trading-platform/apps/data-service/tests/test_polygon_client.py

196 lines
6.1 KiB
Python

"""
Tests for Polygon/Massive Client
OrbiQuant IA Trading Platform
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import aiohttp
from providers.polygon_client import (
PolygonClient, AssetType, Timeframe, OHLCVBar, TickerSnapshot
)
class TestPolygonClient:
    """Test PolygonClient class."""

    @staticmethod
    def _make_mock_session(*responses):
        """Build a fake aiohttp-style session whose ``get()`` yields *responses*.

        The request tests wire responses through
        ``session.get.return_value.__aenter__``, i.e. they expect the client
        to use ``async with session.get(...) as response``. For that, calling
        ``get()`` must be synchronous and return an async context manager.
        A ``MagicMock`` provides exactly that (its ``__aenter__``/``__aexit__``
        are auto-configured as ``AsyncMock``, and ``__aexit__`` returns False
        so exceptions are not swallowed). An ``AsyncMock`` session would
        instead make ``get()`` return a bare coroutine, which cannot be used
        with ``async with``.

        Args:
            *responses: One response is returned for every request; several
                are returned in order, one per request.

        Returns:
            The configured ``MagicMock`` session.
        """
        session = MagicMock()
        # NOTE(review): guards against the client checking ``session.closed``
        # before reuse — an unset MagicMock attribute would be truthy and
        # could make the client replace this mock with a real session.
        session.closed = False
        request_ctx = session.get.return_value
        if len(responses) == 1:
            request_ctx.__aenter__.return_value = responses[0]
        else:
            request_ctx.__aenter__.side_effect = list(responses)
        return session

    def test_init_with_api_key(self):
        """Test initialization with API key."""
        client = PolygonClient(api_key="test_key")
        assert client.api_key == "test_key"
        assert client.base_url == PolygonClient.BASE_URL

    def test_init_with_massive_url(self):
        """Test initialization with Massive URL."""
        client = PolygonClient(api_key="test_key", use_massive_url=True)
        assert client.base_url == PolygonClient.MASSIVE_URL

    def test_init_without_api_key(self):
        """Test initialization without API key raises error."""
        import os

        # NOTE(review): the constructor presumably falls back to an
        # environment variable when ``api_key`` is omitted; clear the
        # environment so the outcome does not depend on whether a real key
        # happens to be exported in the shell running the tests.
        with patch.dict(os.environ, clear=True):
            with pytest.raises(ValueError, match="API_KEY is required"):
                PolygonClient()

    def test_format_symbol_forex(self):
        """Test formatting forex symbols."""
        client = PolygonClient(api_key="test")
        formatted = client._format_symbol("EURUSD", AssetType.FOREX)
        assert formatted == "C:EURUSD"

    def test_format_symbol_crypto(self):
        """Test formatting crypto symbols."""
        client = PolygonClient(api_key="test")
        formatted = client._format_symbol("BTCUSD", AssetType.CRYPTO)
        assert formatted == "X:BTCUSD"

    def test_format_symbol_index(self):
        """Test formatting index symbols."""
        client = PolygonClient(api_key="test")
        formatted = client._format_symbol("SPX", AssetType.INDEX)
        assert formatted == "I:SPX"

    def test_format_symbol_already_formatted(self):
        """Test formatting already formatted symbols."""
        client = PolygonClient(api_key="test")
        formatted = client._format_symbol("C:EURUSD", AssetType.FOREX)
        assert formatted == "C:EURUSD"

    @pytest.mark.asyncio
    async def test_rate_limit_wait(self):
        """Test rate limiting."""
        client = PolygonClient(api_key="test", rate_limit_per_min=2)
        # First request should not wait
        await client._rate_limit_wait()
        assert client._request_count == 1
        # Second request should not wait
        await client._rate_limit_wait()
        assert client._request_count == 2

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """Test using client as context manager."""
        async with PolygonClient(api_key="test") as client:
            assert client._session is not None

    @pytest.mark.asyncio
    async def test_request_with_mock_response(self):
        """Test making API request with mock response."""
        client = PolygonClient(api_key="test")

        mock_response = AsyncMock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={"results": []})
        mock_response.raise_for_status = MagicMock()

        mock_session = self._make_mock_session(mock_response)
        client._session = mock_session

        result = await client._request("/test")
        assert "results" in result
        mock_session.get.assert_called_once()

    @pytest.mark.asyncio
    async def test_request_rate_limited(self):
        """Test handling rate limit response."""
        client = PolygonClient(api_key="test")

        # Mock rate limit then success
        mock_response_429 = AsyncMock()
        mock_response_429.status = 429
        mock_response_429.headers = {"Retry-After": "1"}

        mock_response_200 = AsyncMock()
        mock_response_200.status = 200
        mock_response_200.json = AsyncMock(return_value={"status": "OK"})
        mock_response_200.raise_for_status = MagicMock()

        mock_session = self._make_mock_session(mock_response_429, mock_response_200)
        client._session = mock_session

        # Patch sleep so the Retry-After back-off does not slow the test down.
        with patch('asyncio.sleep', new=AsyncMock()):
            result = await client._request("/test")
        assert result["status"] == "OK"
        # One request answered with 429, then one retry answered with 200.
        assert mock_session.get.call_count == 2
class TestTimeframe:
    """Check the value tuples carried by the ``Timeframe`` enum."""

    def test_timeframe_values(self):
        """Each member's value is the expected (count, unit) string pair."""
        cases = (
            (Timeframe.MINUTE_1, ("1", "minute")),
            (Timeframe.MINUTE_5, ("5", "minute")),
            (Timeframe.MINUTE_15, ("15", "minute")),
            (Timeframe.HOUR_1, ("1", "hour")),
            (Timeframe.HOUR_4, ("4", "hour")),
            (Timeframe.DAY_1, ("1", "day")),
        )
        for member, expected in cases:
            assert member.value == expected, member
