Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
File stats: 626 lines, 21 KiB, Python.
#!/usr/bin/env python3
|
|
"""
|
|
Symbol-Timeframe Training Script
|
|
================================
|
|
Trains separate ML models for each symbol and timeframe combination.
|
|
|
|
This script uses the SymbolTimeframeTrainer to train models for:
|
|
- XAUUSD (Gold)
|
|
- EURUSD (Euro/USD)
|
|
- BTCUSD (Bitcoin) - if data available
|
|
|
|
Each symbol is trained for both 5m and 15m timeframes.
|
|
|
|
Features:
|
|
- Loads data from MySQL database
|
|
- Excludes last year (2025) for backtesting
|
|
- Uses dynamic factor-based sample weighting
|
|
- Generates comprehensive feature set
|
|
- Saves models and training reports
|
|
|
|
Usage:
|
|
python scripts/train_symbol_timeframe_models.py
|
|
python scripts/train_symbol_timeframe_models.py --symbols XAUUSD EURUSD --timeframes 5m 15m
|
|
|
|
Author: ML Training Pipeline
|
|
Version: 1.0.0
|
|
Created: 2026-01-05
|
|
"""
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
from datetime import datetime, timedelta
|
|
import json
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
from loguru import logger
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
|
|
|
|
from training.symbol_timeframe_trainer import (
|
|
SymbolTimeframeTrainer,
|
|
TrainerConfig,
|
|
SYMBOL_CONFIGS
|
|
)
|
|
from data.database import MySQLConnection
|
|
|
|
|
|
def setup_logging(log_dir: Path, experiment_name: str):
    """Set up loguru sinks: INFO to stderr, DEBUG to a timestamped file.

    Args:
        log_dir: Directory for log files (created if missing).
        experiment_name: Prefix used for the log file name.

    Returns:
        Path of the log file that was created.
    """
    log_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = log_dir / f"{experiment_name}_{timestamp}.log"

    # Drop loguru's default sink before installing our own two.
    logger.remove()
    logger.add(sys.stderr, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
    logger.add(log_file, level="DEBUG", rotation="10 MB")

    logger.info(f"Logging to {log_file}")
    return log_file
|
def load_data_from_db(
    db: MySQLConnection,
    symbol: str,
    start_date: str = None,
    end_date: str = None,
    limit: int = None
) -> pd.DataFrame:
    """
    Load OHLCV data from MySQL database.

    Args:
        db: MySQL connection
        symbol: Trading symbol (e.g., 'XAUUSD'); bare symbols are normalized
            to the DB ticker format ('X:' for crypto, 'C:' otherwise)
        start_date: Start date filter (YYYY-MM-DD), inclusive
        end_date: End date filter (YYYY-MM-DD), inclusive
        limit: Maximum records to fetch (coerced to int before use)

    Returns:
        DataFrame with OHLCV data on a sorted DatetimeIndex
        (returned empty, un-indexed, if no rows matched)
    """
    # Normalize bare symbols to the prefixed ticker format stored in the DB:
    # crypto (BTCUSD) -> 'X:...', forex/metals -> 'C:...'.
    db_symbol = symbol
    if not symbol.startswith('C:') and not symbol.startswith('X:'):
        if symbol == 'BTCUSD':
            db_symbol = f'X:{symbol}'
        else:
            db_symbol = f'C:{symbol}'

    logger.info(f"Loading data for {db_symbol}...")

    query = """
        SELECT
            date_agg as time,
            open,
            high,
            low,
            close,
            volume,
            vwap
        FROM tickers_agg_data
        WHERE ticker = :symbol
    """

    params = {'symbol': db_symbol}

    # Date filters are bound as named parameters (driver-escaped, safe).
    if start_date:
        query += " AND date_agg >= :start_date"
        params['start_date'] = start_date
    if end_date:
        query += " AND date_agg <= :end_date"
        params['end_date'] = end_date

    query += " ORDER BY date_agg ASC"

    if limit:
        # LIMIT is interpolated rather than bound; coerce to int so a
        # malformed or malicious value cannot inject SQL here.
        query += f" LIMIT {int(limit)}"

    df = db.execute_query(query, params)

    if df.empty:
        logger.warning(f"No data found for {symbol}")
        return df

    # Set datetime index and guarantee chronological order.
    df['time'] = pd.to_datetime(df['time'])
    df.set_index('time', inplace=True)
    df = df.sort_index()

    # Rename columns to the canonical lowercase names expected downstream.
    # NOTE(review): positional rename relies on the SELECT column order
    # above -- keep the two in sync.
    df.columns = ['open', 'high', 'low', 'close', 'volume', 'vwap']

    logger.info(f"Loaded {len(df)} records for {symbol}")
    logger.info(f" Date range: {df.index.min()} to {df.index.max()}")

    return df
|
def resample_to_timeframe(df: pd.DataFrame, timeframe: str) -> pd.DataFrame:
    """
    Resample 5-minute data to a coarser timeframe.

    Args:
        df: DataFrame with 5m OHLCV(+vwap) data on a DatetimeIndex
        timeframe: Target timeframe ('5m', '15m', '30m', '1H', '4H', '1D')

    Returns:
        Resampled DataFrame; a '5m' request returns the input unchanged.
    """
    if timeframe == '5m':
        return df  # Already at base resolution; nothing to do.

    # Map our timeframe labels to pandas offset aliases. Lowercase
    # 'h'/'min' aliases are used because uppercase 'H' is deprecated in
    # pandas >= 2.2 and removed in later releases.
    tf_map = {
        '15m': '15min',
        '30m': '30min',
        '1H': '1h',
        '4H': '4h',
        '1D': '1D'
    }

    # Unknown labels fall through unchanged and are handed to pandas as-is.
    offset = tf_map.get(timeframe, timeframe)

    resampled = df.resample(offset).agg({
        'open': 'first',
        'high': 'max',
        'low': 'min',
        'close': 'last',
        'volume': 'sum',
        # NOTE(review): a plain mean of bar VWAPs is not a true
        # volume-weighted price across the aggregated bars -- confirm
        # this approximation is intended.
        'vwap': 'mean'
    }).dropna()

    logger.info(f"Resampled to {timeframe}: {len(resampled)} bars")
    return resampled
|
def generate_features(df: pd.DataFrame, symbol: str) -> pd.DataFrame:
    """
    Generate comprehensive feature set for training.

    Builds price-action / technical-indicator features (returns, volatility,
    range, ATR, moving-average distances, RSI, Bollinger Bands, MACD,
    momentum, stochastic, Williams %R, volume, candle anatomy, price
    position, and cyclical time/session flags). Distance-style features
    are normalized by ATR(14) so scales are comparable across symbols.

    Args:
        df: OHLCV DataFrame on a DatetimeIndex (open/high/low/close,
            optionally volume)
        symbol: Symbol name -- used only for log messages in this function

    Returns:
        DataFrame of features aligned to df.index. Rolling-window warm-up
        rows contain NaN; the caller is expected to dropna().
    """
    logger.info(f"Generating features for {symbol}...")

    features = pd.DataFrame(index=df.index)

    close = df['close']
    high = df['high']
    low = df['low']
    open_price = df['open']
    # Fall back to a constant volume of 1 so downstream math stays
    # defined for symbols with no volume column.
    volume = df['volume'] if 'volume' in df.columns else pd.Series(1, index=df.index)

    # ===== Price Returns ===== (percentage change over k bars)
    features['returns_1'] = close.pct_change(1)
    features['returns_3'] = close.pct_change(3)
    features['returns_5'] = close.pct_change(5)
    features['returns_10'] = close.pct_change(10)
    features['returns_20'] = close.pct_change(20)

    # ===== Volatility Features ===== (rolling std of 1-bar returns)
    features['volatility_5'] = close.pct_change().rolling(5).std()
    features['volatility_10'] = close.pct_change().rolling(10).std()
    features['volatility_20'] = close.pct_change().rolling(20).std()

    # ===== Range Features ===== (high-low span per bar and its averages)
    candle_range = high - low
    features['range'] = candle_range
    features['range_pct'] = candle_range / close
    features['range_ma_5'] = candle_range.rolling(5).mean()
    features['range_ma_10'] = candle_range.rolling(10).mean()
    features['range_ma_20'] = candle_range.rolling(20).mean()
    # Ratios can divide by zero on flat windows; infinities are converted
    # to NaN in the clean-up step at the bottom.
    features['range_ratio_5'] = candle_range / features['range_ma_5']
    features['range_ratio_20'] = candle_range / features['range_ma_20']

    # ===== ATR Features ===== (classic true range; SMA-smoothed, not Wilder)
    tr1 = high - low
    tr2 = abs(high - close.shift(1))
    tr3 = abs(low - close.shift(1))
    true_range = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)
    features['atr_5'] = true_range.rolling(5).mean()
    # atr_14 is reused below as the normalizer for MA/MACD/momentum features.
    features['atr_14'] = true_range.rolling(14).mean()
    features['atr_20'] = true_range.rolling(20).mean()
    features['atr_ratio'] = true_range / features['atr_14']

    # ===== Moving Averages =====
    sma_5 = close.rolling(5).mean()
    sma_10 = close.rolling(10).mean()
    sma_20 = close.rolling(20).mean()
    sma_50 = close.rolling(50).mean()

    ema_5 = close.ewm(span=5, adjust=False).mean()
    # NOTE(review): ema_10 is computed but never used below -- dead code?
    ema_10 = close.ewm(span=10, adjust=False).mean()
    ema_20 = close.ewm(span=20, adjust=False).mean()

    # Distances from moving averages, expressed in ATR(14) units.
    features['price_vs_sma5'] = (close - sma_5) / features['atr_14']
    features['price_vs_sma10'] = (close - sma_10) / features['atr_14']
    features['price_vs_sma20'] = (close - sma_20) / features['atr_14']
    features['price_vs_sma50'] = (close - sma_50) / features['atr_14']
    features['sma5_vs_sma20'] = (sma_5 - sma_20) / features['atr_14']
    features['ema5_vs_ema20'] = (ema_5 - ema_20) / features['atr_14']

    # ===== RSI ===== (SMA-based variant; epsilon guards zero-loss windows)
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(14).mean()
    rs = gain / (loss + 1e-10)
    features['rsi_14'] = 100 - (100 / (1 + rs))

    # RSI extremes as binary flags (conventional 30/70 thresholds)
    features['rsi_oversold'] = (features['rsi_14'] < 30).astype(float)
    features['rsi_overbought'] = (features['rsi_14'] > 70).astype(float)

    # ===== Bollinger Bands ===== (20-period, 2 standard deviations)
    bb_middle = close.rolling(20).mean()
    bb_std = close.rolling(20).std()
    bb_upper = bb_middle + 2 * bb_std
    bb_lower = bb_middle - 2 * bb_std
    features['bb_width'] = (bb_upper - bb_lower) / bb_middle
    # bb_position in [0, 1] within the bands (can exceed on breakouts)
    features['bb_position'] = (close - bb_lower) / (bb_upper - bb_lower + 1e-10)

    # ===== MACD ===== (standard 12/26/9, but ATR-normalized for scale)
    ema_12 = close.ewm(span=12, adjust=False).mean()
    ema_26 = close.ewm(span=26, adjust=False).mean()
    macd = ema_12 - ema_26
    macd_signal = macd.ewm(span=9, adjust=False).mean()
    features['macd'] = macd / features['atr_14']
    features['macd_signal'] = macd_signal / features['atr_14']
    features['macd_hist'] = (macd - macd_signal) / features['atr_14']

    # ===== Momentum ===== (k-bar price change in ATR(14) units)
    features['momentum_5'] = (close - close.shift(5)) / features['atr_14']
    features['momentum_10'] = (close - close.shift(10)) / features['atr_14']
    features['momentum_20'] = (close - close.shift(20)) / features['atr_14']

    # ===== Stochastic ===== (%K over 14 bars, %D = 3-bar SMA of %K)
    low_14 = low.rolling(14).min()
    high_14 = high.rolling(14).max()
    features['stoch_k'] = 100 * (close - low_14) / (high_14 - low_14 + 1e-10)
    features['stoch_d'] = features['stoch_k'].rolling(3).mean()

    # ===== Williams %R ===== (range [-100, 0]; reuses the 14-bar extremes)
    features['williams_r'] = -100 * (high_14 - close) / (high_14 - low_14 + 1e-10)

    # ===== Volume Features ===== (skipped entirely when there is no real volume)
    if volume.sum() > 0:
        vol_ma_5 = volume.rolling(5).mean()
        vol_ma_20 = volume.rolling(20).mean()
        features['volume_ratio'] = volume / (vol_ma_20 + 1)
        features['volume_trend'] = (vol_ma_5 - vol_ma_20) / (vol_ma_20 + 1)

    # ===== Candle Patterns ===== (body/shadows as fractions of the bar range)
    body = close - open_price
    features['body_pct'] = body / (candle_range + 1e-10)
    features['upper_shadow'] = (high - np.maximum(close, open_price)) / (candle_range + 1e-10)
    features['lower_shadow'] = (np.minimum(close, open_price) - low) / (candle_range + 1e-10)

    # ===== Price Position ===== (where close sits inside recent ranges, in [0, 1])
    features['close_position'] = (close - low) / (candle_range + 1e-10)
    high_5 = high.rolling(5).max()
    low_5 = low.rolling(5).min()
    features['price_position_5'] = (close - low_5) / (high_5 - low_5 + 1e-10)

    high_20 = high.rolling(20).max()
    low_20 = low.rolling(20).min()
    features['price_position_20'] = (close - low_20) / (high_20 - low_20 + 1e-10)

    # ===== Time Features ===== (cyclical sin/cos encodings avoid the
    # artificial 23 -> 0 discontinuity of raw hour / day-of-week values)
    features['hour'] = df.index.hour
    features['hour_sin'] = np.sin(2 * np.pi * features['hour'] / 24)
    features['hour_cos'] = np.cos(2 * np.pi * features['hour'] / 24)
    features['day_of_week'] = df.index.dayofweek
    features['dow_sin'] = np.sin(2 * np.pi * features['day_of_week'] / 7)
    features['dow_cos'] = np.cos(2 * np.pi * features['day_of_week'] / 7)

    # Trading sessions
    # NOTE(review): session windows assume index timestamps are UTC --
    # TODO confirm against the data source's timezone.
    features['is_london'] = ((features['hour'] >= 8) & (features['hour'] < 16)).astype(float)
    features['is_newyork'] = ((features['hour'] >= 13) & (features['hour'] < 21)).astype(float)
    features['is_overlap'] = ((features['hour'] >= 13) & (features['hour'] < 16)).astype(float)

    # Clean up: infinities produced by zero-denominator divisions become
    # NaN so the caller's dropna() removes those rows.
    features = features.replace([np.inf, -np.inf], np.nan)

    # Drop non-feature columns used for intermediate calculations
    drop_cols = ['hour', 'day_of_week']
    features = features.drop(columns=[c for c in drop_cols if c in features.columns], errors='ignore')

    logger.info(f"Generated {len(features.columns)} features")

    return features
