Files
SnarfCode/tests/unit/test_scanner.py
2026-05-22 00:19:30 -04:00

481 lines
16 KiB
Python

"""Unit tests for the Scanner orchestrator."""
import time
from unittest.mock import MagicMock, patch
import pytest
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProfile,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import (
AuthenticationError,
ConnectionLostError,
Scanner,
ScanTimeoutError,
)
# ---------------------------------------------------------------------------
# Helpers / Fixtures
# ---------------------------------------------------------------------------
def make_profile(**overrides) -> ScanProfile:
    """Build a ScanProfile pre-populated with Kubernetes test defaults.

    Any keyword argument replaces (or adds) the matching ScanProfile field
    before the profile is constructed.
    """
    base_fields = dict(
        provider=ProviderType.KUBERNETES,
        credentials={"kubeconfig_path": "/home/user/.kube/config"},
        endpoints=["https://k8s-api.local:6443"],
        resource_type_filters=None,
    )
    return ScanProfile(**{**base_fields, **overrides})
def make_resource(resource_type: str = "kubernetes_deployment", name: str = "nginx") -> DiscoveredResource:
    """Build a sample Kubernetes DiscoveredResource for tests.

    The unique_id is derived from both the resource type and the name.
    """
    unique_id = f"apps/v1/{resource_type}/{name}"
    return DiscoveredResource(
        name=name,
        resource_type=resource_type,
        unique_id=unique_id,
        provider=ProviderType.KUBERNETES,
        platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
        architecture=CpuArchitecture.AARCH64,
        endpoint="https://k8s-api.local:6443",
        attributes={"replicas": 3},
        raw_references=[],
    )
class MockPlugin(ProviderPlugin):
    """A configurable in-memory provider plugin for driving the Scanner in tests.

    Each constructor argument adjusts one behaviour. Arguments left as
    ``None`` fall back to a well-behaved Kubernetes-like plugin.

    Args:
        supported_types: Resource types returned by
            ``list_supported_resource_types``. ``None`` selects
            deployment/service/ingress; an explicit empty list is honoured.
        endpoints: Endpoints returned by ``list_endpoints``. ``None`` selects
            a single local API endpoint; an explicit empty list is honoured.
        auth_error: Exception raised from ``authenticate`` when not ``None``.
        discover_result: ``ScanResult`` returned as-is from
            ``discover_resources`` when not ``None``.
        discover_side_effect: Exception raised from ``discover_resources``
            when not ``None``; it is checked before ``discover_result``.
    """

    def __init__(
        self,
        supported_types=None,
        endpoints=None,
        auth_error=None,
        discover_result=None,
        discover_side_effect=None,
    ):
        # Compare against None rather than relying on truthiness, so a test
        # can deliberately configure an empty list of types or endpoints.
        if supported_types is None:
            supported_types = [
                "kubernetes_deployment",
                "kubernetes_service",
                "kubernetes_ingress",
            ]
        if endpoints is None:
            endpoints = ["https://k8s-api.local:6443"]
        self._supported_types = supported_types
        self._endpoints = endpoints
        self._auth_error = auth_error
        self._discover_result = discover_result
        self._discover_side_effect = discover_side_effect

    def authenticate(self, credentials: dict[str, str]) -> None:
        """Raise the configured ``auth_error`` if one was given; otherwise succeed."""
        if self._auth_error is not None:
            raise self._auth_error

    def get_platform_category(self) -> PlatformCategory:
        """Always report container orchestration."""
        return PlatformCategory.CONTAINER_ORCHESTRATION

    def list_endpoints(self) -> list[str]:
        """Return the configured endpoint list (same object, not a copy)."""
        return self._endpoints

    def list_supported_resource_types(self) -> list[str]:
        """Return the configured supported-type list (same object, not a copy)."""
        return self._supported_types

    def detect_architecture(self, endpoint: str) -> CpuArchitecture:
        """Always report AArch64, regardless of endpoint."""
        return CpuArchitecture.AARCH64

    def discover_resources(
        self,
        endpoints: list[str],
        resource_types: list[str],
        progress_callback=None,
    ) -> ScanResult:
        """Return a canned result, or synthesize one resource per requested type.

        Resolution order:
        1. ``discover_side_effect`` is raised when configured.
        2. ``discover_result`` is returned when configured. ``is not None`` is
           used so that a result object that evaluates as falsy is still
           returned.
        3. Otherwise one resource is built per entry in ``resource_types``,
           and ``progress_callback`` (if given) is invoked after each.

        The synthesized result leaves ``scan_timestamp`` and ``profile_hash``
        empty.
        """
        if self._discover_side_effect is not None:
            raise self._discover_side_effect
        if self._discover_result is not None:
            return self._discover_result

        resources = []
        total = len(resource_types)
        for completed, rt in enumerate(resource_types, start=1):
            resources.append(make_resource(resource_type=rt, name=f"{rt}_instance"))
            if progress_callback is not None:
                progress_callback(
                    ScanProgress(
                        current_resource_type=rt,
                        resources_discovered=len(resources),
                        resource_types_completed=completed,
                        total_resource_types=total,
                    )
                )
        return ScanResult(
            resources=resources,
            warnings=[],
            errors=[],
            scan_timestamp="",
            profile_hash="",
        )
