trading-platform-ml-engine-v2/scripts/validate_data.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00

529 lines
17 KiB
Python

#!/usr/bin/env python3
"""
Data Validation Script for ML-First Strategy
=============================================
Validates data quality, temporal splits, and readiness for training.
Usage:
python scripts/validate_data.py --check-db
python scripts/validate_data.py --check-splits
python scripts/validate_data.py --full-validation
Author: ML-Specialist (NEXUS v4.0)
Created: 2026-01-04
"""
import os
import sys
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import pandas as pd
import numpy as np
import yaml
from loguru import logger
from dataclasses import dataclass
# Make the repo's src/ package importable when this script is run directly
# (scripts/ sits beside src/ at the repository root).
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from data.database import DatabaseManager
# Replace loguru's default sink with a concise, colorized stdout format.
# NOTE: this mutates global logger state at import time.
logger.remove()
logger.add(
sys.stdout,
format="<green>{time:HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{message}</cyan>",
level="INFO"
)
@dataclass
class ValidationResult:
    """Outcome of one validation check.

    Checks produce one of these per symbol/aspect; the validator collects
    them and folds them into the final report.
    """

    name: str                      # human-readable check name, e.g. "Data Quality: XAUUSD"
    passed: bool                   # True when the check succeeded
    message: str                   # one-line summary shown in logs/report
    details: Optional[Dict] = None # optional structured payload for the report