|
def train_models(
    symbols: list,
    timeframes: list,
    output_dir: Path,
    cutoff_date: str = '2024-12-31',
    db_config_path: str = 'config/database.yaml',
    use_attention: bool = False,
    attention_model_path: str = 'models/attention'
) -> dict:
    """
    Train models for all symbol/timeframe combinations.

    Pipeline per symbol: load 5m bars up to cutoff_date, resample per
    timeframe, generate features, train via SymbolTimeframeTrainer.
    Models and a CSV summary are saved under output_dir.

    Args:
        symbols: List of symbols to train
        timeframes: List of timeframes
        output_dir: Directory to save models
        cutoff_date: Training data cutoff date (inclusive upper bound)
        db_config_path: Path to database config
        use_attention: Whether to include attention features from pre-trained model
        attention_model_path: Path to trained attention models

    Returns:
        Dictionary with keys 'results' (per-model result objects),
        'summary' (summary table as dict, {} if empty) and 'model_dir'.
    """
    logger.info("="*60)
    logger.info("Symbol-Timeframe Model Training")
    logger.info("="*60)
    logger.info(f"Symbols: {symbols}")
    logger.info(f"Timeframes: {timeframes}")
    logger.info(f"Cutoff date: {cutoff_date}")
    logger.info(f"Use attention features: {use_attention}")

    # Connect to database
    db = MySQLConnection(db_config_path)

    # Configure trainer with improved parameters for better R^2
    # Key improvements:
    # 1. Targets are now normalized by ATR (handled in SymbolTimeframeTrainer)
    # 2. Reduced sample weighting aggressiveness
    # 3. More regularization in XGBoost
    config = TrainerConfig(
        symbols=symbols,
        timeframes=timeframes,
        # Horizon is in bars, so the wall-clock lookahead depends on timeframe.
        horizons={
            '5m': 3,  # 15 minutes ahead
            '15m': 3,  # 45 minutes ahead
        },
        train_years=5.0,
        holdout_years=1.0,  # Exclude 2025 for backtesting
        use_dynamic_factor_weighting=True,
        factor_window=200,
        softplus_beta=2.0,  # Reduced from 4.0 - less aggressive weighting
        softplus_w_max=2.0,  # Reduced from 3.0 - cap weights lower
        xgb_params={
            'n_estimators': 150,  # Reduced from 300
            'max_depth': 4,  # Reduced from 6 - shallower trees
            'learning_rate': 0.02,  # Reduced from 0.03
            'subsample': 0.7,  # Reduced from 0.8
            'colsample_bytree': 0.7,  # Reduced from 0.8
            'min_child_weight': 20,  # Increased from 10 - more regularization
            'gamma': 0.3,  # Increased from 0.1
            'reg_alpha': 0.5,  # Increased from 0.1 - L1 regularization
            'reg_lambda': 5.0,  # Increased from 1.0 - L2 regularization
            'tree_method': 'hist',
            'random_state': 42
        },
        min_train_samples=5000,
        use_attention_features=use_attention,
        attention_model_path=attention_model_path
    )

    trainer = SymbolTimeframeTrainer(config)

    # Prepare data dictionary
    # NOTE(review): data_dict is populated but never read after the loop --
    # it appears to be a debugging aid / leftover; consider removing.
    data_dict = {}
    all_results = {}

    for symbol in symbols:
        logger.info(f"\n{'='*60}")
        logger.info(f"Processing {symbol}")
        logger.info(f"{'='*60}")

        # Load raw data (5m), truncated at cutoff_date to keep a holdout period
        df_5m = load_data_from_db(db, symbol, end_date=cutoff_date)

        if df_5m.empty:
            logger.warning(f"No data for {symbol}, skipping...")
            continue

        # Verify we have enough data before committing to training
        if len(df_5m) < 50000:
            logger.warning(f"Insufficient data for {symbol}: {len(df_5m)} rows")
            continue

        data_dict[symbol] = {}

        for timeframe in timeframes:
            logger.info(f"\n--- {symbol} {timeframe} ---")

            # Resample if needed; copies keep indicator math from
            # mutating the shared 5m base frame.
            if timeframe == '5m':
                df_tf = df_5m.copy()
            else:
                df_tf = resample_to_timeframe(df_5m.copy(), timeframe)

            if len(df_tf) < 10000:
                logger.warning(f"Insufficient {timeframe} data: {len(df_tf)} rows")
                continue

            # Generate features
            features = generate_features(df_tf, symbol)

            # Combine OHLCV with features (aligned on the shared index)
            df_combined = pd.concat([df_tf[['open', 'high', 'low', 'close', 'volume']], features], axis=1)

            # Drop NaN rows (rolling-window warm-up and inf->NaN clean-up)
            df_combined = df_combined.dropna()

            logger.info(f"Final data shape: {df_combined.shape}")

            data_dict[symbol][timeframe] = df_combined

            # Train for this symbol/timeframe; a failure in one combination
            # is logged but must not abort the remaining ones.
            try:
                results = trainer.train_single(df_combined, symbol, timeframe)
                all_results.update(results)

                for key, result in results.items():
                    logger.info(f"\n{key}:")
                    logger.info(f" MAE: {result.mae:.6f}")
                    logger.info(f" RMSE: {result.rmse:.6f}")
                    logger.info(f" R2: {result.r2:.4f}")
                    logger.info(f" Dir Accuracy: {result.directional_accuracy:.2%}")
                    logger.info(f" Train samples: {result.n_train}")
                    logger.info(f" Val samples: {result.n_val}")

            except Exception as e:
                logger.error(f"Training failed for {symbol} {timeframe}: {e}")
                import traceback
                traceback.print_exc()

    # Save models
    model_dir = output_dir / 'symbol_timeframe_models'
    trainer.save(str(model_dir))
    logger.info(f"\nModels saved to {model_dir}")

    # Generate summary report
    summary_df = trainer.get_training_summary()

    if not summary_df.empty:
        report_path = output_dir / 'training_summary.csv'
        summary_df.to_csv(report_path, index=False)
        logger.info(f"Summary saved to {report_path}")

        logger.info("\n" + "="*60)
        logger.info("TRAINING SUMMARY")
        logger.info("="*60)
        # Plain print so the table lands on stdout untouched by log formatting
        print(summary_df.to_string(index=False))

    return {
        'results': all_results,
        'summary': summary_df.to_dict() if not summary_df.empty else {},
        'model_dir': str(model_dir)
    }
