[GAP-4] feat: Implement real-time WebSocket signal streaming
- Created SignalStreamManager class for connection management
- Per-symbol subscriptions support
- Connection state tracking
- Targeted and broadcast messaging
- Added background signal generation task
- Generates signals every 60 seconds for subscribed symbols
- Includes attention scores in signals
- Enhanced /ws/signals endpoint with full protocol:
- subscribe: Subscribe to symbols
- unsubscribe: Unsubscribe from symbols
- get_signal: Request immediate signal
- get_attention: Request attention scores
- get_hierarchical: Request full L0→L1→L2 prediction
- ping/pong: Connection keepalive
- Added GET /ws/status endpoint for monitoring
- Added startup/shutdown handlers for background task
WebSocket Protocol:
1. Connect to /ws/signals
2. Send: {"action": "subscribe", "symbols": ["XAUUSD"]}
3. Receive: {"type": "signal", "symbol": "XAUUSD", "data": {...}}
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a1e606c21a
commit
a80aeea3c7
415
src/api/main.py
415
src/api/main.py
@ -1349,56 +1349,425 @@ async def scan_symbols(request: MultiSymbolRequest):
|
||||
# WebSocket for real-time signals
|
||||
# =============================================================================
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
import json
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manage WebSocket connections"""
|
||||
class SignalStreamManager:
|
||||
"""
|
||||
Manages WebSocket connections for real-time signal streaming.
|
||||
|
||||
Supports:
|
||||
- Per-symbol subscriptions
|
||||
- Broadcast to all connections
|
||||
- Targeted messages to specific subscriptions
|
||||
- Connection state tracking
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map of websocket -> set of subscribed symbols
|
||||
self.subscriptions: Dict[WebSocket, set] = {}
|
||||
# Map of symbol -> set of websockets
|
||||
self.symbol_subscribers: Dict[str, set] = {}
|
||||
# All active connections
|
||||
self.active_connections: List[WebSocket] = []
|
||||
# Background task reference
|
||||
self._signal_task = None
|
||||
self._running = False
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
"""Accept new WebSocket connection"""
|
||||
await websocket.accept()
|
||||
self.active_connections.append(websocket)
|
||||
self.subscriptions[websocket] = set()
|
||||
logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
"""Handle WebSocket disconnection"""
|
||||
if websocket in self.active_connections:
|
||||
self.active_connections.remove(websocket)
|
||||
|
||||
# Remove from all symbol subscriptions
|
||||
if websocket in self.subscriptions:
|
||||
for symbol in self.subscriptions[websocket]:
|
||||
if symbol in self.symbol_subscribers:
|
||||
self.symbol_subscribers[symbol].discard(websocket)
|
||||
del self.subscriptions[websocket]
|
||||
|
||||
logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}")
|
||||
|
||||
def subscribe(self, websocket: WebSocket, symbols: List[str]):
|
||||
"""Subscribe websocket to specific symbols"""
|
||||
if websocket not in self.subscriptions:
|
||||
self.subscriptions[websocket] = set()
|
||||
|
||||
for symbol in symbols:
|
||||
symbol = symbol.upper()
|
||||
self.subscriptions[websocket].add(symbol)
|
||||
|
||||
if symbol not in self.symbol_subscribers:
|
||||
self.symbol_subscribers[symbol] = set()
|
||||
self.symbol_subscribers[symbol].add(websocket)
|
||||
|
||||
logger.debug(f"Subscribed to {symbols}. Total symbols: {len(self.subscriptions[websocket])}")
|
||||
|
||||
def unsubscribe(self, websocket: WebSocket, symbols: List[str]):
|
||||
"""Unsubscribe websocket from specific symbols"""
|
||||
if websocket not in self.subscriptions:
|
||||
return
|
||||
|
||||
for symbol in symbols:
|
||||
symbol = symbol.upper()
|
||||
self.subscriptions[websocket].discard(symbol)
|
||||
|
||||
if symbol in self.symbol_subscribers:
|
||||
self.symbol_subscribers[symbol].discard(websocket)
|
||||
|
||||
async def send_to_symbol(self, symbol: str, message: dict):
|
||||
"""Send message to all subscribers of a symbol"""
|
||||
symbol = symbol.upper()
|
||||
if symbol not in self.symbol_subscribers:
|
||||
return
|
||||
|
||||
disconnected = []
|
||||
for websocket in self.symbol_subscribers[symbol]:
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
for connection in self.active_connections:
|
||||
"""Broadcast message to all connections"""
|
||||
disconnected = []
|
||||
for websocket in self.active_connections:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except:
|
||||
pass
|
||||
await websocket.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
async def send_personal(self, websocket: WebSocket, message: dict):
|
||||
"""Send message to specific connection"""
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception:
|
||||
self.disconnect(websocket)
|
||||
|
||||
def get_subscribed_symbols(self, websocket: WebSocket) -> List[str]:
|
||||
"""Get list of symbols a websocket is subscribed to"""
|
||||
return list(self.subscriptions.get(websocket, set()))
|
||||
|
||||
@property
|
||||
def connection_count(self) -> int:
|
||||
return len(self.active_connections)
|
||||
|
||||
@property
|
||||
def subscription_stats(self) -> Dict[str, int]:
|
||||
"""Get subscription statistics"""
|
||||
return {
|
||||
symbol: len(subscribers)
|
||||
for symbol, subscribers in self.symbol_subscribers.items()
|
||||
}
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
# Global signal stream manager
|
||||
signal_manager = SignalStreamManager()
|
||||
|
||||
# Background signal generation task
|
||||
_signal_generation_task = None
|
||||
|
||||
|
||||
async def generate_signals_background():
|
||||
"""
|
||||
Background task that generates signals periodically and broadcasts to subscribers.
|
||||
"""
|
||||
global prediction_service
|
||||
|
||||
logger.info("Starting background signal generation task")
|
||||
|
||||
default_symbols = ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD"]
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Only generate if there are active connections
|
||||
if signal_manager.connection_count > 0:
|
||||
# Get symbols that have subscribers
|
||||
active_symbols = list(signal_manager.symbol_subscribers.keys())
|
||||
if not active_symbols:
|
||||
active_symbols = default_symbols
|
||||
|
||||
for symbol in active_symbols:
|
||||
try:
|
||||
if prediction_service is None:
|
||||
prediction_service = get_prediction_service()
|
||||
|
||||
# Generate signal
|
||||
signal = await prediction_service.generate_signal(
|
||||
symbol=symbol,
|
||||
timeframe="15m",
|
||||
rr_config="rr_2_1"
|
||||
)
|
||||
|
||||
# Get attention info
|
||||
attention = await prediction_service.get_attention_info(symbol, "5m")
|
||||
|
||||
# Build message
|
||||
message = {
|
||||
"type": "signal",
|
||||
"symbol": symbol,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"data": {
|
||||
"signal_id": signal.signal_id,
|
||||
"direction": signal.direction.value,
|
||||
"entry_price": signal.entry_price,
|
||||
"stop_loss": signal.stop_loss,
|
||||
"take_profit": signal.take_profit,
|
||||
"confidence": signal.confidence_score,
|
||||
"amd_phase": signal.amd_phase.value,
|
||||
"valid_until": signal.valid_until.isoformat(),
|
||||
"attention": {
|
||||
"score": attention.attention_score if attention else 1.0,
|
||||
"flow_class": attention.flow_class if attention else 1,
|
||||
"is_tradeable": attention.is_tradeable if attention else True
|
||||
} if attention else None
|
||||
}
|
||||
}
|
||||
|
||||
# Send to symbol subscribers
|
||||
await signal_manager.send_to_symbol(symbol, message)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Signal generation failed for {symbol}: {e}")
|
||||
continue
|
||||
|
||||
# Wait before next cycle (60 seconds)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Background signal generation task cancelled")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Background signal task error: {e}")
|
||||
await asyncio.sleep(10)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def start_signal_streaming():
|
||||
"""Start background signal streaming on app startup"""
|
||||
global _signal_generation_task
|
||||
_signal_generation_task = asyncio.create_task(generate_signals_background())
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def stop_signal_streaming():
|
||||
"""Stop background signal streaming on app shutdown"""
|
||||
global _signal_generation_task
|
||||
if _signal_generation_task:
|
||||
_signal_generation_task.cancel()
|
||||
try:
|
||||
await _signal_generation_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
@app.websocket("/ws/signals")
|
||||
async def websocket_signals(websocket: WebSocket):
|
||||
"""
|
||||
WebSocket endpoint for real-time trading signals
|
||||
WebSocket endpoint for real-time trading signals.
|
||||
|
||||
Connect to receive signals as they are generated
|
||||
Protocol:
|
||||
1. Connect to /ws/signals
|
||||
2. Send subscription message: {"action": "subscribe", "symbols": ["XAUUSD", "EURUSD"]}
|
||||
3. Receive signals: {"type": "signal", "symbol": "XAUUSD", "data": {...}}
|
||||
4. Unsubscribe: {"action": "unsubscribe", "symbols": ["XAUUSD"]}
|
||||
5. Request immediate signal: {"action": "get_signal", "symbol": "XAUUSD"}
|
||||
|
||||
Message Types Received:
|
||||
- signal: Trading signal for subscribed symbol
|
||||
- attention: Attention score update
|
||||
- hierarchical: Full L0→L1→L2 prediction
|
||||
- error: Error message
|
||||
- subscribed: Confirmation of subscription
|
||||
"""
|
||||
await manager.connect(websocket)
|
||||
await signal_manager.connect(websocket)
|
||||
|
||||
# Send welcome message
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "connected",
|
||||
"message": "Connected to ML Engine signal stream",
|
||||
"available_symbols": ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD", "USDJPY"],
|
||||
"instructions": {
|
||||
"subscribe": {"action": "subscribe", "symbols": ["XAUUSD"]},
|
||||
"unsubscribe": {"action": "unsubscribe", "symbols": ["XAUUSD"]},
|
||||
"get_signal": {"action": "get_signal", "symbol": "XAUUSD"},
|
||||
"get_attention": {"action": "get_attention", "symbol": "XAUUSD"},
|
||||
"get_hierarchical": {"action": "get_hierarchical", "symbol": "XAUUSD"}
|
||||
}
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Keep connection alive and send signals
|
||||
# Receive message from client
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# TODO: Process incoming requests and send signals
|
||||
# For now, just echo back
|
||||
await websocket.send_json({
|
||||
"type": "signal",
|
||||
"data": {
|
||||
"symbol": "XAUUSD",
|
||||
"direction": "long",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
})
|
||||
try:
|
||||
message = json.loads(data)
|
||||
action = message.get("action", "")
|
||||
|
||||
if action == "subscribe":
|
||||
symbols = message.get("symbols", [])
|
||||
if symbols:
|
||||
signal_manager.subscribe(websocket, symbols)
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "subscribed",
|
||||
"symbols": symbols,
|
||||
"total_subscriptions": len(signal_manager.get_subscribed_symbols(websocket))
|
||||
})
|
||||
|
||||
elif action == "unsubscribe":
|
||||
symbols = message.get("symbols", [])
|
||||
if symbols:
|
||||
signal_manager.unsubscribe(websocket, symbols)
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "unsubscribed",
|
||||
"symbols": symbols
|
||||
})
|
||||
|
||||
elif action == "get_signal":
|
||||
symbol = message.get("symbol", "XAUUSD").upper()
|
||||
try:
|
||||
if prediction_service:
|
||||
signal = await prediction_service.generate_signal(symbol, "15m", "rr_2_1")
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "signal",
|
||||
"symbol": symbol,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"data": {
|
||||
"signal_id": signal.signal_id,
|
||||
"direction": signal.direction.value,
|
||||
"entry_price": signal.entry_price,
|
||||
"stop_loss": signal.stop_loss,
|
||||
"take_profit": signal.take_profit,
|
||||
"confidence": signal.confidence_score,
|
||||
"amd_phase": signal.amd_phase.value
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": f"Failed to get signal: {str(e)}"
|
||||
})
|
||||
|
||||
elif action == "get_attention":
|
||||
symbol = message.get("symbol", "XAUUSD").upper()
|
||||
timeframe = message.get("timeframe", "5m")
|
||||
try:
|
||||
if prediction_service:
|
||||
attention = await prediction_service.get_attention_info(symbol, timeframe)
|
||||
if attention:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "attention",
|
||||
"symbol": symbol,
|
||||
"timeframe": timeframe,
|
||||
"data": {
|
||||
"attention_score": attention.attention_score,
|
||||
"flow_class": attention.flow_class,
|
||||
"flow_label": attention.flow_label,
|
||||
"is_high_flow": attention.is_high_flow,
|
||||
"is_tradeable": attention.is_tradeable
|
||||
}
|
||||
})
|
||||
else:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "attention",
|
||||
"symbol": symbol,
|
||||
"data": None,
|
||||
"message": "Attention model not available"
|
||||
})
|
||||
except Exception as e:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": f"Failed to get attention: {str(e)}"
|
||||
})
|
||||
|
||||
elif action == "get_hierarchical":
|
||||
symbol = message.get("symbol", "XAUUSD").upper()
|
||||
try:
|
||||
if prediction_service and prediction_service.hierarchical_available:
|
||||
result = await prediction_service.get_hierarchical_prediction(symbol)
|
||||
if result:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "hierarchical",
|
||||
"symbol": symbol,
|
||||
"timestamp": result.timestamp.isoformat(),
|
||||
"data": {
|
||||
"attention_5m": result.attention_score_5m,
|
||||
"attention_15m": result.attention_score_15m,
|
||||
"pred_high_5m": result.pred_high_5m,
|
||||
"pred_low_5m": result.pred_low_5m,
|
||||
"pred_high_15m": result.pred_high_15m,
|
||||
"pred_low_15m": result.pred_low_15m,
|
||||
"delta_high_final": result.delta_high_final,
|
||||
"delta_low_final": result.delta_low_final,
|
||||
"confidence": result.confidence_proba,
|
||||
"should_trade": result.should_trade,
|
||||
"trade_quality": result.trade_quality
|
||||
}
|
||||
})
|
||||
else:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": "Hierarchical prediction failed"
|
||||
})
|
||||
else:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": "Hierarchical pipeline not available"
|
||||
})
|
||||
except Exception as e:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": f"Failed to get hierarchical prediction: {str(e)}"
|
||||
})
|
||||
|
||||
elif action == "ping":
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "pong",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
else:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": f"Unknown action: {action}"
|
||||
})
|
||||
|
||||
except json.JSONDecodeError:
|
||||
await signal_manager.send_personal(websocket, {
|
||||
"type": "error",
|
||||
"message": "Invalid JSON message"
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
signal_manager.disconnect(websocket)
|
||||
|
||||
|
||||
@app.get("/ws/status", tags=["WebSocket"])
|
||||
async def websocket_status():
|
||||
"""Get WebSocket streaming status"""
|
||||
return {
|
||||
"active_connections": signal_manager.connection_count,
|
||||
"subscriptions": signal_manager.subscription_stats,
|
||||
"streaming_active": _signal_generation_task is not None and not _signal_generation_task.done(),
|
||||
"supported_symbols": ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD", "USDJPY"],
|
||||
"signal_interval_seconds": 60
|
||||
}
|
||||
|
||||
|
||||
# Main entry point
|
||||
|
||||
Loading…
Reference in New Issue
Block a user