trading-platform-ml-engine-v2/scripts/prepare_datasets.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00

530 lines
16 KiB
Python

#!/usr/bin/env python3
"""
Dataset Preparation Script for ML-First Strategy
=================================================
Prepares training datasets by timeframe with proper temporal splits.
Usage:
python scripts/prepare_datasets.py --symbol XAUUSD --timeframes 5m,15m,1H,4H,D
python scripts/prepare_datasets.py --all-symbols
Author: ML-Specialist (NEXUS v4.0)
Created: 2026-01-04
"""
import os
import sys
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional
import pandas as pd
import numpy as np
import yaml
from loguru import logger
# Add src to path
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from data.database import DatabaseManager
from data.pipeline import DataPipeline
from data.indicators import TechnicalIndicators
from training.data_splitter import TemporalDataSplitter, create_ml_first_splits
# Route all log output through a single colored stdout sink (loguru).
# The default handler is removed first so records are not emitted twice.
logger.remove()
logger.add(
sys.stdout,
format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
level="INFO"
)
class DatasetPreparer:
    """Builds multi-timeframe training datasets (features, targets, splits)."""

    # Per-timeframe settings: 'periods' is how many 5-minute bars one bar of
    # this timeframe spans, 'resample' is the pandas resampling rule, and
    # 'horizons' maps a horizon name to the number of bars looked ahead when
    # building prediction targets.
    TIMEFRAME_CONFIG = {
        '5m': {'periods': 1, 'resample': '5min', 'horizons': {'scalping': 6}},
        '15m': {'periods': 3, 'resample': '15min', 'horizons': {'scalping': 4, 'intraday': 2}},
        '1H': {'periods': 12, 'resample': '1H', 'horizons': {'intraday': 4, 'swing': 2}},
        '4H': {'periods': 48, 'resample': '4H', 'horizons': {'swing': 6, 'position': 2}},
        'D': {'periods': 288, 'resample': '1D', 'horizons': {'position': 5, 'weekly': 1}},
        'W': {'periods': 2016, 'resample': '1W', 'horizons': {'weekly': 4}},
    }
