"""Tests for VLLMBackend (see module docstring below)."""
"""Tests for VLLMBackend."""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from src.engine.vllm_backend import VLLMBackend
|
|
from src.exceptions import (
|
|
BackendUnavailableError,
|
|
InferenceTimeoutError,
|
|
ModelNotFoundError,
|
|
ValidationError,
|
|
)
|
|
|
|
|
|
class TestVLLMBackendValidation:
    """Input-validation tests for VLLMBackend._validate_chat_inputs."""

    @pytest.fixture
    def backend(self):
        """Build a VLLMBackend without running its real __init__."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            instance = VLLMBackend.__new__(VLLMBackend)
            instance.base_url = "http://localhost:8000"
            instance.default_model = "mistralai/Mistral-7B-Instruct-v0.2"
            instance._client = None
            instance.max_tokens_limit = 4096
            instance.min_tokens = 1
            return instance

    @staticmethod
    def _validate(backend, messages, **overrides):
        """Run _validate_chat_inputs with sane defaults; overrides replace them."""
        params = {"max_tokens": 100, "temperature": 0.7, "top_p": 0.9}
        params.update(overrides)
        backend._validate_chat_inputs(messages=messages, **params)

    def test_validate_empty_messages(self, backend):
        """An empty message list is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [])
        assert "empty" in excinfo.value.message.lower()

    def test_validate_missing_role(self, backend):
        """A message lacking a role key is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"content": "Hello"}])
        assert "role" in excinfo.value.message.lower()

    def test_validate_invalid_role(self, backend):
        """A message with an unknown role is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "invalid", "content": "Hello"}])
        assert "invalid" in excinfo.value.message.lower()

    def test_validate_missing_content(self, backend):
        """A message lacking a content key is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user"}])
        assert "content" in excinfo.value.message.lower()

    def test_validate_max_tokens_too_low(self, backend):
        """max_tokens below the minimum is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user", "content": "Hello"}], max_tokens=0)
        assert "max_tokens" in excinfo.value.param

    def test_validate_max_tokens_too_high(self, backend):
        """max_tokens above the configured limit is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user", "content": "Hello"}], max_tokens=100000)
        assert "max_tokens" in excinfo.value.param

    def test_validate_temperature_too_low(self, backend):
        """temperature below 0 is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user", "content": "Hello"}], temperature=-0.1)
        assert "temperature" in excinfo.value.param

    def test_validate_temperature_too_high(self, backend):
        """temperature above 2 is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user", "content": "Hello"}], temperature=2.5)
        assert "temperature" in excinfo.value.param

    def test_validate_top_p_negative(self, backend):
        """Negative top_p is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user", "content": "Hello"}], top_p=-0.1)
        assert "top_p" in excinfo.value.param

    def test_validate_top_p_too_high(self, backend):
        """top_p above 1 is rejected."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [{"role": "user", "content": "Hello"}], top_p=1.5)
        assert "top_p" in excinfo.value.param

    def test_validate_valid_inputs(self, backend):
        """A well-formed request passes without raising."""
        self._validate(backend, [{"role": "user", "content": "Hello"}])

    def test_validate_all_valid_roles(self, backend):
        """system, user and assistant roles are all accepted."""
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        self._validate(backend, conversation)

    def test_validate_request_id_in_error(self, backend):
        """The request_id passed in is echoed back on the error object."""
        with pytest.raises(ValidationError) as excinfo:
            self._validate(backend, [], request_id="test-123")
        assert excinfo.value.request_id == "test-123"
|
|
|
|
|
|
class TestVLLMBackendTokenEstimation:
    """Tests for VLLMBackend._estimate_tokens."""

    @pytest.fixture
    def backend(self):
        """Bare VLLMBackend instance; __init__ is bypassed entirely."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            return VLLMBackend.__new__(VLLMBackend)

    def test_estimate_tokens_empty(self, backend):
        """An empty string estimates to zero tokens."""
        assert backend._estimate_tokens("") == 0

    def test_estimate_tokens_none_like(self, backend):
        """Falsy input (empty string) estimates to zero tokens.

        NOTE(review): currently identical to test_estimate_tokens_empty;
        consider exercising None here once _estimate_tokens' contract for
        None inputs is confirmed.
        """
        assert backend._estimate_tokens("") == 0

    def test_estimate_tokens_short(self, backend):
        """A short word yields at least one token."""
        assert backend._estimate_tokens("Hello") >= 1

    def test_estimate_tokens_long(self, backend):
        """A full sentence lands near the len(text) / 4 heuristic."""
        sample = "Hello world, this is a longer text to test token estimation."
        estimate = backend._estimate_tokens(sample)
        assert 10 <= estimate <= 20

    def test_estimate_tokens_very_long(self, backend):
        """1000 characters estimate to exactly 1000 / 4 = 250 tokens."""
        assert backend._estimate_tokens("a" * 1000) == 250
