#!/usr/bin/env python3
r"""
Train Metamodels Script
=======================

CLI script to train Asset Metamodels (Nivel 2 of hierarchical architecture).

This script orchestrates:
1. Loading pre-trained Attention Models (Nivel 0)
2. Loading pre-trained Base Models (Nivel 1)
3. Generating OOS predictions
4. Training Metamodels per asset

Usage:
    # Train metamodels for the default symbols (XAUUSD, EURUSD)
    python train_metamodels.py

    # Train for specific symbols
    python train_metamodels.py --symbols XAUUSD EURUSD

    # Specify model paths
    python train_metamodels.py \
        --attention-path models/attention \
        --base-path models/base \
        --output-path models/metamodels

    # Custom OOS period
    python train_metamodels.py \
        --oos-start 2024-01-01 \
        --oos-end 2024-08-31

Author: ML Pipeline
Version: 1.0.0
Created: 2026-01-07
"""

import re
import sys
import argparse
from pathlib import Path
from datetime import datetime

import pandas as pd
import numpy as np
from loguru import logger

# Make the project's ``src`` directory importable (data.database, training.*).
script_dir = Path(__file__).parent
project_dir = script_dir.parent
sys.path.insert(0, str(project_dir / 'src'))

# Timeframes loaded per symbol and passed to the trainer config.
_TIMEFRAMES = ('5m', '15m')

# pandas resample rule for each supported timeframe.
_RESAMPLE_RULES = {'5m': '5min', '15m': '15min'}

# Allowed characters for a DB ticker; the ticker is embedded in SQL text,
# so anything outside this set is rejected.
_SAFE_TICKER_RE = re.compile(r'[A-Za-z0-9:._-]+')


