Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
617 lines
18 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Attention Model Training Script
|
|
================================
|
|
Trains attention score models for identifying high-flow market moments.
|
|
|
|
This script:
|
|
1. Loads OHLCV data from MySQL database
|
|
2. Trains attention models for all symbols and timeframes
|
|
3. Generates both regression (0-3+) and classification (low/medium/high) outputs
|
|
4. Saves models to models/attention/
|
|
5. Generates comprehensive training report
|
|
|
|
Features learned:
|
|
- volume_ratio, volume_z (volume activity)
|
|
- ATR, ATR_ratio (volatility)
|
|
- CMF, MFI, OBV_delta (money flow)
|
|
- BB_width, displacement (price structure)
|
|
|
|
Usage:
|
|
python scripts/train_attention_model.py
|
|
python scripts/train_attention_model.py --symbols XAUUSD EURUSD
|
|
python scripts/train_attention_model.py --cutoff-date 2024-03-01
|
|
|
|
Author: ML Pipeline
|
|
Version: 1.0.0
|
|
Created: 2026-01-06
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
import json
|
|
import os
|
|
|
|
# Setup path BEFORE any other imports: pin the CWD to the repo root (the
# script lives in scripts/, so parent.parent is the project root) so that
# relative paths like 'config/database.yaml' resolve, and put src/ on
# sys.path so project modules import as top-level packages.
_SCRIPT_DIR = Path(__file__).parent.parent.absolute()
os.chdir(_SCRIPT_DIR)
sys.path.insert(0, str(_SCRIPT_DIR / 'src'))
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from loguru import logger
|
|
import importlib.util
|
|
|
|
# Load modules directly to avoid circular imports in models/__init__.py
|
|
def _load_module_direct(module_name: str, file_path: Path):
|
|
"""Load a module directly from file without going through __init__.py"""
|
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
|
module = importlib.util.module_from_spec(spec)
|
|
sys.modules[module_name] = module
|
|
spec.loader.exec_module(module)
|
|
return module
|
|
|
|
# Load attention modules with CONSISTENT names (important for joblib pickle):
# joblib/pickle records the defining module's name inside saved models, so
# these names must match how src/models and src/training would register them.
_src_dir = _SCRIPT_DIR / 'src'

# First load the attention_score_model with a stable name.
_attention_model_module = _load_module_direct(
    "models.attention_score_model",
    _src_dir / 'models' / 'attention_score_model.py'
)

# Now load the trainer (loaded second so the model module is already
# registered in sys.modules when the trainer imports it).
_attention_trainer_module = _load_module_direct(
    "training.attention_trainer",
    _src_dir / 'training' / 'attention_trainer.py'
)

# Re-export only the names this script uses.
AttentionModelTrainer = _attention_trainer_module.AttentionModelTrainer
AttentionTrainerConfig = _attention_trainer_module.AttentionTrainerConfig
generate_attention_training_report = _attention_trainer_module.generate_attention_training_report
AttentionModelConfig = _attention_model_module.AttentionModelConfig