|
|
|
|
|
|
class TestVLLMBackendHealthCheck:
    """Tests for VLLMBackend.health_check."""

    @pytest.fixture
    def backend(self):
        """Backend instance with only base_url set and no live client."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            instance = VLLMBackend.__new__(VLLMBackend)
            instance.base_url = "http://localhost:8000"
            instance._client = None
            return instance

    @staticmethod
    def _install_client(backend, *, status=None, error=None):
        """Attach an open mock HTTP client whose GET returns status or raises error."""
        client = AsyncMock()
        if error is not None:
            client.get = AsyncMock(side_effect=error)
        else:
            reply = MagicMock()
            reply.status_code = status
            client.get = AsyncMock(return_value=reply)
        client.is_closed = False
        backend._client = client
        return client

    @pytest.mark.asyncio
    async def test_health_check_success(self, backend):
        """A 200 from /health maps to True."""
        client = self._install_client(backend, status=200)

        assert await backend.health_check() is True
        client.get.assert_called_once_with("/health")

    @pytest.mark.asyncio
    async def test_health_check_failure(self, backend):
        """A 5xx from /health maps to False."""
        self._install_client(backend, status=500)

        assert await backend.health_check() is False

    @pytest.mark.asyncio
    async def test_health_check_exception(self, backend):
        """A transport error is swallowed and reported as False."""
        self._install_client(backend, error=Exception("Connection refused"))

        assert await backend.health_check() is False
|
|
|
|
|
|
class TestVLLMBackendListModels:
    """Tests for VLLMBackend.list_models."""

    @pytest.fixture
    def backend(self):
        """Backend instance with only base_url set and no live client."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            instance = VLLMBackend.__new__(VLLMBackend)
            instance.base_url = "http://localhost:8000"
            instance._client = None
            return instance

    @staticmethod
    def _install_client(backend, payload):
        """Attach an open mock client whose GET yields a 200 response with payload."""
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = payload
        reply.raise_for_status = MagicMock()
        client = AsyncMock()
        client.get = AsyncMock(return_value=reply)
        client.is_closed = False
        backend._client = client
        return client

    @pytest.mark.asyncio
    async def test_list_models_success(self, backend):
        """Models reported by vLLM are passed through to the caller."""
        self._install_client(backend, {
            "data": [
                {"id": "mistralai/Mistral-7B-Instruct-v0.2", "object": "model"},
                {"id": "codellama/CodeLlama-7b-Instruct-hf", "object": "model"},
            ]
        })

        models = await backend.list_models()

        assert len(models) == 2
        assert models[0]["id"] == "mistralai/Mistral-7B-Instruct-v0.2"

    @pytest.mark.asyncio
    async def test_list_models_empty(self, backend):
        """An empty model catalogue yields an empty list."""
        self._install_client(backend, {"data": []})

        assert await backend.list_models() == []