def setup_logging(log_dir: Path, symbol: str = 'all') -> Path:
    """Configure loguru: INFO+ to stderr, DEBUG+ to a timestamped log file.

    Any previously registered loguru sinks are removed first.

    Args:
        log_dir: Directory for the log file; created if it does not exist.
        symbol: Tag embedded in the log file name.

    Returns:
        Path of the log file that was configured.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f'metamodel_training_{symbol}_{timestamp}.log'

    logger.remove()
    logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
    logger.add(log_file, level="DEBUG", format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}")

    return log_file


def _to_db_symbol(symbol: str) -> str:
    """Map a plain symbol to its tickers_agg_data ticker (``X:`` for BTCUSD, ``C:`` otherwise)."""
    if symbol.startswith(('C:', 'X:')):
        return symbol
    if symbol == 'BTCUSD':
        return f'X:{symbol}'
    return f'C:{symbol}'


def load_ohlcv_from_mysql(symbol: str, timeframe: str, config: dict) -> pd.DataFrame:
    """
    Load OHLCV data from MySQL database.

    Loads raw OHLCV data from tickers_agg_data table so we can generate
    features fresh using the same logic as base model training.

    Args:
        symbol: Plain symbol (e.g. ``XAUUSD``) or an already-prefixed
            ticker (``C:...`` / ``X:...``).
        timeframe: ``'5m'`` or ``'15m'`` resamples the stored rows to that
            bar size; any other value returns the rows unresampled (with a
            warning).
        config: Database settings. NOTE(review): currently unused -
            ``MySQLConnection()`` is constructed with no arguments, so the
            ``--db-*`` CLI flags have no effect here; confirm how
            MySQLConnection is configured.

    Returns:
        DataFrame indexed by timestamp with lowercase
        open/high/low/close/volume columns, or an empty DataFrame when no
        rows exist, the ticker is invalid, or loading fails.
    """
    try:
        db_symbol = _to_db_symbol(symbol)
        # The ticker is interpolated into the SQL text below; refuse anything
        # outside a conservative character set so CLI input cannot inject SQL.
        if not _SAFE_TICKER_RE.fullmatch(db_symbol):
            logger.error(f"Refusing to query invalid ticker: {db_symbol!r}")
            return pd.DataFrame()

        # Project database module (importable thanks to the sys.path setup above).
        from data.database import MySQLConnection
        db = MySQLConnection()

        logger.info(f"Loading OHLCV data for {db_symbol} ({timeframe})...")

        query = f"""
            SELECT date_agg as timestamp, open, high, low, close, volume
            FROM tickers_agg_data
            WHERE ticker = '{db_symbol}'
            ORDER BY date_agg ASC
        """
        df = pd.read_sql(query, db.engine)

        if len(df) == 0:
            logger.warning(f"No data found for {symbol} {timeframe}")
            return pd.DataFrame()

        logger.info(f"Loaded {len(df)} rows for {symbol}")

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df.set_index('timestamp', inplace=True)

        agg_dict = {
            'open': 'first',
            'high': 'max',
            'low': 'min',
            'close': 'last',
            'volume': 'sum'
        }
        rule = _RESAMPLE_RULES.get(timeframe)
        if rule is not None:
            # Empty bins get NaN OHLC values and are dropped.
            df = df.resample(rule).agg(agg_dict).dropna()
            logger.info(f"Resampled to {timeframe}: {len(df)} bars")
        else:
            logger.warning(f"Unknown timeframe {timeframe!r}; returning data without resampling")

        return df

    except Exception as e:
        # Deliberately best-effort: a failed load yields an empty frame so the
        # caller skips this symbol/timeframe instead of aborting the whole run.
        logger.exception(f"Failed to load data from MySQL: {e}")
        return pd.DataFrame()


def load_ohlcv_from_parquet(data_dir: Path, symbol: str, timeframe: str) -> pd.DataFrame:
    """Load OHLCV data for one symbol/timeframe from Parquet files.

    Tries, in order: ``{SYMBOL}_{tf}.parquet``, ``{symbol}_{tf}.parquet``,
    ``{SYMBOL}/{tf}.parquet`` and ``{symbol}/{tf}.parquet`` under *data_dir*
    (where ``{symbol}`` is lowercased). The first existing file wins.

    Returns:
        DataFrame with capitalized Open/High/Low/Close/Volume columns (where
        present), indexed by ``timestamp`` if that column exists, or an empty
        DataFrame if no file matched. A file without a datetime index is
        returned as-is with a warning.
    """
    patterns = [
        f"{symbol}_{timeframe}.parquet",
        f"{symbol.lower()}_{timeframe}.parquet",
        f"{symbol}/{timeframe}.parquet",
        f"{symbol.lower()}/{timeframe}.parquet"
    ]

    for pattern in patterns:
        file_path = data_dir / pattern
        if not file_path.exists():
            continue

        df = pd.read_parquet(file_path)

        if 'timestamp' in df.columns:
            df['timestamp'] = pd.to_datetime(df['timestamp'])
            df.set_index('timestamp', inplace=True)
        elif not isinstance(df.index, pd.DatetimeIndex):
            logger.warning(f"No datetime index in {file_path}")

        col_map = {
            'open': 'Open', 'high': 'High', 'low': 'Low',
            'close': 'Close', 'volume': 'Volume'
        }
        df.rename(columns={k: v for k, v in col_map.items() if k in df.columns}, inplace=True)

        logger.info(f"Loaded {len(df)} rows from {file_path}")
        return df

    logger.warning(f"No parquet file found for {symbol} {timeframe}")
    return pd.DataFrame()


def generate_features(df: pd.DataFrame, symbol: str = '') -> pd.DataFrame:
    """
    Generate the feature set used for training.

    Per the original design this mirrors the feature logic of
    symbol_timeframe_trainer so the base models receive identical inputs;
    any change here must be mirrored there. The code below produces 54
    feature columns (5 returns, 3 volatility, 7 range, 4 ATR, 6 moving
    average, 3 RSI, 2 Bollinger, 3 MACD, 3 momentum, 2 stochastic,
    1 Williams %R, 2 volume, 3 candle, 3 price position, 7 time).

    Args:
        df: OHLCV DataFrame with a DatetimeIndex and columns
            Open/High/Low/Close[/Volume] (any capitalization of those exact
            names, or lowercase). If volume is missing a constant 1 is used.
        symbol: Unused; kept for interface compatibility.

    Returns:
        DataFrame with lowercase open/high/low/close/volume followed by the
        54 feature columns (inf replaced by NaN), or *df* unchanged if empty.
    """
    if len(df) == 0:
        return df

    df = df.copy()

    col_map = {'Open': 'open', 'High': 'high', 'Low': 'low', 'Close': 'close', 'Volume': 'volume'}
    df.rename(columns={k: v for k, v in col_map.items() if k in df.columns}, inplace=True)

    features = pd.DataFrame(index=df.index)
    close = df['close']
    high = df['high']
    low = df['low']
    open_price = df['open']
    volume = df['volume'] if 'volume' in df.columns else pd.Series(1, index=df.index)

    # ===== Price Returns (5 features) =====
    features['returns_1'] = close.pct_change(1)
    features['returns_3'] = close.pct_change(3)
    features['returns_5'] = close.pct_change(5)
    features['returns_10'] = close.pct_change(10)
    features['returns_20'] = close.pct_change(20)

    # ===== Volatility Features (3 features) =====
    features['volatility_5'] = close.pct_change().rolling(5).std()
    features['volatility_10'] = close.pct_change().rolling(10).std()
    features['volatility_20'] = close.pct_change().rolling(20).std()

    # ===== Range Features (7 features) =====
    candle_range = high - low
    features['range'] = candle_range
    features['range_pct'] = candle_range / close
    features['range_ma_5'] = candle_range.rolling(5).mean()
    features['range_ma_10'] = candle_range.rolling(10).mean()
    features['range_ma_20'] = candle_range.rolling(20).mean()
    features['range_ratio_5'] = candle_range / features['range_ma_5']
    features['range_ratio_20'] = candle_range / features['range_ma_20']

    # ===== ATR Features (4 features) =====
    tr1 = high - low
    tr2 = abs(high - close.shift(1))
    tr3 = abs(low - close.shift(1))
    true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    features['atr_5'] = true_range.rolling(5).mean()
    features['atr_14'] = true_range.rolling(14).mean()
    features['atr_20'] = true_range.rolling(20).mean()
    features['atr_ratio'] = true_range / features['atr_14']

    # ===== Moving Averages (6 features, ATR-normalized) =====
    sma_5 = close.rolling(5).mean()
    sma_10 = close.rolling(10).mean()
    sma_20 = close.rolling(20).mean()
    sma_50 = close.rolling(50).mean()
    ema_5 = close.ewm(span=5, adjust=False).mean()
    ema_20 = close.ewm(span=20, adjust=False).mean()
    features['price_vs_sma5'] = (close - sma_5) / features['atr_14']
    features['price_vs_sma10'] = (close - sma_10) / features['atr_14']
    features['price_vs_sma20'] = (close - sma_20) / features['atr_14']
    features['price_vs_sma50'] = (close - sma_50) / features['atr_14']
    features['sma5_vs_sma20'] = (sma_5 - sma_20) / features['atr_14']
    features['ema5_vs_ema20'] = (ema_5 - ema_20) / features['atr_14']

    # ===== RSI (3 features) =====
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-10)
    features['rsi_14'] = 100 - (100 / (1 + rs))
    features['rsi_oversold'] = (features['rsi_14'] < 30).astype(float)
    features['rsi_overbought'] = (features['rsi_14'] > 70).astype(float)

    # ===== Bollinger Bands (2 features) =====
    bb_middle = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    bb_upper = bb_middle + 2 * bb_std
    bb_lower = bb_middle - 2 * bb_std
    features['bb_width'] = (bb_upper - bb_lower) / bb_middle
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower + 1e-10)

    # ===== MACD (3 features, ATR-normalized) =====
    ema_12 = close.ewm(span=12, adjust=False).mean()
    ema_26 = close.ewm(span=26, adjust=False).mean()
    macd = ema_12 - ema_26
    macd_signal = macd.ewm(span=9, adjust=False).mean()
    features['macd'] = macd / features['atr_14']
    features['macd_signal'] = macd_signal / features['atr_14']
    features['macd_hist'] = (macd - macd_signal) / features['atr_14']

    # ===== Momentum (3 features, ATR-normalized) =====
    features['momentum_5'] = (close - close.shift(5)) / features['atr_14']
    features['momentum_10'] = (close - close.shift(10)) / features['atr_14']
    features['momentum_20'] = (close - close.shift(20)) / features['atr_14']

    # ===== Stochastic (2 features) =====
    low_14 = low.rolling(14).min()
    high_14 = high.rolling(14).max()
    features['stoch_k'] = 100 * (close - low_14) / (high_14 - low_14 + 1e-10)
    features['stoch_d'] = features['stoch_k'].rolling(3).mean()

    # ===== Williams %R (1 feature) =====
    features['williams_r'] = -100 * (high_14 - close) / (high_14 - low_14 + 1e-10)

    # ===== Volume Features (2 features) =====
    if volume.sum() > 0:
        vol_ma_5 = volume.rolling(5).mean()
        vol_ma_20 = volume.rolling(20).mean()
        features['volume_ratio'] = volume / (vol_ma_20 + 1)
        features['volume_trend'] = (vol_ma_5 - vol_ma_20) / (vol_ma_20 + 1)
    else:
        features['volume_ratio'] = 1.0
        features['volume_trend'] = 0.0

    # ===== Candle Patterns (3 features) =====
    body = close - open_price
    features['body_pct'] = body / (candle_range + 1e-10)
    features['upper_shadow'] = (high - np.maximum(close, open_price)) / (candle_range + 1e-10)
    features['lower_shadow'] = (np.minimum(close, open_price) - low) / (candle_range + 1e-10)

    # ===== Price Position (3 features) =====
    features['close_position'] = (close - low) / (candle_range + 1e-10)
    high_5 = high.rolling(5).max()
    low_5 = low.rolling(5).min()
    features['price_position_5'] = (close - low_5) / (high_5 - low_5 + 1e-10)
    high_20 = high.rolling(20).max()
    low_20 = low.rolling(20).min()
    features['price_position_20'] = (close - low_20) / (high_20 - low_20 + 1e-10)

    # ===== Time Features (7 features) =====
    # Requires a DatetimeIndex.
    hour = df.index.hour
    day_of_week = df.index.dayofweek
    features['hour_sin'] = np.sin(2 * np.pi * hour / 24)
    features['hour_cos'] = np.cos(2 * np.pi * hour / 24)
    features['dow_sin'] = np.sin(2 * np.pi * day_of_week / 7)
    features['dow_cos'] = np.cos(2 * np.pi * day_of_week / 7)

    # Trading session flags. NOTE(review): hour bounds look like UTC session
    # hours - confirm the index timezone matches.
    features['is_london'] = ((hour >= 8) & (hour < 16)).astype(float)
    features['is_newyork'] = ((hour >= 13) & (hour < 21)).astype(float)
    features['is_overlap'] = ((hour >= 13) & (hour < 16)).astype(float)

    features = features.replace([np.inf, -np.inf], np.nan)

    # Use the (possibly defaulted) volume series so a frame without a volume
    # column does not raise KeyError here; column order is unchanged.
    ohlcv = df[['open', 'high', 'low', 'close']].assign(volume=volume)
    result = pd.concat([ohlcv, features], axis=1)

    logger.info(f"Generated {len(features.columns)} features (total columns: {len(result.columns)})")

    return result


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description='Train Asset Metamodels (Nivel 2)',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    # Symbol configuration
    parser.add_argument('--symbols', nargs='+', default=['XAUUSD', 'EURUSD'],
                        help='Symbols to train (default: XAUUSD EURUSD)')

    # Path configuration
    parser.add_argument('--attention-path', type=str, default='models/attention',
                        help='Path to attention models (default: models/attention)')
    parser.add_argument('--base-path', type=str, default='models/base',
                        help='Path to base models (default: models/base)')
    parser.add_argument('--output-path', type=str, default='models/metamodels',
                        help='Output path for metamodels (default: models/metamodels)')

    # Data source
    parser.add_argument('--data-source', type=str, choices=['mysql', 'parquet'], default='mysql',
                        help='Data source type (default: mysql)')
    parser.add_argument('--data-dir', type=str, default='data/',
                        help='Data directory for parquet files')

    # MySQL configuration. NOTE(review): collected into db_config but not
    # consumed by load_ohlcv_from_mysql - confirm intended wiring.
    parser.add_argument('--db-host', type=str, default='localhost')
    parser.add_argument('--db-user', type=str, default='root')
    parser.add_argument('--db-password', type=str, default='')
    parser.add_argument('--db-name', type=str, default='trading')

    # OOS period
    parser.add_argument('--oos-start', type=str, default='2024-01-01',
                        help='OOS period start date (default: 2024-01-01)')
    parser.add_argument('--oos-end', type=str, default='2024-08-31',
                        help='OOS period end date (default: 2024-08-31)')

    # Training parameters
    parser.add_argument('--min-samples', type=int, default=2000,
                        help='Minimum OOS samples required (default: 2000)')
    parser.add_argument('--val-split', type=float, default=0.15,
                        help='Validation split ratio (default: 0.15)')

    # Output options
    parser.add_argument('--log-dir', type=str, default='models/logs',
                        help='Log directory')
    parser.add_argument('--generate-report', action='store_true',
                        help='Generate markdown training report')

    return parser


def _log_section(title: str) -> None:
    """Log *title* framed by separator rules, preceded by a blank line."""
    logger.info("\n" + "=" * 60)
    logger.info(title)
    logger.info("=" * 60)


def _load_all_data(args: argparse.Namespace) -> dict:
    """Load OHLCV data and generate features for every symbol/timeframe.

    Returns:
        ``{symbol: {timeframe: features_df}}``; timeframes with no data are
        omitted (a warning is logged), but every symbol key is present.
    """
    db_config = {
        'host': args.db_host,
        'user': args.db_user,
        'password': args.db_password,
        'database': args.db_name
    }

    data_dict = {}
    for symbol in args.symbols:
        data_dict[symbol] = {}
        for timeframe in _TIMEFRAMES:
            if args.data_source == 'mysql':
                df = load_ohlcv_from_mysql(symbol, timeframe, db_config)
            else:
                df = load_ohlcv_from_parquet(Path(args.data_dir), symbol, timeframe)

            if len(df) == 0:
                logger.warning(f" {symbol} {timeframe}: No data loaded")
                continue

            # Generate features (same as base model training)
            df = generate_features(df, symbol)
            data_dict[symbol][timeframe] = df
            logger.info(f" {symbol} {timeframe}: {len(df)} rows, {df.shape[1]} columns")

    return data_dict


def _fmt_metric(value, spec: str, suffix: str = '') -> str:
    """Format *value* with *spec* plus *suffix*, or return 'N/A' if missing/non-numeric."""
    if value is None:
        return 'N/A'
    try:
        if pd.isna(value):
            return 'N/A'
        return format(value, spec) + suffix
    except (TypeError, ValueError):
        return 'N/A'


def _write_report(args: argparse.Namespace, summary: pd.DataFrame, log_file: Path) -> Path:
    """Write a markdown training report into ``args.output_path``.

    Missing metrics are rendered as 'N/A'. The file is written as UTF-8
    because the architecture diagram uses box-drawing characters.

    Returns:
        Path of the written report.
    """
    now = datetime.now()
    report_path = Path(args.output_path) / f'training_report_{now.strftime("%Y%m%d_%H%M%S")}.md'

    header = f"""# Metamodel Training Report