# Load database module normally (it doesn't have circular imports).
sys.path.insert(0, str(_src_dir))
from data.database import MySQLConnection
|
def setup_logging(log_dir: Path, experiment_name: str) -> Path:
    """Configure loguru to log to stderr (INFO) and a timestamped file (DEBUG).

    Args:
        log_dir: Directory for log files; created if missing.
        experiment_name: Prefix for the log file name.

    Returns:
        Path of the log file that was configured.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f"{experiment_name}_{stamp}.log"

    # Drop loguru's default handler, then install our own pair.
    logger.remove()
    logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
    logger.add(log_file, level="DEBUG", rotation="10 MB")

    logger.info(f"Logging to {log_file}")
    return log_file
|
|
|
|
|
|
def load_data_from_db(
    db: MySQLConnection,
    symbol: str,
    start_date: str = None,
    end_date: str = None,
    limit: int = None
) -> pd.DataFrame:
    """
    Load OHLCV data from MySQL database.

    Args:
        db: MySQL connection
        symbol: Trading symbol (e.g., 'XAUUSD')
        start_date: Start date filter (YYYY-MM-DD), inclusive
        end_date: End date filter (YYYY-MM-DD), inclusive
        limit: Maximum records to fetch

    Returns:
        DataFrame indexed by datetime with open/high/low/close/volume/vwap
        columns; empty DataFrame when no rows matched.
    """
    # The DB stores tickers with an asset-class prefix: 'X:' for crypto,
    # 'C:' for currencies. Add it when the caller passed a bare symbol.
    if symbol.startswith('C:') or symbol.startswith('X:'):
        db_symbol = symbol
    elif symbol == 'BTCUSD':
        db_symbol = f'X:{symbol}'
    else:
        db_symbol = f'C:{symbol}'

    logger.info(f"Loading data for {db_symbol}...")

    query = """
    SELECT
        date_agg as time,
        open,
        high,
        low,
        close,
        volume,
        vwap
    FROM tickers_agg_data
    WHERE ticker = :symbol
    """
    params = {'symbol': db_symbol}

    # Optional inclusive date-range filters.
    if start_date:
        query += " AND date_agg >= :start_date"
        params['start_date'] = start_date
    if end_date:
        query += " AND date_agg <= :end_date"
        params['end_date'] = end_date

    query += " ORDER BY date_agg ASC"
    if limit:
        query += f" LIMIT {limit}"

    df = db.execute_query(query, params)
    if df.empty:
        logger.warning(f"No data found for {symbol}")
        return df

    # Index by timestamp, oldest first.
    df['time'] = pd.to_datetime(df['time'])
    df.set_index('time', inplace=True)
    df = df.sort_index()

    # Positional rename -- relies on the SELECT column order above.
    df.columns = ['open', 'high', 'low', 'close', 'volume', 'vwap']

    logger.info(f"Loaded {len(df)} records for {symbol}")
    logger.info(f" Date range: {df.index.min()} to {df.index.max()}")

    return df
|
|
|
|
|
|
def resample_to_timeframe(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Aggregate 5-minute bars up to a coarser timeframe.

    Args:
        df: Datetime-indexed DataFrame of 5m OHLCV(+vwap) bars.
        timeframe: Target timeframe ('5m', '15m', '30m', '1H', '4H', '1D').

    Returns:
        The input frame unchanged for '5m', otherwise a resampled copy
        with all-NaN bars dropped.
    """
    # Source data is already 5-minute resolution -- nothing to do.
    if timeframe == '5m':
        return df

    # Translate our labels to pandas offset aliases; unknown labels are
    # handed to pandas verbatim.
    offset_aliases = {
        '15m': '15min',
        '30m': '30min',
        '1H': '1H',
        '4H': '4H',
        '1D': '1D'
    }
    rule = offset_aliases.get(timeframe, timeframe)

    # Standard OHLCV aggregation; vwap is approximated by the bar mean.
    bars = df.resample(rule).agg({
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        'vwap': 'mean'
    }).dropna()

    logger.info(f"Resampled to {timeframe}: {len(bars)} bars")
    return bars
