"""Tests for the ICT/SMC detector and the strategy ensemble."""

import sys
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

# Put the project root (the parent of the directory containing this file) on
# sys.path so ``src.models`` is importable regardless of the working directory.
# pathlib is used instead of ``str(__file__).rsplit('/', 2)[0]`` because that
# split only works with '/' separators and an absolute path; on Windows it
# inserted the test file's own path instead of the project root.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from src.models.ict_smc_detector import (  # noqa: E402 - needs the sys.path tweak above
    ICTSMCDetector,
    ICTAnalysis,
    OrderBlock,
    FairValueGap,
    MarketBias,
)


class TestICTSMCDetector:
    """Test suite for ``ICTSMCDetector`` and the objects its analysis returns."""

    @pytest.fixture
    def sample_ohlcv_data(self):
        """Return 200 hourly bars of synthetic, reproducible OHLCV data.

        Prices are a random walk around 1.1000. A private
        ``np.random.RandomState(42)`` yields exactly the same values that the
        former global ``np.random.seed(42)`` did, but without reseeding NumPy's
        process-wide generator (which otherwise affects later tests).
        """
        rng = np.random.RandomState(42)
        n_periods = 200

        # Generate trending price data
        base_price = 1.1000
        trend = np.cumsum(rng.randn(n_periods) * 0.0005)

        # A fixed end timestamp keeps the index identical on every run
        # (datetime.now() made it differ). '1h' is the hourly alias expected by
        # pandas >= 2.2, where '1H' emits a FutureWarning.
        dates = pd.date_range(end=datetime(2024, 1, 1), periods=n_periods, freq='1h')

        # The per-bar draw order (high, low, open, close, volume) is unchanged,
        # so the generated values match the original fixture exactly.
        data = []
        for drift in trend:
            price = base_price + drift
            high = price + abs(rng.randn() * 0.0010)
            low = price - abs(rng.randn() * 0.0010)
            open_price = price + rng.randn() * 0.0005
            close = price + rng.randn() * 0.0005
            volume = rng.randint(1000, 10000)

            # Clamp open/close into [low, high] so every bar is a valid candle.
            data.append({
                'open': max(low, min(high, open_price)),
                'high': high,
                'low': low,
                'close': max(low, min(high, close)),
                'volume': volume,
            })

        return pd.DataFrame(data, index=dates)

    @pytest.fixture
    def detector(self):
        """Create a detector with explicit, non-default-looking thresholds."""
        return ICTSMCDetector(
            swing_lookback=10,
            ob_min_size=0.001,
            fvg_min_size=0.0005,
        )

    def test_detector_initialization(self, detector):
        """Constructor arguments are stored on the instance."""
        assert detector.swing_lookback == 10
        assert detector.ob_min_size == 0.001
        assert detector.fvg_min_size == 0.0005

    def test_analyze_returns_ict_analysis(self, detector, sample_ohlcv_data):
        """analyze() returns an ICTAnalysis tagged with the symbol/timeframe."""
        result = detector.analyze(sample_ohlcv_data, "EURUSD", "1H")

        assert isinstance(result, ICTAnalysis)
        assert result.symbol == "EURUSD"
        assert result.timeframe == "1H"
        assert result.market_bias in [MarketBias.BULLISH, MarketBias.BEARISH, MarketBias.NEUTRAL]

    def test_analyze_with_insufficient_data(self, detector):
        """analyze() handles a 2-bar frame gracefully (neutral bias, zero score)."""
        df = pd.DataFrame(
            {
                'open': [1.1, 1.2],
                'high': [1.15, 1.25],
                'low': [1.05, 1.15],
                'close': [1.12, 1.22],
                'volume': [1000, 1000],
            },
            index=pd.date_range(end=datetime(2024, 1, 1), periods=2, freq='1h'),
        )

        result = detector.analyze(df, "TEST", "1H")

        # Should return an empty analysis
        assert result.market_bias == MarketBias.NEUTRAL
        assert result.score == 0

    def test_swing_points_detection(self, detector, sample_ohlcv_data):
        """Swing highs and lows are found and are (int index, float price) pairs."""
        swing_highs, swing_lows = detector._find_swing_points(sample_ohlcv_data)

        # A 200-bar random walk should contain some swing points.
        assert len(swing_highs) > 0
        assert len(swing_lows) > 0

        # Check both collections; the original only type-checked the highs.
        for points in (swing_highs, swing_lows):
            for idx, price in points:
                assert isinstance(idx, int)
                assert isinstance(price, float)

    def test_order_blocks_detection(self, detector, sample_ohlcv_data):
        """Any detected order block is well-formed."""
        swing_highs, swing_lows = detector._find_swing_points(sample_ohlcv_data)
        order_blocks = detector._find_order_blocks(sample_ohlcv_data, swing_highs, swing_lows)

        # May or may not find order blocks depending on data
        for ob in order_blocks:
            assert isinstance(ob, OrderBlock)
            assert ob.type in ['bullish', 'bearish']
            assert ob.high > ob.low
            assert 0 <= ob.strength <= 1

    def test_fair_value_gaps_detection(self, detector, sample_ohlcv_data):
        """Any detected fair value gap is well-formed with a positive size."""
        fvgs = detector._find_fair_value_gaps(sample_ohlcv_data)

        for fvg in fvgs:
            assert isinstance(fvg, FairValueGap)
            assert fvg.type in ['bullish', 'bearish']
            assert fvg.high > fvg.low
            assert fvg.size > 0

    def test_premium_discount_zones(self, detector, sample_ohlcv_data):
        """Premium lies (at least partly) above equilibrium, discount below."""
        swing_highs, swing_lows = detector._find_swing_points(sample_ohlcv_data)
        premium, discount, equilibrium = detector._calculate_zones(
            sample_ohlcv_data, swing_highs, swing_lows
        )

        # Premium zone should be above equilibrium
        assert premium[0] >= equilibrium or premium[1] >= equilibrium

        # Discount zone should be below equilibrium
        assert discount[0] <= equilibrium or discount[1] <= equilibrium

    def test_trade_recommendation(self, detector, sample_ohlcv_data):
        """get_trade_recommendation() returns an action and a score."""
        analysis = detector.analyze(sample_ohlcv_data, "EURUSD", "1H")
        recommendation = detector.get_trade_recommendation(analysis)

        assert 'action' in recommendation
        assert recommendation['action'] in ['BUY', 'SELL', 'HOLD']
        assert 'score' in recommendation

    def test_analysis_to_dict(self, detector, sample_ohlcv_data):
        """ICTAnalysis.to_dict() exposes the expected top-level keys."""
        analysis = detector.analyze(sample_ohlcv_data, "EURUSD", "1H")
        result = analysis.to_dict()

        assert isinstance(result, dict)
        assert 'symbol' in result
        assert 'market_bias' in result
        assert 'order_blocks' in result
        assert 'fair_value_gaps' in result
        assert 'signals' in result
        assert 'score' in result

    def test_setup_score_range(self, detector, sample_ohlcv_data):
        """The setup score lies in [0, 100]."""
        analysis = detector.analyze(sample_ohlcv_data, "EURUSD", "1H")

        assert 0 <= analysis.score <= 100

    def test_bias_confidence_range(self, detector, sample_ohlcv_data):
        """The bias confidence lies in [0, 1]."""
        analysis = detector.analyze(sample_ohlcv_data, "EURUSD", "1H")

        assert 0 <= analysis.bias_confidence <= 1