**Generated:** {now.strftime('%Y-%m-%d %H:%M:%S')}

## Configuration

- **Symbols:** {', '.join(args.symbols)}
- **OOS Period:** {args.oos_start} to {args.oos_end}
- **Attention Models:** {args.attention_path}
- **Base Models:** {args.base_path}
- **Output:** {args.output_path}

## Training Results

| Symbol | Status | Samples | MAE High | MAE Low | R2 High | R2 Low | Confidence Acc | Improvement |
|--------|--------|---------|----------|---------|---------|--------|----------------|-------------|
"""

    rows = []
    for _, row in summary.iterrows():
        if row['status'] == 'success':
            rows.append(
                f"| {row['symbol']} | {row['status']} | {_fmt_metric(row.get('n_samples'), '')} | "
                f"{_fmt_metric(row.get('mae_high'), '.4f')} | {_fmt_metric(row.get('mae_low'), '.4f')} | "
                f"{_fmt_metric(row.get('r2_high'), '.4f')} | {_fmt_metric(row.get('r2_low'), '.4f')} | "
                f"{_fmt_metric(row.get('confidence_accuracy'), '.2%')} | "
                f"{_fmt_metric(row.get('improvement_over_avg'), '.1f', '%')} |\n"
            )
        else:
            rows.append(f"| {row['symbol']} | {row['status']} | - | - | - | - | - | - | - |\n")

    footer = f"""
