481 lines
16 KiB
Python
481 lines
16 KiB
Python
"""Unit tests for the Scanner orchestrator."""
|
|
|
|
import time
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from iac_reverse.models import (
|
|
CpuArchitecture,
|
|
DiscoveredResource,
|
|
PlatformCategory,
|
|
ProviderType,
|
|
ScanProfile,
|
|
ScanProgress,
|
|
ScanResult,
|
|
)
|
|
from iac_reverse.plugin_base import ProviderPlugin
|
|
from iac_reverse.scanner.scanner import (
|
|
AuthenticationError,
|
|
ConnectionLostError,
|
|
Scanner,
|
|
ScanTimeoutError,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers / Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def make_profile(**overrides) -> ScanProfile:
    """Build a ScanProfile for tests, layering keyword overrides over defaults.

    The defaults describe a Kubernetes provider with a kubeconfig credential,
    a single local API endpoint and no resource-type filtering. Any keyword in
    ``overrides`` replaces the matching default before the profile is built.
    """
    settings = dict(
        provider=ProviderType.KUBERNETES,
        credentials={"kubeconfig_path": "/home/user/.kube/config"},
        endpoints=["https://k8s-api.local:6443"],
        resource_type_filters=None,
    )
    return ScanProfile(**{**settings, **overrides})
|
|
|
|
|
|
def make_resource(resource_type: str = "kubernetes_deployment", name: str = "nginx") -> DiscoveredResource:
    """Return a Kubernetes-flavoured DiscoveredResource for use in tests.

    Only ``resource_type`` and ``name`` vary; they also feed the unique id.
    Every other field is fixed: aarch64 architecture, the local k8s API
    endpoint, ``replicas=3`` and no raw references.
    """
    unique_id = f"apps/v1/{resource_type}/{name}"
    fields = {
        "resource_type": resource_type,
        "unique_id": unique_id,
        "name": name,
        "provider": ProviderType.KUBERNETES,
        "platform_category": PlatformCategory.CONTAINER_ORCHESTRATION,
        "architecture": CpuArchitecture.AARCH64,
        "endpoint": "https://k8s-api.local:6443",
        "attributes": {"replicas": 3},
        "raw_references": [],
    }
    return DiscoveredResource(**fields)
|
|
|
|
|
|
class MockPlugin(ProviderPlugin):
    """A configurable, in-memory ProviderPlugin used to drive the Scanner in tests.

    Args:
        supported_types: Types returned by ``list_supported_resource_types``.
            ``None`` selects three default Kubernetes types; an explicit empty
            list is honoured (not replaced by the defaults).
        endpoints: Endpoints returned by ``list_endpoints``. ``None`` selects
            the default local k8s API endpoint; an explicit empty list is
            honoured.
        auth_error: When set, ``authenticate`` raises this exception.
        discover_result: When set, ``discover_resources`` returns it verbatim.
        discover_side_effect: When set, ``discover_resources`` raises it; this
            takes precedence over ``discover_result``.
    """

    def __init__(
        self,
        supported_types=None,
        endpoints=None,
        auth_error=None,
        discover_result=None,
        discover_side_effect=None,
    ):
        # Use `is None` rather than `or` so callers can deliberately pass [].
        if supported_types is None:
            supported_types = [
                "kubernetes_deployment",
                "kubernetes_service",
                "kubernetes_ingress",
            ]
        if endpoints is None:
            endpoints = ["https://k8s-api.local:6443"]
        self._supported_types = supported_types
        self._endpoints = endpoints
        self._auth_error = auth_error
        self._discover_result = discover_result
        self._discover_side_effect = discover_side_effect

    def authenticate(self, credentials: dict[str, str]) -> None:
        """Succeed silently unless an ``auth_error`` was configured."""
        if self._auth_error is not None:
            raise self._auth_error

    def get_platform_category(self) -> PlatformCategory:
        return PlatformCategory.CONTAINER_ORCHESTRATION

    def list_endpoints(self) -> list[str]:
        return self._endpoints

    def list_supported_resource_types(self) -> list[str]:
        return self._supported_types

    def detect_architecture(self, endpoint: str) -> CpuArchitecture:
        return CpuArchitecture.AARCH64

    def discover_resources(
        self,
        endpoints: list[str],
        resource_types: list[str],
        progress_callback=None,
    ) -> ScanResult:
        """Return configured results, or one synthetic resource per requested type.

        In the default mode a ScanProgress update is emitted after each
        resource type, so progress reporting can be asserted on. ``endpoints``
        is accepted for interface compatibility but not used.
        """
        if self._discover_side_effect is not None:
            raise self._discover_side_effect

        # Explicit None check: an empty-but-valid ScanResult must still be returned.
        if self._discover_result is not None:
            return self._discover_result

        resources = []
        total = len(resource_types)
        for completed, rt in enumerate(resource_types, start=1):
            resources.append(make_resource(resource_type=rt, name=f"{rt}_instance"))
            if progress_callback is not None:
                progress_callback(
                    ScanProgress(
                        current_resource_type=rt,
                        resources_discovered=len(resources),
                        resource_types_completed=completed,
                        total_resource_types=total,
                    )
                )

        return ScanResult(
            resources=resources,
            warnings=[],
            errors=[],
            scan_timestamp="",
            profile_hash="",
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Successful scan flow
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _EndpointRecordingPlugin(MockPlugin):
    """MockPlugin that records every endpoint handed to ``discover_resources``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seen_endpoints = []

    def discover_resources(self, endpoints, resource_types, progress_callback=None):
        # Accumulate across calls so the assertion holds whether the scanner
        # calls once with all endpoints or once per endpoint/type.
        self.seen_endpoints.extend(endpoints)
        return super().discover_resources(endpoints, resource_types, progress_callback)


class TestSuccessfulScan:
    """Tests for the happy path scan flow."""

    def test_scan_returns_all_discovered_resources(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert len(result.resources) == 3
        assert result.is_partial is False
        assert result.scan_timestamp != ""
        assert result.profile_hash != ""

    def test_scan_with_resource_type_filters(self):
        profile = make_profile(
            resource_type_filters=["kubernetes_deployment", "kubernetes_service"]
        )
        scanner = Scanner(profile, MockPlugin())

        result = scanner.scan()

        resource_types = sorted(r.resource_type for r in result.resources)
        assert resource_types == ["kubernetes_deployment", "kubernetes_service"]

    def test_scan_uses_plugin_endpoints_when_profile_has_none(self):
        profile = make_profile(endpoints=None)
        plugin = _EndpointRecordingPlugin(endpoints=["https://fallback.local:6443"])
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert len(result.resources) == 3
        # The fallback must actually reach discovery, not merely be listed.
        assert set(plugin.seen_endpoints) == {"https://fallback.local:6443"}

    def test_scan_uses_profile_endpoints_when_provided(self):
        profile = make_profile(endpoints=["https://custom.local:6443"])
        # The plugin's own default endpoint differs, so this discriminates
        # between "profile endpoints used" and "plugin endpoints used".
        plugin = _EndpointRecordingPlugin()
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert result is not None
        assert set(plugin.seen_endpoints) == {"https://custom.local:6443"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Authentication failure handling
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAuthenticationFailure:
    """Authentication failures and pre-flight validation errors raise from scan()."""

    @staticmethod
    def _auth_failure(profile, auth_error):
        """Scan with a plugin whose authenticate() raises; return the AuthenticationError."""
        scanner = Scanner(profile, MockPlugin(auth_error=auth_error))
        with pytest.raises(AuthenticationError) as caught:
            scanner.scan()
        return caught.value

    def test_auth_failure_raises_authentication_error(self):
        error = self._auth_failure(make_profile(), RuntimeError("Invalid token"))

        assert "kubernetes" in error.provider_name
        assert "Invalid token" in error.reason

    def test_auth_error_contains_provider_name(self):
        swarm_profile = make_profile(provider=ProviderType.DOCKER_SWARM)
        error = self._auth_failure(swarm_profile, RuntimeError("Connection refused"))

        assert error.provider_name == "docker_swarm"

    def test_auth_error_contains_reason(self):
        error = self._auth_failure(make_profile(), RuntimeError("Certificate expired"))

        assert "Certificate expired" in error.reason

    def test_invalid_profile_raises_value_error(self):
        scanner = Scanner(make_profile(credentials={}), MockPlugin())

        with pytest.raises(ValueError, match="Invalid scan profile"):
            scanner.scan()

    def test_no_plugin_raises_value_error(self):
        scanner = Scanner(make_profile(), plugin=None)

        with pytest.raises(ValueError, match="No provider plugin"):
            scanner.scan()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Progress callback invocation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestProgressCallback:
    """Tests for progress reporting via the optional ``progress_callback``."""

    @staticmethod
    def _collect_progress():
        """Run a default scan and return every progress update it emitted."""
        updates = []
        scanner = Scanner(make_profile(), MockPlugin())
        scanner.scan(progress_callback=updates.append)
        return updates

    def test_progress_callback_invoked_per_resource_type(self):
        # One progress update per resource type (MockPlugin defaults to 3 types).
        assert len(self._collect_progress()) == 3

    def test_progress_callback_contains_correct_counts(self):
        updates = self._collect_progress()

        # Fail with an assertion, not an IndexError, if nothing was reported.
        assert updates, "expected at least one progress update"
        last = updates[-1]
        assert last.resource_types_completed == 3
        assert last.total_resource_types == 3

    def test_progress_callback_shows_incremental_completion(self):
        updates = self._collect_progress()

        # Without this guard an empty update list would pass vacuously.
        assert updates, "expected at least one progress update"
        completed = [update.resource_types_completed for update in updates]
        assert completed == list(range(1, len(updates) + 1))

    def test_scan_works_without_progress_callback(self):
        scanner = Scanner(make_profile(), MockPlugin())

        # Should not raise
        result = scanner.scan(progress_callback=None)
        assert result is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Retry logic on transient errors
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRetryLogic:
    """Transient discovery failures are retried with exponential backoff."""

    def test_retries_on_transient_error_then_succeeds(self):
        profile = make_profile()
        attempts = []
        expected_result = ScanResult(
            resources=[make_resource()],
            warnings=[],
            errors=[],
            scan_timestamp="",
            profile_hash="",
        )

        class RetryPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                attempts.append(None)
                # First two attempts fail; the third succeeds.
                if len(attempts) <= 2:
                    raise RuntimeError("Transient network error")
                return expected_result

        scanner = Scanner(profile, RetryPlugin())

        with patch("iac_reverse.scanner.scanner.time.sleep"):
            result = scanner.scan()

        assert len(result.resources) == 1
        assert len(attempts) == 3

    def test_returns_error_result_after_max_retries_exhausted(self):
        class AlwaysFailPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise RuntimeError("Persistent failure")

        scanner = Scanner(make_profile(), AlwaysFailPlugin())

        with patch("iac_reverse.scanner.scanner.time.sleep"):
            result = scanner.scan()

        assert result.is_partial is True
        assert result.errors
        assert "Persistent failure" in result.errors[0]

    def test_exponential_backoff_timing(self):
        class FailPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise RuntimeError("Transient error")

        scanner = Scanner(make_profile(), FailPlugin())

        with patch("iac_reverse.scanner.scanner.time.sleep") as fake_sleep:
            scanner.scan()

        delays = [recorded.args[0] for recorded in fake_sleep.call_args_list]
        # Retries 0, 1, 2 back off by 1s, 2s, then 4s.
        assert delays == [1.0, 2.0, 4.0]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Partial inventory on connection loss
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConnectionLoss:
    """Connection loss during discovery raises ConnectionLostError.

    The raised error carries the partial inventory collected so far
    (``partial_result``), which is what these tests inspect.
    """

    # The retry tests patch this; do the same here in case connection errors
    # pass through the retry/backoff loop, so the tests never block on a real
    # time.sleep. Patching is a no-op if sleep is never called.
    _SLEEP_TARGET = "iac_reverse.scanner.scanner.time.sleep"

    def test_connection_error_returns_partial_result(self):
        """A raw ConnectionError surfaces as ConnectionLostError with a partial result."""
        profile = make_profile()

        class ConnectionLossPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise ConnectionError("Connection reset by peer")

        scanner = Scanner(profile, ConnectionLossPlugin())

        with patch(self._SLEEP_TARGET), pytest.raises(ConnectionLostError) as exc_info:
            scanner.scan()

        partial = exc_info.value.partial_result
        assert partial.is_partial is True
        assert len(partial.errors) > 0

    def test_connection_lost_error_has_partial_result(self):
        """A ConnectionLostError raised by the plugin keeps its partial result."""
        profile = make_profile()
        partial = ScanResult(
            resources=[make_resource()],
            warnings=["partial"],
            errors=["lost connection"],
            scan_timestamp="",
            profile_hash="",
            is_partial=True,
        )

        class RaisesConnectionLost(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise ConnectionLostError(partial_result=partial)

        scanner = Scanner(profile, RaisesConnectionLost())

        with patch(self._SLEEP_TARGET), pytest.raises(ConnectionLostError) as exc_info:
            scanner.scan()

        assert exc_info.value.partial_result.is_partial is True
        assert len(exc_info.value.partial_result.resources) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests: Warning for unsupported resource types
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUnsupportedResourceTypes:
    """Tests for warnings emitted for resource types the plugin does not support."""

    @staticmethod
    def _scan_with_filters(filters):
        """Scan with the default MockPlugin, restricted to ``filters``."""
        profile = make_profile(resource_type_filters=filters)
        return Scanner(profile, MockPlugin()).scan()

    @staticmethod
    def _unsupported_warnings(result):
        """Return only the warnings about unsupported resource types."""
        return [w for w in result.warnings if "Unsupported resource type" in w]

    def test_unsupported_types_generate_warnings(self):
        result = self._scan_with_filters(
            ["kubernetes_deployment", "nonexistent_type", "another_fake_type"]
        )

        # Should have warnings for the 2 unsupported types
        unsupported_warnings = self._unsupported_warnings(result)
        assert len(unsupported_warnings) == 2
        assert "nonexistent_type" in unsupported_warnings[0]
        assert "another_fake_type" in unsupported_warnings[1]

    def test_unsupported_types_do_not_block_supported_types(self):
        result = self._scan_with_filters(["kubernetes_deployment", "nonexistent_type"])

        # Should still discover the supported type
        assert len(result.resources) == 1
        assert result.resources[0].resource_type == "kubernetes_deployment"

    def test_all_unsupported_types_returns_empty_resources(self):
        result = self._scan_with_filters(["fake_type_1", "fake_type_2"])

        assert len(result.resources) == 0
        assert len(result.warnings) == 2
        # The two warnings must be the unsupported-type ones, naming each type,
        # not just any two warnings.
        unsupported_warnings = self._unsupported_warnings(result)
        assert len(unsupported_warnings) == 2
        assert "fake_type_1" in unsupported_warnings[0]
        assert "fake_type_2" in unsupported_warnings[1]

    def test_warning_includes_provider_name(self):
        result = self._scan_with_filters(["unsupported_thing"])

        # Fail with an assertion rather than IndexError if no warning was logged.
        assert result.warnings, "expected a warning for the unsupported type"
        assert "kubernetes" in result.warnings[0]
        assert "unsupported_thing" in result.warnings[0]
|