#!/usr/bin/env python3
"""
Attention Model Training Script
================================

Trains attention score models for identifying high-flow market moments.

This script:
1. Loads OHLCV data from MySQL database
2. Trains attention models for all symbols and timeframes
3. Generates both regression (0-3+) and classification (low/medium/high) outputs
4. Saves models to models/attention/
5. Generates comprehensive training report

Features learned:
- volume_ratio, volume_z (volume activity)
- ATR, ATR_ratio (volatility)
- CMF, MFI, OBV_delta (money flow)
- BB_width, displacement (price structure)

Usage:
    python scripts/train_attention_model.py
    python scripts/train_attention_model.py --symbols XAUUSD EURUSD
    python scripts/train_attention_model.py --cutoff-date 2024-03-01

Author: ML Pipeline
Version: 1.0.0
Created: 2026-01-06
"""

import argparse
import sys
from pathlib import Path
from datetime import datetime, timedelta
import json
import os

# Setup path BEFORE any other imports
_SCRIPT_DIR = Path(__file__).parent.parent.absolute()
os.chdir(_SCRIPT_DIR)
sys.path.insert(0, str(_SCRIPT_DIR / 'src'))

import numpy as np
import pandas as pd
from loguru import logger
import importlib.util


# Load modules directly to avoid circular imports in models/__init__.py
def _load_module_direct(module_name: str, file_path: Path):
    """Load a module directly from file without going through __init__.py"""
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    # Register under a stable dotted name BEFORE exec so joblib pickles
    # created during training resolve to the same module path on reload.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Load attention modules with CONSISTENT names (important for joblib pickle)
_src_dir = _SCRIPT_DIR / 'src'

# First load the attention_score_model with a stable name
_attention_model_module = _load_module_direct(
    "models.attention_score_model",
    _src_dir / 'models' / 'attention_score_model.py'
)

# Now load the trainer
_attention_trainer_module = _load_module_direct(
    "training.attention_trainer",
    _src_dir / 'training' / 'attention_trainer.py'
)

AttentionModelTrainer = _attention_trainer_module.AttentionModelTrainer
AttentionTrainerConfig = _attention_trainer_module.AttentionTrainerConfig
generate_attention_training_report = _attention_trainer_module.generate_attention_training_report
AttentionModelConfig = _attention_model_module.AttentionModelConfig

# Load database module normally (it doesn't have circular imports)
sys.path.insert(0, str(_src_dir))
from data.database import MySQLConnection


def setup_logging(log_dir: Path, experiment_name: str) -> Path:
    """Configure loguru logging to both console (INFO) and a file (DEBUG).

    Args:
        log_dir: Directory where the log file is created.
        experiment_name: Prefix for the timestamped log file name.

    Returns:
        Path to the created log file.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{experiment_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"

    logger.remove()
    logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
    logger.add(log_file, level="DEBUG", rotation="10 MB")
    logger.info(f"Logging to {log_file}")
    return log_file


def load_data_from_db(
    db: MySQLConnection,
    symbol: str,
    start_date: str = None,
    end_date: str = None,
    limit: int = None
) -> pd.DataFrame:
    """
    Load OHLCV data from MySQL database.

    Args:
        db: MySQL connection
        symbol: Trading symbol (e.g., 'XAUUSD')
        start_date: Start date filter (YYYY-MM-DD)
        end_date: End date filter (YYYY-MM-DD)
        limit: Maximum records to fetch

    Returns:
        DataFrame with OHLCV data
    """
    # Normalize symbol name for database: crypto tickers use the 'X:'
    # prefix, forex tickers use 'C:'; already-prefixed symbols pass through.
    db_symbol = symbol
    if not symbol.startswith('C:') and not symbol.startswith('X:'):
        if symbol == 'BTCUSD':
            db_symbol = f'X:{symbol}'
        else:
            db_symbol = f'C:{symbol}'

    logger.info(f"Loading data for {db_symbol}...")

    query = """
        SELECT date_agg as time, open, high, low, close, volume, vwap
        FROM tickers_agg_data
        WHERE ticker = :symbol
    """
    params = {'symbol': db_symbol}

    if start_date:
        query += " AND date_agg >= :start_date"
        params['start_date'] = start_date
    if end_date:
        query += " AND date_agg <= :end_date"
        params['end_date'] = end_date

    query += " ORDER BY date_agg ASC"

    if limit:
        # Coerce to int so non-integer values can never be interpolated
        # into the SQL text (LIMIT cannot be bound as a named parameter).
        query += f" LIMIT {int(limit)}"

    df = db.execute_query(query, params)

    if df.empty:
        logger.warning(f"No data found for {symbol}")
        return df

    # Set datetime index
    df['time'] = pd.to_datetime(df['time'])
    df.set_index('time', inplace=True)
    df = df.sort_index()

    # Rename columns to standard format (positions match the SELECT order)
    df.columns = ['open', 'high', 'low', 'close', 'volume', 'vwap']

    logger.info(f"Loaded {len(df)} records for {symbol}")
    logger.info(f"  Date range: {df.index.min()} to {df.index.max()}")
    return df


def resample_to_timeframe(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Resample 5-minute data to different timeframe.

    Args:
        df: DataFrame with 5m data
        timeframe: Target timeframe ('5m', '15m', '1H', etc.)

    Returns:
        Resampled DataFrame
    """
    if timeframe == '5m':
        return df  # Already in 5m

    # Map timeframe to pandas offset. Lowercase 'h' aliases are used because
    # uppercase 'H' is deprecated since pandas 2.2 (lowercase works on 1.x too).
    tf_map = {
        '15m': '15min',
        '30m': '30min',
        '1H': '1h',
        '4H': '4h',
        '1D': '1D'
    }
    offset = tf_map.get(timeframe, timeframe)

    resampled = df.resample(offset).agg({
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        'vwap': 'mean'
    }).dropna()

    logger.info(f"Resampled to {timeframe}: {len(resampled)} bars")
    return resampled


