"""Unit tests for the Scanner orchestrator."""

import time
from unittest.mock import MagicMock, patch

import pytest

from iac_reverse.models import (
    CpuArchitecture,
    DiscoveredResource,
    PlatformCategory,
    ProviderType,
    ScanProfile,
    ScanProgress,
    ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import (
    AuthenticationError,
    ConnectionLostError,
    Scanner,
    ScanTimeoutError,
)


# ---------------------------------------------------------------------------
# Helpers / Fixtures
# ---------------------------------------------------------------------------


def make_profile(**overrides) -> ScanProfile:
    """Create a valid ScanProfile with sensible defaults.

    Any keyword argument replaces the matching default field before the
    profile is constructed.
    """
    defaults = {
        "provider": ProviderType.KUBERNETES,
        "credentials": {"kubeconfig_path": "/home/user/.kube/config"},
        "endpoints": ["https://k8s-api.local:6443"],
        "resource_type_filters": None,
    }
    defaults.update(overrides)
    return ScanProfile(**defaults)


def make_resource(
    resource_type: str = "kubernetes_deployment", name: str = "nginx"
) -> DiscoveredResource:
    """Create a sample DiscoveredResource of the given type and name."""
    return DiscoveredResource(
        resource_type=resource_type,
        unique_id=f"apps/v1/{resource_type}/{name}",
        name=name,
        provider=ProviderType.KUBERNETES,
        platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
        architecture=CpuArchitecture.AARCH64,
        endpoint="https://k8s-api.local:6443",
        attributes={"replicas": 3},
        raw_references=[],
    )


class MockPlugin(ProviderPlugin):
    """A mock provider plugin for testing.

    Args:
        supported_types: Types reported by ``list_supported_resource_types``.
            ``None`` selects three default Kubernetes types; an explicit
            empty list is honoured as "no supported types".
        endpoints: Endpoints reported by ``list_endpoints`` (same ``None``
            semantics as ``supported_types``).
        auth_error: Exception raised by ``authenticate`` when not ``None``.
        discover_result: ScanResult returned verbatim by
            ``discover_resources`` when not ``None``.
        discover_side_effect: Exception raised by ``discover_resources``
            when not ``None``.

    Every ``discover_resources`` call appends ``(endpoints, resource_types)``
    to ``discover_calls`` so tests can check what the Scanner passed in.
    """

    def __init__(
        self,
        supported_types=None,
        endpoints=None,
        auth_error=None,
        discover_result=None,
        discover_side_effect=None,
    ):
        # `is None` (not `or`) so that an explicit empty list is respected.
        if supported_types is None:
            supported_types = [
                "kubernetes_deployment",
                "kubernetes_service",
                "kubernetes_ingress",
            ]
        if endpoints is None:
            endpoints = ["https://k8s-api.local:6443"]
        self._supported_types = supported_types
        self._endpoints = endpoints
        self._auth_error = auth_error
        self._discover_result = discover_result
        self._discover_side_effect = discover_side_effect
        self.discover_calls = []

    def authenticate(self, credentials: dict[str, str]) -> None:
        """Raise the configured auth error, otherwise succeed silently."""
        if self._auth_error is not None:
            raise self._auth_error

    def get_platform_category(self) -> PlatformCategory:
        return PlatformCategory.CONTAINER_ORCHESTRATION

    def list_endpoints(self) -> list[str]:
        return self._endpoints

    def list_supported_resource_types(self) -> list[str]:
        return self._supported_types

    def detect_architecture(self, endpoint: str) -> CpuArchitecture:
        return CpuArchitecture.AARCH64

    def discover_resources(
        self,
        endpoints: list[str],
        resource_types: list[str],
        progress_callback=None,
    ) -> ScanResult:
        """Record the call, then raise/return the configured outcome.

        When no side effect or canned result is configured, one resource is
        produced per requested type and ``progress_callback`` (if any) is
        invoked after each type.
        """
        # Stored as received (no copy) so a None argument cannot crash here.
        self.discover_calls.append((endpoints, resource_types))

        if self._discover_side_effect is not None:
            raise self._discover_side_effect
        if self._discover_result is not None:
            return self._discover_result

        # Default: return one resource per type and invoke progress callback
        resources = []
        for i, rt in enumerate(resource_types):
            resources.append(make_resource(resource_type=rt, name=f"{rt}_instance"))
            if progress_callback:
                progress_callback(
                    ScanProgress(
                        current_resource_type=rt,
                        resources_discovered=len(resources),
                        resource_types_completed=i + 1,
                        total_resource_types=len(resource_types),
                    )
                )
        return ScanResult(
            resources=resources,
            warnings=[],
            errors=[],
            scan_timestamp="",
            profile_hash="",
        )


def _endpoints_passed_to_discovery(plugin: MockPlugin) -> set[str]:
    """Return every endpoint the Scanner handed to ``plugin.discover_resources``."""
    return {ep for endpoints, _types in plugin.discover_calls for ep in endpoints}


# ---------------------------------------------------------------------------
# Tests: Successful scan flow
# ---------------------------------------------------------------------------


class TestSuccessfulScan:
    """Tests for the happy path scan flow."""

    def test_scan_returns_all_discovered_resources(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert len(result.resources) == 3
        assert result.is_partial is False
        assert result.scan_timestamp != ""
        assert result.profile_hash != ""

    def test_scan_with_resource_type_filters(self):
        profile = make_profile(
            resource_type_filters=["kubernetes_deployment", "kubernetes_service"]
        )
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert len(result.resources) == 2
        resource_types = [r.resource_type for r in result.resources]
        assert "kubernetes_deployment" in resource_types
        assert "kubernetes_service" in resource_types

    def test_scan_uses_plugin_endpoints_when_profile_has_none(self):
        profile = make_profile(endpoints=None)
        plugin = MockPlugin(endpoints=["https://fallback.local:6443"])
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert len(result.resources) == 3
        # Discovery must have run against the plugin's fallback endpoints.
        assert _endpoints_passed_to_discovery(plugin) == {
            "https://fallback.local:6443"
        }

    def test_scan_uses_profile_endpoints_when_provided(self):
        profile = make_profile(endpoints=["https://custom.local:6443"])
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)

        result = scanner.scan()

        assert result is not None
        # The plugin's discover_resources must be called with the profile
        # endpoints, not the plugin's own default endpoint.
        assert _endpoints_passed_to_discovery(plugin) == {"https://custom.local:6443"}


# ---------------------------------------------------------------------------
# Tests: Authentication failure handling
# ---------------------------------------------------------------------------


class TestAuthenticationFailure:
    """Tests for authentication error handling."""

    def test_auth_failure_raises_authentication_error(self):
        profile = make_profile()
        plugin = MockPlugin(auth_error=RuntimeError("Invalid token"))
        scanner = Scanner(profile, plugin)

        with pytest.raises(AuthenticationError) as exc_info:
            scanner.scan()

        assert "kubernetes" in exc_info.value.provider_name
        assert "Invalid token" in exc_info.value.reason

    def test_auth_error_contains_provider_name(self):
        profile = make_profile(provider=ProviderType.DOCKER_SWARM)
        plugin = MockPlugin(auth_error=RuntimeError("Connection refused"))
        scanner = Scanner(profile, plugin)

        with pytest.raises(AuthenticationError) as exc_info:
            scanner.scan()

        assert exc_info.value.provider_name == "docker_swarm"

    def test_auth_error_contains_reason(self):
        profile = make_profile()
        plugin = MockPlugin(auth_error=RuntimeError("Certificate expired"))
        scanner = Scanner(profile, plugin)

        with pytest.raises(AuthenticationError) as exc_info:
            scanner.scan()

        assert "Certificate expired" in exc_info.value.reason

    def test_invalid_profile_raises_value_error(self):
        profile = make_profile(credentials={})
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)

        with pytest.raises(ValueError, match="Invalid scan profile"):
            scanner.scan()

    def test_no_plugin_raises_value_error(self):
        profile = make_profile()
        scanner = Scanner(profile, plugin=None)

        with pytest.raises(ValueError, match="No provider plugin"):
            scanner.scan()


