"""Tests for VLLMBackend.

Backends are created with ``VLLMBackend.__new__`` so ``__init__`` never runs;
each test class sets the attributes its tests rely on and injects an
``AsyncMock`` HTTP client in place of a real one.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.engine.vllm_backend import VLLMBackend
from src.exceptions import (
    # NOTE(review): the next two are not exercised by any test visible here.
    BackendUnavailableError,
    InferenceTimeoutError,
    ModelNotFoundError,
    ValidationError,
)

BASE_URL = "http://localhost:8000"
DEFAULT_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"


def _make_backend(**attrs):
    """Return a VLLMBackend whose ``__init__`` was never run.

    Calling ``__new__`` directly only allocates the instance, so whatever setup
    ``__init__`` performs is skipped without needing to patch it. Each keyword
    argument is assigned as an attribute on the returned instance.
    """
    backend = VLLMBackend.__new__(VLLMBackend)
    for name, value in attrs.items():
        setattr(backend, name, value)
    return backend


def _mock_response(payload=None, status_code=200):
    """Return a MagicMock standing in for an HTTP response.

    Args:
        payload: Value returned by ``response.json()``; left as the default
            MagicMock attribute when None.
        status_code: Value of ``response.status_code``.
    """
    response = MagicMock()
    response.status_code = status_code
    if payload is not None:
        response.json.return_value = payload
    response.raise_for_status = MagicMock()
    return response


def _mock_client(**methods):
    """Return an open (``is_closed = False``) AsyncMock HTTP client.

    Each keyword argument replaces the client attribute of the same name,
    e.g. ``get=AsyncMock(return_value=...)``.
    """
    client = AsyncMock()
    client.is_closed = False
    for name, method in methods.items():
        setattr(client, name, method)
    return client


def _chat_inputs(**overrides):
    """Return a valid ``_validate_chat_inputs`` kwargs dict with overrides applied."""
    kwargs = {
        "messages": [{"role": "user", "content": "Hello"}],
        "max_tokens": 100,
        "temperature": 0.7,
        "top_p": 0.9,
    }
    kwargs.update(overrides)
    return kwargs


class TestVLLMBackendValidation:
    """Test input validation in VLLMBackend."""

    @pytest.fixture
    def backend(self):
        """Create a backend with the attributes used by the validation tests."""
        return _make_backend(
            base_url=BASE_URL,
            default_model=DEFAULT_MODEL,
            _client=None,
            max_tokens_limit=4096,
            min_tokens=1,
        )

    def test_validate_empty_messages(self, backend):
        """Test validation rejects empty messages."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(messages=[]))
        assert "empty" in exc.value.message.lower()

    def test_validate_missing_role(self, backend):
        """Test validation rejects message without role."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(
                **_chat_inputs(messages=[{"content": "Hello"}])
            )
        assert "role" in exc.value.message.lower()

    def test_validate_invalid_role(self, backend):
        """Test validation rejects invalid role."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(
                **_chat_inputs(messages=[{"role": "invalid", "content": "Hello"}])
            )
        assert "invalid" in exc.value.message.lower()

    def test_validate_missing_content(self, backend):
        """Test validation rejects message without content."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(messages=[{"role": "user"}]))
        assert "content" in exc.value.message.lower()

    def test_validate_max_tokens_too_low(self, backend):
        """Test validation rejects max_tokens below minimum."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(max_tokens=0))
        assert "max_tokens" in exc.value.param

    def test_validate_max_tokens_too_high(self, backend):
        """Test validation rejects max_tokens above maximum."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(max_tokens=100000))
        assert "max_tokens" in exc.value.param

    def test_validate_temperature_too_low(self, backend):
        """Test validation rejects temperature below 0."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(temperature=-0.1))
        assert "temperature" in exc.value.param

    def test_validate_temperature_too_high(self, backend):
        """Test validation rejects temperature above 2."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(temperature=2.5))
        assert "temperature" in exc.value.param

    def test_validate_top_p_negative(self, backend):
        """Test validation rejects negative top_p."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(top_p=-0.1))
        assert "top_p" in exc.value.param

    def test_validate_top_p_too_high(self, backend):
        """Test validation rejects top_p above 1."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(**_chat_inputs(top_p=1.5))
        assert "top_p" in exc.value.param

    def test_validate_valid_inputs(self, backend):
        """Test validation passes for valid inputs."""
        # Should not raise
        backend._validate_chat_inputs(**_chat_inputs())

    def test_validate_all_valid_roles(self, backend):
        """Test validation accepts all valid roles."""
        # Should not raise
        backend._validate_chat_inputs(
            **_chat_inputs(
                messages=[
                    {"role": "system", "content": "You are helpful."},
                    {"role": "user", "content": "Hello"},
                    {"role": "assistant", "content": "Hi there!"},
                ]
            )
        )

    def test_validate_request_id_in_error(self, backend):
        """Test validation includes request_id in error."""
        with pytest.raises(ValidationError) as exc:
            backend._validate_chat_inputs(
                **_chat_inputs(messages=[], request_id="test-123")
            )
        assert exc.value.request_id == "test-123"


class TestVLLMBackendTokenEstimation:
    """Test token estimation."""

    @pytest.fixture
    def backend(self):
        """Create a bare backend; token estimation needs no attributes."""
        return _make_backend()

    def test_estimate_tokens_empty(self, backend):
        """Test token estimation for empty string."""
        assert backend._estimate_tokens("") == 0

    def test_estimate_tokens_none_like(self, backend):
        """Test token estimation for falsy values."""
        # NOTE(review): this duplicates test_estimate_tokens_empty. If
        # _estimate_tokens accepts None, assert _estimate_tokens(None) == 0
        # here instead — confirm against the implementation.
        assert backend._estimate_tokens("") == 0

    def test_estimate_tokens_short(self, backend):
        """Test token estimation for short text."""
        result = backend._estimate_tokens("Hello")
        assert result >= 1

    def test_estimate_tokens_long(self, backend):
        """Test token estimation for longer text."""
        text = "Hello world, this is a longer text to test token estimation."
        result = backend._estimate_tokens(text)
        # Should be roughly len(text) / 4
        assert 10 <= result <= 20

    def test_estimate_tokens_very_long(self, backend):
        """Test token estimation for very long text."""
        text = "a" * 1000
        result = backend._estimate_tokens(text)
        assert result == 250  # 1000 / 4


class TestVLLMBackendHealthCheck:
    """Test health check functionality."""

    @pytest.fixture
    def backend(self):
        """Create a backend with a base URL and no client yet."""
        return _make_backend(base_url=BASE_URL, _client=None)

    @pytest.mark.asyncio
    async def test_health_check_success(self, backend):
        """Test health check returns True when server is healthy."""
        mock_client = _mock_client(
            get=AsyncMock(return_value=_mock_response(status_code=200))
        )
        backend._client = mock_client

        result = await backend.health_check()

        assert result is True
        mock_client.get.assert_called_once_with("/health")

    @pytest.mark.asyncio
    async def test_health_check_failure(self, backend):
        """Test health check returns False when server returns error."""
        backend._client = _mock_client(
            get=AsyncMock(return_value=_mock_response(status_code=500))
        )

        result = await backend.health_check()

        assert result is False

    @pytest.mark.asyncio
    async def test_health_check_exception(self, backend):
        """Test health check returns False on exception."""
        backend._client = _mock_client(
            get=AsyncMock(side_effect=Exception("Connection refused"))
        )

        result = await backend.health_check()

        assert result is False


class TestVLLMBackendListModels:
    """Test model listing functionality."""

    @pytest.fixture
    def backend(self):
        """Create a backend with a base URL and no client yet."""
        return _make_backend(base_url=BASE_URL, _client=None)

    @pytest.mark.asyncio
    async def test_list_models_success(self, backend):
        """Test list models returns models from vLLM."""
        payload = {
            "data": [
                {"id": DEFAULT_MODEL, "object": "model"},
                {"id": "codellama/CodeLlama-7b-Instruct-hf", "object": "model"},
            ]
        }
        backend._client = _mock_client(
            get=AsyncMock(return_value=_mock_response(payload))
        )

        result = await backend.list_models()

        assert len(result) == 2
        assert result[0]["id"] == DEFAULT_MODEL

    @pytest.mark.asyncio
    async def test_list_models_empty(self, backend):
        """Test list models handles empty response."""
        backend._client = _mock_client(
            get=AsyncMock(return_value=_mock_response({"data": []}))
        )

        result = await backend.list_models()

        assert result == []


class TestVLLMBackendLoRAAdapters:
    """Test LoRA adapter functionality."""

    @pytest.fixture
    def backend(self):
        """Create a backend with a base URL and no client yet."""
        return _make_backend(base_url=BASE_URL, _client=None)

    @pytest.mark.asyncio
    async def test_list_lora_adapters(self, backend):
        """Test listing LoRA adapters."""
        payload = {
            "data": [
                {"id": DEFAULT_MODEL, "object": "model"},
                {"id": "erp-core-lora", "object": "model", "parent": DEFAULT_MODEL},
            ]
        }
        backend._client = _mock_client(
            get=AsyncMock(return_value=_mock_response(payload))
        )

        result = await backend.list_lora_adapters()

        # Should only return the adapter with parent
        assert len(result) == 1
        assert result[0]["id"] == "erp-core-lora"


class TestVLLMBackendChatCompletion:
    """Test chat completion functionality."""

    @pytest.fixture
    def backend(self):
        """Create a backend with the attributes chat completion reads."""
        return _make_backend(
            base_url=BASE_URL,
            default_model=DEFAULT_MODEL,
            _client=None,
            max_tokens_limit=4096,
            min_tokens=1,
        )

    @pytest.fixture
    def vllm_settings(self):
        """Patch the module-level ``settings`` with fixed timeouts for one test."""
        with patch("src.engine.vllm_backend.settings") as mock_settings:
            mock_settings.request_timeout_ms = 60000
            mock_settings.connect_timeout_ms = 5000
            yield mock_settings

    @pytest.fixture
    def mock_vllm_response(self):
        """Sample vLLM response."""
        return {
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello! How can I assist you today?",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 8,
                "total_tokens": 18,
            },
        }

    @pytest.mark.asyncio
    async def test_chat_completion_success(
        self, backend, mock_vllm_response, vllm_settings
    ):
        """Test successful chat completion."""
        backend._client = _mock_client(
            post=AsyncMock(return_value=_mock_response(mock_vllm_response))
        )

        result = await backend.chat_completion(
            model=DEFAULT_MODEL,
            messages=[{"role": "user", "content": "Hello"}],
        )

        assert result["content"] == "Hello! How can I assist you today?"
        assert result["usage"]["total_tokens"] == 18
        assert result["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_chat_completion_with_lora(
        self, backend, mock_vllm_response, vllm_settings
    ):
        """Test chat completion with LoRA adapter."""
        mock_client = _mock_client(
            post=AsyncMock(return_value=_mock_response(mock_vllm_response))
        )
        backend._client = mock_client

        result = await backend.chat_completion(
            model=DEFAULT_MODEL,
            messages=[{"role": "user", "content": "Hello"}],
            lora_adapter="erp-core",
        )

        assert result["lora_adapter"] == "erp-core"

        # Verify the request included lora_request
        request_json = mock_client.post.call_args.kwargs["json"]
        assert "extra_body" in request_json
        assert request_json["extra_body"]["lora_request"]["lora_name"] == "erp-core"

    @pytest.mark.asyncio
    async def test_chat_completion_model_not_found(self, backend, vllm_settings):
        """Test chat completion with non-existent model."""
        backend._client = _mock_client(
            post=AsyncMock(return_value=_mock_response(status_code=404))
        )

        with pytest.raises(ModelNotFoundError) as exc:
            await backend.chat_completion(
                model="non-existent-model",
                messages=[{"role": "user", "content": "Hello"}],
            )

        assert "non-existent-model" in str(exc.value.message)

    @pytest.mark.asyncio
    async def test_chat_completion_validation_error(self, backend):
        """Test chat completion with invalid parameters."""
        with pytest.raises(ValidationError):
            await backend.chat_completion(
                model=DEFAULT_MODEL,
                messages=[],  # Empty messages
            )


class TestVLLMBackendClose:
    """Test close functionality."""

    @pytest.fixture
    def backend(self):
        """Create a backend with no client yet."""
        return _make_backend(_client=None)

    @pytest.mark.asyncio
    async def test_close_with_client(self, backend):
        """Test close when client exists."""
        mock_client = _mock_client(aclose=AsyncMock())
        backend._client = mock_client

        await backend.close()

        mock_client.aclose.assert_called_once()
        assert backend._client is None

    @pytest.mark.asyncio
    async def test_close_without_client(self, backend):
        """Test close when no client exists."""
        backend._client = None

        # Should not raise
        await backend.close()

        assert backend._client is None

    @pytest.mark.asyncio
    async def test_close_already_closed_client(self, backend):
        """Test close when client is already closed."""
        mock_client = AsyncMock()
        mock_client.is_closed = True
        mock_client.aclose = AsyncMock()
        backend._client = mock_client

        await backend.close()

        # Should not call aclose since client is already closed; without this
        # assertion the test passed regardless of close()'s behavior.
        mock_client.aclose.assert_not_called()