def train_attention_models(
    symbols: list,
    timeframes: list,
    output_dir: Path,
    cutoff_date: str = '2024-12-31',
    train_years: float = 5.0,
    holdout_years: float = 1.0,
    db_config_path: str = 'config/database.yaml'
) -> dict:
    """
    Train attention models for all symbol/timeframe combinations.

    Args:
        symbols: List of symbols to train
        timeframes: List of timeframes
        output_dir: Directory to save models
        cutoff_date: Training data cutoff date
        train_years: Years of training data
        holdout_years: Years reserved for holdout validation
        db_config_path: Path to database config

    Returns:
        Dictionary with training results ('results', 'summary',
        'model_dir', 'trainer' keys).
    """
    logger.info("=" * 60)
    logger.info("Attention Model Training")
    logger.info("=" * 60)
    logger.info(f"Symbols: {symbols}")
    logger.info(f"Timeframes: {timeframes}")
    logger.info(f"Cutoff date: {cutoff_date}")
    logger.info(f"Training years: {train_years}")
    logger.info(f"Holdout years: {holdout_years}")

    # Connect to database
    db = MySQLConnection(db_config_path)

    # Configure attention model (XGBoost hyperparameters for the regression
    # head and the 3-class flow classifier)
    model_config = AttentionModelConfig(
        factor_window=200,
        horizon_bars=3,
        low_flow_threshold=1.0,
        high_flow_threshold=2.0,
        reg_params={
            'n_estimators': 200,
            'max_depth': 5,
            'learning_rate': 0.05,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'min_child_weight': 10,
            'gamma': 0.1,
            'reg_alpha': 0.1,
            'reg_lambda': 1.0,
            'tree_method': 'hist',
            'random_state': 42
        },
        clf_params={
            'n_estimators': 150,
            'max_depth': 4,
            'learning_rate': 0.05,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'min_child_weight': 15,
            'gamma': 0.2,
            'reg_alpha': 0.1,
            'reg_lambda': 1.0,
            'tree_method': 'hist',
            'random_state': 42,
            'objective': 'multi:softmax',
            'num_class': 3
        },
        min_train_samples=5000
    )

    # Configure trainer
    trainer_config = AttentionTrainerConfig(
        symbols=symbols,
        timeframes=timeframes,
        train_years=train_years,
        holdout_years=holdout_years,
        model_config=model_config,
        output_dir=str(output_dir / 'attention')
    )
    trainer = AttentionModelTrainer(trainer_config)

    # Prepare data dictionary: {symbol: {timeframe: DataFrame}}
    data_dict = {}
    for symbol in symbols:
        logger.info(f"\n{'='*60}")
        logger.info(f"Loading data for {symbol}")
        logger.info(f"{'='*60}")

        # Load raw data (5m)
        df_5m = load_data_from_db(db, symbol, end_date=cutoff_date)
        if df_5m.empty:
            logger.warning(f"No data for {symbol}, skipping...")
            continue

        # Verify we have enough data
        if len(df_5m) < 50000:
            logger.warning(f"Insufficient data for {symbol}: {len(df_5m)} rows (need 50000+)")
            continue

        data_dict[symbol] = {}
        for timeframe in timeframes:
            logger.info(f"\n--- Preparing {symbol} {timeframe} ---")

            # Resample if needed
            if timeframe == '5m':
                df_tf = df_5m.copy()
            else:
                df_tf = resample_to_timeframe(df_5m.copy(), timeframe)

            if len(df_tf) < 10000:
                logger.warning(f"Insufficient {timeframe} data: {len(df_tf)} rows (need 10000+)")
                continue

            logger.info(f"Data shape: {df_tf.shape}")
            logger.info(f"Date range: {df_tf.index.min()} to {df_tf.index.max()}")
            data_dict[symbol][timeframe] = df_tf

    # Train all models
    logger.info("\n" + "=" * 60)
    logger.info("Starting model training")
    logger.info("=" * 60)
    all_results = trainer.train_all(data_dict)

    # Save models
    model_dir = output_dir / 'attention'
    trainer.save(str(model_dir))
    logger.info(f"\nModels saved to {model_dir}")

    # Generate training summary
    summary_df = trainer.get_training_summary()
    if not summary_df.empty:
        summary_path = output_dir / 'attention_training_summary.csv'
        summary_df.to_csv(summary_path, index=False)
        logger.info(f"Summary saved to {summary_path}")

        logger.info("\n" + "=" * 60)
        logger.info("TRAINING SUMMARY")
        logger.info("=" * 60)
        print(summary_df.to_string(index=False))

    return {
        'results': all_results,
        'summary': summary_df.to_dict() if not summary_df.empty else {},
        'model_dir': str(model_dir),
        'trainer': trainer
    }