|
def generate_markdown_report(results: dict, output_dir: Path, cutoff_date: str = '2024-12-31') -> Path:
    """Generate a Markdown training report.

    Args:
        results: Output of train_models(); expects keys 'results'
            (mapping of model key -> result object) and 'model_dir'.
        output_dir: Directory the report file is written into.
        cutoff_date: Training cutoff to display (defaults to the
            pipeline's standard cutoff for backward compatibility).

    Returns:
        Path to the written report file.
    """
    report_path = output_dir / f"TRAINING_REPORT_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md"

    # Weighting parameters below mirror TrainerConfig in train_models()
    # (softplus_beta=2.0, softplus_w_max=2.0) -- previously this report
    # showed the stale pre-tuning values (4.0 / 3.0).
    report = f"""# Symbol-Timeframe Model Training Report

**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## Configuration

- **Training Data Cutoff:** {cutoff_date} (later data reserved for backtesting)
- **Dynamic Factor Weighting:** Enabled
- **Sample Weight Method:** Softplus with beta=2.0, w_max=2.0

## Training Results Summary

| Model | Symbol | Timeframe | Target | MAE | RMSE | R2 | Dir Accuracy | Train | Val |
|-------|--------|-----------|--------|-----|------|----|--------------| ----- | --- |
"""

    # One table row per trained model result.
    for key, result in results.get('results', {}).items():
        report += f"| {key} | {result.model_key.symbol} | {result.model_key.timeframe} | "
        report += f"{result.model_key.target_type} | {result.mae:.6f} | {result.rmse:.6f} | "
        report += f"{result.r2:.4f} | {result.directional_accuracy:.2%} | "
        report += f"{result.n_train} | {result.n_val} |\n"

    model_dir_str = results.get('model_dir', 'N/A')
    report += f"""
## Model Files

Models saved to: `{model_dir_str}`

### Model Naming Convention
- `{{symbol}}_{{timeframe}}_high_h{{horizon}}.joblib` - High range predictor
- `{{symbol}}_{{timeframe}}_low_h{{horizon}}.joblib` - Low range predictor

## Usage Example

```python
from training.symbol_timeframe_trainer import SymbolTimeframeTrainer

# Load trained models
trainer = SymbolTimeframeTrainer()
trainer.load('models/symbol_timeframe_models/')

# Predict for XAUUSD 15m
predictions = trainer.predict(features, 'XAUUSD', '15m')
print(f"Predicted High: {{predictions['high']}}")
print(f"Predicted Low: {{predictions['low']}}")
```

## Notes

1. Models exclude 2025 data for out-of-sample backtesting
2. Dynamic factor weighting emphasizes high-movement samples
3. Separate models for HIGH and LOW predictions per symbol/timeframe

---
*Report generated by Symbol-Timeframe Training Pipeline*
"""

    with open(report_path, 'w') as f:
        f.write(report)

    logger.info(f"Report saved to {report_path}")
    return report_path
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the training script."""
    parser = argparse.ArgumentParser(description='Train Symbol-Timeframe ML Models')
    parser.add_argument('--symbols', nargs='+', default=['XAUUSD', 'EURUSD'],
                        help='Symbols to train (default: XAUUSD EURUSD)')
    parser.add_argument('--timeframes', nargs='+', default=['5m', '15m'],
                        help='Timeframes to train (default: 5m 15m)')
    parser.add_argument('--output-dir', type=str, default='models/',
                        help='Output directory for models')
    parser.add_argument('--cutoff-date', type=str, default='2024-12-31',
                        help='Training data cutoff date')
    parser.add_argument('--db-config', type=str, default='config/database.yaml',
                        help='Database configuration file')
    parser.add_argument('--use-attention', action='store_true',
                        help='Use attention features from pre-trained attention model')
    parser.add_argument('--attention-model-path', type=str, default='models/attention',
                        help='Path to trained attention models')
    return parser


def main():
    """CLI entry point: parse args, set up paths/logging, train, report."""
    args = _build_arg_parser().parse_args()

    # Relative paths are resolved against the project root (parent of scripts/).
    project_root = Path(__file__).parent.parent
    output_dir = project_root / args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    setup_logging(output_dir / 'logs', 'symbol_timeframe_training')

    try:
        results = train_models(
            symbols=args.symbols,
            timeframes=args.timeframes,
            output_dir=output_dir,
            cutoff_date=args.cutoff_date,
            db_config_path=str(project_root / args.db_config),
            use_attention=args.use_attention,
            attention_model_path=str(project_root / args.attention_model_path)
        )

        generate_markdown_report(results, output_dir)

        logger.info("\n" + "="*60)
        logger.info("TRAINING COMPLETE!")
        logger.info("="*60)

    except Exception as e:
        # Full traceback goes to the log; non-zero exit signals failure to callers.
        logger.exception(f"Training failed: {e}")
        sys.exit(1)
|
# Entry guard: run the CLI only when executed directly, so the module can
# be imported (e.g. for its feature-generation helpers) without side effects.
if __name__ == "__main__":
    main()