|
|
|
|
|
|
class TestVLLMBackendLoRAAdapters:
    """Tests for VLLMBackend.list_lora_adapters."""

    @pytest.fixture
    def backend(self):
        """Backend instance with only base_url set and no live client."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            instance = VLLMBackend.__new__(VLLMBackend)
            instance.base_url = "http://localhost:8000"
            instance._client = None
            return instance

    @pytest.mark.asyncio
    async def test_list_lora_adapters(self, backend):
        """Only catalogue entries carrying a 'parent' (adapters) are returned."""
        catalogue = {
            "data": [
                {"id": "mistralai/Mistral-7B-Instruct-v0.2", "object": "model"},
                {"id": "erp-core-lora", "object": "model", "parent": "mistralai/Mistral-7B-Instruct-v0.2"},
            ]
        }
        reply = MagicMock()
        reply.status_code = 200
        reply.json.return_value = catalogue
        reply.raise_for_status = MagicMock()
        client = AsyncMock()
        client.get = AsyncMock(return_value=reply)
        client.is_closed = False
        backend._client = client

        adapters = await backend.list_lora_adapters()

        assert len(adapters) == 1
        assert adapters[0]["id"] == "erp-core-lora"
|
|
|
|
|
|
class TestVLLMBackendChatCompletion:
    """Tests for VLLMBackend.chat_completion."""

    @pytest.fixture
    def backend(self):
        """Backend instance carrying the attributes chat_completion reads."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            instance = VLLMBackend.__new__(VLLMBackend)
            instance.base_url = "http://localhost:8000"
            instance.default_model = "mistralai/Mistral-7B-Instruct-v0.2"
            instance._client = None
            instance.max_tokens_limit = 4096
            instance.min_tokens = 1
            return instance

    @pytest.fixture
    def mock_vllm_response(self):
        """Canned OpenAI-style chat completion payload."""
        return {
            "id": "chatcmpl-abc123",
            "object": "chat.completion",
            "created": 1700000000,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello! How can I assist you today?",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 8,
                "total_tokens": 18,
            },
        }

    @staticmethod
    def _install_post_client(backend, *, status=200, payload=None):
        """Attach an open mock HTTP client whose POST yields status/payload."""
        reply = MagicMock()
        reply.status_code = status
        if payload is not None:
            reply.json.return_value = payload
            reply.raise_for_status = MagicMock()
        client = AsyncMock()
        client.post = AsyncMock(return_value=reply)
        client.is_closed = False
        backend._client = client
        return client

    @pytest.mark.asyncio
    async def test_chat_completion_success(self, backend, mock_vllm_response):
        """A 200 response is unpacked into content/usage/finish_reason."""
        self._install_post_client(backend, payload=mock_vllm_response)

        with patch('src.engine.vllm_backend.settings') as cfg:
            cfg.request_timeout_ms = 60000
            cfg.connect_timeout_ms = 5000

            result = await backend.chat_completion(
                model="mistralai/Mistral-7B-Instruct-v0.2",
                messages=[{"role": "user", "content": "Hello"}],
            )

        assert result["content"] == "Hello! How can I assist you today?"
        assert result["usage"]["total_tokens"] == 18
        assert result["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_chat_completion_with_lora(self, backend, mock_vllm_response):
        """The LoRA adapter name is forwarded via extra_body.lora_request."""
        client = self._install_post_client(backend, payload=mock_vllm_response)

        with patch('src.engine.vllm_backend.settings') as cfg:
            cfg.request_timeout_ms = 60000
            cfg.connect_timeout_ms = 5000

            result = await backend.chat_completion(
                model="mistralai/Mistral-7B-Instruct-v0.2",
                messages=[{"role": "user", "content": "Hello"}],
                lora_adapter="erp-core",
            )

        assert result["lora_adapter"] == "erp-core"
        # The outgoing request body must carry the lora_request hint.
        sent = client.post.call_args.kwargs["json"]
        assert "extra_body" in sent
        assert sent["extra_body"]["lora_request"]["lora_name"] == "erp-core"

    @pytest.mark.asyncio
    async def test_chat_completion_model_not_found(self, backend):
        """A 404 from vLLM surfaces as ModelNotFoundError naming the model."""
        self._install_post_client(backend, status=404)

        with patch('src.engine.vllm_backend.settings') as cfg:
            cfg.request_timeout_ms = 60000
            cfg.connect_timeout_ms = 5000

            with pytest.raises(ModelNotFoundError) as excinfo:
                await backend.chat_completion(
                    model="non-existent-model",
                    messages=[{"role": "user", "content": "Hello"}],
                )

        assert "non-existent-model" in str(excinfo.value.message)

    @pytest.mark.asyncio
    async def test_chat_completion_validation_error(self, backend):
        """Invalid parameters fail validation before any HTTP call is made."""
        with pytest.raises(ValidationError):
            await backend.chat_completion(
                model="mistralai/Mistral-7B-Instruct-v0.2",
                messages=[],  # empty conversation is invalid
            )
|
|
|
|
|
|
class TestVLLMBackendClose:
    """Tests for VLLMBackend.close."""

    @pytest.fixture
    def backend(self):
        """Backend instance with no live client; __init__ is bypassed."""
        with patch.object(VLLMBackend, '__init__', lambda self, base_url=None: None):
            backend = VLLMBackend.__new__(VLLMBackend)
            backend._client = None
            return backend

    @pytest.mark.asyncio
    async def test_close_with_client(self, backend):
        """close() closes an open client and drops the reference."""
        mock_client = AsyncMock()
        mock_client.is_closed = False
        mock_client.aclose = AsyncMock()
        backend._client = mock_client

        await backend.close()

        mock_client.aclose.assert_called_once()
        assert backend._client is None

    @pytest.mark.asyncio
    async def test_close_without_client(self, backend):
        """close() is a no-op when no client was ever created."""
        backend._client = None

        # Should not raise
        await backend.close()
        assert backend._client is None

    @pytest.mark.asyncio
    async def test_close_already_closed_client(self, backend):
        """close() must not re-close an already closed client."""
        mock_client = AsyncMock()
        mock_client.is_closed = True
        mock_client.aclose = AsyncMock()
        backend._client = mock_client

        await backend.close()

        # Fix: the original test stated this intent in a comment but asserted
        # nothing, so it passed vacuously. Verify aclose is really skipped.
        mock_client.aclose.assert_not_called()