# ---------------------------------------------------------------------------
# Tests: Progress callback invocation
# ---------------------------------------------------------------------------


class TestProgressCallback:
    """Tests for progress reporting."""

    def test_progress_callback_invoked_per_resource_type(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        progress_updates = []

        scanner.scan(progress_callback=progress_updates.append)

        # Should have one progress update per resource type (3 types)
        assert len(progress_updates) == 3

    def test_progress_callback_contains_correct_counts(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        progress_updates = []

        scanner.scan(progress_callback=progress_updates.append)

        # Last update should show all types completed
        last = progress_updates[-1]
        assert last.resource_types_completed == 3
        assert last.total_resource_types == 3

    def test_progress_callback_shows_incremental_completion(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)
        progress_updates = []

        scanner.scan(progress_callback=progress_updates.append)

        # Guard against a vacuous pass when no updates were delivered at all.
        assert progress_updates, "expected at least one progress update"
        for i, update in enumerate(progress_updates):
            assert update.resource_types_completed == i + 1

    def test_scan_works_without_progress_callback(self):
        profile = make_profile()
        plugin = MockPlugin()
        scanner = Scanner(profile, plugin)

        # Should not raise
        result = scanner.scan(progress_callback=None)

        assert result is not None


# ---------------------------------------------------------------------------
# Tests: Retry logic on transient errors
# ---------------------------------------------------------------------------
class TestRetryLogic:
    """Retry behaviour when plugin discovery hits transient failures."""

    def test_retries_on_transient_error_then_succeeds(self):
        attempts = {"n": 0}
        final_result = ScanResult(
            resources=[make_resource()],
            warnings=[],
            errors=[],
            scan_timestamp="",
            profile_hash="",
        )

        class RetryPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                attempts["n"] += 1
                # Fail the first two attempts, succeed on the third.
                if attempts["n"] >= 3:
                    return final_result
                raise RuntimeError("Transient network error")

        scanner = Scanner(make_profile(), RetryPlugin())
        with patch("iac_reverse.scanner.scanner.time.sleep"):
            outcome = scanner.scan()

        assert len(outcome.resources) == 1
        assert attempts["n"] == 3

    def test_returns_error_result_after_max_retries_exhausted(self):
        class AlwaysFailPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise RuntimeError("Persistent failure")

        scanner = Scanner(make_profile(), AlwaysFailPlugin())
        with patch("iac_reverse.scanner.scanner.time.sleep"):
            outcome = scanner.scan()

        assert outcome.is_partial is True
        assert len(outcome.errors) > 0
        assert "Persistent failure" in outcome.errors[0]

    def test_exponential_backoff_timing(self):
        class FailPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise RuntimeError("Transient error")

        delays = []
        scanner = Scanner(make_profile(), FailPlugin())
        with patch("iac_reverse.scanner.scanner.time.sleep", side_effect=delays.append):
            scanner.scan()

        # Three retries (attempts 0, 1, 2) back off for 1, 2 and 4 seconds.
        assert delays == [1.0, 2.0, 4.0]


# ---------------------------------------------------------------------------
# Tests: Partial inventory on connection loss
# ---------------------------------------------------------------------------
class TestConnectionLoss:
    """A connection drop mid-scan must surface whatever inventory was gathered."""

    def test_connection_error_returns_partial_result(self):
        class ConnectionLossPlugin(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise ConnectionError("Connection reset by peer")

        scanner = Scanner(make_profile(), ConnectionLossPlugin())

        with pytest.raises(ConnectionLostError) as caught:
            scanner.scan()

        salvaged = caught.value.partial_result
        assert salvaged.is_partial is True
        assert len(salvaged.errors) > 0

    def test_connection_lost_error_has_partial_result(self):
        snapshot = ScanResult(
            resources=[make_resource()],
            warnings=["partial"],
            errors=["lost connection"],
            scan_timestamp="",
            profile_hash="",
            is_partial=True,
        )

        class RaisesConnectionLost(MockPlugin):
            def discover_resources(self, endpoints, resource_types, progress_callback=None):
                raise ConnectionLostError(partial_result=snapshot)

        scanner = Scanner(make_profile(), RaisesConnectionLost())

        with pytest.raises(ConnectionLostError) as caught:
            scanner.scan()

        propagated = caught.value.partial_result
        assert propagated.is_partial is True
        assert len(propagated.resources) == 1


# ---------------------------------------------------------------------------
# Tests: Warning for unsupported resource types
# ---------------------------------------------------------------------------


class TestUnsupportedResourceTypes:
    """Filters naming unknown resource types yield warnings, not failures."""

    @staticmethod
    def _scan_with_filters(filters):
        """Scan with the default MockPlugin, restricted to ``filters``."""
        profile = make_profile(resource_type_filters=filters)
        return Scanner(profile, MockPlugin()).scan()

    def test_unsupported_types_generate_warnings(self):
        result = self._scan_with_filters(
            ["kubernetes_deployment", "nonexistent_type", "another_fake_type"]
        )

        # Only the two unknown filter entries should be flagged, in order.
        flagged = [msg for msg in result.warnings if "Unsupported resource type" in msg]
        assert len(flagged) == 2
        assert "nonexistent_type" in flagged[0]
        assert "another_fake_type" in flagged[1]

    def test_unsupported_types_do_not_block_supported_types(self):
        result = self._scan_with_filters(["kubernetes_deployment", "nonexistent_type"])

        # The supported type is still discovered despite the unknown one.
        assert len(result.resources) == 1
        assert result.resources[0].resource_type == "kubernetes_deployment"

    def test_all_unsupported_types_returns_empty_resources(self):
        result = self._scan_with_filters(["fake_type_1", "fake_type_2"])

        assert len(result.resources) == 0
        assert len(result.warnings) == 2

    def test_warning_includes_provider_name(self):
        result = self._scan_with_filters(["unsupported_thing"])

        assert "kubernetes" in result.warnings[0]