class DataValidator:
    """
    Validates data quality and readiness for ML training.

    Every ``check_*`` method appends its :class:`ValidationResult` to
    ``self.results`` and returns it, so checks can be run individually or
    in bulk via :meth:`run_full_validation` followed by
    :meth:`generate_report`.
    """

    def __init__(self, config_path: str = "config/validation_oos.yaml"):
        """Initialize the data validator.

        Args:
            config_path: Path to the YAML validation config. A missing or
                empty file degrades to an empty config so checks fall back
                to their built-in default dates/thresholds.
        """
        self.db_manager = DatabaseManager()
        self.results: List[ValidationResult] = []
        # Load config
        config_file = Path(config_path)
        if config_file.exists():
            with open(config_file, 'r') as f:
                # `or {}`: safe_load() returns None for an empty file,
                # which would crash later self.config.get(...) calls.
                self.config = yaml.safe_load(f) or {}
        else:
            logger.warning(f"Config not found: {config_path}")
            self.config = {}

    def check_database_connection(self) -> ValidationResult:
        """Check database connectivity via a lightweight symbol-list query."""
        try:
            # Test query
            symbols = self.db_manager.db.get_available_symbols()
            result = ValidationResult(
                name="Database Connection",
                passed=True,
                message=f"Connected successfully. Found {len(symbols)} symbols.",
                details={'symbols': symbols[:10]}  # first 10 only, keeps report small
            )
        except Exception as e:
            result = ValidationResult(
                name="Database Connection",
                passed=False,
                message=f"Connection failed: {str(e)}"
            )
        self.results.append(result)
        return result

    def check_symbol_data(
        self,
        symbol: str,
        min_records: int = 10000
    ) -> ValidationResult:
        """Check that *symbol* has sufficient data.

        Args:
            symbol: Ticker symbol to inspect.
            min_records: Minimum row count required to pass.

        Returns:
            ValidationResult with count, date range and columns in details.
        """
        try:
            # One query answers both "is there any data?" and the
            # count/date-range question (the previous limit=1 probe was a
            # redundant extra round-trip).
            df_full = self.db_manager.db.get_ticker_data(symbol, limit=500000)
            if df_full.empty:
                result = ValidationResult(
                    name=f"Symbol Data: {symbol}",
                    passed=False,
                    message=f"No data found for {symbol}"
                )
            else:
                count = len(df_full)
                date_range = f"{df_full.index.min()} to {df_full.index.max()}"
                result = ValidationResult(
                    name=f"Symbol Data: {symbol}",
                    passed=count >= min_records,
                    message=f"{count:,} records ({date_range})",
                    details={
                        'count': count,
                        'min_date': str(df_full.index.min()),
                        'max_date': str(df_full.index.max()),
                        'columns': list(df_full.columns)
                    }
                )
        except Exception as e:
            result = ValidationResult(
                name=f"Symbol Data: {symbol}",
                passed=False,
                message=f"Error checking data: {str(e)}"
            )
        self.results.append(result)
        return result

    def check_data_quality(
        self,
        symbol: str,
        sample_size: int = 10000
    ) -> ValidationResult:
        """Check data quality (gaps, nulls, outliers) on a recent sample.

        Args:
            symbol: Ticker symbol to inspect.
            sample_size: Number of rows to sample for the checks.
        """
        try:
            df = self.db_manager.db.get_ticker_data(symbol, limit=sample_size)
            if df.empty:
                result = ValidationResult(
                    name=f"Data Quality: {symbol}",
                    passed=False,
                    message="No data to validate"
                )
                # Record the failure too (the old early return skipped
                # self.results, hiding it from the final report).
                self.results.append(result)
                return result
            # Check for nulls
            null_counts = df.isnull().sum()
            null_pct = (null_counts / len(df) * 100).round(2)
            # Check for gaps: anything over 2x the expected bar interval
            time_diffs = df.index.to_series().diff().dropna()
            expected_interval = pd.Timedelta(minutes=5)  # assumes 5-min data — TODO confirm per timeframe
            gaps = time_diffs[time_diffs > expected_interval * 2]
            # Check for outliers in price: |return| beyond 5 standard deviations
            if 'close' in df.columns:
                price_returns = df['close'].pct_change().dropna()
                outlier_threshold = price_returns.std() * 5
                outliers = int((price_returns.abs() > outlier_threshold).sum())
            else:
                outliers = 0
            issues = []
            if null_counts.sum() > 0:
                issues.append(f"Nulls: {null_counts.sum()}")
            if len(gaps) > 0:
                issues.append(f"Time gaps: {len(gaps)}")
            if outliers > 0:
                issues.append(f"Price outliers: {outliers}")
            passed = len(issues) == 0
            result = ValidationResult(
                name=f"Data Quality: {symbol}",
                passed=passed,
                message="OK" if passed else ", ".join(issues),
                details={
                    # Cast numpy scalars to builtins so yaml.dump of the
                    # final report does not raise RepresenterError.
                    'null_counts': {k: float(v) for k, v in null_pct.to_dict().items()},
                    'gap_count': len(gaps),
                    'outlier_count': outliers,
                    'sample_size': len(df)
                }
            )
        except Exception as e:
            result = ValidationResult(
                name=f"Data Quality: {symbol}",
                passed=False,
                message=f"Error: {str(e)}"
            )
        self.results.append(result)
        return result

    def check_temporal_coverage(
        self,
        symbol: str
    ) -> ValidationResult:
        """Check if data covers required time periods for training/validation.

        Periods come from config keys ``validation.train`` and
        ``validation.test_oos``; hard-coded defaults apply when absent.
        """
        try:
            df = self.db_manager.db.get_ticker_data(symbol, limit=500000)
            if df.empty:
                result = ValidationResult(
                    name=f"Temporal Coverage: {symbol}",
                    passed=False,
                    message="No data"
                )
                # Record the failure (previously lost by an early return).
                self.results.append(result)
                return result
            # Required periods from config
            validation_cfg = self.config.get('validation', {})
            train_cfg = validation_cfg.get('train', {})
            test_cfg = validation_cfg.get('test_oos', {})
            train_start = pd.to_datetime(train_cfg.get('start_date', '2023-01-01'))
            train_end = pd.to_datetime(train_cfg.get('end_date', '2024-12-31'))
            test_start = pd.to_datetime(test_cfg.get('start_date', '2025-01-01'))
            test_end = pd.to_datetime(test_cfg.get('end_date', '2025-12-31'))
            data_start = df.index.min()
            data_end = df.index.max()
            # Coverage flags (bool() strips numpy.bool_ for YAML dumping)
            train_covered = bool(data_start <= train_start and data_end >= train_end)
            test_covered = bool(data_start <= test_start and data_end >= test_end)
            # Count samples per period (int() strips numpy.int64 for YAML)
            train_mask = (df.index >= train_start) & (df.index <= train_end)
            test_mask = (df.index >= test_start) & (df.index <= test_end)
            train_count = int(train_mask.sum())
            test_count = int(test_mask.sum())
            # Year breakdown with plain-int keys/values
            year_counts = {int(y): int(c) for y, c in df.groupby(df.index.year).size().items()}
            passed = train_covered and test_covered and train_count > 10000 and test_count > 1000
            result = ValidationResult(
                name=f"Temporal Coverage: {symbol}",
                passed=passed,
                message=f"Train: {train_count:,}, Test OOS: {test_count:,}",
                details={
                    'data_range': f"{data_start} to {data_end}",
                    'train_samples': train_count,
                    'test_samples': test_count,
                    'year_counts': year_counts,
                    'train_covered': train_covered,
                    'test_covered': test_covered
                }
            )
        except Exception as e:
            result = ValidationResult(
                name=f"Temporal Coverage: {symbol}",
                passed=False,
                message=f"Error: {str(e)}"
            )
        self.results.append(result)
        return result

    def check_required_columns(
        self,
        symbol: str
    ) -> ValidationResult:
        """Check if all required OHLCV + indicator columns exist for *symbol*."""
        required_columns = [
            'open', 'high', 'low', 'close', 'volume',
            'rsi', 'macd_histogram', 'macd_signal',
            'sma_10', 'sma_20', 'atr'
        ]
        try:
            df = self.db_manager.db.get_ticker_data(symbol, limit=100)
            if df.empty:
                result = ValidationResult(
                    name=f"Required Columns: {symbol}",
                    passed=False,
                    message="No data"
                )
                # Record the failure (previously lost by an early return).
                self.results.append(result)
                return result
            available = set(df.columns)
            required = set(required_columns)
            missing = required - available
            result = ValidationResult(
                name=f"Required Columns: {symbol}",
                passed=len(missing) == 0,
                message="All required columns present" if len(missing) == 0 else f"Missing: {missing}",
                details={
                    'available': list(available),
                    'missing': list(missing),
                    'total_columns': len(df.columns)
                }
            )
        except Exception as e:
            result = ValidationResult(
                name=f"Required Columns: {symbol}",
                passed=False,
                message=f"Error: {str(e)}"
            )
        self.results.append(result)
        return result

    def check_prepared_datasets(
        self,
        datasets_dir: str = "datasets"
    ) -> List[ValidationResult]:
        """Check prepared dataset files under ``datasets_dir``.

        Expects the layout ``<symbol>/<timeframe>/{train,val,test_oos}.parquet``
        plus a ``metadata.yaml`` per timeframe directory.

        Returns:
            One ValidationResult per symbol/timeframe directory found.
        """
        results = []
        datasets_path = Path(datasets_dir)
        if not datasets_path.exists():
            result = ValidationResult(
                name="Prepared Datasets",
                passed=False,
                message=f"Directory not found: {datasets_dir}"
            )
            results.append(result)
            self.results.append(result)
            return results
        for symbol_dir in datasets_path.iterdir():
            if not symbol_dir.is_dir():
                continue
            for tf_dir in symbol_dir.iterdir():
                if not tf_dir.is_dir():
                    continue
                # Check for required files
                train_file = tf_dir / 'train.parquet'
                val_file = tf_dir / 'val.parquet'
                test_file = tf_dir / 'test_oos.parquet'
                metadata_file = tf_dir / 'metadata.yaml'
                files_exist = {
                    'train': train_file.exists(),
                    'val': val_file.exists(),
                    'test_oos': test_file.exists(),
                    'metadata': metadata_file.exists()
                }
                all_exist = all(files_exist.values())
                # Get row counts if files exist (loads each parquet fully)
                sizes = {}
                if train_file.exists():
                    sizes['train'] = len(pd.read_parquet(train_file))
                if val_file.exists():
                    sizes['val'] = len(pd.read_parquet(val_file))
                if test_file.exists():
                    sizes['test_oos'] = len(pd.read_parquet(test_file))
                result = ValidationResult(
                    name=f"Dataset: {symbol_dir.name}/{tf_dir.name}",
                    passed=all_exist,
                    message=f"OK - Train: {sizes.get('train', 0):,}, Val: {sizes.get('val', 0):,}, Test: {sizes.get('test_oos', 0):,}" if all_exist else f"Missing files: {[k for k, v in files_exist.items() if not v]}",
                    details={
                        'files': files_exist,
                        'sizes': sizes
                    }
                )
                results.append(result)
                self.results.append(result)
        return results

    def run_full_validation(
        self,
        symbols: Optional[List[str]] = None
    ) -> Dict:
        """Run the complete validation suite and return the report dict.

        Args:
            symbols: Symbols to validate; when None, the first 5 available
                symbols are used (falling back to XAUUSD if the DB query fails).
        """
        logger.info("=" * 70)
        logger.info("STARTING FULL DATA VALIDATION")
        logger.info("=" * 70)
        # 1. Check database connection
        logger.info("\n[1/5] Checking database connection...")
        self.check_database_connection()
        # 2. Get symbols if not provided
        if symbols is None:
            try:
                symbols = self.db_manager.db.get_available_symbols()[:5]  # First 5
            except Exception:
                # narrow fallback; bare except would also swallow KeyboardInterrupt
                symbols = ['XAUUSD']
        # 3. Check each symbol
        logger.info(f"\n[2/5] Checking symbol data ({len(symbols)} symbols)...")
        for symbol in symbols:
            self.check_symbol_data(symbol)
        # 4. Check data quality
        logger.info(f"\n[3/5] Checking data quality...")
        for symbol in symbols:
            self.check_data_quality(symbol)
        # 5. Check temporal coverage
        logger.info(f"\n[4/5] Checking temporal coverage...")
        for symbol in symbols:
            self.check_temporal_coverage(symbol)
        # 6. Check required columns
        logger.info(f"\n[5/5] Checking required columns...")
        for symbol in symbols:
            self.check_required_columns(symbol)
        # Generate report
        return self.generate_report()

    def generate_report(self) -> Dict:
        """Log a summary of all accumulated results and return a report dict."""
        passed = sum(1 for r in self.results if r.passed)
        failed = sum(1 for r in self.results if not r.passed)
        total = len(self.results)
        # Guard division: previously crashed with ZeroDivisionError when
        # generate_report() was called with no checks run.
        pass_rate = passed / total if total > 0 else 0
        fail_rate = failed / total if total > 0 else 0
        logger.info("\n" + "=" * 70)
        logger.info("VALIDATION REPORT")
        logger.info("=" * 70)
        logger.info(f"Total checks: {total}")
        logger.info(f"Passed: {passed} ({pass_rate * 100:.1f}%)")
        logger.info(f"Failed: {failed} ({fail_rate * 100:.1f}%)")
        logger.info("-" * 70)
        for result in self.results:
            status = "[PASS]" if result.passed else "[FAIL]"
            logger.info(f"{status} {result.name}: {result.message}")
        logger.info("=" * 70)
        report = {
            'timestamp': datetime.now().isoformat(),
            'summary': {
                'total': total,
                'passed': passed,
                'failed': failed,
                'pass_rate': pass_rate
            },
            'results': [
                {
                    'name': r.name,
                    'passed': r.passed,
                    'message': r.message,
                    'details': r.details
                }
                for r in self.results
            ]
        }
        return report
