Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)
Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations
Note: Trained models (*.joblib, *.pt) are gitignored.
Regenerate with training scripts.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
529 lines · 17 KiB · Python
#!/usr/bin/env python3
|
|
"""
|
|
Data Validation Script for ML-First Strategy
|
|
=============================================
|
|
Validates data quality, temporal splits, and readiness for training.
|
|
|
|
Usage:
|
|
python scripts/validate_data.py --check-db
|
|
python scripts/validate_data.py --check-splits
|
|
python scripts/validate_data.py --full-validation
|
|
|
|
Author: ML-Specialist (NEXUS v4.0)
|
|
Created: 2026-01-04
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Tuple
|
|
import pandas as pd
|
|
import numpy as np
|
|
import yaml
|
|
from loguru import logger
|
|
from dataclasses import dataclass
|
|
|
|
# Add src to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
|
|
|
from data.database import DatabaseManager
|
|
|
|
|
|
# Configure logging
|
|
logger.remove()
|
|
logger.add(
|
|
sys.stdout,
|
|
format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
|
|
level="INFO"
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class ValidationResult:
|
|
"""Result of a validation check"""
|
|
name: str
|
|
passed: bool
|
|
message: str
|
|
details: Optional[Dict] = None
|
|
|
|
|
|
class DataValidator:
|
|
"""
|
|
Validates data quality and readiness for ML training.
|
|
"""
|
|
|
|
def __init__(self, config_path: str = "config/validation_oos.yaml"):
|
|
"""Initialize the data validator"""
|
|
self.db_manager = DatabaseManager()
|
|
self.results: List[ValidationResult] = []
|
|
|
|
# Load config
|
|
config_file = Path(config_path)
|
|
if config_file.exists():
|
|
with open(config_file, 'r') as f:
|
|
self.config = yaml.safe_load(f)
|
|
else:
|
|
logger.warning(f"Config not found: {config_path}")
|
|
self.config = {}
|
|
|
|
def check_database_connection(self) -> ValidationResult:
|
|
"""Check database connectivity"""
|
|
try:
|
|
# Test query
|
|
symbols = self.db_manager.db.get_available_symbols()
|
|
result = ValidationResult(
|
|
name="Database Connection",
|
|
passed=True,
|
|
message=f"Connected successfully. Found {len(symbols)} symbols.",
|
|
details={'symbols': symbols[:10]} # First 10
|
|
)
|
|
except Exception as e:
|
|
result = ValidationResult(
|
|
name="Database Connection",
|
|
passed=False,
|
|
message=f"Connection failed: {str(e)}"
|
|
)
|
|
|
|
self.results.append(result)
|
|
return result
|
|
|
|
def check_symbol_data(
|
|
self,
|
|
symbol: str,
|
|
min_records: int = 10000
|
|
) -> ValidationResult:
|
|
"""Check if symbol has sufficient data"""
|
|
try:
|
|
df = self.db_manager.db.get_ticker_data(symbol, limit=1)
|
|
|
|
if df.empty:
|
|
result = ValidationResult(
|
|
name=f"Symbol Data: {symbol}",
|
|
passed=False,
|
|
message=f"No data found for {symbol}"
|
|
)
|
|
else:
|
|
# Get full count and date range
|
|
df_full = self.db_manager.db.get_ticker_data(symbol, limit=500000)
|
|
count = len(df_full)
|
|
date_range = f"{df_full.index.min()} to {df_full.index.max()}"
|
|
|
|
result = ValidationResult(
|
|
name=f"Symbol Data: {symbol}",
|
|
passed=count >= min_records,
|
|
message=f"{count:,} records ({date_range})",
|
|
details={
|
|
'count': count,
|
|
'min_date': str(df_full.index.min()),
|
|
'max_date': str(df_full.index.max()),
|
|
'columns': list(df_full.columns)
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
result = ValidationResult(
|
|
name=f"Symbol Data: {symbol}",
|
|
passed=False,
|
|
message=f"Error checking data: {str(e)}"
|
|
)
|
|
|
|
self.results.append(result)
|
|
return result
|
|
|
|
def check_data_quality(
|
|
self,
|
|
symbol: str,
|
|
sample_size: int = 10000
|
|
) -> ValidationResult:
|
|
"""Check data quality (gaps, nulls, outliers)"""
|
|
try:
|
|
df = self.db_manager.db.get_ticker_data(symbol, limit=sample_size)
|
|
|
|
if df.empty:
|
|
return ValidationResult(
|
|
name=f"Data Quality: {symbol}",
|
|
passed=False,
|
|
message="No data to validate"
|
|
)
|
|
|
|
# Check for nulls
|
|
null_counts = df.isnull().sum()
|
|
null_pct = (null_counts / len(df) * 100).round(2)
|
|
|
|
# Check for gaps
|
|
time_diffs = df.index.to_series().diff().dropna()
|
|
expected_interval = pd.Timedelta(minutes=5) # Assuming 5-min data
|
|
gaps = time_diffs[time_diffs > expected_interval * 2]
|
|
|
|
# Check for outliers in price
|
|
if 'close' in df.columns:
|
|
price_returns = df['close'].pct_change().dropna()
|
|
outlier_threshold = price_returns.std() * 5
|
|
outliers = (price_returns.abs() > outlier_threshold).sum()
|
|
else:
|
|
outliers = 0
|
|
|
|
issues = []
|
|
if null_counts.sum() > 0:
|
|
issues.append(f"Nulls: {null_counts.sum()}")
|
|
if len(gaps) > 0:
|
|
issues.append(f"Time gaps: {len(gaps)}")
|
|
if outliers > 0:
|
|
issues.append(f"Price outliers: {outliers}")
|
|
|
|
passed = len(issues) == 0
|
|
|
|
result = ValidationResult(
|
|
name=f"Data Quality: {symbol}",
|
|
passed=passed,
|
|
message="OK" if passed else ", ".join(issues),
|
|
details={
|
|
'null_counts': null_pct.to_dict(),
|
|
'gap_count': len(gaps),
|
|
'outlier_count': outliers,
|
|
'sample_size': len(df)
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
result = ValidationResult(
|
|
name=f"Data Quality: {symbol}",
|
|
passed=False,
|
|
message=f"Error: {str(e)}"
|
|
)
|
|
|
|
self.results.append(result)
|
|
return result
|
|
|
|
def check_temporal_coverage(
|
|
self,
|
|
symbol: str
|
|
) -> ValidationResult:
|
|
"""Check if data covers required time periods for training/validation"""
|
|
try:
|
|
df = self.db_manager.db.get_ticker_data(symbol, limit=500000)
|
|
|
|
if df.empty:
|
|
return ValidationResult(
|
|
name=f"Temporal Coverage: {symbol}",
|
|
passed=False,
|
|
message="No data"
|
|
)
|
|
|
|
# Required periods from config
|
|
train_start = pd.to_datetime(self.config.get('validation', {}).get('train', {}).get('start_date', '2023-01-01'))
|
|
train_end = pd.to_datetime(self.config.get('validation', {}).get('train', {}).get('end_date', '2024-12-31'))
|
|
test_start = pd.to_datetime(self.config.get('validation', {}).get('test_oos', {}).get('start_date', '2025-01-01'))
|
|
test_end = pd.to_datetime(self.config.get('validation', {}).get('test_oos', {}).get('end_date', '2025-12-31'))
|
|
|
|
data_start = df.index.min()
|
|
data_end = df.index.max()
|
|
|
|
# Check coverage
|
|
train_covered = data_start <= train_start and data_end >= train_end
|
|
test_covered = data_start <= test_start and data_end >= test_end
|
|
|
|
# Count samples per period
|
|
train_mask = (df.index >= train_start) & (df.index <= train_end)
|
|
test_mask = (df.index >= test_start) & (df.index <= test_end)
|
|
|
|
train_count = train_mask.sum()
|
|
test_count = test_mask.sum()
|
|
|
|
# Year breakdown
|
|
year_counts = df.groupby(df.index.year).size().to_dict()
|
|
|
|
passed = train_covered and test_covered and train_count > 10000 and test_count > 1000
|
|
|
|
result = ValidationResult(
|
|
name=f"Temporal Coverage: {symbol}",
|
|
passed=passed,
|
|
message=f"Train: {train_count:,}, Test OOS: {test_count:,}",
|
|
details={
|
|
'data_range': f"{data_start} to {data_end}",
|
|
'train_samples': train_count,
|
|
'test_samples': test_count,
|
|
'year_counts': year_counts,
|
|
'train_covered': train_covered,
|
|
'test_covered': test_covered
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
result = ValidationResult(
|
|
name=f"Temporal Coverage: {symbol}",
|
|
passed=False,
|
|
message=f"Error: {str(e)}"
|
|
)
|
|
|
|
self.results.append(result)
|
|
return result
|
|
|
|
def check_required_columns(
|
|
self,
|
|
symbol: str
|
|
) -> ValidationResult:
|
|
"""Check if all required columns exist"""
|
|
required_columns = [
|
|
'open', 'high', 'low', 'close', 'volume',
|
|
'rsi', 'macd_histogram', 'macd_signal',
|
|
'sma_10', 'sma_20', 'atr'
|
|
]
|
|
|
|
try:
|
|
df = self.db_manager.db.get_ticker_data(symbol, limit=100)
|
|
|
|
if df.empty:
|
|
return ValidationResult(
|
|
name=f"Required Columns: {symbol}",
|
|
passed=False,
|
|
message="No data"
|
|
)
|
|
|
|
available = set(df.columns)
|
|
required = set(required_columns)
|
|
missing = required - available
|
|
|
|
result = ValidationResult(
|
|
name=f"Required Columns: {symbol}",
|
|
passed=len(missing) == 0,
|
|
message="All required columns present" if len(missing) == 0 else f"Missing: {missing}",
|
|
details={
|
|
'available': list(available),
|
|
'missing': list(missing),
|
|
'total_columns': len(df.columns)
|
|
}
|
|
)
|
|
|
|
except Exception as e:
|
|
result = ValidationResult(
|
|
name=f"Required Columns: {symbol}",
|
|
passed=False,
|
|
message=f"Error: {str(e)}"
|
|
)
|
|
|
|
self.results.append(result)
|
|
return result
|
|
|
|
def check_prepared_datasets(
|
|
self,
|
|
datasets_dir: str = "datasets"
|
|
) -> List[ValidationResult]:
|
|
"""Check prepared dataset files"""
|
|
results = []
|
|
datasets_path = Path(datasets_dir)
|
|
|
|
if not datasets_path.exists():
|
|
result = ValidationResult(
|
|
name="Prepared Datasets",
|
|
passed=False,
|
|
message=f"Directory not found: {datasets_dir}"
|
|
)
|
|
results.append(result)
|
|
self.results.append(result)
|
|
return results
|
|
|
|
for symbol_dir in datasets_path.iterdir():
|
|
if not symbol_dir.is_dir():
|
|
continue
|
|
|
|
for tf_dir in symbol_dir.iterdir():
|
|
if not tf_dir.is_dir():
|
|
continue
|
|
|
|
# Check for required files
|
|
train_file = tf_dir / 'train.parquet'
|
|
val_file = tf_dir / 'val.parquet'
|
|
test_file = tf_dir / 'test_oos.parquet'
|
|
metadata_file = tf_dir / 'metadata.yaml'
|
|
|
|
files_exist = {
|
|
'train': train_file.exists(),
|
|
'val': val_file.exists(),
|
|
'test_oos': test_file.exists(),
|
|
'metadata': metadata_file.exists()
|
|
}
|
|
|
|
all_exist = all(files_exist.values())
|
|
|
|
# Get sizes if files exist
|
|
sizes = {}
|
|
if train_file.exists():
|
|
sizes['train'] = len(pd.read_parquet(train_file))
|
|
if val_file.exists():
|
|
sizes['val'] = len(pd.read_parquet(val_file))
|
|
if test_file.exists():
|
|
sizes['test_oos'] = len(pd.read_parquet(test_file))
|
|
|
|
result = ValidationResult(
|
|
name=f"Dataset: {symbol_dir.name}/{tf_dir.name}",
|
|
passed=all_exist,
|
|
message=f"OK - Train: {sizes.get('train', 0):,}, Val: {sizes.get('val', 0):,}, Test: {sizes.get('test_oos', 0):,}" if all_exist else f"Missing files: {[k for k, v in files_exist.items() if not v]}",
|
|
details={
|
|
'files': files_exist,
|
|
'sizes': sizes
|
|
}
|
|
)
|
|
results.append(result)
|
|
self.results.append(result)
|
|
|
|
return results
|
|
|
|
def run_full_validation(
|
|
self,
|
|
symbols: Optional[List[str]] = None
|
|
) -> Dict:
|
|
"""Run complete validation suite"""
|
|
logger.info("=" * 70)
|
|
logger.info("STARTING FULL DATA VALIDATION")
|
|
logger.info("=" * 70)
|
|
|
|
# 1. Check database connection
|
|
logger.info("\n[1/5] Checking database connection...")
|
|
self.check_database_connection()
|
|
|
|
# 2. Get symbols if not provided
|
|
if symbols is None:
|
|
try:
|
|
symbols = self.db_manager.db.get_available_symbols()[:5] # First 5
|
|
except:
|
|
symbols = ['XAUUSD']
|
|
|
|
# 3. Check each symbol
|
|
logger.info(f"\n[2/5] Checking symbol data ({len(symbols)} symbols)...")
|
|
for symbol in symbols:
|
|
self.check_symbol_data(symbol)
|
|
|
|
# 4. Check data quality
|
|
logger.info(f"\n[3/5] Checking data quality...")
|
|
for symbol in symbols:
|
|
self.check_data_quality(symbol)
|
|
|
|
# 5. Check temporal coverage
|
|
logger.info(f"\n[4/5] Checking temporal coverage...")
|
|
for symbol in symbols:
|
|
self.check_temporal_coverage(symbol)
|
|
|
|
# 6. Check required columns
|
|
logger.info(f"\n[5/5] Checking required columns...")
|
|
for symbol in symbols:
|
|
self.check_required_columns(symbol)
|
|
|
|
# Generate report
|
|
return self.generate_report()
|
|
|
|
def generate_report(self) -> Dict:
|
|
"""Generate validation report"""
|
|
passed = sum(1 for r in self.results if r.passed)
|
|
failed = sum(1 for r in self.results if not r.passed)
|
|
total = len(self.results)
|
|
|
|
logger.info("\n" + "=" * 70)
|
|
logger.info("VALIDATION REPORT")
|
|
logger.info("=" * 70)
|
|
logger.info(f"Total checks: {total}")
|
|
logger.info(f"Passed: {passed} ({passed/total*100:.1f}%)")
|
|
logger.info(f"Failed: {failed} ({failed/total*100:.1f}%)")
|
|
logger.info("-" * 70)
|
|
|
|
for result in self.results:
|
|
status = "[PASS]" if result.passed else "[FAIL]"
|
|
logger.info(f"{status} {result.name}: {result.message}")
|
|
|
|
logger.info("=" * 70)
|
|
|
|
report = {
|
|
'timestamp': datetime.now().isoformat(),
|
|
'summary': {
|
|
'total': total,
|
|
'passed': passed,
|
|
'failed': failed,
|
|
'pass_rate': passed / total if total > 0 else 0
|
|
},
|
|
'results': [
|
|
{
|
|
'name': r.name,
|
|
'passed': r.passed,
|
|
'message': r.message,
|
|
'details': r.details
|
|
}
|
|
for r in self.results
|
|
]
|
|
}
|
|
|
|
return report
|
|
|
|
|
|
def main():
|
|
"""Main entry point"""
|
|
parser = argparse.ArgumentParser(
|
|
description="Validate data quality and readiness for ML training"
|
|
)
|
|
parser.add_argument(
|
|
'--check-db',
|
|
action='store_true',
|
|
help='Check database connection only'
|
|
)
|
|
parser.add_argument(
|
|
'--check-splits',
|
|
action='store_true',
|
|
help='Check prepared dataset splits'
|
|
)
|
|
parser.add_argument(
|
|
'--full-validation',
|
|
action='store_true',
|
|
help='Run complete validation suite'
|
|
)
|
|
parser.add_argument(
|
|
'--symbol',
|
|
type=str,
|
|
help='Specific symbol to validate'
|
|
)
|
|
parser.add_argument(
|
|
'--config',
|
|
type=str,
|
|
default='config/validation_oos.yaml',
|
|
help='Path to validation config'
|
|
)
|
|
parser.add_argument(
|
|
'--datasets-dir',
|
|
type=str,
|
|
default='datasets',
|
|
help='Directory with prepared datasets'
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Initialize validator
|
|
validator = DataValidator(config_path=args.config)
|
|
|
|
if args.check_db:
|
|
result = validator.check_database_connection()
|
|
print(f"{'PASSED' if result.passed else 'FAILED'}: {result.message}")
|
|
|
|
elif args.check_splits:
|
|
results = validator.check_prepared_datasets(args.datasets_dir)
|
|
for r in results:
|
|
print(f"{'PASSED' if r.passed else 'FAILED'}: {r.name} - {r.message}")
|
|
|
|
elif args.full_validation or args.symbol:
|
|
symbols = [args.symbol] if args.symbol else None
|
|
report = validator.run_full_validation(symbols)
|
|
|
|
# Save report
|
|
report_path = Path('reports') / 'validation_report.yaml'
|
|
report_path.parent.mkdir(parents=True, exist_ok=True)
|
|
with open(report_path, 'w') as f:
|
|
yaml.dump(report, f, default_flow_style=False)
|
|
logger.info(f"\nReport saved to: {report_path}")
|
|
|
|
else:
|
|
# Default: run full validation
|
|
report = validator.run_full_validation()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|