def generate_markdown_report(
    trainer: AttentionModelTrainer,
    output_dir: Path,
    symbols: list,
    timeframes: list,
    cutoff_date: str
) -> Path:
    """Generate detailed Markdown training report.

    Args:
        trainer: Trained AttentionModelTrainer with populated ``results``.
        output_dir: Directory the report file is written to.
        symbols: Symbols that were trained (for the config section).
        timeframes: Timeframes that were trained (for the config section).
        cutoff_date: Training data cutoff date (for the config section).

    Returns:
        Path to the written report file.
    """
    report_path = output_dir / f"ATTENTION_TRAINING_REPORT_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"

    report = f"""# Attention Score Model Training Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Overview

The attention model learns to identify high-flow market moments using volume,
volatility, and money flow indicators - WITHOUT hardcoding specific trading
hours or sessions.

## Configuration

- **Symbols:** {', '.join(symbols)}
- **Timeframes:** {', '.join(timeframes)}
- **Training Data Cutoff:** {cutoff_date}
- **Training Years:** {trainer.config.train_years}
- **Holdout Years:** {trainer.config.holdout_years}

### Model Parameters

| Parameter | Value |
|-----------|-------|
| Factor Window | {trainer.config.model_config.factor_window} |
| Horizon Bars | {trainer.config.model_config.horizon_bars} |
| Low Flow Threshold | {trainer.config.model_config.low_flow_threshold} |
| High Flow Threshold | {trainer.config.model_config.high_flow_threshold} |

### Features Used (9 total)

| Feature | Description |
|---------|-------------|
| volume_ratio | volume / rolling_median(volume, 20) |
| volume_z | z-score of volume over 20 periods |
| ATR | Average True Range (14 periods) |
| ATR_ratio | ATR / rolling_median(ATR, 50) |
| CMF | Chaikin Money Flow (20 periods) |
| MFI | Money Flow Index (14 periods) |
| OBV_delta | diff(OBV) / rolling_std(OBV, 20) |
| BB_width | (BB_upper - BB_lower) / close |
| displacement | (close - open) / ATR |

## Training Results

| Model | Symbol | TF | Reg MAE | Reg R2 | Clf Acc | Clf F1 | N Train | High Flow % |
|-------|--------|-----|---------|--------|---------|--------|---------|-------------|
"""

    for key, result in trainer.results.items():
        total_samples = sum(result.class_distribution.values())
        high_pct = result.class_distribution.get('high_flow', 0) / max(total_samples, 1) * 100
        report += f"| {key} | {result.symbol} | {result.timeframe} | "
        report += f"{result.reg_mae:.4f} | {result.reg_r2:.4f} | "
        report += f"{result.clf_accuracy:.2%} | {result.clf_f1:.2%} | "
        report += f"{result.n_train} | {high_pct:.1f}% |\n"

    report += """
## Class Distribution (Holdout Set)

| Model | Low Flow | Medium Flow | High Flow |
|-------|----------|-------------|-----------|
"""

    for key, result in trainer.results.items():
        low = result.class_distribution.get('low_flow', 0)
        med = result.class_distribution.get('medium_flow', 0)
        high = result.class_distribution.get('high_flow', 0)
        total = max(low + med + high, 1)
        report += f"| {key} | {low} ({low/total*100:.1f}%) | {med} ({med/total*100:.1f}%) | {high} ({high/total*100:.1f}%) |\n"

    report += """
## Feature Importance

"""

    for key, result in trainer.results.items():
        report += f"### {key}\n\n"
        report += "| Rank | Feature | Combined Importance |\n|------|---------|--------------------|\n"
        # Highest combined importance first
        sorted_features = sorted(result.feature_importance.items(), key=lambda x: -x[1])
        for rank, (feat, imp) in enumerate(sorted_features, 1):
            report += f"| {rank} | {feat} | {imp:.4f} |\n"
        report += "\n"

    report += f"""
## Interpretation

### Attention Score (Regression)
- **< 1.0**: Low flow period - below average market movement expected
- **1.0 - 2.0**: Medium flow period - average market conditions
- **> 2.0**: High flow period - above average movement expected (best trading opportunities)

### Flow Class (Classification)
- **0 (low_flow)**: move_multiplier < 1.0
- **1 (medium_flow)**: 1.0 <= move_multiplier < 2.0
- **2 (high_flow)**: move_multiplier >= 2.0

## Trading Recommendations

1. **Filter by attention_score**: Only trade when attention_score > 1.0
2. **Adjust position sizing**: Increase size when attention_score > 2.0
3. **Combine with base models**: Use attention_score as feature #51 in prediction models
4. **Time-agnostic**: The model identifies flow without hardcoded sessions

## Usage Example

```python
from training.attention_trainer import AttentionModelTrainer

# Load trained models
trainer = AttentionModelTrainer.load('models/attention/')

# Get attention score for new OHLCV data
attention = trainer.get_attention_score(df_ohlcv, 'XAUUSD', '5m')

# Filter trades
mask_trade = attention > 1.0  # Only trade in medium/high flow

# Or use as feature in base models
df['attention_score'] = attention
```

## Files Generated

- `models/attention/{{symbol}}_{{timeframe}}_attention/` - Model directories
- `models/attention/trainer_metadata.joblib` - Trainer configuration
- `models/attention/training_summary.csv` - Summary metrics

---
*Report generated by Attention Model Training Pipeline*
"""

    with open(report_path, 'w') as f:
        f.write(report)

    logger.info(f"Report saved to {report_path}")
    return report_path


