Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
530 lines · 16 KiB · Python
#!/usr/bin/env python3
|
|
"""
|
|
Dataset Preparation Script for ML-First Strategy
|
|
=================================================
|
|
Prepares training datasets by timeframe with proper temporal splits.
|
|
|
|
Usage:
|
|
python scripts/prepare_datasets.py --symbol XAUUSD --timeframes 5m,15m,1H,4H,D
|
|
python scripts/prepare_datasets.py --all-symbols
|
|
|
|
Author: ML-Specialist (NEXUS v4.0)
|
|
Created: 2026-01-04
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional
|
|
import pandas as pd
|
|
import numpy as np
|
|
import yaml
|
|
from loguru import logger
|
|
|
|
# Add src to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from data.database import DatabaseManager
|
|
from data.pipeline import DataPipeline
|
|
from data.indicators import TechnicalIndicators
|
|
from training.data_splitter import TemporalDataSplitter, create_ml_first_splits
|
|
|
|
|
|
# Configure logging: replace loguru's default stderr sink with a colorized
# stdout sink at INFO level so script progress reads cleanly in a terminal.
logger.remove()  # drop the default handler loguru installs at import time
logger.add(
    sys.stdout,
    format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
    level="INFO"
)
|
|
|
|
|
|
class DatasetPreparer:
    """
    Prepares multi-timeframe datasets for ML training.
    """

    # Timeframe configuration, keyed by timeframe label.
    #   'periods':  size of one bar of this timeframe measured in 5-minute
    #               base bars (e.g. '1H' = 12 five-minute bars).
    #   'resample': pandas offset alias passed to DataFrame.resample().
    #   'horizons': prediction horizons for target creation, expressed in
    #               bars of THIS timeframe, keyed by trading style.
    TIMEFRAME_CONFIG = {
        '5m': {'periods': 1, 'resample': '5min', 'horizons': {'scalping': 6}},
        '15m': {'periods': 3, 'resample': '15min', 'horizons': {'scalping': 4, 'intraday': 2}},
        '1H': {'periods': 12, 'resample': '1H', 'horizons': {'intraday': 4, 'swing': 2}},
        '4H': {'periods': 48, 'resample': '4H', 'horizons': {'swing': 6, 'position': 2}},
        'D': {'periods': 288, 'resample': '1D', 'horizons': {'position': 5, 'weekly': 1}},
        'W': {'periods': 2016, 'resample': '1W', 'horizons': {'weekly': 4}}
    }
|
|
|
|
    def __init__(
        self,
        output_dir: str = "datasets",
        config_path: str = "config/validation_oos.yaml"
    ):
        """
        Initialize the dataset preparer.

        Args:
            output_dir: Directory to save datasets
            config_path: Path to validation configuration
        """
        # Ensure the output tree exists before any split is written.
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.config_path = config_path

        # Collaborators: DB access, temporal splitting, indicator calculation.
        # NOTE(review): DatabaseManager() presumably establishes a DB
        # connection at construction time — confirm it is cheap/safe here.
        self.db_manager = DatabaseManager()
        self.splitter = TemporalDataSplitter(config_path)
        self.indicators = TechnicalIndicators()

        # Keep the parsed validation config; it is embedded verbatim into
        # each dataset's metadata.yaml by _save_datasets.
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)
|
|
|
|
def fetch_raw_data(
|
|
self,
|
|
symbol: str,
|
|
limit: int = 500000
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Fetch raw data from MySQL database.
|
|
|
|
Args:
|
|
symbol: Trading symbol (e.g., 'XAUUSD')
|
|
limit: Maximum number of records
|
|
|
|
Returns:
|
|
DataFrame with OHLCV data
|
|
"""
|
|
logger.info(f"Fetching data for {symbol}...")
|
|
|
|
# Get data from database
|
|
df = self.db_manager.db.get_ticker_data(symbol, limit=limit)
|
|
|
|
if df.empty:
|
|
logger.warning(f"No data found for {symbol}")
|
|
return df
|
|
|
|
# Ensure proper datetime index
|
|
if not isinstance(df.index, pd.DatetimeIndex):
|
|
df.index = pd.to_datetime(df.index)
|
|
|
|
# Sort by time
|
|
df = df.sort_index()
|
|
|
|
logger.info(
|
|
f"Loaded {len(df):,} records for {symbol} "
|
|
f"({df.index.min()} to {df.index.max()})"
|
|
)
|
|
|
|
return df
|
|
|
|
def resample_data(
|
|
self,
|
|
df: pd.DataFrame,
|
|
timeframe: str
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Resample data to specified timeframe.
|
|
|
|
Args:
|
|
df: DataFrame with 5-minute data
|
|
timeframe: Target timeframe (e.g., '15m', '1H', '4H', 'D', 'W')
|
|
|
|
Returns:
|
|
Resampled DataFrame
|
|
"""
|
|
if timeframe not in self.TIMEFRAME_CONFIG:
|
|
raise ValueError(f"Unknown timeframe: {timeframe}")
|
|
|
|
if timeframe == '5m':
|
|
# Already in 5-minute resolution
|
|
return df.copy()
|
|
|
|
resample_rule = self.TIMEFRAME_CONFIG[timeframe]['resample']
|
|
|
|
# OHLCV resampling rules
|
|
ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']
|
|
available_cols = [col for col in ohlcv_cols if col in df.columns]
|
|
|
|
resample_dict = {}
|
|
if 'open' in available_cols:
|
|
resample_dict['open'] = 'first'
|
|
if 'high' in available_cols:
|
|
resample_dict['high'] = 'max'
|
|
if 'low' in available_cols:
|
|
resample_dict['low'] = 'min'
|
|
if 'close' in available_cols:
|
|
resample_dict['close'] = 'last'
|
|
if 'volume' in available_cols:
|
|
resample_dict['volume'] = 'sum'
|
|
|
|
df_resampled = df[available_cols].resample(resample_rule).agg(resample_dict)
|
|
df_resampled = df_resampled.dropna()
|
|
|
|
logger.info(
|
|
f"Resampled to {timeframe}: {len(df_resampled):,} bars "
|
|
f"({df_resampled.index.min()} to {df_resampled.index.max()})"
|
|
)
|
|
|
|
return df_resampled
|
|
|
|
def calculate_features(
|
|
self,
|
|
df: pd.DataFrame,
|
|
timeframe: str
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Calculate technical indicators and features for the given timeframe.
|
|
|
|
Args:
|
|
df: OHLCV DataFrame
|
|
timeframe: Timeframe identifier
|
|
|
|
Returns:
|
|
DataFrame with features added
|
|
"""
|
|
logger.info(f"Calculating features for {timeframe}...")
|
|
|
|
# Calculate all indicators
|
|
df = self.indicators.calculate_all_indicators(df, minimal=True)
|
|
|
|
# Calculate rolling features with timeframe-appropriate windows
|
|
windows = self._get_rolling_windows(timeframe)
|
|
df = self.indicators.calculate_rolling_features(df, windows)
|
|
|
|
# Transform to ratios
|
|
df = self.indicators.transform_to_ratios(df)
|
|
|
|
# Drop NaN values
|
|
df = df.dropna()
|
|
|
|
logger.info(f"Features calculated: {len(df.columns)} columns, {len(df):,} rows")
|
|
|
|
return df
|
|
|
|
def _get_rolling_windows(self, timeframe: str) -> List[int]:
|
|
"""Get appropriate rolling windows for timeframe"""
|
|
window_config = {
|
|
'5m': [12, 48, 96], # 1h, 4h, 8h in 5m bars
|
|
'15m': [4, 16, 32], # 1h, 4h, 8h in 15m bars
|
|
'1H': [4, 12, 24], # 4h, 12h, 24h in 1H bars
|
|
'4H': [6, 12, 24], # 1d, 2d, 4d in 4H bars
|
|
'D': [5, 10, 20], # 1w, 2w, 1m in D bars
|
|
'W': [4, 8, 12] # 1m, 2m, 3m in W bars
|
|
}
|
|
return window_config.get(timeframe, [15, 60, 120])
|
|
|
|
def create_targets(
|
|
self,
|
|
df: pd.DataFrame,
|
|
timeframe: str
|
|
) -> pd.DataFrame:
|
|
"""
|
|
Create target variables for the given timeframe.
|
|
|
|
Args:
|
|
df: DataFrame with features
|
|
timeframe: Timeframe identifier
|
|
|
|
Returns:
|
|
DataFrame with targets added
|
|
"""
|
|
horizons = self.TIMEFRAME_CONFIG[timeframe]['horizons']
|
|
|
|
for horizon_name, periods in horizons.items():
|
|
# Future high
|
|
future_highs = [df['high'].shift(-i) for i in range(1, periods + 1)]
|
|
df[f'target_max_high_{horizon_name}'] = pd.concat(future_highs, axis=1).max(axis=1)
|
|
|
|
# Future low
|
|
future_lows = [df['low'].shift(-i) for i in range(1, periods + 1)]
|
|
df[f'target_min_low_{horizon_name}'] = pd.concat(future_lows, axis=1).min(axis=1)
|
|
|
|
# Future close
|
|
df[f'target_close_{horizon_name}'] = df['close'].shift(-periods)
|
|
|
|
# Delta ratios (targets for regression)
|
|
df[f'target_delta_high_{horizon_name}'] = (
|
|
df[f'target_max_high_{horizon_name}'] / df['close'] - 1
|
|
)
|
|
df[f'target_delta_low_{horizon_name}'] = (
|
|
df[f'target_min_low_{horizon_name}'] / df['close'] - 1
|
|
)
|
|
df[f'target_delta_close_{horizon_name}'] = (
|
|
df[f'target_close_{horizon_name}'] / df['close'] - 1
|
|
)
|
|
|
|
# Direction (target for classification)
|
|
df[f'target_direction_{horizon_name}'] = (
|
|
df[f'target_close_{horizon_name}'] > df['close']
|
|
).astype(int)
|
|
|
|
# Remove rows with NaN targets
|
|
target_cols = [col for col in df.columns if col.startswith('target_')]
|
|
df = df.dropna(subset=target_cols)
|
|
|
|
logger.info(f"Targets created: {len(target_cols)} target columns")
|
|
|
|
return df
|
|
|
|
def prepare_symbol_timeframe(
|
|
self,
|
|
symbol: str,
|
|
timeframe: str,
|
|
save: bool = True
|
|
) -> Dict[str, pd.DataFrame]:
|
|
"""
|
|
Prepare complete dataset for a symbol and timeframe.
|
|
|
|
Args:
|
|
symbol: Trading symbol
|
|
timeframe: Target timeframe
|
|
save: Whether to save to disk
|
|
|
|
Returns:
|
|
Dictionary with train/val/test_oos DataFrames
|
|
"""
|
|
logger.info(f"=" * 60)
|
|
logger.info(f"Preparing {symbol} @ {timeframe}")
|
|
logger.info(f"=" * 60)
|
|
|
|
# Step 1: Fetch raw data
|
|
df_raw = self.fetch_raw_data(symbol)
|
|
if df_raw.empty:
|
|
return {}
|
|
|
|
# Step 2: Resample if needed
|
|
df = self.resample_data(df_raw, timeframe)
|
|
|
|
# Step 3: Calculate features
|
|
df = self.calculate_features(df, timeframe)
|
|
|
|
# Step 4: Create targets
|
|
df = self.create_targets(df, timeframe)
|
|
|
|
# Step 5: Show data summary
|
|
self.splitter.print_data_summary(df)
|
|
|
|
# Step 6: Create temporal splits
|
|
splits = create_ml_first_splits(df, self.config_path)
|
|
|
|
# Step 7: Save datasets
|
|
if save:
|
|
self._save_datasets(splits, symbol, timeframe)
|
|
|
|
return splits
|
|
|
|
def _save_datasets(
|
|
self,
|
|
splits: Dict[str, pd.DataFrame],
|
|
symbol: str,
|
|
timeframe: str
|
|
):
|
|
"""Save dataset splits to parquet files"""
|
|
for split_name, df in splits.items():
|
|
# Create directory structure
|
|
save_dir = self.output_dir / symbol / timeframe
|
|
save_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Save as parquet
|
|
save_path = save_dir / f"{split_name}.parquet"
|
|
df.to_parquet(save_path, engine='pyarrow', compression='snappy')
|
|
|
|
logger.info(f"Saved {split_name}: {save_path} ({len(df):,} rows)")
|
|
|
|
# Save metadata
|
|
metadata = {
|
|
'symbol': symbol,
|
|
'timeframe': timeframe,
|
|
'created_at': datetime.now().isoformat(),
|
|
'config': self.config,
|
|
'splits': {
|
|
name: {
|
|
'rows': len(df),
|
|
'columns': list(df.columns),
|
|
'date_range': {
|
|
'start': str(df.index.min()),
|
|
'end': str(df.index.max())
|
|
}
|
|
}
|
|
for name, df in splits.items()
|
|
}
|
|
}
|
|
|
|
metadata_path = self.output_dir / symbol / timeframe / 'metadata.yaml'
|
|
with open(metadata_path, 'w') as f:
|
|
yaml.dump(metadata, f, default_flow_style=False)
|
|
|
|
logger.info(f"Saved metadata: {metadata_path}")
|
|
|
|
def prepare_all_timeframes(
|
|
self,
|
|
symbol: str,
|
|
timeframes: Optional[List[str]] = None
|
|
) -> Dict[str, Dict[str, pd.DataFrame]]:
|
|
"""
|
|
Prepare datasets for all timeframes for a symbol.
|
|
|
|
Args:
|
|
symbol: Trading symbol
|
|
timeframes: List of timeframes (defaults to all)
|
|
|
|
Returns:
|
|
Nested dictionary of splits by timeframe
|
|
"""
|
|
if timeframes is None:
|
|
timeframes = list(self.TIMEFRAME_CONFIG.keys())
|
|
|
|
results = {}
|
|
for tf in timeframes:
|
|
try:
|
|
results[tf] = self.prepare_symbol_timeframe(symbol, tf)
|
|
except Exception as e:
|
|
logger.error(f"Failed to prepare {symbol}@{tf}: {e}")
|
|
results[tf] = {}
|
|
|
|
return results
|
|
|
|
def prepare_all_symbols(
|
|
self,
|
|
symbols: Optional[List[str]] = None,
|
|
timeframes: Optional[List[str]] = None
|
|
) -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
|
|
"""
|
|
Prepare datasets for all symbols and timeframes.
|
|
|
|
Args:
|
|
symbols: List of symbols (defaults to available in DB)
|
|
timeframes: List of timeframes (defaults to all)
|
|
|
|
Returns:
|
|
Nested dictionary of splits by symbol and timeframe
|
|
"""
|
|
if symbols is None:
|
|
symbols = self.db_manager.db.get_available_symbols()
|
|
logger.info(f"Found {len(symbols)} symbols in database")
|
|
|
|
results = {}
|
|
for symbol in symbols:
|
|
logger.info(f"\n{'='*60}")
|
|
logger.info(f"Processing {symbol}")
|
|
logger.info(f"{'='*60}\n")
|
|
results[symbol] = self.prepare_all_timeframes(symbol, timeframes)
|
|
|
|
return results
|
|
|
|
def generate_report(self) -> str:
|
|
"""Generate summary report of prepared datasets"""
|
|
report_lines = [
|
|
"=" * 70,
|
|
"DATASET PREPARATION REPORT",
|
|
f"Generated: {datetime.now().isoformat()}",
|
|
"=" * 70,
|
|
""
|
|
]
|
|
|
|
# Walk through output directory
|
|
for symbol_dir in self.output_dir.iterdir():
|
|
if not symbol_dir.is_dir():
|
|
continue
|
|
|
|
report_lines.append(f"Symbol: {symbol_dir.name}")
|
|
report_lines.append("-" * 50)
|
|
|
|
for tf_dir in symbol_dir.iterdir():
|
|
if not tf_dir.is_dir():
|
|
continue
|
|
|
|
metadata_path = tf_dir / 'metadata.yaml'
|
|
if metadata_path.exists():
|
|
with open(metadata_path, 'r') as f:
|
|
metadata = yaml.safe_load(f)
|
|
|
|
report_lines.append(f" Timeframe: {tf_dir.name}")
|
|
for split_name, info in metadata['splits'].items():
|
|
report_lines.append(
|
|
f" {split_name}: {info['rows']:,} rows "
|
|
f"({info['date_range']['start']} to {info['date_range']['end']})"
|
|
)
|
|
report_lines.append("")
|
|
|
|
report = "\n".join(report_lines)
|
|
logger.info(report)
|
|
|
|
# Save report
|
|
report_path = self.output_dir / 'preparation_report.txt'
|
|
with open(report_path, 'w') as f:
|
|
f.write(report)
|
|
|
|
return report
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments and run dataset preparation."""
    parser = argparse.ArgumentParser(
        description="Prepare multi-timeframe datasets for ML training"
    )
    parser.add_argument(
        '--symbol',
        type=str,
        help='Symbol to process (e.g., XAUUSD)'
    )
    parser.add_argument(
        '--timeframes',
        type=str,
        default='5m,15m,1H,4H,D',
        help='Comma-separated list of timeframes (default: 5m,15m,1H,4H,D)'
    )
    parser.add_argument(
        '--all-symbols',
        action='store_true',
        help='Process all available symbols'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='datasets',
        help='Output directory for datasets (default: datasets)'
    )
    parser.add_argument(
        '--config',
        type=str,
        default='config/validation_oos.yaml',
        help='Path to validation config (default: config/validation_oos.yaml)'
    )
    parser.add_argument(
        '--report-only',
        action='store_true',
        help='Only generate report of existing datasets'
    )

    args = parser.parse_args()

    # Initialize preparer
    preparer = DatasetPreparer(
        output_dir=args.output_dir,
        config_path=args.config
    )

    if args.report_only:
        preparer.generate_report()
        return

    # Tolerate whitespace around commas and trailing commas
    # (e.g. --timeframes "5m, 15m,"), which a bare split(',') would not.
    timeframes = [tf.strip() for tf in args.timeframes.split(',') if tf.strip()]

    if args.all_symbols:
        preparer.prepare_all_symbols(timeframes=timeframes)
    elif args.symbol:
        preparer.prepare_all_timeframes(args.symbol, timeframes=timeframes)
    else:
        # Default: prepare XAUUSD
        logger.info("No symbol specified, using XAUUSD")
        preparer.prepare_all_timeframes('XAUUSD', timeframes=timeframes)

    # Generate report
    preparer.generate_report()
|
|
|
|
|
|
# Run the CLI only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|