|
|
|
|
|
|
def train_attention_models(
    symbols: list,
    timeframes: list,
    output_dir: Path,
    cutoff_date: str = '2024-12-31',
    train_years: float = 5.0,
    holdout_years: float = 1.0,
    db_config_path: str = 'config/database.yaml'
) -> dict:
    """
    Train attention models for all symbol/timeframe combinations.

    Loads 5m OHLCV data per symbol from MySQL (up to `cutoff_date`),
    resamples it to each requested timeframe, and hands the resulting
    nested data dictionary to AttentionModelTrainer for fitting,
    saving, and summary generation.

    Args:
        symbols: List of symbols to train
        timeframes: List of timeframes
        output_dir: Directory to save models
        cutoff_date: Training data cutoff date (inclusive upper bound)
        train_years: Years of training data
        holdout_years: Years reserved for holdout validation
        db_config_path: Path to database config

    Returns:
        Dictionary with training results:
        'results' (per-model result objects from train_all),
        'summary' (summary metrics dict, possibly empty),
        'model_dir' (str path models were saved to),
        'trainer' (the fitted AttentionModelTrainer).
    """
    # Banner with the effective run configuration.
    logger.info("="*60)
    logger.info("Attention Model Training")
    logger.info("="*60)
    logger.info(f"Symbols: {symbols}")
    logger.info(f"Timeframes: {timeframes}")
    logger.info(f"Cutoff date: {cutoff_date}")
    logger.info(f"Training years: {train_years}")
    logger.info(f"Holdout years: {holdout_years}")

    # Connect to database
    db = MySQLConnection(db_config_path)

    # Configure attention model. The flow thresholds split the continuous
    # attention score into low (<1.0), medium, and high (>=2.0) classes.
    model_config = AttentionModelConfig(
        factor_window=200,
        horizon_bars=3,
        low_flow_threshold=1.0,
        high_flow_threshold=2.0,
        # XGBoost regressor hyperparameters (continuous attention score).
        reg_params={
            'n_estimators': 200,
            'max_depth': 5,
            'learning_rate': 0.05,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'min_child_weight': 10,
            'gamma': 0.1,
            'reg_alpha': 0.1,
            'reg_lambda': 1.0,
            'tree_method': 'hist',
            'random_state': 42
        },
        # XGBoost classifier hyperparameters (3-class flow label).
        clf_params={
            'n_estimators': 150,
            'max_depth': 4,
            'learning_rate': 0.05,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'min_child_weight': 15,
            'gamma': 0.2,
            'reg_alpha': 0.1,
            'reg_lambda': 1.0,
            'tree_method': 'hist',
            'random_state': 42,
            'objective': 'multi:softmax',
            'num_class': 3
        },
        min_train_samples=5000
    )

    # Configure trainer
    trainer_config = AttentionTrainerConfig(
        symbols=symbols,
        timeframes=timeframes,
        train_years=train_years,
        holdout_years=holdout_years,
        model_config=model_config,
        output_dir=str(output_dir / 'attention')
    )

    trainer = AttentionModelTrainer(trainer_config)

    # Build {symbol: {timeframe: DataFrame}} for trainer.train_all().
    data_dict = {}

    for symbol in symbols:
        logger.info(f"\n{'='*60}")
        logger.info(f"Loading data for {symbol}")
        logger.info(f"{'='*60}")

        # Load raw data (5m); higher timeframes are derived from it below.
        df_5m = load_data_from_db(db, symbol, end_date=cutoff_date)

        if df_5m.empty:
            logger.warning(f"No data for {symbol}, skipping...")
            continue

        # Verify we have enough data (minimum row count on the raw 5m feed).
        if len(df_5m) < 50000:
            logger.warning(f"Insufficient data for {symbol}: {len(df_5m)} rows (need 50000+)")
            continue

        data_dict[symbol] = {}

        for timeframe in timeframes:
            logger.info(f"\n--- Preparing {symbol} {timeframe} ---")

            # Resample if needed ('5m' is already the native resolution).
            if timeframe == '5m':
                df_tf = df_5m.copy()
            else:
                df_tf = resample_to_timeframe(df_5m.copy(), timeframe)

            # Per-timeframe minimum after resampling.
            if len(df_tf) < 10000:
                logger.warning(f"Insufficient {timeframe} data: {len(df_tf)} rows (need 10000+)")
                continue

            logger.info(f"Data shape: {df_tf.shape}")
            logger.info(f"Date range: {df_tf.index.min()} to {df_tf.index.max()}")

            data_dict[symbol][timeframe] = df_tf

    # Train all models
    logger.info("\n" + "="*60)
    logger.info("Starting model training")
    logger.info("="*60)

    all_results = trainer.train_all(data_dict)

    # Save models
    model_dir = output_dir / 'attention'
    trainer.save(str(model_dir))
    logger.info(f"\nModels saved to {model_dir}")

    # Generate training summary
    summary_df = trainer.get_training_summary()

    if not summary_df.empty:
        summary_path = output_dir / 'attention_training_summary.csv'
        summary_df.to_csv(summary_path, index=False)
        logger.info(f"Summary saved to {summary_path}")

        logger.info("\n" + "="*60)
        logger.info("TRAINING SUMMARY")
        logger.info("="*60)
        print(summary_df.to_string(index=False))

    return {
        'results': all_results,
        'summary': summary_df.to_dict() if not summary_df.empty else {},
        'model_dir': str(model_dir),
        'trainer': trainer
    }
