diff --git a/src/api/main.py b/src/api/main.py index 8ddc8ed..adb574c 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -1349,56 +1349,425 @@ async def scan_symbols(request: MultiSymbolRequest): # WebSocket for real-time signals # ============================================================================= from fastapi import WebSocket, WebSocketDisconnect +import json -class ConnectionManager: - """Manage WebSocket connections""" +class SignalStreamManager: + """ + Manages WebSocket connections for real-time signal streaming. + + Supports: + - Per-symbol subscriptions + - Broadcast to all connections + - Targeted messages to specific subscriptions + - Connection state tracking + """ + def __init__(self): + # Map of websocket -> set of subscribed symbols + self.subscriptions: Dict[WebSocket, set] = {} + # Map of symbol -> set of websockets + self.symbol_subscribers: Dict[str, set] = {} + # All active connections self.active_connections: List[WebSocket] = [] + # Background task reference + self._signal_task = None + self._running = False async def connect(self, websocket: WebSocket): + """Accept new WebSocket connection""" await websocket.accept() self.active_connections.append(websocket) + self.subscriptions[websocket] = set() + logger.info(f"WebSocket connected. Total connections: {len(self.active_connections)}") def disconnect(self, websocket: WebSocket): - self.active_connections.remove(websocket) + """Handle WebSocket disconnection""" + if websocket in self.active_connections: + self.active_connections.remove(websocket) + + # Remove from all symbol subscriptions + if websocket in self.subscriptions: + for symbol in self.subscriptions[websocket]: + if symbol in self.symbol_subscribers: + self.symbol_subscribers[symbol].discard(websocket) + del self.subscriptions[websocket] + + logger.info(f"WebSocket disconnected. Total connections: {len(self.active_connections)}") + + def subscribe(self, websocket: WebSocket, symbols: List[str]): + """Subscribe websocket to specific symbols""" + if websocket not in self.subscriptions: + self.subscriptions[websocket] = set() + + for symbol in symbols: + symbol = symbol.upper() + self.subscriptions[websocket].add(symbol) + + if symbol not in self.symbol_subscribers: + self.symbol_subscribers[symbol] = set() + self.symbol_subscribers[symbol].add(websocket) + + logger.debug(f"Subscribed to {symbols}. Total symbols: {len(self.subscriptions[websocket])}") + + def unsubscribe(self, websocket: WebSocket, symbols: List[str]): + """Unsubscribe websocket from specific symbols""" + if websocket not in self.subscriptions: + return + + for symbol in symbols: + symbol = symbol.upper() + self.subscriptions[websocket].discard(symbol) + + if symbol in self.symbol_subscribers: + self.symbol_subscribers[symbol].discard(websocket) + + async def send_to_symbol(self, symbol: str, message: dict): + """Send message to all subscribers of a symbol""" + symbol = symbol.upper() + if symbol not in self.symbol_subscribers: + return + + disconnected = [] + for websocket in self.symbol_subscribers[symbol]: + try: + await websocket.send_json(message) + except Exception: + disconnected.append(websocket) + + # Clean up disconnected + for ws in disconnected: + self.disconnect(ws) async def broadcast(self, message: dict): - for connection in self.active_connections: + """Broadcast message to all connections""" + disconnected = [] + for websocket in self.active_connections: try: - await connection.send_json(message) - except: - pass + await websocket.send_json(message) + except Exception: + disconnected.append(websocket) + + # Clean up disconnected + for ws in disconnected: + self.disconnect(ws) + + async def send_personal(self, websocket: WebSocket, message: dict): + """Send message to specific connection""" + try: + await websocket.send_json(message) + except Exception: + self.disconnect(websocket) + + def get_subscribed_symbols(self, websocket: WebSocket) -> List[str]: + """Get list of symbols a websocket is subscribed to""" + return list(self.subscriptions.get(websocket, set())) + + @property + def connection_count(self) -> int: + return len(self.active_connections) + + @property + def subscription_stats(self) -> Dict[str, int]: + """Get subscription statistics""" + return { + symbol: len(subscribers) + for symbol, subscribers in self.symbol_subscribers.items() + } -manager = ConnectionManager() +# Global signal stream manager +signal_manager = SignalStreamManager() + +# Background signal generation task +_signal_generation_task = None + + +async def generate_signals_background(): + """ + Background task that generates signals periodically and broadcasts to subscribers. + """ + global prediction_service + + logger.info("Starting background signal generation task") + + default_symbols = ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD"] + + while True: + try: + # Only generate if there are active connections + if signal_manager.connection_count > 0: + # Get symbols that have subscribers + active_symbols = list(signal_manager.symbol_subscribers.keys()) + if not active_symbols: + active_symbols = default_symbols + + for symbol in active_symbols: + try: + if prediction_service is None: + prediction_service = get_prediction_service() + + # Generate signal + signal = await prediction_service.generate_signal( + symbol=symbol, + timeframe="15m", + rr_config="rr_2_1" + ) + + # Get attention info + attention = await prediction_service.get_attention_info(symbol, "5m") + + # Build message + message = { + "type": "signal", + "symbol": symbol, + "timestamp": datetime.utcnow().isoformat(), + "data": { + "signal_id": signal.signal_id, + "direction": signal.direction.value, + "entry_price": signal.entry_price, + "stop_loss": signal.stop_loss, + "take_profit": signal.take_profit, + "confidence": signal.confidence_score, + "amd_phase": signal.amd_phase.value, + "valid_until": signal.valid_until.isoformat(), + "attention": { + "score": attention.attention_score if attention else 1.0, + "flow_class": attention.flow_class if attention else 1, + "is_tradeable": attention.is_tradeable if attention else True + } if attention else None + } + } + + # Send to symbol subscribers + await signal_manager.send_to_symbol(symbol, message) + + except Exception as e: + logger.debug(f"Signal generation failed for {symbol}: {e}") + continue + + # Wait before next cycle (60 seconds) + await asyncio.sleep(60) + + except asyncio.CancelledError: + logger.info("Background signal generation task cancelled") + break + except Exception as e: + logger.error(f"Background signal task error: {e}") + await asyncio.sleep(10) + + +@app.on_event("startup") +async def start_signal_streaming(): + """Start background signal streaming on app startup""" + global _signal_generation_task + _signal_generation_task = asyncio.create_task(generate_signals_background()) + + +@app.on_event("shutdown") +async def stop_signal_streaming(): + """Stop background signal streaming on app shutdown""" + global _signal_generation_task + if _signal_generation_task: + _signal_generation_task.cancel() + try: + await _signal_generation_task + except asyncio.CancelledError: + pass @app.websocket("/ws/signals") async def websocket_signals(websocket: WebSocket): """ - WebSocket endpoint for real-time trading signals + WebSocket endpoint for real-time trading signals. - Connect to receive signals as they are generated + Protocol: + 1. Connect to /ws/signals + 2. Send subscription message: {"action": "subscribe", "symbols": ["XAUUSD", "EURUSD"]} + 3. Receive signals: {"type": "signal", "symbol": "XAUUSD", "data": {...}} + 4. Unsubscribe: {"action": "unsubscribe", "symbols": ["XAUUSD"]} + 5. Request immediate signal: {"action": "get_signal", "symbol": "XAUUSD"} + + Message Types Received: + - signal: Trading signal for subscribed symbol + - attention: Attention score update + - hierarchical: Full L0→L1→L2 prediction + - error: Error message + - subscribed: Confirmation of subscription """ - await manager.connect(websocket) + await signal_manager.connect(websocket) + + # Send welcome message + await signal_manager.send_personal(websocket, { + "type": "connected", + "message": "Connected to ML Engine signal stream", + "available_symbols": ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD", "USDJPY"], + "instructions": { + "subscribe": {"action": "subscribe", "symbols": ["XAUUSD"]}, + "unsubscribe": {"action": "unsubscribe", "symbols": ["XAUUSD"]}, + "get_signal": {"action": "get_signal", "symbol": "XAUUSD"}, + "get_attention": {"action": "get_attention", "symbol": "XAUUSD"}, + "get_hierarchical": {"action": "get_hierarchical", "symbol": "XAUUSD"} + } + }) + try: while True: - # Keep connection alive and send signals + # Receive message from client data = await websocket.receive_text() - # TODO: Process incoming requests and send signals - # For now, just echo back - await websocket.send_json({ - "type": "signal", - "data": { - "symbol": "XAUUSD", - "direction": "long", - "timestamp": datetime.utcnow().isoformat() - } - }) + try: + message = json.loads(data) + action = message.get("action", "") + + if action == "subscribe": + symbols = message.get("symbols", []) + if symbols: + signal_manager.subscribe(websocket, symbols) + await signal_manager.send_personal(websocket, { + "type": "subscribed", + "symbols": symbols, + "total_subscriptions": len(signal_manager.get_subscribed_symbols(websocket)) + }) + + elif action == "unsubscribe": + symbols = message.get("symbols", []) + if symbols: + signal_manager.unsubscribe(websocket, symbols) + await signal_manager.send_personal(websocket, { + "type": "unsubscribed", + "symbols": symbols + }) + + elif action == "get_signal": + symbol = message.get("symbol", "XAUUSD").upper() + try: + if prediction_service: + signal = await prediction_service.generate_signal(symbol, "15m", "rr_2_1") + await signal_manager.send_personal(websocket, { + "type": "signal", + "symbol": symbol, + "timestamp": datetime.utcnow().isoformat(), + "data": { + "signal_id": signal.signal_id, + "direction": signal.direction.value, + "entry_price": signal.entry_price, + "stop_loss": signal.stop_loss, + "take_profit": signal.take_profit, + "confidence": signal.confidence_score, + "amd_phase": signal.amd_phase.value + } + }) + except Exception as e: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": f"Failed to get signal: {str(e)}" + }) + + elif action == "get_attention": + symbol = message.get("symbol", "XAUUSD").upper() + timeframe = message.get("timeframe", "5m") + try: + if prediction_service: + attention = await prediction_service.get_attention_info(symbol, timeframe) + if attention: + await signal_manager.send_personal(websocket, { + "type": "attention", + "symbol": symbol, + "timeframe": timeframe, + "data": { + "attention_score": attention.attention_score, + "flow_class": attention.flow_class, + "flow_label": attention.flow_label, + "is_high_flow": attention.is_high_flow, + "is_tradeable": attention.is_tradeable + } + }) + else: + await signal_manager.send_personal(websocket, { + "type": "attention", + "symbol": symbol, + "data": None, + "message": "Attention model not available" + }) + except Exception as e: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": f"Failed to get attention: {str(e)}" + }) + + elif action == "get_hierarchical": + symbol = message.get("symbol", "XAUUSD").upper() + try: + if prediction_service and prediction_service.hierarchical_available: + result = await prediction_service.get_hierarchical_prediction(symbol) + if result: + await signal_manager.send_personal(websocket, { + "type": "hierarchical", + "symbol": symbol, + "timestamp": result.timestamp.isoformat(), + "data": { + "attention_5m": result.attention_score_5m, + "attention_15m": result.attention_score_15m, + "pred_high_5m": result.pred_high_5m, + "pred_low_5m": result.pred_low_5m, + "pred_high_15m": result.pred_high_15m, + "pred_low_15m": result.pred_low_15m, + "delta_high_final": result.delta_high_final, + "delta_low_final": result.delta_low_final, + "confidence": result.confidence_proba, + "should_trade": result.should_trade, + "trade_quality": result.trade_quality + } + }) + else: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": "Hierarchical prediction failed" + }) + else: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": "Hierarchical pipeline not available" + }) + except Exception as e: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": f"Failed to get hierarchical prediction: {str(e)}" + }) + + elif action == "ping": + await signal_manager.send_personal(websocket, { + "type": "pong", + "timestamp": datetime.utcnow().isoformat() + }) + + else: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": f"Unknown action: {action}" + }) + + except json.JSONDecodeError: + await signal_manager.send_personal(websocket, { + "type": "error", + "message": "Invalid JSON message" + }) + except WebSocketDisconnect: - manager.disconnect(websocket) + signal_manager.disconnect(websocket) + + +@app.get("/ws/status", tags=["WebSocket"]) +async def websocket_status(): + """Get WebSocket streaming status""" + return { + "active_connections": signal_manager.connection_count, + "subscriptions": signal_manager.subscription_stats, + "streaming_active": _signal_generation_task is not None and not _signal_generation_task.done(), + "supported_symbols": ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD", "USDJPY"], + "signal_interval_seconds": 60 + } # Main entry point