# ---------------------------------------------------------------------------
# Tests: Successful scan flow
# ---------------------------------------------------------------------------
class TestSuccessfulScan:
    """Tests for the happy path scan flow."""

    @staticmethod
    def _recording_plugin(seen_endpoints, **plugin_kwargs):
        """Return a MockPlugin that appends every endpoint list it is asked to scan."""

        class RecordingPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                seen_endpoints.append(list(endpoints))
                return super().discover_resources(endpoints, resource_types, progress_callback)

        return RecordingPlugin(**plugin_kwargs)

    def test_scan_returns_all_discovered_resources(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        result = scanner.scan()
        # MockPlugin synthesizes one resource per supported type (3 by default).
        assert len(result.resources) == 3
        assert result.is_partial is False
        # MockPlugin leaves these blank, so the Scanner must fill them in.
        assert result.scan_timestamp != ""
        assert result.profile_hash != ""

    def test_scan_with_resource_type_filters(self):
        profile = make_profile(
            resource_type_filters=["kubernetes_deployment", "kubernetes_service"]
        )
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        result = scanner.scan()
        assert len(result.resources) == 2
        resource_types = [r.resource_type for r in result.resources]
        assert "kubernetes_deployment" in resource_types
        assert "kubernetes_service" in resource_types

    def test_scan_uses_plugin_endpoints_when_profile_has_none(self):
        profile = make_profile(endpoints=None)
        seen_endpoints = []
        plugin = self._recording_plugin(
            seen_endpoints, endpoints=["https://fallback.local:6443"]
        )
        scanner = Scanner(profile, plugin)
        result = scanner.scan()
        assert len(result.resources) == 3
        # The plugin's own endpoints must be the ones actually scanned.
        assert seen_endpoints, "discover_resources was never called"
        assert all(eps == ["https://fallback.local:6443"] for eps in seen_endpoints)

    def test_scan_uses_profile_endpoints_when_provided(self):
        profile = make_profile(endpoints=["https://custom.local:6443"])
        seen_endpoints = []
        plugin = self._recording_plugin(seen_endpoints)
        scanner = Scanner(profile, plugin)
        result = scanner.scan()
        assert result is not None
        # Profile endpoints must take precedence over the plugin's defaults.
        assert seen_endpoints, "discover_resources was never called"
        assert all(eps == ["https://custom.local:6443"] for eps in seen_endpoints)
# ---------------------------------------------------------------------------
# Tests: Authentication failure handling
# ---------------------------------------------------------------------------
class TestAuthenticationFailure:
    """Scanner behaviour when authentication or pre-scan validation fails."""

    def test_auth_failure_raises_authentication_error(self):
        failing_plugin = MockPlugin(auth_error=RuntimeError("Invalid token"))
        scanner = Scanner(make_profile(), failing_plugin)
        with pytest.raises(AuthenticationError) as caught:
            scanner.scan()
        err = caught.value
        assert "kubernetes" in err.provider_name
        assert "Invalid token" in err.reason

    def test_auth_error_contains_provider_name(self):
        swarm_profile = make_profile(provider=ProviderType.DOCKER_SWARM)
        failing_plugin = MockPlugin(auth_error=RuntimeError("Connection refused"))
        scanner = Scanner(swarm_profile, failing_plugin)
        with pytest.raises(AuthenticationError) as caught:
            scanner.scan()
        assert caught.value.provider_name == "docker_swarm"

    def test_auth_error_contains_reason(self):
        failing_plugin = MockPlugin(auth_error=RuntimeError("Certificate expired"))
        scanner = Scanner(make_profile(), failing_plugin)
        with pytest.raises(AuthenticationError) as caught:
            scanner.scan()
        assert "Certificate expired" in caught.value.reason

    def test_invalid_profile_raises_value_error(self):
        # An empty credentials dict should be rejected before scanning.
        scanner = Scanner(make_profile(credentials={}), MockPlugin())
        with pytest.raises(ValueError, match="Invalid scan profile"):
            scanner.scan()

    def test_no_plugin_raises_value_error(self):
        scanner = Scanner(make_profile(), plugin=None)
        with pytest.raises(ValueError, match="No provider plugin"):
            scanner.scan()
# ---------------------------------------------------------------------------
# Tests: Progress callback invocation
# ---------------------------------------------------------------------------
class TestProgressCallback:
    """Tests for progress reporting."""

    def test_progress_callback_invoked_per_resource_type(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        progress_updates = []
        scanner.scan(progress_callback=progress_updates.append)
        # MockPlugin reports once per resource type, and supports 3 by default.
        assert len(progress_updates) == 3

    def test_progress_callback_contains_correct_counts(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        progress_updates = []
        scanner.scan(progress_callback=progress_updates.append)
        # Fail with a clear message rather than an IndexError below.
        assert progress_updates, "progress callback was never invoked"
        # The last update should show all types completed.
        last = progress_updates[-1]
        assert last.resource_types_completed == 3
        assert last.total_resource_types == 3

    def test_progress_callback_shows_incremental_completion(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        progress_updates = []
        scanner.scan(progress_callback=progress_updates.append)
        # Compare the whole sequence: a per-item loop would pass vacuously
        # if no updates were emitted at all.
        completed = [update.resource_types_completed for update in progress_updates]
        assert completed == [1, 2, 3]

    def test_scan_works_without_progress_callback(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        # Should not raise, and should still discover every supported type.
        result = scanner.scan(progress_callback=None)
        assert result is not None
        assert len(result.resources) == 3
# ---------------------------------------------------------------------------
# Tests: Retry logic on transient errors
# ---------------------------------------------------------------------------
class TestRetryLogic:
    """Retry and backoff behaviour when discover_resources raises transient errors."""

    def test_retries_on_transient_error_then_succeeds(self):
        profile = make_profile()
        attempts = []
        canned = ScanResult(
            resources=[make_resource()],
            warnings=[],
            errors=[],
            scan_timestamp="",
            profile_hash="",
        )

        class RetryPlugin(MockPlugin):
            # Fails on the first two attempts, then succeeds.
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                attempts.append(None)
                if len(attempts) < 3:
                    raise RuntimeError("Transient network error")
                return canned

        scanner = Scanner(profile, RetryPlugin())
        with patch("iac_reverse.scanner.scanner.time.sleep"):
            outcome = scanner.scan()
        assert len(outcome.resources) == 1
        assert len(attempts) == 3

    def test_returns_error_result_after_max_retries_exhausted(self):
        profile = make_profile()

        class AlwaysFailPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise RuntimeError("Persistent failure")

        scanner = Scanner(profile, AlwaysFailPlugin())
        with patch("iac_reverse.scanner.scanner.time.sleep"):
            outcome = scanner.scan()
        assert outcome.is_partial is True
        assert len(outcome.errors) > 0
        assert "Persistent failure" in outcome.errors[0]

    def test_exponential_backoff_timing(self):
        profile = make_profile()
        delays = []

        class FailPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise RuntimeError("Transient error")

        scanner = Scanner(profile, FailPlugin())
        with patch("iac_reverse.scanner.scanner.time.sleep", side_effect=delays.append):
            scanner.scan()
        # Three retries (0, 1, 2) back off by doubling from one second.
        assert delays == [1.0, 2.0, 4.0]
# ---------------------------------------------------------------------------
# Tests: Partial inventory on connection loss
# ---------------------------------------------------------------------------
class TestConnectionLoss:
    """Scanner behaviour when the connection drops mid-scan."""

    def test_connection_error_returns_partial_result(self):
        profile = make_profile()

        class ConnectionLossPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise ConnectionError("Connection reset by peer")

        scanner = Scanner(profile, ConnectionLossPlugin())
        # A raw ConnectionError should surface as ConnectionLostError carrying
        # whatever was gathered so far.
        with pytest.raises(ConnectionLostError) as caught:
            scanner.scan()
        salvaged = caught.value.partial_result
        assert salvaged.is_partial is True
        assert len(salvaged.errors) > 0

    def test_connection_lost_error_has_partial_result(self):
        profile = make_profile()
        prepared = ScanResult(
            resources=[make_resource()],
            warnings=["partial"],
            errors=["lost connection"],
            scan_timestamp="",
            profile_hash="",
            is_partial=True,
        )

        class RaisesConnectionLost(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise ConnectionLostError(partial_result=prepared)

        scanner = Scanner(profile, RaisesConnectionLost())
        with pytest.raises(ConnectionLostError) as caught:
            scanner.scan()
        surfaced = caught.value.partial_result
        assert surfaced.is_partial is True
        assert len(surfaced.resources) == 1
# ---------------------------------------------------------------------------
# Tests: Warning for unsupported resource types
# ---------------------------------------------------------------------------
class TestUnsupportedResourceTypes:
    """Warnings emitted when resource_type_filters name types the plugin lacks."""

    @staticmethod
    def _scan_with_filters(filters):
        """Run a default MockPlugin scan restricted to the given type filters."""
        profile = make_profile(resource_type_filters=filters)
        scanner = Scanner(profile, MockPlugin())
        return scanner.scan()

    def test_unsupported_types_generate_warnings(self):
        result = self._scan_with_filters(
            ["kubernetes_deployment", "nonexistent_type", "another_fake_type"]
        )
        # Only the two filters MockPlugin does not support should be flagged.
        flagged = [w for w in result.warnings if "Unsupported resource type" in w]
        assert len(flagged) == 2
        first, second = flagged
        assert "nonexistent_type" in first
        assert "another_fake_type" in second

    def test_unsupported_types_do_not_block_supported_types(self):
        result = self._scan_with_filters(["kubernetes_deployment", "nonexistent_type"])
        # The supported type must still be discovered.
        assert len(result.resources) == 1
        assert result.resources[0].resource_type == "kubernetes_deployment"

    def test_all_unsupported_types_returns_empty_resources(self):
        result = self._scan_with_filters(["fake_type_1", "fake_type_2"])
        assert len(result.resources) == 0
        assert len(result.warnings) == 2

    def test_warning_includes_provider_name(self):
        result = self._scan_with_filters(["unsupported_thing"])
        assert "kubernetes" in result.warnings[0]