trading-platform-ml-engine-v2/tests/test_api.py
rckrdmrd 75c4d07690 feat: Initial commit - ML Engine codebase
Hierarchical ML Pipeline for trading predictions:
- Level 0: Attention Models (volatility/flow classification)
- Level 1: Base Models (XGBoost per symbol/timeframe)
- Level 2: Metamodels (XGBoost Stacking + Neural Gating)

Key components:
- src/pipelines/hierarchical_pipeline.py - Main prediction pipeline
- src/models/ - All ML model classes
- src/training/ - Training utilities
- src/api/ - FastAPI endpoints
- scripts/ - Training and evaluation scripts
- config/ - YAML configurations

Note: Trained models (*.joblib, *.pt) are gitignored.
      Regenerate with training scripts.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 04:27:40 -06:00

192 lines
5.0 KiB
Python

"""
Test ML Engine API endpoints
"""
import pytest
from fastapi.testclient import TestClient
from datetime import datetime
from src.api.main import app
@pytest.fixture
def client():
"""Create test client"""
return TestClient(app)
def test_health_check(client):
"""Test health check endpoint"""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "healthy"
assert "version" in data
assert "timestamp" in data
assert isinstance(data["models_loaded"], bool)
def test_list_models(client):
"""Test list models endpoint"""
response = client.get("/models")
assert response.status_code == 200
assert isinstance(response.json(), list)
def test_list_symbols(client):
"""Test list symbols endpoint"""
response = client.get("/symbols")
assert response.status_code == 200
symbols = response.json()
assert isinstance(symbols, list)
assert "XAUUSD" in symbols
assert "EURUSD" in symbols
def test_predict_range(client):
"""Test range prediction endpoint"""
request_data = {
"symbol": "XAUUSD",
"timeframe": "15m",
"horizon": "15m"
}
response = client.post("/predict/range", json=request_data)
# May return 503 if models not loaded, which is acceptable
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert isinstance(data, list)
assert len(data) > 0
def test_predict_tpsl(client):
"""Test TP/SL prediction endpoint"""
request_data = {
"symbol": "XAUUSD",
"timeframe": "15m",
"horizon": "15m"
}
response = client.post("/predict/tpsl?rr_config=rr_2_1", json=request_data)
# May return 503 if models not loaded
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert "prob_tp_first" in data
assert "rr_config" in data
assert "confidence" in data
def test_generate_signal(client):
"""Test signal generation endpoint"""
request_data = {
"symbol": "XAUUSD",
"timeframe": "15m",
"horizon": "15m"
}
response = client.post("/generate/signal?rr_config=rr_2_1", json=request_data)
# May return 503 if models not loaded
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert "signal_id" in data
assert "symbol" in data
assert "direction" in data
assert "entry_price" in data
assert "stop_loss" in data
assert "take_profit" in data
def test_amd_detection(client):
"""Test AMD phase detection endpoint"""
response = client.post("/api/amd/XAUUSD?timeframe=15m&lookback_periods=100")
# May return 503 if AMD detector not loaded
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert "phase" in data
assert "confidence" in data
assert "strength" in data
assert "characteristics" in data
assert "signals" in data
assert "trading_bias" in data
def test_backtest(client):
"""Test backtesting endpoint"""
request_data = {
"symbol": "XAUUSD",
"start_date": "2024-01-01T00:00:00",
"end_date": "2024-02-01T00:00:00",
"initial_capital": 10000.0,
"risk_per_trade": 0.02,
"rr_config": "rr_2_1",
"filter_by_amd": True,
"min_confidence": 0.55
}
response = client.post("/api/backtest", json=request_data)
# May return 503 if backtester not loaded
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert "total_trades" in data
assert "winrate" in data
assert "net_profit" in data
assert "profit_factor" in data
assert "max_drawdown" in data
def test_train_models(client):
"""Test model training endpoint"""
request_data = {
"symbol": "XAUUSD",
"start_date": "2023-01-01T00:00:00",
"end_date": "2024-01-01T00:00:00",
"models_to_train": ["range_predictor", "tpsl_classifier"],
"use_walk_forward": True,
"n_splits": 5
}
response = client.post("/api/train/full", json=request_data)
# May return 503 if pipeline not loaded
assert response.status_code in [200, 503]
if response.status_code == 200:
data = response.json()
assert "status" in data
assert "models_trained" in data
assert "metrics" in data
assert "model_paths" in data
def test_websocket_connection(client):
"""Test WebSocket connection"""
with client.websocket_connect("/ws/signals") as websocket:
# Send a test message
websocket.send_text("test")
# Receive response
data = websocket.receive_json()
assert "type" in data
assert "data" in data
if __name__ == "__main__":
pytest.main([__file__, "-v"])