class TestStrategyEnsemble:
    """Test suite for ``StrategyEnsemble``.

    ``src.models.strategy_ensemble`` is imported inside each test, so an
    import problem in that module fails these tests individually instead of
    breaking collection of the whole file.
    """

    @pytest.fixture
    def sample_ohlcv_data(self):
        """Return 300 hourly bars of synthetic, reproducible OHLCV data.

        A private ``np.random.RandomState(42)`` produces the same values as the
        former global ``np.random.seed(42)`` without reseeding NumPy's
        process-wide generator.
        """
        rng = np.random.RandomState(42)
        n_periods = 300

        base_price = 1.1000
        trend = np.cumsum(rng.randn(n_periods) * 0.0005)
        # Fixed end timestamp for a reproducible index; '1h' avoids the
        # pandas >= 2.2 FutureWarning raised for '1H'.
        dates = pd.date_range(end=datetime(2024, 1, 1), periods=n_periods, freq='1h')

        # Draw order per bar is unchanged, so values match the original fixture.
        data = []
        for drift in trend:
            price = base_price + drift
            high = price + abs(rng.randn() * 0.0010)
            low = price - abs(rng.randn() * 0.0010)
            open_price = price + rng.randn() * 0.0005
            close = price + rng.randn() * 0.0005
            volume = rng.randint(1000, 10000)

            # Clamp open/close into [low, high] so every bar is a valid candle.
            data.append({
                'open': max(low, min(high, open_price)),
                'high': high,
                'low': low,
                'close': max(low, min(high, close)),
                'volume': volume,
            })

        return pd.DataFrame(data, index=dates)

    def test_ensemble_import(self):
        """The ensemble module exposes its public names."""
        from src.models.strategy_ensemble import (
            StrategyEnsemble,
            EnsembleSignal,
            TradeAction,
            SignalStrength,
        )

        assert StrategyEnsemble is not None
        assert EnsembleSignal is not None
        assert TradeAction is not None
        assert SignalStrength is not None

    def test_ensemble_initialization(self):
        """Constructor stores min_confidence and normalizes the weights."""
        from src.models.strategy_ensemble import StrategyEnsemble

        ensemble = StrategyEnsemble(
            amd_weight=0.25,
            ict_weight=0.35,
            min_confidence=0.6,
        )

        assert ensemble.min_confidence == 0.6
        # Weights should be normalized
        total = sum(ensemble.weights.values())
        assert abs(total - 1.0) < 0.01

    def test_ensemble_analyze(self, sample_ohlcv_data):
        """analyze() returns an EnsembleSignal with bounded score/confidence."""
        from src.models.strategy_ensemble import StrategyEnsemble, EnsembleSignal

        ensemble = StrategyEnsemble()
        signal = ensemble.analyze(sample_ohlcv_data, "EURUSD", "1H")

        assert isinstance(signal, EnsembleSignal)
        assert signal.symbol == "EURUSD"
        assert -1 <= signal.net_score <= 1
        assert 0 <= signal.confidence <= 1

    def test_quick_signal(self, sample_ohlcv_data):
        """get_quick_signal() returns a dict with action, confidence and score."""
        from src.models.strategy_ensemble import StrategyEnsemble

        ensemble = StrategyEnsemble()
        signal = ensemble.get_quick_signal(sample_ohlcv_data, "EURUSD")

        assert isinstance(signal, dict)
        assert 'action' in signal
        assert 'confidence' in signal
        assert 'score' in signal


if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly (e.g. from
    # a shell script or CI step) reports failures instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))