class TestAssetType:
    """Check the string values carried by the ``AssetType`` enum."""

    def test_asset_type_values(self):
        """Each member's value is the expected lowercase string."""
        cases = (
            (AssetType.FOREX, "forex"),
            (AssetType.CRYPTO, "crypto"),
            (AssetType.INDEX, "index"),
            (AssetType.FUTURES, "futures"),
            (AssetType.STOCK, "stock"),
        )
        for member, expected in cases:
            assert member.value == expected, member
class TestOHLCVBar:
    """Test OHLCVBar dataclass."""

    def test_ohlcv_bar_creation(self):
        """Test that every constructor argument is stored on the bar."""
        # A fixed timestamp (instead of datetime.now()) so it can be
        # asserted exactly alongside the other fields.
        ts = datetime(2024, 1, 2, 15, 30)
        bar = OHLCVBar(
            timestamp=ts,
            open=1.10,
            high=1.15,
            low=1.09,
            close=1.12,
            volume=1000000,
            vwap=1.11,
            transactions=1500,
        )
        assert bar.timestamp == ts
        assert bar.open == 1.10
        assert bar.high == 1.15
        assert bar.low == 1.09
        assert bar.close == 1.12
        assert bar.volume == 1000000
        assert bar.vwap == 1.11
        assert bar.transactions == 1500
class TestTickerSnapshot:
    """Test TickerSnapshot dataclass."""

    def test_ticker_snapshot_creation(self):
        """Test that every constructor argument is stored on the snapshot."""
        # A fixed timestamp (instead of datetime.now()) so it can be
        # asserted exactly alongside the other fields.
        ts = datetime(2024, 1, 2, 15, 30)
        snapshot = TickerSnapshot(
            symbol="EURUSD",
            bid=1.1000,
            ask=1.1002,
            last_price=1.1001,
            timestamp=ts,
            daily_high=1.1050,
            daily_low=1.0950,
        )
        assert snapshot.symbol == "EURUSD"
        assert snapshot.bid == 1.1000
        assert snapshot.ask == 1.1002
        assert snapshot.last_price == 1.1001
        assert snapshot.timestamp == ts
        assert snapshot.daily_high == 1.1050
        assert snapshot.daily_low == 1.0950
if __name__ == "__main__":
    # pytest.main() returns the session exit status instead of exiting;
    # re-raise it so shell/CI callers see a non-zero code on failure.
    raise SystemExit(pytest.main([__file__, "-v"]))