228 lines
7.1 KiB
Python
228 lines
7.1 KiB
Python
"""
Tests for Data Synchronization Service

OrbiQuant IA Trading Platform
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from services.sync_service import DataSyncService, SyncStatus
|
|
from providers.polygon_client import AssetType, Timeframe, OHLCVBar
|
|
|
|
|
|
@pytest.fixture
def mock_polygon_client():
    """Provide a stand-in Polygon client.

    Only ``get_ticker_details`` is pre-configured: it is awaitable and returns
    a EUR/USD detail dict. Individual tests assign other methods (for example
    ``get_aggregates``) themselves.
    """
    ticker_details = {
        "name": "EUR/USD",
        "primary_exchange": "FOREX",
    }
    fake_client = MagicMock()
    fake_client.get_ticker_details = AsyncMock(return_value=ticker_details)
    return fake_client
|
|
|
|
|
|
@pytest.fixture
def mock_db_pool():
    """Provide a stand-in database pool.

    ``pool.acquire()`` behaves as an async context manager that yields a mocked
    connection. The connection has awaitable ``fetchrow``, ``fetchval``,
    ``fetch``, ``execute`` and ``executemany`` methods with these defaults:

    * ``fetchrow`` -> ``{"id": 1, "last_ts": None}``
    * ``fetchval`` -> ``1``
    * ``fetch`` -> ``[]``

    Tests may override these return values on the yielded connection.
    """
    conn = MagicMock()
    conn.fetchrow = AsyncMock(return_value={"id": 1, "last_ts": None})
    conn.fetchval = AsyncMock(return_value=1)
    conn.fetch = AsyncMock(return_value=[])
    conn.execute = AsyncMock()
    conn.executemany = AsyncMock()

    pool = MagicMock()
    pool.acquire = MagicMock()
    acquire_ctx = pool.acquire.return_value
    acquire_ctx.__aenter__ = AsyncMock(return_value=conn)
    # __aexit__ must return a falsy value. A bare AsyncMock() resolves to a
    # truthy mock object, which would make ``async with pool.acquire()``
    # silently swallow any exception raised inside the block and hide
    # failures in the code under test.
    acquire_ctx.__aexit__ = AsyncMock(return_value=False)

    return pool
|
|
|
|
|
|
@pytest.fixture
def sync_service(mock_polygon_client, mock_db_pool):
    """Build the DataSyncService under test, wired to the mocked client and pool."""
    service_kwargs = dict(
        polygon_client=mock_polygon_client,
        db_pool=mock_db_pool,
        batch_size=100,
    )
    return DataSyncService(**service_kwargs)
|
|
|
|
|
|
class TestDataSyncService:
    """Test DataSyncService class against mocked Polygon and database layers."""

    @staticmethod
    def _conn(pool):
        """Return the mocked connection that ``pool.acquire()`` yields.

        Reads the configured ``__aenter__`` return value directly instead of
        calling ``pool.acquire().__aenter__()``. The direct call would record
        extra calls on the pool mock before the code under test even runs.
        """
        return pool.acquire.return_value.__aenter__.return_value

    @pytest.mark.asyncio
    async def test_get_or_create_ticker_existing(self, sync_service, mock_db_pool):
        """Test getting existing ticker."""
        # The lookup query finds an existing row.
        conn = self._conn(mock_db_pool)
        conn.fetchrow.return_value = {"id": 123}

        ticker_id = await sync_service.get_or_create_ticker("EURUSD", AssetType.FOREX)

        assert ticker_id == 123
        conn.fetchrow.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_or_create_ticker_new(self, sync_service, mock_db_pool):
        """Test creating new ticker."""
        # No existing row; the creation query returns the new id.
        conn = self._conn(mock_db_pool)
        conn.fetchrow.return_value = None
        conn.fetchval.return_value = 456

        ticker_id = await sync_service.get_or_create_ticker("GBPUSD", AssetType.FOREX)

        assert ticker_id == 456
        conn.fetchval.assert_called_once()

    @pytest.mark.asyncio
    async def test_sync_ticker_data_success(self, sync_service, mock_polygon_client):
        """Test successful ticker sync."""

        # Polygon returns a single bar as an async generator.
        async def mock_aggregates(*args, **kwargs):
            bars = [
                OHLCVBar(
                    timestamp=datetime.now(),
                    open=1.1000,
                    high=1.1050,
                    low=1.0950,
                    close=1.1025,
                    volume=1000000,
                    vwap=1.1012,
                    transactions=1500,
                )
            ]
            for bar in bars:
                yield bar

        mock_polygon_client.get_aggregates = mock_aggregates

        result = await sync_service.sync_ticker_data(
            symbol="EURUSD",
            asset_type=AssetType.FOREX,
            timeframe=Timeframe.MINUTE_5,
            backfill_days=1,
        )

        assert result["status"] == SyncStatus.SUCCESS
        assert result["symbol"] == "EURUSD"
        assert result["rows_inserted"] >= 0

    @pytest.mark.asyncio
    async def test_sync_ticker_data_no_ticker(self, sync_service, mock_db_pool):
        """Test sync when ticker creation fails."""
        # Neither the lookup nor the creation query yields a ticker id.
        conn = self._conn(mock_db_pool)
        conn.fetchrow.return_value = None
        conn.fetchval.return_value = None

        result = await sync_service.sync_ticker_data(
            symbol="INVALID",
            asset_type=AssetType.FOREX,
            backfill_days=1,
        )

        assert result["status"] == SyncStatus.FAILED
        assert "Failed to get/create ticker" in result["error"]

    @pytest.mark.asyncio
    async def test_insert_bars(self, sync_service):
        """Test inserting bars."""
        bars = [
            (1, datetime.now(), 1.1, 1.15, 1.09, 1.12, 1000, 1.11, 100, 1234567890)
        ]

        inserted = await sync_service._insert_bars("ohlcv_5min", bars)

        assert inserted == 1

    @pytest.mark.asyncio
    async def test_get_supported_symbols(self, sync_service):
        """Test getting supported symbols."""
        symbols = await sync_service.get_supported_symbols()

        assert len(symbols) > 0
        assert all("symbol" in s for s in symbols)
        assert all("asset_type" in s for s in symbols)

    @pytest.mark.asyncio
    async def test_get_supported_symbols_filtered(self, sync_service):
        """Test getting supported symbols with filter."""
        forex_symbols = await sync_service.get_supported_symbols(
            asset_type=AssetType.FOREX
        )

        assert len(forex_symbols) > 0
        assert all(s["asset_type"] == "forex" for s in forex_symbols)

    @pytest.mark.asyncio
    async def test_get_sync_status(self, sync_service, mock_db_pool):
        """Test getting sync status."""
        # The status query returns a single row.
        conn = self._conn(mock_db_pool)
        conn.fetch.return_value = [
            {
                "symbol": "EURUSD",
                "asset_type": "forex",
                "timeframe": "5min",
                "last_sync_timestamp": datetime.now(),
                "last_sync_rows": 100,
                "sync_status": "success",
                "error_message": None,
                "updated_at": datetime.now(),
            }
        ]

        status = await sync_service.get_sync_status()

        assert len(status) == 1
        assert status[0]["symbol"] == "EURUSD"

    @pytest.mark.asyncio
    async def test_sync_all_active_tickers(self, sync_service, mock_db_pool, mock_polygon_client):
        """Test syncing all active tickers."""
        # Two active tickers are returned by the database.
        conn = self._conn(mock_db_pool)
        conn.fetch.return_value = [
            {"id": 1, "symbol": "EURUSD", "asset_type": "forex"},
            {"id": 2, "symbol": "GBPUSD", "asset_type": "forex"},
        ]

        # Polygon yields no bars for any ticker.
        async def mock_aggregates(*args, **kwargs):
            return
            yield  # unreachable; its presence makes this an (empty) async generator

        mock_polygon_client.get_aggregates = mock_aggregates

        result = await sync_service.sync_all_active_tickers(
            timeframe=Timeframe.MINUTE_5,
            backfill_days=1,
        )

        assert "total_tickers" in result
        assert "successful" in result
        assert "total_rows_inserted" in result
|
|
|
|
|
class TestSyncStatus:
    """Tests for the SyncStatus enum."""

    def test_sync_status_values(self):
        """Each SyncStatus member compares equal to its expected string value."""
        expected_values = [
            (SyncStatus.PENDING, "pending"),
            (SyncStatus.IN_PROGRESS, "in_progress"),
            (SyncStatus.SUCCESS, "success"),
            (SyncStatus.FAILED, "failed"),
            (SyncStatus.PARTIAL, "partial"),
        ]
        for member, value in expected_values:
            assert member == value
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status. Otherwise running this file directly
    # always exits 0, even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|