|
|
|
|
|
|
def generate_markdown_report(
    trainer: AttentionModelTrainer,
    output_dir: Path,
    symbols: list,
    timeframes: list,
    cutoff_date: str
) -> Path:
    """Generate detailed Markdown training report.

    Renders the run configuration, per-model metrics, holdout class
    distributions, and feature importances from `trainer.results` into
    a timestamped Markdown file under `output_dir`.

    Args:
        trainer: Fitted AttentionModelTrainer whose results are reported.
        output_dir: Directory the report file is written into.
        symbols: Symbols that were trained (config section).
        timeframes: Timeframes that were trained (config section).
        cutoff_date: Training data cutoff date (config section).

    Returns:
        Path of the written report file.
    """
    report_path = output_dir / f"ATTENTION_TRAINING_REPORT_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"

    # Static header, configuration, and feature tables.
    report = f"""# Attention Score Model Training Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Overview

The attention model learns to identify high-flow market moments using volume, volatility, and money flow indicators - WITHOUT hardcoding specific trading hours or sessions.

## Configuration

- **Symbols:** {', '.join(symbols)}
- **Timeframes:** {', '.join(timeframes)}
- **Training Data Cutoff:** {cutoff_date}
- **Training Years:** {trainer.config.train_years}
- **Holdout Years:** {trainer.config.holdout_years}

### Model Parameters

| Parameter | Value |
|-----------|-------|
| Factor Window | {trainer.config.model_config.factor_window} |
| Horizon Bars | {trainer.config.model_config.horizon_bars} |
| Low Flow Threshold | {trainer.config.model_config.low_flow_threshold} |
| High Flow Threshold | {trainer.config.model_config.high_flow_threshold} |

### Features Used (9 total)

| Feature | Description |
|---------|-------------|
| volume_ratio | volume / rolling_median(volume, 20) |
| volume_z | z-score of volume over 20 periods |
| ATR | Average True Range (14 periods) |
| ATR_ratio | ATR / rolling_median(ATR, 50) |
| CMF | Chaikin Money Flow (20 periods) |
| MFI | Money Flow Index (14 periods) |
| OBV_delta | diff(OBV) / rolling_std(OBV, 20) |
| BB_width | (BB_upper - BB_lower) / close |
| displacement | (close - open) / ATR |

## Training Results

| Model | Symbol | TF | Reg MAE | Reg R2 | Clf Acc | Clf F1 | N Train | High Flow % |
|-------|--------|-----|---------|--------|---------|--------|---------|-------------|
"""

    # One table row per trained model.
    for key, result in trainer.results.items():
        total_samples = sum(result.class_distribution.values())
        # max(..., 1) guards against division by zero on empty distributions.
        high_pct = result.class_distribution.get('high_flow', 0) / max(total_samples, 1) * 100

        report += f"| {key} | {result.symbol} | {result.timeframe} | "
        report += f"{result.reg_mae:.4f} | {result.reg_r2:.4f} | "
        report += f"{result.clf_accuracy:.2%} | {result.clf_f1:.2%} | "
        report += f"{result.n_train} | {high_pct:.1f}% |\n"

    report += """

## Class Distribution (Holdout Set)

| Model | Low Flow | Medium Flow | High Flow |
|-------|----------|-------------|-----------|
"""

    # Holdout class counts with percentages per model.
    for key, result in trainer.results.items():
        low = result.class_distribution.get('low_flow', 0)
        med = result.class_distribution.get('medium_flow', 0)
        high = result.class_distribution.get('high_flow', 0)
        total = max(low + med + high, 1)

        report += f"| {key} | {low} ({low/total*100:.1f}%) | {med} ({med/total*100:.1f}%) | {high} ({high/total*100:.1f}%) |\n"

    report += """

## Feature Importance

"""

    # One ranked importance table per model, highest importance first.
    for key, result in trainer.results.items():
        report += f"### {key}\n\n"
        report += "| Rank | Feature | Combined Importance |\n|------|---------|--------------------|\n"

        sorted_features = sorted(result.feature_importance.items(), key=lambda x: -x[1])
        for rank, (feat, imp) in enumerate(sorted_features, 1):
            report += f"| {rank} | {feat} | {imp:.4f} |\n"
        report += "\n"

    # Static interpretation / usage footer (f-string only for the escaped
    # {{...}} braces in the file paths).
    report += f"""

## Interpretation

### Attention Score (Regression)

- **< 1.0**: Low flow period - below average market movement expected
- **1.0 - 2.0**: Medium flow period - average market conditions
- **> 2.0**: High flow period - above average movement expected (best trading opportunities)

### Flow Class (Classification)

- **0 (low_flow)**: move_multiplier < 1.0
- **1 (medium_flow)**: 1.0 <= move_multiplier < 2.0
- **2 (high_flow)**: move_multiplier >= 2.0

## Trading Recommendations

1. **Filter by attention_score**: Only trade when attention_score > 1.0
2. **Adjust position sizing**: Increase size when attention_score > 2.0
3. **Combine with base models**: Use attention_score as feature #51 in prediction models
4. **Time-agnostic**: The model identifies flow without hardcoded sessions

## Usage Example

```python
from training.attention_trainer import AttentionModelTrainer

# Load trained models
trainer = AttentionModelTrainer.load('models/attention/')

# Get attention score for new OHLCV data
attention = trainer.get_attention_score(df_ohlcv, 'XAUUSD', '5m')

# Filter trades
mask_trade = attention > 1.0  # Only trade in medium/high flow

# Or use as feature in base models
df['attention_score'] = attention
```

## Files Generated

- `models/attention/{{symbol}}_{{timeframe}}_attention/` - Model directories
- `models/attention/trainer_metadata.joblib` - Trainer configuration
- `models/attention/training_summary.csv` - Summary metrics

---
*Report generated by Attention Model Training Pipeline*
"""

    with open(report_path, 'w') as f:
        f.write(report)

    logger.info(f"Report saved to {report_path}")
    return report_path
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, train attention models, write reports."""
    parser = argparse.ArgumentParser(description='Train Attention Score Models')
    parser.add_argument('--symbols', nargs='+',
                        default=['XAUUSD', 'EURUSD', 'BTCUSD', 'GBPUSD', 'USDJPY'],
                        help='Symbols to train (default: XAUUSD EURUSD BTCUSD GBPUSD USDJPY)')
    parser.add_argument('--timeframes', nargs='+', default=['5m', '15m'],
                        help='Timeframes to train (default: 5m 15m)')
    parser.add_argument('--output-dir', type=str, default='models/',
                        help='Output directory for models (default: models/)')
    parser.add_argument('--cutoff-date', type=str, default='2024-12-31',
                        help='Training data cutoff date (default: 2024-12-31)')
    parser.add_argument('--train-years', type=float, default=5.0,
                        help='Years of training data (default: 5.0)')
    parser.add_argument('--holdout-years', type=float, default=1.0,
                        help='Years for holdout validation (default: 1.0)')
    parser.add_argument('--db-config', type=str, default='config/database.yaml',
                        help='Database configuration file')
    args = parser.parse_args()

    # Resolve paths relative to the repo root (this file lives in scripts/).
    root = Path(__file__).parent.parent
    out_dir = root / args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    setup_logging(out_dir / 'logs', 'attention_model_training')

    try:
        results = train_attention_models(
            symbols=args.symbols,
            timeframes=args.timeframes,
            output_dir=out_dir,
            cutoff_date=args.cutoff_date,
            train_years=args.train_years,
            holdout_years=args.holdout_years,
            db_config_path=str(root / args.db_config)
        )

        # Detailed Markdown report, only when training produced a trainer.
        if results.get('trainer'):
            generate_markdown_report(
                results['trainer'],
                out_dir,
                args.symbols,
                args.timeframes,
                args.cutoff_date
            )

        logger.info("\n" + "="*60)
        logger.info("ATTENTION MODEL TRAINING COMPLETE!")
        logger.info("="*60)
        logger.info(f"Models saved to: {results.get('model_dir', 'N/A')}")
        logger.info(f"Total models trained: {len(results.get('results', {}))}")

        # One-line recap per trained model.
        if results.get('results'):
            logger.info("\nQuick Summary:")
            for key, result in results['results'].items():
                dist = result.class_distribution
                high_pct = dist.get('high_flow', 0) / max(sum(dist.values()), 1) * 100
                logger.info(f" {key}: R2={result.reg_r2:.3f}, Clf Acc={result.clf_accuracy:.1%}, High Flow={high_pct:.1f}%")

    except Exception as e:
        # Log full traceback, then signal failure to the shell.
        logger.exception(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()