def main():
    """CLI entry point: parse arguments, train models, write reports."""
    parser = argparse.ArgumentParser(description='Train Attention Score Models')
    parser.add_argument(
        '--symbols', nargs='+',
        default=['XAUUSD', 'EURUSD', 'BTCUSD', 'GBPUSD', 'USDJPY'],
        help='Symbols to train (default: XAUUSD EURUSD BTCUSD GBPUSD USDJPY)'
    )
    parser.add_argument(
        '--timeframes', nargs='+',
        default=['5m', '15m'],
        help='Timeframes to train (default: 5m 15m)'
    )
    parser.add_argument(
        '--output-dir', type=str, default='models/',
        help='Output directory for models (default: models/)'
    )
    parser.add_argument(
        '--cutoff-date', type=str, default='2024-12-31',
        help='Training data cutoff date (default: 2024-12-31)'
    )
    parser.add_argument(
        '--train-years', type=float, default=5.0,
        help='Years of training data (default: 5.0)'
    )
    parser.add_argument(
        '--holdout-years', type=float, default=1.0,
        help='Years for holdout validation (default: 1.0)'
    )
    parser.add_argument(
        '--db-config', type=str, default='config/database.yaml',
        help='Database configuration file'
    )
    args = parser.parse_args()

    # Setup paths
    script_dir = Path(__file__).parent.parent
    output_dir = script_dir / args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    logs_dir = output_dir / 'logs'
    setup_logging(logs_dir, 'attention_model_training')

    # Run training
    try:
        results = train_attention_models(
            symbols=args.symbols,
            timeframes=args.timeframes,
            output_dir=output_dir,
            cutoff_date=args.cutoff_date,
            train_years=args.train_years,
            holdout_years=args.holdout_years,
            db_config_path=str(script_dir / args.db_config)
        )

        # Generate detailed report
        if results.get('trainer'):
            generate_markdown_report(
                results['trainer'],
                output_dir,
                args.symbols,
                args.timeframes,
                args.cutoff_date
            )

        logger.info("\n" + "=" * 60)
        logger.info("ATTENTION MODEL TRAINING COMPLETE!")
        logger.info("=" * 60)
        logger.info(f"Models saved to: {results.get('model_dir', 'N/A')}")
        logger.info(f"Total models trained: {len(results.get('results', {}))}")

        # Print quick summary
        if results.get('results'):
            logger.info("\nQuick Summary:")
            for key, result in results['results'].items():
                high_pct = result.class_distribution.get('high_flow', 0) / max(sum(result.class_distribution.values()), 1) * 100
                logger.info(f"  {key}: R2={result.reg_r2:.3f}, Clf Acc={result.clf_accuracy:.1%}, High Flow={high_pct:.1f}%")

    except Exception as e:
        logger.exception(f"Training failed: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()