[GAP-2] feat: Wire HierarchicalPipeline into prediction service
- Added HierarchicalPipeline initialization in PredictionService._load_models()
- Added HierarchicalResult dataclass for typed responses
- Added get_hierarchical_prediction() method for full L0→L1→L2 predictions
- Added get_hierarchical_model_info() method for model introspection
- Added hierarchical_available property
- Added API endpoints:
- GET /api/hierarchical/{symbol} - Full 3-level prediction
- GET /api/hierarchical/{symbol}/models - Model info per symbol
- GET /api/hierarchical/status - Pipeline status
- Updated health endpoint to include hierarchical_available flag
This completes GAP #2: L1→L2 Hierarchical Predictor integration.
The pipeline now uses symbol-specific metamodels (XGBoost or Neural Gating)
to synthesize 5m and 15m base model predictions.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
b0b4a712eb
commit
a22fb11968
165
src/api/main.py
165
src/api/main.py
@ -24,7 +24,8 @@ from ..services.prediction_service import (
|
||||
initialize_prediction_service,
|
||||
Direction,
|
||||
AMDPhase as ServiceAMDPhase,
|
||||
VolatilityRegime as ServiceVolatilityRegime
|
||||
VolatilityRegime as ServiceVolatilityRegime,
|
||||
HierarchicalResult
|
||||
)
|
||||
|
||||
# API Models
|
||||
@ -123,6 +124,7 @@ class HealthResponse(BaseModel):
|
||||
status: str
|
||||
version: str
|
||||
models_loaded: bool
|
||||
hierarchical_available: bool = False
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@ -199,10 +201,14 @@ async def shutdown_event():
|
||||
@app.get("/health", response_model=HealthResponse, tags=["System"])
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
hierarchical = False
|
||||
if prediction_service:
|
||||
hierarchical = prediction_service.hierarchical_available
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
version="0.1.0",
|
||||
models_loaded=models_state["loaded"],
|
||||
hierarchical_available=hierarchical,
|
||||
timestamp=datetime.utcnow()
|
||||
)
|
||||
|
||||
@ -471,6 +477,163 @@ async def list_attention_models():
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hierarchical Pipeline Endpoints (L0→L1→L2)
|
||||
# =============================================================================
|
||||
|
||||
class HierarchicalPredictionResponse(BaseModel):
|
||||
"""Full hierarchical L0→L1→L2 prediction response"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
# Level 0: Attention
|
||||
attention_score_5m: float = Field(..., description="Attention score for 5m timeframe")
|
||||
attention_score_15m: float = Field(..., description="Attention score for 15m timeframe")
|
||||
attention_class_5m: int = Field(..., description="Flow class 5m: 0=low, 1=med, 2=high")
|
||||
attention_class_15m: int = Field(..., description="Flow class 15m: 0=low, 1=med, 2=high")
|
||||
# Level 1: Base Models
|
||||
pred_high_5m: float = Field(..., description="Predicted delta high from 5m model")
|
||||
pred_low_5m: float = Field(..., description="Predicted delta low from 5m model")
|
||||
pred_high_15m: float = Field(..., description="Predicted delta high from 15m model")
|
||||
pred_low_15m: float = Field(..., description="Predicted delta low from 15m model")
|
||||
# Level 2: Metamodel
|
||||
delta_high_final: float = Field(..., description="Final synthesized delta high")
|
||||
delta_low_final: float = Field(..., description="Final synthesized delta low")
|
||||
confidence: bool = Field(..., description="Metamodel confidence flag")
|
||||
confidence_proba: float = Field(..., description="Metamodel confidence probability")
|
||||
# Trading signals
|
||||
should_trade: bool = Field(..., description="Should trade based on attention+confidence")
|
||||
trade_quality: str = Field(..., description="Trade quality: high, medium, low, skip")
|
||||
|
||||
|
||||
@app.get("/api/hierarchical/{symbol}", response_model=HierarchicalPredictionResponse, tags=["Hierarchical"])
|
||||
async def get_hierarchical_prediction(symbol: str):
|
||||
"""
|
||||
Get full hierarchical L0→L1→L2 prediction for a symbol.
|
||||
|
||||
This is the main prediction endpoint that uses the 3-level architecture:
|
||||
- Level 0 (Attention): Determines WHEN to pay attention to market
|
||||
- Level 1 (Base Models): Symbol/timeframe specific predictions (5m + 15m)
|
||||
- Level 2 (Metamodel): Synthesizes predictions into final values
|
||||
|
||||
The metamodel type (XGBoost or Neural Gating) is selected automatically
|
||||
based on best performance per symbol.
|
||||
|
||||
Returns should_trade=True only when attention is sufficient AND
|
||||
metamodel confidence exceeds threshold.
|
||||
"""
|
||||
global prediction_service
|
||||
|
||||
if prediction_service is None:
|
||||
prediction_service = get_prediction_service()
|
||||
|
||||
if not prediction_service.hierarchical_available:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Hierarchical pipeline not available"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await prediction_service.get_hierarchical_prediction(
|
||||
symbol=symbol.upper()
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Hierarchical prediction failed for {symbol}"
|
||||
)
|
||||
|
||||
return HierarchicalPredictionResponse(
|
||||
symbol=result.symbol,
|
||||
timestamp=result.timestamp,
|
||||
attention_score_5m=result.attention_score_5m,
|
||||
attention_score_15m=result.attention_score_15m,
|
||||
attention_class_5m=result.attention_class_5m,
|
||||
attention_class_15m=result.attention_class_15m,
|
||||
pred_high_5m=result.pred_high_5m,
|
||||
pred_low_5m=result.pred_low_5m,
|
||||
pred_high_15m=result.pred_high_15m,
|
||||
pred_low_15m=result.pred_low_15m,
|
||||
delta_high_final=result.delta_high_final,
|
||||
delta_low_final=result.delta_low_final,
|
||||
confidence=result.confidence,
|
||||
confidence_proba=result.confidence_proba,
|
||||
should_trade=result.should_trade,
|
||||
trade_quality=result.trade_quality
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Hierarchical prediction failed for {symbol}: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Hierarchical prediction failed: {str(e)}")
|
||||
|
||||
|
||||
@app.get("/api/hierarchical/{symbol}/models", tags=["Hierarchical"])
|
||||
async def get_hierarchical_model_info(symbol: str):
|
||||
"""
|
||||
Get information about loaded hierarchical models for a symbol.
|
||||
|
||||
Shows which models are loaded for each level of the hierarchy.
|
||||
"""
|
||||
global prediction_service
|
||||
|
||||
if prediction_service is None:
|
||||
prediction_service = get_prediction_service()
|
||||
|
||||
if not prediction_service.hierarchical_available:
|
||||
return {
|
||||
"available": False,
|
||||
"symbol": symbol.upper(),
|
||||
"models_loaded": False
|
||||
}
|
||||
|
||||
try:
|
||||
info = await prediction_service.get_hierarchical_model_info(symbol.upper())
|
||||
return {
|
||||
"available": True,
|
||||
**info
|
||||
} if info else {
|
||||
"available": True,
|
||||
"symbol": symbol.upper(),
|
||||
"models_loaded": False,
|
||||
"attention_models": [],
|
||||
"base_models": [],
|
||||
"metamodel_type": "none"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get hierarchical model info: {e}")
|
||||
return {
|
||||
"available": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/hierarchical/status", tags=["Hierarchical"])
|
||||
async def get_hierarchical_status():
|
||||
"""
|
||||
Get overall status of the hierarchical pipeline.
|
||||
|
||||
Returns availability and configuration info.
|
||||
"""
|
||||
global prediction_service
|
||||
|
||||
if prediction_service is None:
|
||||
prediction_service = get_prediction_service()
|
||||
|
||||
available = prediction_service.hierarchical_available
|
||||
|
||||
return {
|
||||
"available": available,
|
||||
"supported_symbols": ["XAUUSD", "EURUSD", "BTCUSD", "GBPUSD", "USDJPY"] if available else [],
|
||||
"timeframes": ["5m", "15m"] if available else [],
|
||||
"architecture": {
|
||||
"level_0": "Attention Models (flow detection)",
|
||||
"level_1": "Base Models (symbol/timeframe specific)",
|
||||
"level_2": "Metamodels (XGBoost or Neural Gating)"
|
||||
} if available else {}
|
||||
}
|
||||
|
||||
|
||||
# Active signals endpoint - GET version for easy consumption
|
||||
class ActiveSignalsResponse(BaseModel):
|
||||
"""Response with active signals for all symbols"""
|
||||
|
||||
@ -34,6 +34,13 @@ from ..data.indicators import TechnicalIndicators
|
||||
# Attention provider for Level 0 features
|
||||
from .attention_provider import AttentionProvider, get_attention_provider
|
||||
|
||||
# Hierarchical Pipeline for L0→L1→L2 predictions
|
||||
from ..pipelines.hierarchical_pipeline import (
|
||||
HierarchicalPipeline,
|
||||
PipelineConfig,
|
||||
PredictionResult as HierarchicalPredictionResult
|
||||
)
|
||||
|
||||
|
||||
class Direction(Enum):
|
||||
LONG = "long"
|
||||
@ -86,6 +93,31 @@ class AttentionInfo:
|
||||
is_tradeable: bool # True if attention_score >= 1.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class HierarchicalResult:
|
||||
"""Result from hierarchical L0→L1→L2 pipeline"""
|
||||
symbol: str
|
||||
timestamp: datetime
|
||||
# Level 0 outputs
|
||||
attention_score_5m: float
|
||||
attention_score_15m: float
|
||||
attention_class_5m: int
|
||||
attention_class_15m: int
|
||||
# Level 1 outputs
|
||||
pred_high_5m: float
|
||||
pred_low_5m: float
|
||||
pred_high_15m: float
|
||||
pred_low_15m: float
|
||||
# Level 2 outputs (final)
|
||||
delta_high_final: float
|
||||
delta_low_final: float
|
||||
confidence: bool
|
||||
confidence_proba: float
|
||||
# Trading signals
|
||||
should_trade: bool
|
||||
trade_quality: str # 'high', 'medium', 'low', 'skip'
|
||||
|
||||
|
||||
@dataclass
|
||||
class TradingSignal:
|
||||
"""Complete trading signal"""
|
||||
@ -154,6 +186,7 @@ class PredictionService:
|
||||
self._tpsl_classifier = None
|
||||
self._amd_detector = None
|
||||
self._attention_provider = None # Level 0 attention models
|
||||
self._hierarchical_pipeline = None # L0→L1→L2 pipeline
|
||||
self._models_loaded = False
|
||||
|
||||
# Symbol-specific trainers (nuevos modelos por símbolo/timeframe)
|
||||
@ -211,6 +244,20 @@ class PredictionService:
|
||||
logger.info("No attention models directory found, creating provider anyway")
|
||||
self._attention_provider = AttentionProvider(attention_path)
|
||||
|
||||
# Load Hierarchical Pipeline (L0→L1→L2)
|
||||
try:
|
||||
pipeline_config = PipelineConfig(
|
||||
attention_model_path=os.path.join(self.models_dir, "attention"),
|
||||
base_model_path=os.path.join(self.models_dir, "symbol_timeframe_models"),
|
||||
metamodel_path=os.path.join(self.models_dir, "metamodels"),
|
||||
neural_gating_path=os.path.join(self.models_dir, "metamodels_neural")
|
||||
)
|
||||
self._hierarchical_pipeline = HierarchicalPipeline(pipeline_config)
|
||||
logger.info("✅ HierarchicalPipeline initialized")
|
||||
except Exception as e:
|
||||
logger.warning(f"HierarchicalPipeline initialization failed: {e}")
|
||||
self._hierarchical_pipeline = None
|
||||
|
||||
self._models_loaded = True
|
||||
|
||||
# Cargar modelos por símbolo si el feature flag está activo
|
||||
@ -589,6 +636,85 @@ class PredictionService:
|
||||
should_trade = attention.attention_score >= min_attention
|
||||
return should_trade, attention
|
||||
|
||||
@property
|
||||
def hierarchical_available(self) -> bool:
|
||||
"""Check if hierarchical pipeline is available."""
|
||||
return self._hierarchical_pipeline is not None
|
||||
|
||||
async def get_hierarchical_prediction(
|
||||
self,
|
||||
symbol: str,
|
||||
load_models: bool = True
|
||||
) -> Optional[HierarchicalResult]:
|
||||
"""
|
||||
Get full hierarchical L0→L1→L2 prediction.
|
||||
|
||||
This is the main method for getting predictions from the 3-level
|
||||
architecture combining attention, base models, and metamodels.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'XAUUSD')
|
||||
load_models: Whether to load models if not already loaded
|
||||
|
||||
Returns:
|
||||
HierarchicalResult with all layer outputs, or None if unavailable
|
||||
"""
|
||||
if not self._hierarchical_pipeline:
|
||||
logger.warning("HierarchicalPipeline not available")
|
||||
return None
|
||||
|
||||
# Load models for this symbol if needed
|
||||
if load_models:
|
||||
try:
|
||||
self._hierarchical_pipeline.load_models(symbol)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load hierarchical models for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
# Get market data for both timeframes
|
||||
try:
|
||||
df_5m, df_15m = await asyncio.gather(
|
||||
self.get_market_data(symbol, "5m", lookback_periods=300),
|
||||
self.get_market_data(symbol, "15m", lookback_periods=150)
|
||||
)
|
||||
|
||||
if df_5m.empty or df_15m.empty:
|
||||
logger.warning(f"Insufficient data for hierarchical prediction: {symbol}")
|
||||
return None
|
||||
|
||||
# Run hierarchical prediction
|
||||
result = self._hierarchical_pipeline.predict(df_5m, df_15m, symbol)
|
||||
|
||||
# Convert to HierarchicalResult dataclass
|
||||
return HierarchicalResult(
|
||||
symbol=result.symbol,
|
||||
timestamp=result.timestamp,
|
||||
attention_score_5m=result.attention_score_5m,
|
||||
attention_score_15m=result.attention_score_15m,
|
||||
attention_class_5m=result.attention_class_5m,
|
||||
attention_class_15m=result.attention_class_15m,
|
||||
pred_high_5m=result.pred_high_5m,
|
||||
pred_low_5m=result.pred_low_5m,
|
||||
pred_high_15m=result.pred_high_15m,
|
||||
pred_low_15m=result.pred_low_15m,
|
||||
delta_high_final=result.delta_high_final,
|
||||
delta_low_final=result.delta_low_final,
|
||||
confidence=result.confidence,
|
||||
confidence_proba=result.confidence_proba,
|
||||
should_trade=result.should_trade,
|
||||
trade_quality=result.trade_quality
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Hierarchical prediction failed for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def get_hierarchical_model_info(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get information about loaded hierarchical models for a symbol."""
|
||||
if not self._hierarchical_pipeline:
|
||||
return None
|
||||
return self._hierarchical_pipeline.get_model_info(symbol)
|
||||
|
||||
async def generate_signal(
|
||||
self,
|
||||
symbol: str,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user