def main():
    """Command-line entry point: parse flags and dispatch to the validator."""
    parser = argparse.ArgumentParser(
        description="Validate data quality and readiness for ML training"
    )
    parser.add_argument('--check-db', action='store_true',
                        help='Check database connection only')
    parser.add_argument('--check-splits', action='store_true',
                        help='Check prepared dataset splits')
    parser.add_argument('--full-validation', action='store_true',
                        help='Run complete validation suite')
    parser.add_argument('--symbol', type=str,
                        help='Specific symbol to validate')
    parser.add_argument('--config', type=str,
                        default='config/validation_oos.yaml',
                        help='Path to validation config')
    parser.add_argument('--datasets-dir', type=str,
                        default='datasets',
                        help='Directory with prepared datasets')
    args = parser.parse_args()

    validator = DataValidator(config_path=args.config)

    if args.check_db:
        outcome = validator.check_database_connection()
        verdict = 'PASSED' if outcome.passed else 'FAILED'
        print(f"{verdict}: {outcome.message}")
    elif args.check_splits:
        for outcome in validator.check_prepared_datasets(args.datasets_dir):
            verdict = 'PASSED' if outcome.passed else 'FAILED'
            print(f"{verdict}: {outcome.name} - {outcome.message}")
    elif args.full_validation or args.symbol:
        # A bare --symbol implies a full validation restricted to it.
        target_symbols = [args.symbol] if args.symbol else None
        report = validator.run_full_validation(target_symbols)
        # Persist the report as YAML under reports/.
        report_path = Path('reports') / 'validation_report.yaml'
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, 'w') as f:
            yaml.dump(report, f, default_flow_style=False)
        logger.info(f"\nReport saved to: {report_path}")
    else:
        # No flags: run the full suite (report is logged, not saved).
        validator.run_full_validation()


if __name__ == "__main__":
    main()