def __init__(
self,
output_dir: str = "datasets",
config_path: str = "config/validation_oos.yaml"
):
"""
Initialize the dataset preparer.
Args:
output_dir: Directory to save datasets
config_path: Path to validation configuration
"""
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
self.config_path = config_path
self.db_manager = DatabaseManager()
self.splitter = TemporalDataSplitter(config_path)
self.indicators = TechnicalIndicators()
# Load validation config
with open(config_path, 'r') as f:
self.config = yaml.safe_load(f)
def fetch_raw_data(
self,
symbol: str,
limit: int = 500000
) -> pd.DataFrame:
"""
Fetch raw data from MySQL database.
Args:
symbol: Trading symbol (e.g., 'XAUUSD')
limit: Maximum number of records
Returns:
DataFrame with OHLCV data
"""
logger.info(f"Fetching data for {symbol}...")
# Get data from database
df = self.db_manager.db.get_ticker_data(symbol, limit=limit)
if df.empty:
logger.warning(f"No data found for {symbol}")
return df
# Ensure proper datetime index
if not isinstance(df.index, pd.DatetimeIndex):
df.index = pd.to_datetime(df.index)
# Sort by time
df = df.sort_index()
logger.info(
f"Loaded {len(df):,} records for {symbol} "
f"({df.index.min()} to {df.index.max()})"
)
return df
def resample_data(
self,
df: pd.DataFrame,
timeframe: str
) -> pd.DataFrame:
"""
Resample data to specified timeframe.
Args:
df: DataFrame with 5-minute data
timeframe: Target timeframe (e.g., '15m', '1H', '4H', 'D', 'W')
Returns:
Resampled DataFrame
"""
if timeframe not in self.TIMEFRAME_CONFIG:
raise ValueError(f"Unknown timeframe: {timeframe}")
if timeframe == '5m':
# Already in 5-minute resolution
return df.copy()
resample_rule = self.TIMEFRAME_CONFIG[timeframe]['resample']
# OHLCV resampling rules
ohlcv_cols = ['open', 'high', 'low', 'close', 'volume']
available_cols = [col for col in ohlcv_cols if col in df.columns]
resample_dict = {}
if 'open' in available_cols:
resample_dict['open'] = 'first'
if 'high' in available_cols:
resample_dict['high'] = 'max'
if 'low' in available_cols:
resample_dict['low'] = 'min'
if 'close' in available_cols:
resample_dict['close'] = 'last'
if 'volume' in available_cols:
resample_dict['volume'] = 'sum'
df_resampled = df[available_cols].resample(resample_rule).agg(resample_dict)
df_resampled = df_resampled.dropna()
logger.info(
f"Resampled to {timeframe}: {len(df_resampled):,} bars "
f"({df_resampled.index.min()} to {df_resampled.index.max()})"
)
return df_resampled
def calculate_features(
self,
df: pd.DataFrame,
timeframe: str
) -> pd.DataFrame:
"""
Calculate technical indicators and features for the given timeframe.
Args:
df: OHLCV DataFrame
timeframe: Timeframe identifier
Returns:
DataFrame with features added
"""
logger.info(f"Calculating features for {timeframe}...")
# Calculate all indicators
df = self.indicators.calculate_all_indicators(df, minimal=True)
# Calculate rolling features with timeframe-appropriate windows
windows = self._get_rolling_windows(timeframe)
df = self.indicators.calculate_rolling_features(df, windows)
# Transform to ratios
df = self.indicators.transform_to_ratios(df)
# Drop NaN values
df = df.dropna()
logger.info(f"Features calculated: {len(df.columns)} columns, {len(df):,} rows")
return df
def _get_rolling_windows(self, timeframe: str) -> List[int]:
"""Get appropriate rolling windows for timeframe"""
window_config = {
'5m': [12, 48, 96], # 1h, 4h, 8h in 5m bars
'15m': [4, 16, 32], # 1h, 4h, 8h in 15m bars
'1H': [4, 12, 24], # 4h, 12h, 24h in 1H bars
'4H': [6, 12, 24], # 1d, 2d, 4d in 4H bars
'D': [5, 10, 20], # 1w, 2w, 1m in D bars
'W': [4, 8, 12] # 1m, 2m, 3m in W bars
}
return window_config.get(timeframe, [15, 60, 120])
def create_targets(
self,
df: pd.DataFrame,
timeframe: str
) -> pd.DataFrame:
"""
Create target variables for the given timeframe.
Args:
df: DataFrame with features
timeframe: Timeframe identifier
Returns:
DataFrame with targets added
"""
horizons = self.TIMEFRAME_CONFIG[timeframe]['horizons']
for horizon_name, periods in horizons.items():
# Future high
future_highs = [df['high'].shift(-i) for i in range(1, periods + 1)]
df[f'target_max_high_{horizon_name}'] = pd.concat(future_highs, axis=1).max(axis=1)
# Future low
future_lows = [df['low'].shift(-i) for i in range(1, periods + 1)]
df[f'target_min_low_{horizon_name}'] = pd.concat(future_lows, axis=1).min(axis=1)
# Future close
df[f'target_close_{horizon_name}'] = df['close'].shift(-periods)
# Delta ratios (targets for regression)
df[f'target_delta_high_{horizon_name}'] = (
df[f'target_max_high_{horizon_name}'] / df['close'] - 1
)
df[f'target_delta_low_{horizon_name}'] = (
df[f'target_min_low_{horizon_name}'] / df['close'] - 1
)
df[f'target_delta_close_{horizon_name}'] = (
df[f'target_close_{horizon_name}'] / df['close'] - 1
)
# Direction (target for classification)
df[f'target_direction_{horizon_name}'] = (
df[f'target_close_{horizon_name}'] > df['close']
).astype(int)
# Remove rows with NaN targets
target_cols = [col for col in df.columns if col.startswith('target_')]
df = df.dropna(subset=target_cols)
logger.info(f"Targets created: {len(target_cols)} target columns")
return df
def prepare_symbol_timeframe(
self,
symbol: str,
timeframe: str,
save: bool = True
) -> Dict[str, pd.DataFrame]:
"""
Prepare complete dataset for a symbol and timeframe.
Args:
symbol: Trading symbol
timeframe: Target timeframe
save: Whether to save to disk
Returns:
Dictionary with train/val/test_oos DataFrames
"""
logger.info(f"=" * 60)
logger.info(f"Preparing {symbol} @ {timeframe}")
logger.info(f"=" * 60)
# Step 1: Fetch raw data
df_raw = self.fetch_raw_data(symbol)
if df_raw.empty:
return {}
# Step 2: Resample if needed
df = self.resample_data(df_raw, timeframe)
# Step 3: Calculate features
df = self.calculate_features(df, timeframe)
# Step 4: Create targets
df = self.create_targets(df, timeframe)
# Step 5: Show data summary
self.splitter.print_data_summary(df)
# Step 6: Create temporal splits
splits = create_ml_first_splits(df, self.config_path)
# Step 7: Save datasets
if save:
self._save_datasets(splits, symbol, timeframe)
return splits
def _save_datasets(
self,
splits: Dict[str, pd.DataFrame],
symbol: str,
timeframe: str
):
"""Save dataset splits to parquet files"""
for split_name, df in splits.items():
# Create directory structure
save_dir = self.output_dir / symbol / timeframe
save_dir.mkdir(parents=True, exist_ok=True)
# Save as parquet
save_path = save_dir / f"{split_name}.parquet"
df.to_parquet(save_path, engine='pyarrow', compression='snappy')
logger.info(f"Saved {split_name}: {save_path} ({len(df):,} rows)")
# Save metadata
metadata = {
'symbol': symbol,
'timeframe': timeframe,
'created_at': datetime.now().isoformat(),
'config': self.config,
'splits': {
name: {
'rows': len(df),
'columns': list(df.columns),
'date_range': {
'start': str(df.index.min()),
'end': str(df.index.max())
}
}
for name, df in splits.items()
}
}
metadata_path = self.output_dir / symbol / timeframe / 'metadata.yaml'
with open(metadata_path, 'w') as f:
yaml.dump(metadata, f, default_flow_style=False)
logger.info(f"Saved metadata: {metadata_path}")
def prepare_all_timeframes(
self,
symbol: str,
timeframes: Optional[List[str]] = None
) -> Dict[str, Dict[str, pd.DataFrame]]:
"""
Prepare datasets for all timeframes for a symbol.
Args:
symbol: Trading symbol
timeframes: List of timeframes (defaults to all)
Returns:
Nested dictionary of splits by timeframe
"""
if timeframes is None:
timeframes = list(self.TIMEFRAME_CONFIG.keys())
results = {}
for tf in timeframes:
try:
results[tf] = self.prepare_symbol_timeframe(symbol, tf)
except Exception as e:
logger.error(f"Failed to prepare {symbol}@{tf}: {e}")
results[tf] = {}
return results
def prepare_all_symbols(
self,
symbols: Optional[List[str]] = None,
timeframes: Optional[List[str]] = None
) -> Dict[str, Dict[str, Dict[str, pd.DataFrame]]]:
"""
Prepare datasets for all symbols and timeframes.
Args:
symbols: List of symbols (defaults to available in DB)
timeframes: List of timeframes (defaults to all)
Returns:
Nested dictionary of splits by symbol and timeframe
"""
if symbols is None:
symbols = self.db_manager.db.get_available_symbols()
logger.info(f"Found {len(symbols)} symbols in database")
results = {}
for symbol in symbols:
logger.info(f"\n{'='*60}")
logger.info(f"Processing {symbol}")
logger.info(f"{'='*60}\n")
results[symbol] = self.prepare_all_timeframes(symbol, timeframes)
return results
def generate_report(self) -> str:
"""Generate summary report of prepared datasets"""
report_lines = [
"=" * 70,
"DATASET PREPARATION REPORT",
f"Generated: {datetime.now().isoformat()}",
"=" * 70,
""
]
# Walk through output directory
for symbol_dir in self.output_dir.iterdir():
if not symbol_dir.is_dir():
continue
report_lines.append(f"Symbol: {symbol_dir.name}")
report_lines.append("-" * 50)
for tf_dir in symbol_dir.iterdir():
if not tf_dir.is_dir():
continue
metadata_path = tf_dir / 'metadata.yaml'
if metadata_path.exists():
with open(metadata_path, 'r') as f:
metadata = yaml.safe_load(f)
report_lines.append(f" Timeframe: {tf_dir.name}")
for split_name, info in metadata['splits'].items():
report_lines.append(
f" {split_name}: {info['rows']:,} rows "
f"({info['date_range']['start']} to {info['date_range']['end']})"
)
report_lines.append("")
report = "\n".join(report_lines)
logger.info(report)
# Save report
report_path = self.output_dir / 'preparation_report.txt'
with open(report_path, 'w') as f:
f.write(report)
return report
def main():
    """CLI entry point: parse arguments and run dataset preparation."""
    parser = argparse.ArgumentParser(
        description="Prepare multi-timeframe datasets for ML training"
    )
    # Declarative argument table keeps the CLI surface easy to scan.
    cli_args = [
        (('--symbol',),
         dict(type=str, help='Symbol to process (e.g., XAUUSD)')),
        (('--timeframes',),
         dict(type=str, default='5m,15m,1H,4H,D',
              help='Comma-separated list of timeframes (default: 5m,15m,1H,4H,D)')),
        (('--all-symbols',),
         dict(action='store_true', help='Process all available symbols')),
        (('--output-dir',),
         dict(type=str, default='datasets',
              help='Output directory for datasets (default: datasets)')),
        (('--config',),
         dict(type=str, default='config/validation_oos.yaml',
              help='Path to validation config (default: config/validation_oos.yaml)')),
        (('--report-only',),
         dict(action='store_true', help='Only generate report of existing datasets')),
    ]
    for flags, options in cli_args:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    preparer = DatasetPreparer(
        output_dir=args.output_dir,
        config_path=args.config
    )

    # Report-only mode touches nothing: just summarize what already exists.
    if args.report_only:
        preparer.generate_report()
        return

    timeframes = args.timeframes.split(',')
    if args.all_symbols:
        preparer.prepare_all_symbols(timeframes=timeframes)
    elif args.symbol:
        preparer.prepare_all_timeframes(args.symbol, timeframes=timeframes)
    else:
        # Fall back to the flagship symbol when none was given.
        logger.info("No symbol specified, using XAUUSD")
        preparer.prepare_all_timeframes('XAUUSD', timeframes=timeframes)

    preparer.generate_report()
# Allow direct execution: python scripts/prepare_datasets.py [options]
if __name__ == "__main__":
    main()