192 lines
5.0 KiB
Python
192 lines
5.0 KiB
Python
"""
Test ML Engine API endpoints
"""
|
|
|
|
import pytest
|
|
from fastapi.testclient import TestClient
|
|
from datetime import datetime
|
|
|
|
from src.api.main import app
|
|
|
|
|
|
@pytest.fixture
def client():
    """Yield a ``TestClient`` bound to the ML Engine FastAPI ``app``.

    The client is entered as a context manager so that any startup and
    shutdown (lifespan) handlers registered on ``app`` run around each
    test, and so the client is closed once the test finishes. A bare
    ``TestClient(app)`` skips those handlers.

    Yields:
        TestClient: an entered client ready to issue HTTP and WebSocket
        requests against ``app``.
    """
    with TestClient(app) as test_client:
        yield test_client
|
|
|
|
|
|
def test_health_check(client):
    """GET /health must answer 200 with the expected status payload."""
    health_response = client.get("/health")
    assert health_response.status_code == 200

    payload = health_response.json()
    assert payload["status"] == "healthy"
    # These fields only need to be present; their values are not checked.
    for field in ("version", "timestamp"):
        assert field in payload
    assert isinstance(payload["models_loaded"], bool)
|
|
|
|
|
|
def test_list_models(client):
    """GET /models must answer 200 with a JSON list body."""
    models_response = client.get("/models")
    assert models_response.status_code == 200

    listed_models = models_response.json()
    assert isinstance(listed_models, list)
|
|
|
|
|
|
def test_list_symbols(client):
    """GET /symbols must answer 200 with a list containing XAUUSD and EURUSD."""
    symbols_response = client.get("/symbols")
    assert symbols_response.status_code == 200

    listed = symbols_response.json()
    assert isinstance(listed, list)
    for expected_symbol in ("XAUUSD", "EURUSD"):
        assert expected_symbol in listed
|
|
|
|
|
|
def test_predict_range(client):
    """POST /predict/range must return a non-empty list, or 503."""
    payload = {"symbol": "XAUUSD", "timeframe": "15m", "horizon": "15m"}
    reply = client.post("/predict/range", json=payload)

    # A 503 is tolerated: the service answers that when models are not loaded.
    assert reply.status_code in (200, 503)
    if reply.status_code != 200:
        return

    predictions = reply.json()
    assert isinstance(predictions, list)
    assert len(predictions) > 0
|
|
|
|
|
|
def test_predict_tpsl(client):
    """POST /predict/tpsl must return the TP/SL prediction fields, or 503."""
    payload = {"symbol": "XAUUSD", "timeframe": "15m", "horizon": "15m"}
    reply = client.post("/predict/tpsl?rr_config=rr_2_1", json=payload)

    # A 503 is tolerated: the service answers that when models are not loaded.
    assert reply.status_code in (200, 503)
    if reply.status_code != 200:
        return

    body = reply.json()
    for field in ("prob_tp_first", "rr_config", "confidence"):
        assert field in body
|
|
|
|
|
|
def test_generate_signal(client):
    """POST /generate/signal must return a complete signal, or 503."""
    payload = {"symbol": "XAUUSD", "timeframe": "15m", "horizon": "15m"}
    reply = client.post("/generate/signal?rr_config=rr_2_1", json=payload)

    # A 503 is tolerated: the service answers that when models are not loaded.
    assert reply.status_code in (200, 503)
    if reply.status_code != 200:
        return

    signal = reply.json()
    required_fields = (
        "signal_id",
        "symbol",
        "direction",
        "entry_price",
        "stop_loss",
        "take_profit",
    )
    for field in required_fields:
        assert field in signal
|
|
|
|
|
|
def test_amd_detection(client):
    """POST /api/amd/XAUUSD must return the AMD phase fields, or 503."""
    reply = client.post("/api/amd/XAUUSD?timeframe=15m&lookback_periods=100")

    # A 503 is tolerated: the service answers that when the AMD detector
    # is not loaded.
    assert reply.status_code in (200, 503)
    if reply.status_code != 200:
        return

    detection = reply.json()
    required_fields = (
        "phase",
        "confidence",
        "strength",
        "characteristics",
        "signals",
        "trading_bias",
    )
    for field in required_fields:
        assert field in detection
|
|
|
|
|
|
def test_backtest(client):
    """POST /api/backtest must return the summary metrics, or 503."""
    payload = {
        "symbol": "XAUUSD",
        "start_date": "2024-01-01T00:00:00",
        "end_date": "2024-02-01T00:00:00",
        "initial_capital": 10000.0,
        "risk_per_trade": 0.02,
        "rr_config": "rr_2_1",
        "filter_by_amd": True,
        "min_confidence": 0.55,
    }
    reply = client.post("/api/backtest", json=payload)

    # A 503 is tolerated: the service answers that when the backtester
    # is not loaded.
    assert reply.status_code in (200, 503)
    if reply.status_code != 200:
        return

    summary = reply.json()
    required_fields = (
        "total_trades",
        "winrate",
        "net_profit",
        "profit_factor",
        "max_drawdown",
    )
    for field in required_fields:
        assert field in summary
|
|
|
|
|
|
def test_train_models(client):
    """POST /api/train/full must return the training result fields, or 503."""
    payload = {
        "symbol": "XAUUSD",
        "start_date": "2023-01-01T00:00:00",
        "end_date": "2024-01-01T00:00:00",
        "models_to_train": ["range_predictor", "tpsl_classifier"],
        "use_walk_forward": True,
        "n_splits": 5,
    }
    reply = client.post("/api/train/full", json=payload)

    # A 503 is tolerated: the service answers that when the training
    # pipeline is not loaded.
    assert reply.status_code in (200, 503)
    if reply.status_code != 200:
        return

    outcome = reply.json()
    for field in ("status", "models_trained", "metrics", "model_paths"):
        assert field in outcome
|
|
|
|
|
|
def test_websocket_connection(client):
    """Open /ws/signals, send a text frame, and check the JSON reply's keys."""
    with client.websocket_connect("/ws/signals") as ws:
        # Send one text frame, then read one JSON message back.
        ws.send_text("test")
        message = ws.receive_json()

        for field in ("type", "data"):
            assert field in message
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting itself.
    # Propagate it so that running this file directly (e.g. from CI or a
    # shell script) reports a non-zero status when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|