## Architecture

```
Nivel 2: Metamodel (per asset)
├── Input Features (10):
│   ├── pred_high_5m, pred_low_5m
│   ├── pred_high_15m, pred_low_15m
│   ├── attention_5m, attention_15m
│   ├── attention_class_5m, attention_class_15m
│   └── ATR_ratio, volume_z
├── Models:
│   ├── XGBoost Regressor (HIGH)
│   ├── XGBoost Regressor (LOW)
│   └── XGBoost Classifier (Confidence)
└── Outputs:
    ├── delta_high_final
    ├── delta_low_final
    └── confidence (binary + probability)
```

## Log File

`{log_file}`
"""

    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text(header + ''.join(rows) + footer, encoding='utf-8')
    logger.info(f"Report saved to: {report_path}")
    return report_path


def _exit_code_from_results(results: dict) -> int:
    """Log the outcome and map trainer results to a process exit code.

    Returns 0 when at least one metamodel trained successfully, 1 when none
    did - including when *results* is empty.
    """
    total_count = len(results)
    if total_count == 0:
        logger.error("No metamodel results were produced")
        return 1

    success_count = sum(1 for r in results.values() if r.get('status') == 'success')

    if success_count == total_count:
        logger.info(f"All {total_count} metamodels trained successfully")
        return 0
    if success_count > 0:
        logger.warning(f"{success_count}/{total_count} metamodels trained successfully")
        return 0

    logger.error("No metamodels were trained successfully")
    return 1


def main() -> int:
    """Run the metamodel training pipeline from the command line.

    Parses arguments, configures logging, loads the pre-trained Nivel 0/1
    models, loads OHLCV data and generates features, trains and saves the
    metamodels, and optionally writes a markdown report.

    Returns:
        Process exit code (see ``_exit_code_from_results``).
    """
    args = _build_arg_parser().parse_args()

    log_file = setup_logging(Path(args.log_dir), 'metamodels')
    logger.info(f"Log file: {log_file}")

    logger.info("=" * 60)
    logger.info("METAMODEL TRAINING SCRIPT")
    logger.info("=" * 60)
    logger.info(f"Symbols: {args.symbols}")
    logger.info(f"OOS Period: {args.oos_start} to {args.oos_end}")
    logger.info(f"Attention models: {args.attention_path}")
    logger.info(f"Base models: {args.base_path}")
    logger.info(f"Output: {args.output_path}")

    # Imported here so the sys.path setup at module level has taken effect.
    from training.metamodel_trainer import MetamodelTrainer, MetamodelTrainerConfig

    config = MetamodelTrainerConfig(
        symbols=args.symbols,
        timeframes=list(_TIMEFRAMES),
        attention_model_path=args.attention_path,
        base_model_path=args.base_path,
        output_path=args.output_path,
        oos_start_date=args.oos_start,
        oos_end_date=args.oos_end,
        min_oos_samples=args.min_samples,
        val_split=args.val_split
    )
    trainer = MetamodelTrainer(config)

    _log_section("Loading pre-trained models...")
    if not trainer.load_models():
        logger.warning("Some models failed to load, training may be incomplete")

    _log_section("Loading OHLCV data...")
    data_dict = _load_all_data(args)

    _log_section("Training Metamodels...")
    results = trainer.train_all(data_dict)

    _log_section("TRAINING SUMMARY")
    summary = trainer.get_training_summary()
    if len(summary) > 0:
        print("\n" + summary.to_string(index=False))

    _log_section("Saving models...")
    trainer.save()

    if args.generate_report:
        _write_report(args, summary, log_file)

    _log_section("TRAINING COMPLETE")

    return _exit_code_from_results(results)


if __name__ == "__main__":
    sys.exit(main())