"""Tests for OllamaBackend."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from src.engine.ollama_backend import OllamaBackend
from src.exceptions import (
    BackendUnavailableError,
    InferenceTimeoutError,
    ModelNotFoundError,
    ValidationError,
)


class TestOllamaBackendValidation:
    """Input-validation tests for OllamaBackend._validate_chat_inputs."""

    @pytest.fixture
    def backend(self):
        """Provide a fresh backend pointed at a local Ollama server."""
        return OllamaBackend(base_url="http://localhost:11434")

    def _validate(self, backend, **overrides):
        """Run validation with known-good defaults, overridden per test."""
        params = {
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 100,
            "temperature": 0.7,
            "top_p": 0.9,
        }
        params.update(overrides)
        backend._validate_chat_inputs(**params)

    def test_validate_empty_messages(self, backend):
        """An empty message list is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, messages=[])
        assert "empty" in exc.value.message.lower()

    def test_validate_missing_role(self, backend):
        """A message lacking a 'role' key is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, messages=[{"content": "Hello"}])
        assert "role" in exc.value.message.lower()

    def test_validate_invalid_role(self, backend):
        """A message with an unrecognized role is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(
                backend, messages=[{"role": "invalid", "content": "Hello"}]
            )
        assert "invalid" in exc.value.message.lower()

    def test_validate_max_tokens_too_low(self, backend):
        """A max_tokens below the minimum is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, max_tokens=0)
        assert "max_tokens" in exc.value.param

    def test_validate_max_tokens_too_high(self, backend):
        """A max_tokens above the maximum is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, max_tokens=100000)
        assert "max_tokens" in exc.value.param

    def test_validate_temperature_too_low(self, backend):
        """A negative temperature is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, temperature=-0.1)
        assert "temperature" in exc.value.param

    def test_validate_temperature_too_high(self, backend):
        """A temperature above 2 is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, temperature=2.5)
        assert "temperature" in exc.value.param

    def test_validate_top_p_out_of_range(self, backend):
        """A top_p above 1 is rejected."""
        with pytest.raises(ValidationError) as exc:
            self._validate(backend, top_p=1.5)
        assert "top_p" in exc.value.param

    def test_validate_valid_inputs(self, backend):
        """Fully valid arguments pass without raising."""
        self._validate(backend)


class TestOllamaBackendTokenEstimation:
    """Tests for the heuristic token estimator."""

    @pytest.fixture
    def backend(self):
        """Provide a fresh backend pointed at a local Ollama server."""
        return OllamaBackend(base_url="http://localhost:11434")

    def test_estimate_tokens_empty(self, backend):
        """An empty string estimates to zero tokens."""
        assert backend._estimate_tokens("") == 0

    def test_estimate_tokens_short(self, backend):
        """A single short word estimates to at least one token."""
        assert backend._estimate_tokens("Hello") >= 1

    def test_estimate_tokens_long(self, backend):
        """A full sentence estimates to roughly len(text) / 4 tokens."""
        sentence = "Hello world, this is a longer text to test token estimation."
        estimate = backend._estimate_tokens(sentence)
        assert 10 <= estimate <= 20


class TestOllamaBackendModelMapping:
    """Tests for model-name mapping."""

    @pytest.fixture
    def backend(self):
        """Provide a fresh backend pointed at a local Ollama server."""
        return OllamaBackend(base_url="http://localhost:11434")

    def test_map_known_model(self, backend):
        """A known alias maps onto the backend's default model."""
        assert backend._map_model_name("gpt-oss-20b") == backend.default_model

    def test_map_unknown_model_passthrough(self, backend):
        """Unrecognized model names are returned unchanged."""
        assert backend._map_model_name("custom-model") == "custom-model"