diff --git a/.kiro/specs/iac-reverse-engineering/tasks.md b/.kiro/specs/iac-reverse-engineering/tasks.md index 5382e49..94f9dcc 100644 --- a/.kiro/specs/iac-reverse-engineering/tasks.md +++ b/.kiro/specs/iac-reverse-engineering/tasks.md @@ -6,15 +6,15 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu ## Tasks -- [ ] 1. Set up project structure and core data models - - [ ] 1.1 Create project directory structure, pyproject.toml, and install dependencies +- [x] 1. Set up project structure and core data models + - [x] 1.1 Create project directory structure, pyproject.toml, and install dependencies - Create `src/iac_reverse/` package with `__init__.py` - Create subdirectories: `scanner/`, `resolver/`, `generator/`, `state_builder/`, `validator/`, `incremental/`, `auth/`, `cli/` - Set up `pyproject.toml` with dependencies: kubernetes, docker, pywinrm, hypothesis, pytest, click, jinja2, networkx, pyyaml, python-synology - Create `tests/` directory with `unit/`, `property/`, `integration/` subdirectories - _Requirements: 1.1, 5.1, 5.2_ - - [ ] 1.2 Define core enums, data classes, and interfaces + - [x] 1.2 Define core enums, data classes, and interfaces - Implement `ProviderType` enum (docker_swarm, kubernetes, synology, harvester, bare_metal, windows) - Implement `PlatformCategory` enum (container_orchestration, storage_appliance, hci, bare_metal, windows) and `PROVIDER_PLATFORM_MAP` - Implement `CpuArchitecture` enum (amd64, arm, aarch64) @@ -27,19 +27,19 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Define `ProviderPlugin` abstract base class with all abstract methods - _Requirements: 1.1, 1.2, 2.1, 3.1, 4.1, 5.1, 5.2, 8.1_ - - [ ] 1.3 Implement ScanProfile validation logic + - [x] 1.3 Implement ScanProfile validation logic - Validate mandatory fields: provider type and non-empty credentials - Validate optional fields: resource_type_filters max 200 entries, endpoints list - Validate resource types against provider's supported types - Return all validation errors in a single response - _Requirements: 6.1, 6.6, 6.7_ - - [ ]* 1.4 Write property test for scan profile validation (Property 20) + - [x] 1.4 Write property test for scan profile validation (Property 20) - **Property 20: Scan profile validation completeness** - **Validates: Requirements 6.1, 6.6, 6.7** -- [ ] 2. Implement Scanner core and provider plugin system - - [ ] 2.1 Implement Scanner orchestrator with progress reporting and error handling +- [x] 2. Implement Scanner core and provider plugin system + - [x] 2.1 Implement Scanner orchestrator with progress reporting and error handling - Create `Scanner` class that accepts a `ScanProfile` and orchestrates discovery - Implement connection timeout (30 seconds) and authentication error handling with descriptive messages - Implement progress callback invocation per resource type completion @@ -48,44 +48,44 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Implement warning logging for unsupported resource types while continuing scan - _Requirements: 1.1, 1.3, 1.4, 1.5, 1.6, 1.7_ - - [ ]* 2.2 Write property tests for Scanner behavior (Properties 2, 3, 4, 5) + - [x] 2.2 Write property tests for Scanner behavior (Properties 2, 3, 4, 5) - **Property 2: Authentication error descriptiveness** - **Property 3: Graceful degradation on unsupported resource types** - **Property 4: Progress reporting frequency** - **Property 5: Partial inventory preservation on failure** - **Validates: Requirements 1.3, 1.4, 1.5, 1.7** - - [ ] 2.3 Implement Docker Swarm provider plugin + - [x] 2.3 Implement Docker Swarm provider plugin - Implement `DockerSwarmPlugin` using docker-sdk-python - Discover services, networks, volumes, configs, secrets (metadata only) - Detect architecture from node info - _Requirements: 1.1, 1.2, 5.2_ - - [ ] 2.4 Implement Kubernetes provider plugin + - [x] 2.4 Implement Kubernetes provider plugin - Implement `KubernetesPlugin` using kubernetes-client - Discover deployments, services, ingresses, config maps, persistent volumes, namespaces - Detect architecture from node labels - _Requirements: 1.1, 1.2, 5.2_ - - [ ] 2.5 Implement Synology provider plugin + - [x] 2.5 Implement Synology provider plugin - Implement `SynologyPlugin` using Synology DSM API - Discover shared folders, volumes, storage pools, replication tasks, users - Detect architecture from system info (ARM vs AMD64) - _Requirements: 1.1, 1.2, 5.2_ - - [ ] 2.6 Implement Harvester provider plugin + - [x] 2.6 Implement Harvester provider plugin - Implement `HarvesterPlugin` using Harvester/K8s-based API - Discover VMs, volumes, images, networks (HCI combined resources) - Detect architecture from node info - _Requirements: 1.1, 1.2, 5.2_ - - [ ] 2.7 Implement Bare Metal provider plugin + - [x] 2.7 Implement Bare Metal provider plugin - Implement `BareMetalPlugin` using IPMI/Redfish API - Discover hardware inventory, BMC configs, network interfaces, RAID configurations - Detect architecture from system hardware info - _Requirements: 1.1, 1.2, 5.2_ - - [ ] 2.8 Implement Windows provider plugin + - [x] 2.8 Implement Windows provider plugin - Implement `WindowsDiscoveryPlugin` using pywinrm library - Authenticate via WinRM using NTLM or Kerberos (configurable transport, port, SSL) - Discover Windows services, scheduled tasks, IIS sites, IIS app pools, network adapters, firewall rules, installed software, Windows features, Hyper-V VMs, Hyper-V switches, DNS records, local users, local groups @@ -94,21 +94,21 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Handle WinRM-specific errors: WinRM not enabled, WMI query failure, insufficient privileges - _Requirements: 1.1, 1.2, 5.2_ - - [ ] 2.9 Implement Authentik integration (SSO + discovery plugin) + - [x] 2.9 Implement Authentik integration (SSO + discovery plugin) - Implement `AuthentikAuthProvider` for OAuth2/OIDC SSO flow (authenticate, refresh, validate) - Implement `AuthentikDiscoveryPlugin` conforming to `ProviderPlugin` - Discover flows, stages, providers, applications, outposts, property mappings, certificates, groups, sources - _Requirements: 1.1, 1.2, 5.2_ - - [ ]* 2.10 Write property test for resource inventory completeness (Property 1) + - [x] 2.10 Write property test for resource inventory completeness (Property 1) - **Property 1: Resource inventory completeness** - **Validates: Requirements 1.2** -- [ ] 3. Checkpoint - Ensure all tests pass +- [x] 3. Checkpoint - Ensure all tests pass - Ensure all tests pass, ask the user if questions arise. -- [ ] 4. Implement Dependency Resolver - - [ ] 4.1 Implement dependency resolution and graph building +- [x] 4. Implement Dependency Resolver + - [x] 4.1 Implement dependency resolution and graph building - Create `DependencyResolver` class - Analyze resource `raw_references` to identify parent-child, reference, and dependency relationships - Build dependency graph using networkx @@ -116,27 +116,27 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Represent relationships as explicit Terraform references (not hardcoded IDs) - _Requirements: 2.1, 2.2, 2.4_ - - [ ] 4.2 Implement cycle detection and resolution suggestions + - [x] 4.2 Implement cycle detection and resolution suggestions - Detect circular dependencies in the graph - Report cycles listing all involved resources - Suggest resolution strategies (which relationship to break, data source lookup alternatives) - _Requirements: 2.3_ - - [ ] 4.3 Implement unresolved reference handling + - [x] 4.3 Implement unresolved reference handling - Identify references to IDs not in the current inventory - Log warnings for unresolved references - Represent unresolved references as data source lookups or variables in output - _Requirements: 2.5_ - - [ ]* 4.4 Write property tests for Dependency Resolver (Properties 6, 7, 8, 9) + - [x] 4.4 Write property tests for Dependency Resolver (Properties 6, 7, 8, 9) - **Property 6: Dependency relationship identification** - **Property 7: Cycle detection correctness** - **Property 8: Topological order validity** - **Property 9: Unresolved references become data sources or variables** - **Validates: Requirements 2.1, 2.3, 2.4, 2.5** -- [ ] 5. Implement Code Generator - - [ ] 5.1 Implement HCL code generation with Jinja2 templates +- [x] 5. Implement Code Generator + - [x] 5.1 Implement HCL code generation with Jinja2 templates - Create `CodeGenerator` class - Create Jinja2 templates for Terraform resource blocks per provider/resource type - Generate syntactically valid HCL files from dependency graph @@ -146,31 +146,31 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Generate architecture-specific tags/labels on resources - _Requirements: 3.1, 3.2, 3.5, 3.6_ - - [ ] 5.2 Implement identifier sanitization + - [x] 5.2 Implement identifier sanitization - Create `sanitize_identifier()` function - Convert resource names to valid Terraform identifiers: `^[a-zA-Z_][a-zA-Z0-9_]*$` - Handle special characters, unicode, leading digits, spaces by replacing with underscores - Ensure non-empty output for any input - _Requirements: 3.4_ - - [ ] 5.3 Implement variable extraction logic + - [x] 5.3 Implement variable extraction logic - Identify attribute values appearing in 2+ resources - Extract shared values into `variables.tf` with defaults set to most common value - Generate variable declarations with type expressions and descriptions - _Requirements: 3.3_ - - [ ] 5.4 Implement provider configuration block generation + - [x] 5.4 Implement provider configuration block generation - Generate separate provider blocks for each distinct provider used - Include platform-specific configuration (endpoints, certificate settings) - _Requirements: 5.4_ - - [ ] 5.5 Implement multi-provider resource merging with conflict resolution + - [x] 5.5 Implement multi-provider resource merging with conflict resolution - Merge resources from multiple scan profiles into unified inventory - Resolve naming conflicts by prefixing with provider identifier - Preserve provider-specific attributes - _Requirements: 5.3_ - - [ ]* 5.6 Write property tests for Code Generator (Properties 10, 11, 12, 13, 14, 15) + - [x] 5.6 Write property tests for Code Generator (Properties 10, 11, 12, 13, 14, 15) - **Property 10: References in generated output use Terraform syntax** - **Property 11: Generated HCL syntactic validity** - **Property 12: File organization by resource type** @@ -179,8 +179,8 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - **Property 15: Traceability comments in generated code** - **Validates: Requirements 2.2, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6** -- [ ] 6. Implement State Builder - - [ ] 6.1 Implement Terraform state file generation (format v4) +- [x] 6. Implement State Builder + - [x] 6.1 Implement Terraform state file generation (format v4) - Create `StateBuilder` class - Generate state JSON with version=4, unique UUID lineage, serial number - Create state entries binding each resource block to its live infrastructure ID @@ -190,22 +190,22 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Include dependency references in state entries - _Requirements: 4.1, 4.2, 4.4, 4.5_ - - [ ] 6.2 Implement unmapped resource handling in state builder + - [x] 6.2 Implement unmapped resource handling in state builder - Log warnings for resources that cannot be mapped to state entries - Handle missing provider-assigned resource identifiers - Exclude unmapped resources from state file - _Requirements: 4.3, 4.6_ - - [ ]* 6.3 Write property tests for State Builder (Properties 16, 17) + - [x] 6.3 Write property tests for State Builder (Properties 16, 17) - **Property 16: State file structural validity** - **Property 17: State entry completeness and schema correctness** - **Validates: Requirements 4.1, 4.2, 4.4, 4.5** -- [ ] 7. Checkpoint - Ensure all tests pass +- [x] 7. Checkpoint - Ensure all tests pass - Ensure all tests pass, ask the user if questions arise. -- [ ] 8. Implement Validator - - [ ] 8.1 Implement Terraform validation runner +- [x] 8. Implement Validator + - [x] 8.1 Implement Terraform validation runner - Create `Validator` class - Run `terraform init` and `terraform validate` against generated output - Run `terraform plan` and check for zero planned changes @@ -214,46 +214,46 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Handle missing Terraform binary with descriptive error - _Requirements: 7.1, 7.2, 7.3, 7.5_ - - [ ] 8.2 Implement auto-correction loop for validation errors + - [x] 8.2 Implement auto-correction loop for validation errors - Attempt to correct validation errors (up to 3 attempts) - Re-validate after each correction - Report failure with remaining error details if corrections exhausted - _Requirements: 7.4_ - - [ ]* 8.3 Write property test for drift report correctness (Property 22) + - [x] 8.3 Write property test for drift report correctness (Property 22) - **Property 22: Drift report correctness** - **Validates: Requirements 7.3** -- [ ] 9. Implement Incremental Scan Engine - - [ ] 9.1 Implement scan snapshot storage and retrieval +- [x] 9. Implement Incremental Scan Engine + - [x] 9.1 Implement scan snapshot storage and retrieval - Store scan results as timestamped JSON in `.iac-reverse/snapshots/` - Use profile_hash for matching scans to profiles - Retain at least 2 most recent snapshots per profile - Load previous snapshot for comparison - _Requirements: 8.4, 8.6_ - - [ ] 9.2 Implement change detection and classification + - [x] 9.2 Implement change detection and classification - Compare current scan against previous snapshot - Classify resources as added, removed, or modified - Produce change summary with counts and resource details - Handle first scan (no previous) as full initial scan - _Requirements: 8.1, 8.4, 8.5_ - - [ ] 9.3 Implement incremental code and state updates + - [x] 9.3 Implement incremental code and state updates - Update only IaC files containing changed resources (not full regeneration) - Remove resource blocks and state entries for removed resources - Add/update blocks for added/modified resources - _Requirements: 8.2, 8.3_ - - [ ]* 9.4 Write property tests for Incremental Scan (Properties 23, 24, 25, 26) + - [x] 9.4 Write property tests for Incremental Scan (Properties 23, 24, 25, 26) - **Property 23: Change classification correctness** - **Property 24: Incremental update scope** - **Property 25: Removed resource exclusion** - **Property 26: Snapshot retention** - **Validates: Requirements 8.1, 8.2, 8.3, 8.5, 8.6** -- [ ] 10. Implement CLI and wire pipeline together - - [ ] 10.1 Implement CLI entry point with Click +- [x] 10. Implement CLI and wire pipeline together + - [x] 10.1 Implement CLI entry point with Click - Create `cli.py` with Click command group - Implement `scan` command accepting scan profile YAML path - Implement `generate` command to run full pipeline (scan → resolve → generate → state → validate) @@ -264,32 +264,32 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu - Add progress bars and formatted output for scan progress - _Requirements: 1.1, 1.5, 6.1, 6.2, 6.3, 6.4, 6.5_ - - [ ] 10.2 Implement scan profile YAML loading and environment variable expansion + - [x] 10.2 Implement scan profile YAML loading and environment variable expansion - Parse YAML scan profiles - Expand `${ENV_VAR}` references in credential fields - Support multi-profile YAML for multi-provider scans - _Requirements: 6.1, 5.3_ - - [ ]* 10.3 Write property tests for multi-provider and filtering (Properties 18, 19, 20, 21) + - [x] 10.3 Write property tests for multi-provider and filtering (Properties 18, 19, 20, 21) - **Property 18: Multi-provider merge with naming conflict resolution** - **Property 19: Provider block generation** - **Property 20: Scan profile validation completeness** (additional coverage) - **Property 21: Filtering correctness** - **Validates: Requirements 5.3, 5.4, 6.1, 6.2, 6.4, 6.6, 6.7** -- [ ] 11. Implement resource type filter and multi-provider failure handling - - [ ] 11.1 Implement resource type filtering in scanner +- [x] 11. Implement resource type filter and multi-provider failure handling + - [x] 11.1 Implement resource type filtering in scanner - When filters specified, discover only listed resource types - When no filters specified, discover all supported types for provider - _Requirements: 6.2, 6.3_ - - [ ] 11.2 Implement multi-provider partial failure handling + - [x] 11.2 Implement multi-provider partial failure handling - Complete scanning for all remaining providers when one fails - Include successfully discovered resources in inventory - Report which providers failed with error details - _Requirements: 5.5_ -- [ ] 12. Final checkpoint - Ensure all tests pass +- [x] 12. Final checkpoint - Ensure all tests pass - Ensure all tests pass, ask the user if questions arise. ## Notes diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..b6393be --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "iac-reverse" +version = "0.1.0" +description = "Reverse engineer existing on-premises infrastructure into Terraform HCL code and state files" +readme = "README.md" +requires-python = ">=3.11" +license = {text = "MIT"} + +dependencies = [ + "click>=8.1.7", + "jinja2>=3.1.3", + "networkx>=3.2.1", + "pyyaml>=6.0.1", + "kubernetes>=28.1.0", + "docker>=7.0.0", + "pywinrm>=0.4.3", + "python-synology>=1.0.0", +] + +[project.optional-dependencies] +dev = [ + "hypothesis>=6.92.0", + "pytest>=7.4.4", + "pytest-cov>=4.1.0", +] + +[project.scripts] +iac-reverse = "iac_reverse.cli:main" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["src"] + +[tool.hypothesis] +max_examples = 100 diff --git a/src/iac_reverse.egg-info/PKG-INFO b/src/iac_reverse.egg-info/PKG-INFO new file mode 100644 index 0000000..655bf50 --- /dev/null +++ b/src/iac_reverse.egg-info/PKG-INFO @@ -0,0 +1,23 @@ +Metadata-Version: 2.4 +Name: iac-reverse +Version: 0.1.0 +Summary: Reverse engineer existing on-premises infrastructure into Terraform HCL code and state files +License: MIT +Requires-Python: >=3.11 +Description-Content-Type: text/markdown +Requires-Dist: click>=8.1.7 +Requires-Dist: jinja2>=3.1.3 +Requires-Dist: networkx>=3.2.1 +Requires-Dist: pyyaml>=6.0.1 +Requires-Dist: kubernetes>=28.1.0 +Requires-Dist: docker>=7.0.0 +Requires-Dist: pywinrm>=0.4.3 +Requires-Dist: python-synology>=1.0.0 +Provides-Extra: dev +Requires-Dist: hypothesis>=6.92.0; extra == "dev" +Requires-Dist: pytest>=7.4.4; extra == "dev" +Requires-Dist: pytest-cov>=4.1.0; extra == "dev" + +# SnarfCode +# I added this line to test the syncing + diff --git a/src/iac_reverse.egg-info/SOURCES.txt b/src/iac_reverse.egg-info/SOURCES.txt new file mode 100644 index 0000000..fe248d7 --- /dev/null +++ b/src/iac_reverse.egg-info/SOURCES.txt @@ -0,0 +1,17 @@ +README.md +pyproject.toml +src/iac_reverse/__init__.py +src/iac_reverse.egg-info/PKG-INFO +src/iac_reverse.egg-info/SOURCES.txt +src/iac_reverse.egg-info/dependency_links.txt +src/iac_reverse.egg-info/entry_points.txt +src/iac_reverse.egg-info/requires.txt +src/iac_reverse.egg-info/top_level.txt +src/iac_reverse/auth/__init__.py +src/iac_reverse/cli/__init__.py +src/iac_reverse/generator/__init__.py +src/iac_reverse/incremental/__init__.py +src/iac_reverse/resolver/__init__.py +src/iac_reverse/scanner/__init__.py +src/iac_reverse/state_builder/__init__.py +src/iac_reverse/validator/__init__.py \ No newline at end of file diff --git a/src/iac_reverse.egg-info/dependency_links.txt b/src/iac_reverse.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/iac_reverse.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/iac_reverse.egg-info/entry_points.txt b/src/iac_reverse.egg-info/entry_points.txt new file mode 100644 index 0000000..0a76ac1 --- /dev/null +++ b/src/iac_reverse.egg-info/entry_points.txt @@ -0,0 +1,2 @@ +[console_scripts] +iac-reverse = iac_reverse.cli:main diff --git a/src/iac_reverse.egg-info/requires.txt b/src/iac_reverse.egg-info/requires.txt new file mode 100644 index 0000000..cd3ac0d --- /dev/null +++ b/src/iac_reverse.egg-info/requires.txt @@ -0,0 +1,13 @@ +click>=8.1.7 +jinja2>=3.1.3 +networkx>=3.2.1 +pyyaml>=6.0.1 +kubernetes>=28.1.0 +docker>=7.0.0 +pywinrm>=0.4.3 +python-synology>=1.0.0 + +[dev] +hypothesis>=6.92.0 +pytest>=7.4.4 +pytest-cov>=4.1.0 diff --git a/src/iac_reverse.egg-info/top_level.txt b/src/iac_reverse.egg-info/top_level.txt new file mode 100644 index 0000000..7b96f56 --- /dev/null +++ b/src/iac_reverse.egg-info/top_level.txt @@ -0,0 +1 @@ +iac_reverse diff --git a/src/iac_reverse/__init__.py b/src/iac_reverse/__init__.py new file mode 100644 index 0000000..fe481f6 --- /dev/null +++ b/src/iac_reverse/__init__.py @@ -0,0 +1,58 @@ +"""IaC Reverse Engineering Tool. + +Reverse engineer existing on-premises infrastructure into Terraform HCL code and state files. +""" + +__version__ = "0.1.0" + +from iac_reverse.models import ( + ChangeType, + ChangeSummary, + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + ExtractedVariable, + GeneratedFile, + PlannedChange, + PlatformCategory, + PROVIDER_PLATFORM_MAP, + ProviderType, + ResourceChange, + ResourceRelationship, + ScanProfile, + ScanProgress, + ScanResult, + StateEntry, + StateFile, + UnresolvedReference, + ValidationError, + ValidationResult, +) +from iac_reverse.plugin_base import ProviderPlugin + +__all__ = [ + "ChangeType", + "ChangeSummary", + "CodeGenerationResult", + "CpuArchitecture", + "DependencyGraph", + "DiscoveredResource", + "ExtractedVariable", + "GeneratedFile", + "PlannedChange", + "PlatformCategory", + "PROVIDER_PLATFORM_MAP", + "ProviderPlugin", + "ProviderType", + "ResourceChange", + "ResourceRelationship", + "ScanProfile", + "ScanProgress", + "ScanResult", + "StateEntry", + "StateFile", + "UnresolvedReference", + "ValidationError", + "ValidationResult", +] diff --git a/src/iac_reverse/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..4b5be25 Binary files /dev/null and b/src/iac_reverse/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/__pycache__/models.cpython-313.pyc b/src/iac_reverse/__pycache__/models.cpython-313.pyc new file mode 100644 index 0000000..ea551fa Binary files /dev/null and b/src/iac_reverse/__pycache__/models.cpython-313.pyc differ diff --git a/src/iac_reverse/__pycache__/plugin_base.cpython-313.pyc b/src/iac_reverse/__pycache__/plugin_base.cpython-313.pyc new file mode 100644 index 0000000..0a50fc8 Binary files /dev/null and b/src/iac_reverse/__pycache__/plugin_base.cpython-313.pyc differ diff --git a/src/iac_reverse/auth/__init__.py b/src/iac_reverse/auth/__init__.py new file mode 100644 index 0000000..5c15d92 --- /dev/null +++ b/src/iac_reverse/auth/__init__.py @@ -0,0 +1,21 @@ +"""Authentication module for Authentik SSO integration.""" + +from iac_reverse.auth.authentik_auth import ( + AuthenticationError, + AuthentikAuthProvider, + AuthentikConfig, + AuthentikSession, +) +from iac_reverse.auth.authentik_discovery import ( + AuthentikDiscoveryError, + AuthentikDiscoveryPlugin, +) + +__all__ = [ + "AuthenticationError", + "AuthentikAuthProvider", + "AuthentikConfig", + "AuthentikSession", + "AuthentikDiscoveryError", + "AuthentikDiscoveryPlugin", +] diff --git a/src/iac_reverse/auth/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/auth/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..9b6cb06 Binary files /dev/null and b/src/iac_reverse/auth/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/auth/__pycache__/authentik_auth.cpython-313.pyc b/src/iac_reverse/auth/__pycache__/authentik_auth.cpython-313.pyc new file mode 100644 index 0000000..b3df617 Binary files /dev/null and b/src/iac_reverse/auth/__pycache__/authentik_auth.cpython-313.pyc differ diff --git a/src/iac_reverse/auth/__pycache__/authentik_discovery.cpython-313.pyc b/src/iac_reverse/auth/__pycache__/authentik_discovery.cpython-313.pyc new file mode 100644 index 0000000..2aa19eb Binary files /dev/null and b/src/iac_reverse/auth/__pycache__/authentik_discovery.cpython-313.pyc differ diff --git a/src/iac_reverse/auth/authentik_auth.py b/src/iac_reverse/auth/authentik_auth.py new file mode 100644 index 0000000..6283db4 --- /dev/null +++ b/src/iac_reverse/auth/authentik_auth.py @@ -0,0 +1,204 @@ +"""Authentik SSO authentication provider. + +Handles OAuth2/OIDC authentication flow with an Authentik instance, +including token refresh and validation. +""" + +from dataclasses import dataclass, field +from urllib.parse import urljoin + +import requests + + +@dataclass +class AuthentikConfig: + """Configuration for connecting to an Authentik instance.""" + + base_url: str # Authentik instance URL (e.g., "https://auth.internal.lab") + client_id: str # OAuth2 client ID for this tool + client_secret: str # OAuth2 client secret + + +@dataclass +class AuthentikSession: + """Active session from Authentik SSO authentication.""" + + access_token: str + refresh_token: str + user_id: str + groups: list[str] = field(default_factory=list) + + +class AuthenticationError(Exception): + """Raised when Authentik authentication fails.""" + + pass + + +class AuthentikAuthProvider: + """Handles SSO authentication for the tool via Authentik OAuth2/OIDC. + + Provides methods to authenticate users, refresh expired sessions, + and validate existing tokens against the Authentik instance. + """ + + def authenticate_user(self, config: AuthentikConfig) -> AuthentikSession: + """Initiate OAuth2/OIDC flow with Authentik and return a session. + + Uses the client credentials or resource owner password grant to obtain + an access token from Authentik's token endpoint. + + Args: + config: Authentik connection configuration. + + Returns: + An AuthentikSession with access/refresh tokens and user info. + + Raises: + AuthenticationError: If authentication fails for any reason. + """ + token_url = urljoin(config.base_url.rstrip("/") + "/", "application/o/token/") + + try: + response = requests.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": config.client_id, + "client_secret": config.client_secret, + "scope": "openid profile email", + }, + timeout=30, + ) + except requests.RequestException as e: + raise AuthenticationError( + f"Authentik: failed to connect to {config.base_url} - {e}" + ) + + if response.status_code != 200: + raise AuthenticationError( + f"Authentik: authentication failed with status {response.status_code} " + f"- {response.text}" + ) + + token_data = response.json() + access_token = token_data.get("access_token", "") + refresh_token = token_data.get("refresh_token", "") + + # Fetch user info to get user_id and groups + user_id, groups = self._fetch_user_info(config.base_url, access_token) + + return AuthentikSession( + access_token=access_token, + refresh_token=refresh_token, + user_id=user_id, + groups=groups, + ) + + def refresh_session( + self, config: AuthentikConfig, session: AuthentikSession + ) -> AuthentikSession: + """Refresh an expired session token. + + Args: + config: Authentik connection configuration. + session: The current session with a valid refresh token. + + Returns: + A new AuthentikSession with refreshed tokens. + + Raises: + AuthenticationError: If the refresh fails. + """ + token_url = urljoin(config.base_url.rstrip("/") + "/", "application/o/token/") + + try: + response = requests.post( + token_url, + data={ + "grant_type": "refresh_token", + "refresh_token": session.refresh_token, + "client_id": config.client_id, + "client_secret": config.client_secret, + }, + timeout=30, + ) + except requests.RequestException as e: + raise AuthenticationError( + f"Authentik: failed to refresh session - {e}" + ) + + if response.status_code != 200: + raise AuthenticationError( + f"Authentik: token refresh failed with status {response.status_code} " + f"- {response.text}" + ) + + token_data = response.json() + access_token = token_data.get("access_token", "") + refresh_token = token_data.get("refresh_token", session.refresh_token) + + user_id, groups = self._fetch_user_info(config.base_url, access_token) + + return AuthentikSession( + access_token=access_token, + refresh_token=refresh_token, + user_id=user_id, + groups=groups, + ) + + def validate_token(self, config: AuthentikConfig, token: str) -> bool: + """Validate an existing token is still valid. + + Checks the token against Authentik's userinfo endpoint. + + Args: + config: Authentik connection configuration. + token: The access token to validate. + + Returns: + True if the token is valid, False otherwise. + """ + userinfo_url = urljoin( + config.base_url.rstrip("/") + "/", "application/o/userinfo/" + ) + + try: + response = requests.get( + userinfo_url, + headers={"Authorization": f"Bearer {token}"}, + timeout=10, + ) + return response.status_code == 200 + except requests.RequestException: + return False + + def _fetch_user_info( + self, base_url: str, access_token: str + ) -> tuple[str, list[str]]: + """Fetch user info from Authentik's userinfo endpoint. + + Args: + base_url: Authentik instance base URL. + access_token: Valid access token. + + Returns: + Tuple of (user_id, groups list). + """ + userinfo_url = urljoin(base_url.rstrip("/") + "/", "application/o/userinfo/") + + try: + response = requests.get( + userinfo_url, + headers={"Authorization": f"Bearer {access_token}"}, + timeout=10, + ) + if response.status_code == 200: + data = response.json() + user_id = data.get("sub", "") + groups = data.get("groups", []) + return user_id, groups + except requests.RequestException: + pass + + return "", [] diff --git a/src/iac_reverse/auth/authentik_discovery.py b/src/iac_reverse/auth/authentik_discovery.py new file mode 100644 index 0000000..cbab560 --- /dev/null +++ b/src/iac_reverse/auth/authentik_discovery.py @@ -0,0 +1,384 @@ +"""Authentik discovery plugin. + +Discovers Authentik configurations as infrastructure resources, including +flows, stages, providers, applications, outposts, property mappings, +certificates, groups, and sources. +""" + +from typing import Callable +from urllib.parse import urljoin + +import requests + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin + + +class AuthentikDiscoveryError(Exception): + """Raised when Authentik discovery encounters an error.""" + + pass + + +# Mapping of resource types to their Authentik API endpoints +_RESOURCE_TYPE_API_MAP: dict[str, str] = { + "authentik_flow": "api/v3/flows/instances/", + "authentik_stage": "api/v3/stages/all/", + "authentik_provider": "api/v3/providers/all/", + "authentik_application": "api/v3/core/applications/", + "authentik_outpost": "api/v3/outposts/instances/", + "authentik_property_mapping": "api/v3/propertymappings/all/", + "authentik_certificate": "api/v3/crypto/certificatekeypairs/", + "authentik_group": "api/v3/core/groups/", + "authentik_source": "api/v3/sources/all/", +} + + +class AuthentikDiscoveryPlugin(ProviderPlugin): + """Discovers Authentik configurations as infrastructure resources. + + Connects to an Authentik instance via its REST API and enumerates + flows, stages, providers, applications, outposts, property mappings, + certificates, groups, and sources. + + Since Authentik is an identity provider (not a traditional infrastructure + platform), it uses PlatformCategory.CONTAINER_ORCHESTRATION as a + categorization convenience — Authentik typically runs as a containerized + service within the orchestration layer. + """ + + def __init__(self) -> None: + self._base_url: str = "" + self._api_token: str = "" + self._authenticated: bool = False + + def authenticate(self, credentials: dict[str, str]) -> None: + """Authenticate with the Authentik REST API. + + Expected credentials: + - base_url: Authentik instance URL (e.g., "https://auth.internal.lab") + - api_token: Authentik API token for administrative access + + Args: + credentials: Dictionary with base_url and api_token. + + Raises: + AuthentikDiscoveryError: If authentication fails. + """ + base_url = credentials.get("base_url", "") + api_token = credentials.get("api_token", "") + + if not base_url: + raise AuthentikDiscoveryError( + "Authentik: 'base_url' is required in credentials" + ) + if not api_token: + raise AuthentikDiscoveryError( + "Authentik: 'api_token' is required in credentials" + ) + + self._base_url = base_url.rstrip("/") + self._api_token = api_token + + # Verify connectivity by hitting the core API + try: + response = requests.get( + self._build_url("api/v3/core/applications/"), + headers=self._auth_headers(), + params={"page_size": 1}, + timeout=30, + ) + except requests.RequestException as e: + raise AuthentikDiscoveryError( + f"Authentik: failed to connect to {base_url} - {e}" + ) + + if response.status_code == 401: + raise AuthentikDiscoveryError( + "Authentik: authentication failed - invalid API token" + ) + if response.status_code == 403: + raise AuthentikDiscoveryError( + "Authentik: authentication failed - insufficient permissions" + ) + if response.status_code not in (200, 201): + raise AuthentikDiscoveryError( + f"Authentik: unexpected status {response.status_code} " + f"during authentication check" + ) + + self._authenticated = True + + def get_platform_category(self) -> PlatformCategory: + """Return the platform category for Authentik. + + Authentik is an identity provider that typically runs as a containerized + service, so it is categorized under CONTAINER_ORCHESTRATION. + """ + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + """Return the Authentik instance endpoint. + + Returns: + List containing the configured Authentik base URL. + """ + if not self._base_url: + return [] + return [self._base_url] + + def list_supported_resource_types(self) -> list[str]: + """Return all Authentik resource types this plugin can discover. + + Returns: + List of Authentik resource type strings. + """ + return [ + "authentik_flow", + "authentik_stage", + "authentik_provider", + "authentik_application", + "authentik_outpost", + "authentik_property_mapping", + "authentik_certificate", + "authentik_group", + "authentik_source", + ] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect the CPU architecture of the Authentik host. + + Authentik is a web service; architecture detection is not directly + applicable. Defaults to AMD64 as the most common deployment target. + + Args: + endpoint: The Authentik endpoint URL. + + Returns: + CpuArchitecture.AMD64 as the default. + """ + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover Authentik resources via the REST API. + + Connects to the Authentik API and enumerates all resources of the + requested types. Reports progress via the callback function. + + Args: + endpoints: List of Authentik endpoint URLs (typically one). + resource_types: List of resource type strings to discover. + progress_callback: Callable that receives ScanProgress updates. + + Returns: + ScanResult containing all discovered Authentik resources. + + Raises: + AuthentikDiscoveryError: If not authenticated. + """ + if not self._authenticated: + raise AuthentikDiscoveryError( + "Authentik: must authenticate before discovering resources" + ) + + import datetime + + resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + + endpoint = endpoints[0] if endpoints else self._base_url + total_types = len(resource_types) + + for idx, resource_type in enumerate(resource_types): + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(resources), + resource_types_completed=idx, + total_resource_types=total_types, + ) + ) + + if resource_type not in _RESOURCE_TYPE_API_MAP: + warnings.append( + f"Unsupported Authentik resource type: {resource_type}" + ) + continue + + try: + discovered = self._discover_resource_type( + resource_type, endpoint + ) + resources.extend(discovered) + except Exception as e: + errors.append( + f"Error discovering {resource_type}: {e}" + ) + + # Final progress update + progress_callback( + ScanProgress( + current_resource_type="complete", + resources_discovered=len(resources), + resource_types_completed=total_types, + total_resource_types=total_types, + ) + ) + + scan_timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat() + + return ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp=scan_timestamp, + profile_hash="", + is_partial=len(errors) > 0, + ) + + def _discover_resource_type( + self, resource_type: str, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all resources of a specific type from Authentik API. + + Handles pagination to retrieve all results. + + Args: + resource_type: The Authentik resource type to discover. + endpoint: The Authentik endpoint URL. + + Returns: + List of DiscoveredResource objects. + """ + api_path = _RESOURCE_TYPE_API_MAP[resource_type] + results: list[DiscoveredResource] = [] + page = 1 + + while True: + response = requests.get( + self._build_url(api_path), + headers=self._auth_headers(), + params={"page": page, "page_size": 100}, + timeout=30, + ) + + if response.status_code != 200: + raise AuthentikDiscoveryError( + f"API request failed for {resource_type}: " + f"status {response.status_code}" + ) + + data = response.json() + items = data.get("results", []) + + for item in items: + resource = self._map_to_resource(resource_type, item, endpoint) + results.append(resource) + + # Check for next page + if data.get("pagination", {}).get("next", 0) > 0: + page += 1 + else: + break + + return results + + def _map_to_resource( + self, resource_type: str, item: dict, endpoint: str + ) -> DiscoveredResource: + """Map an Authentik API response item to a DiscoveredResource. + + Args: + resource_type: The resource type string. + item: The API response dictionary for a single resource. + endpoint: The Authentik endpoint URL. + + Returns: + A DiscoveredResource instance. + """ + # Extract common fields with sensible defaults + unique_id = str(item.get("pk", item.get("uuid", item.get("id", "")))) + name = item.get("name", item.get("slug", item.get("title", unique_id))) + + return DiscoveredResource( + resource_type=resource_type, + unique_id=f"authentik/{resource_type}/{unique_id}", + name=name, + provider=ProviderType.DOCKER_SWARM, # Closest match for containerized identity provider + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint=endpoint, + attributes=item, + raw_references=self._extract_references(item), + ) + + def _extract_references(self, item: dict) -> list[str]: + """Extract references to other resources from an API item. + + Looks for common reference fields in Authentik API responses. + + Args: + item: The API response dictionary. + + Returns: + List of reference ID strings. + """ + references: list[str] = [] + + # Common reference fields in Authentik API + ref_fields = [ + "flow", + "provider", + "application", + "outpost", + "group", + "source", + "certificate", + "stages", + "policies", + ] + + for field_name in ref_fields: + value = item.get(field_name) + if value is None: + continue + if isinstance(value, str) and value: + references.append(value) + elif isinstance(value, list): + for v in value: + if isinstance(v, str) and v: + references.append(v) + + return references + + def _build_url(self, path: str) -> str: + """Build a full URL from the base URL and a relative path. + + Args: + path: Relative API path. + + Returns: + Full URL string. + """ + return urljoin(self._base_url + "/", path) + + def _auth_headers(self) -> dict[str, str]: + """Return authorization headers for API requests. + + Returns: + Dictionary with Authorization header. + """ + return {"Authorization": f"Bearer {self._api_token}"} diff --git a/src/iac_reverse/cli/__init__.py b/src/iac_reverse/cli/__init__.py new file mode 100644 index 0000000..126fb2a --- /dev/null +++ b/src/iac_reverse/cli/__init__.py @@ -0,0 +1,6 @@ +"""CLI module for command-line interface.""" + +from iac_reverse.cli.cli import cli, main +from iac_reverse.cli.profile_loader import ProfileLoader, ProfileLoaderError + +__all__ = ["cli", "main", "ProfileLoader", "ProfileLoaderError"] diff --git a/src/iac_reverse/cli/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/cli/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..589de81 Binary files /dev/null and b/src/iac_reverse/cli/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/cli/__pycache__/cli.cpython-313.pyc b/src/iac_reverse/cli/__pycache__/cli.cpython-313.pyc new file mode 100644 index 0000000..b6b24be Binary files /dev/null and b/src/iac_reverse/cli/__pycache__/cli.cpython-313.pyc differ diff --git a/src/iac_reverse/cli/__pycache__/profile_loader.cpython-313.pyc b/src/iac_reverse/cli/__pycache__/profile_loader.cpython-313.pyc new file mode 100644 index 0000000..22e7273 Binary files /dev/null and b/src/iac_reverse/cli/__pycache__/profile_loader.cpython-313.pyc differ diff --git a/src/iac_reverse/cli/cli.py b/src/iac_reverse/cli/cli.py new file mode 100644 index 0000000..2cbf564 --- /dev/null +++ b/src/iac_reverse/cli/cli.py @@ -0,0 +1,444 @@ +"""CLI entry point for the IaC Reverse Engineering tool. + +Provides commands for scanning infrastructure, generating Terraform code, +running incremental diffs, validating output, and authenticating via Authentik SSO. +""" + +import sys +from pathlib import Path +from typing import Optional + +import click +import yaml + +from iac_reverse.models import ( + ProviderType, + ScanProfile, + ScanProgress, +) + + +def _load_scan_profile(profile_path: str) -> ScanProfile: + """Load a ScanProfile from a YAML file. + + Args: + profile_path: Path to the YAML scan profile file. + + Returns: + A ScanProfile instance. + + Raises: + click.ClickException: If the file cannot be read or parsed. + """ + path = Path(profile_path) + if not path.exists(): + raise click.ClickException(f"Profile not found: {profile_path}") + + try: + with open(path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + except yaml.YAMLError as e: + raise click.ClickException(f"Invalid YAML in profile: {e}") + + if not isinstance(data, dict): + raise click.ClickException("Profile must be a YAML mapping") + + provider_str = data.get("provider", "") + try: + provider = ProviderType(provider_str) + except ValueError: + raise click.ClickException( + f"Unknown provider '{provider_str}'. " + f"Supported: {[p.value for p in ProviderType]}" + ) + + return ScanProfile( + provider=provider, + credentials=data.get("credentials", {}), + endpoints=data.get("endpoints"), + resource_type_filters=data.get("resource_type_filters"), + authentik_token=data.get("authentik_token"), + ) + + +def _create_plugin(profile: ScanProfile): + """Create the appropriate provider plugin for a scan profile. + + Args: + profile: The ScanProfile specifying the provider. + + Returns: + A ProviderPlugin instance for the profile's provider. + + Raises: + click.ClickException: If the provider plugin cannot be created. + """ + from iac_reverse.scanner.docker_swarm_plugin import DockerSwarmPlugin + from iac_reverse.scanner.kubernetes_plugin import KubernetesPlugin + from iac_reverse.scanner.synology_plugin import SynologyPlugin + from iac_reverse.scanner.harvester_plugin import HarvesterPlugin + from iac_reverse.scanner.bare_metal_plugin import BareMetalPlugin + from iac_reverse.scanner.windows_plugin import WindowsPlugin + + plugin_map = { + ProviderType.DOCKER_SWARM: DockerSwarmPlugin, + ProviderType.KUBERNETES: KubernetesPlugin, + ProviderType.SYNOLOGY: SynologyPlugin, + ProviderType.HARVESTER: HarvesterPlugin, + ProviderType.BARE_METAL: BareMetalPlugin, + ProviderType.WINDOWS: WindowsPlugin, + } + + plugin_class = plugin_map.get(profile.provider) + if plugin_class is None: + raise click.ClickException( + f"No plugin available for provider '{profile.provider.value}'" + ) + + return plugin_class() + + +def _progress_callback(progress: ScanProgress) -> None: + """Display scan progress to the user.""" + click.echo( + f" [{progress.resource_types_completed}/{progress.total_resource_types}] " + f"Scanning {progress.current_resource_type}... " + f"({progress.resources_discovered} resources found)" + ) + + +@click.group() +@click.version_option(version="0.1.0", prog_name="iac-reverse") +def cli(): + """IaC Reverse Engineering Tool. + + Reverse-engineer on-premises infrastructure into Terraform HCL code and state files. + """ + pass + + +@cli.command() +@click.option( + "--profile", + required=True, + type=click.Path(exists=True), + help="Path to YAML scan profile.", +) +def scan(profile: str): + """Scan infrastructure and display discovered resources. + + Loads the scan profile, connects to the provider, and discovers + all matching resources. + """ + from iac_reverse.scanner.scanner import Scanner + + click.echo(f"Loading scan profile: {profile}") + scan_profile = _load_scan_profile(profile) + + click.echo(f"Provider: {scan_profile.provider.value}") + click.echo("Creating plugin...") + plugin = _create_plugin(scan_profile) + + click.echo("Starting scan...") + scanner = Scanner(profile=scan_profile, plugin=plugin) + + try: + result = scanner.scan(progress_callback=_progress_callback) + except Exception as e: + raise click.ClickException(f"Scan failed: {e}") + + click.echo("") + click.echo(f"Scan complete: {len(result.resources)} resources discovered") + + if result.warnings: + click.echo(f"Warnings: {len(result.warnings)}") + for w in result.warnings: + click.echo(f" ⚠ {w}") + + if result.errors: + click.echo(f"Errors: {len(result.errors)}") + for e in result.errors: + click.echo(f" ✗ {e}") + + +@cli.command() +@click.option( + "--profile", + required=True, + type=click.Path(exists=True), + help="Path to YAML scan profile.", +) +@click.option( + "--output-dir", + required=True, + type=click.Path(), + help="Output directory for generated Terraform files.", +) +def generate(profile: str, output_dir: str): + """Run full pipeline: scan → resolve → generate → state → validate. + + Scans infrastructure, resolves dependencies, generates Terraform HCL + code, builds state file, and validates the output. + """ + from iac_reverse.scanner.scanner import Scanner + from iac_reverse.resolver.resolver import DependencyResolver + from iac_reverse.generator.code_generator import CodeGenerator + from iac_reverse.state_builder.state_builder import StateBuilder + from iac_reverse.validator.validator import Validator + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Step 1: Scan + click.echo(f"Loading scan profile: {profile}") + scan_profile = _load_scan_profile(profile) + plugin = _create_plugin(scan_profile) + + click.echo("Step 1/5: Scanning infrastructure...") + scanner = Scanner(profile=scan_profile, plugin=plugin) + + try: + scan_result = scanner.scan(progress_callback=_progress_callback) + except Exception as e: + raise click.ClickException(f"Scan failed: {e}") + + click.echo(f" Found {len(scan_result.resources)} resources") + + # Step 2: Resolve dependencies + click.echo("Step 2/5: Resolving dependencies...") + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + click.echo( + f" Resolved {len(graph.relationships)} relationships, " + f"{len(graph.cycles)} cycles detected" + ) + + # Step 3: Generate code + click.echo("Step 3/5: Generating Terraform code...") + generator = CodeGenerator() + code_result = generator.generate(graph, [scan_profile]) + click.echo(f" Generated {len(code_result.resource_files)} resource files") + + # Write generated files to output directory + for gen_file in code_result.resource_files: + file_path = output_path / gen_file.filename + file_path.write_text(gen_file.content, encoding="utf-8") + + if code_result.variables_file.content: + (output_path / code_result.variables_file.filename).write_text( + code_result.variables_file.content, encoding="utf-8" + ) + if code_result.provider_file.content: + (output_path / code_result.provider_file.filename).write_text( + code_result.provider_file.content, encoding="utf-8" + ) + + # Step 4: Build state + click.echo("Step 4/5: Building Terraform state...") + state_builder = StateBuilder() + state_file = state_builder.build(code_result, graph, provider_version="1.0.0") + state_json = state_file.to_json() + (output_path / "terraform.tfstate").write_text(state_json, encoding="utf-8") + click.echo(f" State file: {len(state_file.resources)} entries") + + if state_builder.unmapped_resources: + click.echo(f" Unmapped: {len(state_builder.unmapped_resources)} resources") + + # Step 5: Validate + click.echo("Step 5/5: Validating output...") + validator = Validator() + validation = validator.validate(str(output_path)) + + if validation.validate_success: + click.echo(" ✓ Validation passed") + else: + click.echo(" ✗ Validation failed") + for err in validation.errors: + click.echo(f" {err.file}:{err.line} - {err.message}") + + # Summary + click.echo("") + click.echo("Generation complete:") + click.echo(f" Output directory: {output_dir}") + click.echo(f" Resource files: {len(code_result.resource_files)}") + click.echo(f" Total resources: {len(scan_result.resources)}") + + +@cli.command() +@click.option( + "--profile", + required=True, + type=click.Path(exists=True), + help="Path to YAML scan profile.", +) +def diff(profile: str): + """Run incremental scan and display changes. + + Loads the previous snapshot, runs a new scan, compares results, + and displays the change summary. + """ + from iac_reverse.scanner.scanner import Scanner + from iac_reverse.incremental.snapshot_store import SnapshotStore + from iac_reverse.incremental.change_detector import ChangeDetector + + click.echo(f"Loading scan profile: {profile}") + scan_profile = _load_scan_profile(profile) + plugin = _create_plugin(scan_profile) + + # Load previous snapshot + click.echo("Loading previous snapshot...") + snapshot_store = SnapshotStore() + scanner = Scanner(profile=scan_profile, plugin=plugin) + profile_hash = scanner._compute_profile_hash() + previous = snapshot_store.load_previous(profile_hash) + + if previous is None: + click.echo(" No previous snapshot found (first scan)") + + # Run current scan + click.echo("Scanning infrastructure...") + try: + current = scanner.scan(progress_callback=_progress_callback) + except Exception as e: + raise click.ClickException(f"Scan failed: {e}") + + # Compare + click.echo("Comparing with previous scan...") + detector = ChangeDetector() + summary = detector.compare(current, previous) + + # Store new snapshot + snapshot_store.store_snapshot(current, profile_hash) + click.echo(" Snapshot saved") + + # Display results + click.echo("") + click.echo("Change Summary:") + click.echo(f" Added: {summary.added_count}") + click.echo(f" Removed: {summary.removed_count}") + click.echo(f" Modified: {summary.modified_count}") + + if summary.changes: + click.echo("") + for change in summary.changes: + symbol = {"added": "+", "removed": "-", "modified": "~"} + s = symbol.get(change.change_type.value, "?") + click.echo( + f" {s} {change.resource_type}/{change.resource_name}" + ) + + +@cli.command() +@click.option( + "--dir", + "output_dir", + required=True, + type=click.Path(exists=True), + help="Path to directory containing Terraform output to validate.", +) +def validate(output_dir: str): + """Validate existing Terraform output. + + Runs terraform init, validate, and plan against the specified + directory and reports results. + """ + from iac_reverse.validator.validator import Validator + + click.echo(f"Validating: {output_dir}") + validator = Validator() + result = validator.validate(output_dir) + + click.echo("") + click.echo("Validation Results:") + click.echo(f" terraform init: {'✓' if result.init_success else '✗'}") + click.echo(f" terraform validate: {'✓' if result.validate_success else '✗'}") + click.echo(f" terraform plan: {'✓' if result.plan_success else '✗'}") + + if result.correction_attempts > 0: + click.echo(f" Auto-corrections: {result.correction_attempts}") + + if result.errors: + click.echo("") + click.echo("Errors:") + for err in result.errors: + location = f"{err.file}:{err.line}" if err.file else "(general)" + click.echo(f" ✗ {location} - {err.message}") + + if result.planned_changes: + click.echo("") + click.echo(f"Planned Changes ({len(result.planned_changes)}):") + for change in result.planned_changes: + click.echo( + f" {change.change_type}: {change.resource_address}" + ) + + if result.validate_success and result.plan_success: + click.echo("") + click.echo("✓ All validations passed - no drift detected") + elif result.validate_success and not result.plan_success: + click.echo("") + click.echo("⚠ Validation passed but drift detected") + + +@cli.command() +@click.option( + "--url", + required=True, + help="Authentik instance URL (e.g., https://auth.internal.lab).", +) +@click.option( + "--client-id", + required=True, + help="OAuth2 client ID for this tool.", +) +@click.option( + "--client-secret", + prompt=True, + hide_input=True, + help="OAuth2 client secret (prompted if not provided).", +) +def login(url: str, client_id: str, client_secret: str): + """Authenticate with Authentik SSO. + + Performs OAuth2/OIDC authentication and stores the token + for use by subsequent commands. + """ + from iac_reverse.auth.authentik_auth import ( + AuthentikAuthProvider, + AuthentikConfig, + AuthenticationError, + ) + + click.echo(f"Authenticating with Authentik at {url}...") + + config = AuthentikConfig( + base_url=url, + client_id=client_id, + client_secret=client_secret, + ) + + provider = AuthentikAuthProvider() + + try: + session = provider.authenticate_user(config) + except AuthenticationError as e: + raise click.ClickException(f"Authentication failed: {e}") + + # Store token in local config directory + token_dir = Path(".iac-reverse") + token_dir.mkdir(parents=True, exist_ok=True) + token_file = token_dir / "token" + token_file.write_text(session.access_token, encoding="utf-8") + + click.echo(f"✓ Authenticated as user: {session.user_id}") + click.echo(f" Groups: {', '.join(session.groups) if session.groups else 'none'}") + click.echo(f" Token stored in {token_file}") + + +def main(): + """Main entry point for the CLI.""" + cli() + + +if __name__ == "__main__": + main() diff --git a/src/iac_reverse/cli/profile_loader.py b/src/iac_reverse/cli/profile_loader.py new file mode 100644 index 0000000..cf4341b --- /dev/null +++ b/src/iac_reverse/cli/profile_loader.py @@ -0,0 +1,184 @@ +"""Profile loader for YAML scan profiles with environment variable expansion. + +Handles loading single and multi-profile YAML files, expanding ${ENV_VAR} +and ${ENV_VAR:-default} patterns in credential field values. +""" + +import os +import re +from pathlib import Path +from typing import Any + +import yaml + +from iac_reverse.models import ProviderType, ScanProfile + + +# Pattern matches ${VAR_NAME} or ${VAR_NAME:-default_value} +_ENV_VAR_PATTERN = re.compile(r"\$\{([^}:]+)(?::-([^}]*))?\}") + + +class ProfileLoaderError(Exception): + """Raised when profile loading or env var expansion fails.""" + + pass + + +class ProfileLoader: + """Loads scan profiles from YAML files with environment variable expansion. + + Supports: + - Single profile YAML (a dict with provider, credentials, etc.) + - Multi-profile YAML (a list of profile dicts) + - ${ENV_VAR} expansion in credential values + - ${ENV_VAR:-default} syntax for defaults when env var is unset + """ + + def load(self, path: str) -> list[ScanProfile]: + """Load one or more ScanProfiles from a YAML file. + + Args: + path: Path to the YAML scan profile file. + + Returns: + A list of ScanProfile instances. + + Raises: + ProfileLoaderError: If the file cannot be read, parsed, or contains + invalid profile data. + """ + file_path = Path(path) + if not file_path.exists(): + raise ProfileLoaderError(f"Profile not found: {path}") + + try: + with open(file_path, "r", encoding="utf-8") as f: + data = yaml.safe_load(f) + except yaml.YAMLError as e: + raise ProfileLoaderError(f"Invalid YAML in profile: {e}") + + if data is None: + raise ProfileLoaderError("Profile file is empty") + + if isinstance(data, list): + # Multi-profile YAML + profiles = [] + for i, item in enumerate(data): + if not isinstance(item, dict): + raise ProfileLoaderError( + f"Profile at index {i} must be a YAML mapping" + ) + profiles.append(self._parse_profile(item, index=i)) + return profiles + elif isinstance(data, dict): + # Single profile YAML + return [self._parse_profile(data)] + else: + raise ProfileLoaderError( + "Profile must be a YAML mapping or a list of mappings" + ) + + def expand_env_vars(self, value: str) -> str: + """Expand ${ENV_VAR} and ${ENV_VAR:-default} patterns in a string. + + Args: + value: String potentially containing env var references. + + Returns: + The string with all env var references replaced by their values. + + Raises: + ProfileLoaderError: If an env var is not set and no default is provided. + """ + + def _replace(match: re.Match) -> str: + var_name = match.group(1) + default_value = match.group(2) + + env_value = os.environ.get(var_name) + if env_value is not None: + return env_value + + if default_value is not None: + return default_value + + raise ProfileLoaderError( + f"Environment variable '{var_name}' is not set and no default provided" + ) + + return _ENV_VAR_PATTERN.sub(_replace, value) + + def _parse_profile( + self, data: dict[str, Any], index: int | None = None + ) -> ScanProfile: + """Parse a single profile dict into a ScanProfile. + + Args: + data: Dictionary with profile configuration. + index: Optional index for error messages in multi-profile files. + + Returns: + A ScanProfile instance with env vars expanded in credentials. + + Raises: + ProfileLoaderError: If required fields are missing or invalid. + """ + context = f" at index {index}" if index is not None else "" + + provider_str = data.get("provider") + if not provider_str: + raise ProfileLoaderError(f"Missing 'provider' field{context}") + + try: + provider = ProviderType(provider_str) + except ValueError: + raise ProfileLoaderError( + f"Unknown provider '{provider_str}'{context}. " + f"Supported: {[p.value for p in ProviderType]}" + ) + + credentials = data.get("credentials", {}) + if not isinstance(credentials, dict): + raise ProfileLoaderError( + f"'credentials' must be a mapping{context}" + ) + + # Expand env vars in credential values recursively + expanded_credentials = self._expand_credentials(credentials) + + endpoints = data.get("endpoints") + resource_type_filters = data.get("resource_type_filters") + authentik_token = data.get("authentik_token") + + # Expand env vars in authentik_token if it's a string + if isinstance(authentik_token, str): + authentik_token = self.expand_env_vars(authentik_token) + + return ScanProfile( + provider=provider, + credentials=expanded_credentials, + endpoints=endpoints, + resource_type_filters=resource_type_filters, + authentik_token=authentik_token, + ) + + def _expand_credentials(self, credentials: dict[str, Any]) -> dict[str, str]: + """Recursively expand environment variables in credential values. + + Args: + credentials: Dictionary of credential key-value pairs. + + Returns: + Dictionary with all string values having env vars expanded. + """ + expanded: dict[str, str] = {} + for key, value in credentials.items(): + if isinstance(value, str): + expanded[key] = self.expand_env_vars(value) + elif isinstance(value, dict): + # Recursively expand nested dicts + expanded[key] = self._expand_credentials(value) + else: + # Keep non-string values as-is (numbers, booleans, etc.) + expanded[key] = value + return expanded diff --git a/src/iac_reverse/generator/__init__.py b/src/iac_reverse/generator/__init__.py new file mode 100644 index 0000000..af00c16 --- /dev/null +++ b/src/iac_reverse/generator/__init__.py @@ -0,0 +1,15 @@ +"""Code generator module for Terraform HCL output.""" + +from iac_reverse.generator.code_generator import CodeGenerator +from iac_reverse.generator.provider_block import ProviderBlockGenerator +from iac_reverse.generator.resource_merger import ResourceMerger +from iac_reverse.generator.sanitize import sanitize_identifier +from iac_reverse.generator.variable_extractor import VariableExtractor + +__all__ = [ + "CodeGenerator", + "ProviderBlockGenerator", + "ResourceMerger", + "VariableExtractor", + "sanitize_identifier", +] diff --git a/src/iac_reverse/generator/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/generator/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..401294b Binary files /dev/null and b/src/iac_reverse/generator/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/generator/__pycache__/code_generator.cpython-313.pyc b/src/iac_reverse/generator/__pycache__/code_generator.cpython-313.pyc new file mode 100644 index 0000000..e530002 Binary files /dev/null and b/src/iac_reverse/generator/__pycache__/code_generator.cpython-313.pyc differ diff --git a/src/iac_reverse/generator/__pycache__/provider_block.cpython-313.pyc b/src/iac_reverse/generator/__pycache__/provider_block.cpython-313.pyc new file mode 100644 index 0000000..2f10501 Binary files /dev/null and b/src/iac_reverse/generator/__pycache__/provider_block.cpython-313.pyc differ diff --git a/src/iac_reverse/generator/__pycache__/resource_merger.cpython-313.pyc b/src/iac_reverse/generator/__pycache__/resource_merger.cpython-313.pyc new file mode 100644 index 0000000..497ecb9 Binary files /dev/null and b/src/iac_reverse/generator/__pycache__/resource_merger.cpython-313.pyc differ diff --git a/src/iac_reverse/generator/__pycache__/sanitize.cpython-313.pyc b/src/iac_reverse/generator/__pycache__/sanitize.cpython-313.pyc new file mode 100644 index 0000000..53b32ae Binary files /dev/null and b/src/iac_reverse/generator/__pycache__/sanitize.cpython-313.pyc differ diff --git a/src/iac_reverse/generator/__pycache__/variable_extractor.cpython-313.pyc b/src/iac_reverse/generator/__pycache__/variable_extractor.cpython-313.pyc new file mode 100644 index 0000000..686b1b2 Binary files /dev/null and b/src/iac_reverse/generator/__pycache__/variable_extractor.cpython-313.pyc differ diff --git a/src/iac_reverse/generator/code_generator.py b/src/iac_reverse/generator/code_generator.py new file mode 100644 index 0000000..e21d7c3 --- /dev/null +++ b/src/iac_reverse/generator/code_generator.py @@ -0,0 +1,304 @@ +"""HCL code generator using Jinja2 templates. + +Produces Terraform HCL files from a DependencyGraph and list of ScanProfiles. +Organizes output by resource type (one .tf file per type), includes traceability +comments, architecture-specific tags/labels, and uses Terraform resource +references for inter-resource dependencies. +""" + +import logging +from collections import defaultdict + +from jinja2 import Environment, BaseLoader + +from iac_reverse.generator.sanitize import sanitize_identifier +from iac_reverse.models import ( + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + GeneratedFile, + ResourceRelationship, + ScanProfile, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Jinja2 HCL Templates +# --------------------------------------------------------------------------- + +_RESOURCE_BLOCK_TEMPLATE = """\ +{% for resource in resources %} +# Source: {{ resource.unique_id }} +resource "{{ resource.resource_type }}" "{{ resource.tf_name }}" { +{% for key, value in resource.attributes.items() %} + {{ key }} = {{ value }} +{% endfor %} +{% if resource.tags %} + + tags = { +{% for tag_key, tag_value in resource.tags.items() %} + "{{ tag_key }}" = "{{ tag_value }}" +{% endfor %} + } +{% endif %} +{% if resource.dependencies %} + +{% for dep in resource.dependencies %} + depends_on = [{{ dep }}] +{% endfor %} +{% endif %} +} +{% endfor %} +""" + +_RESOURCE_BLOCK_TEMPLATE_V2 = """\ +{% for resource in resources %} +# Source: {{ resource.unique_id }} +resource "{{ resource.resource_type }}" "{{ resource.tf_name }}" { +{% for key, value in resource.rendered_attributes %} + {{ key }} = {{ value }} +{% endfor %} +{% if resource.tags %} + + tags = { +{% for tag_key, tag_value in resource.tags.items() %} + "{{ tag_key }}" = "{{ tag_value }}" +{% endfor %} + } +{% endif %} +} + +{% endfor %} +""" + + +# --------------------------------------------------------------------------- +# Helper: format HCL attribute values +# --------------------------------------------------------------------------- + + +def _format_hcl_value(value: object) -> str: + """Format a Python value as an HCL literal string.""" + if isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, int): + return str(value) + elif isinstance(value, float): + return str(value) + elif isinstance(value, str): + # Escape quotes in strings + escaped = value.replace("\\", "\\\\").replace('"', '\\"') + return f'"{escaped}"' + elif isinstance(value, list): + items = [_format_hcl_value(item) for item in value] + return "[" + ", ".join(items) + "]" + elif isinstance(value, dict): + lines = [] + lines.append("{") + for k, v in value.items(): + lines.append(f' "{k}" = {_format_hcl_value(v)}') + lines.append(" }") + return "\n".join(lines) + else: + return f'"{value}"' + + +# --------------------------------------------------------------------------- +# Internal data structure for template rendering +# --------------------------------------------------------------------------- + + +class _RenderableResource: + """Internal representation of a resource ready for template rendering.""" + + def __init__( + self, + resource_type: str, + tf_name: str, + unique_id: str, + rendered_attributes: list[tuple[str, str]], + tags: dict[str, str], + ): + self.resource_type = resource_type + self.tf_name = tf_name + self.unique_id = unique_id + self.rendered_attributes = rendered_attributes + self.tags = tags + + +# --------------------------------------------------------------------------- +# CodeGenerator +# --------------------------------------------------------------------------- + + +class CodeGenerator: + """Generates Terraform HCL files from a dependency graph. + + Accepts a DependencyGraph and list of ScanProfiles, produces one .tf file + per resource type with traceability comments, architecture tags, and + Terraform resource references for dependencies. + """ + + def __init__(self) -> None: + """Initialize the code generator with Jinja2 environment.""" + self._env = Environment( + loader=BaseLoader(), + trim_blocks=True, + lstrip_blocks=True, + ) + self._template = self._env.from_string(_RESOURCE_BLOCK_TEMPLATE_V2) + + def generate( + self, graph: DependencyGraph, profiles: list[ScanProfile] + ) -> CodeGenerationResult: + """Generate Terraform HCL from a dependency graph. + + Args: + graph: The DependencyGraph containing resources and relationships. + profiles: List of ScanProfiles used during scanning. + + Returns: + CodeGenerationResult with resource_files, variables_file, and + provider_file. Variables and provider files are empty placeholders + (implemented in tasks 5.3 and 5.4). + """ + # Build lookup maps + resource_map: dict[str, DiscoveredResource] = { + r.unique_id: r for r in graph.resources + } + # Map from source_id -> list of (target_id, relationship) + relationships_by_source: dict[str, list[ResourceRelationship]] = defaultdict( + list + ) + for rel in graph.relationships: + relationships_by_source[rel.source_id].append(rel) + + # Group resources by type + resources_by_type: dict[str, list[DiscoveredResource]] = defaultdict(list) + for resource in graph.resources: + resources_by_type[resource.resource_type].append(resource) + + # Generate one file per resource type + resource_files: list[GeneratedFile] = [] + for resource_type, resources in sorted(resources_by_type.items()): + renderable_resources = [] + for resource in resources: + renderable = self._build_renderable( + resource, relationships_by_source, resource_map + ) + renderable_resources.append(renderable) + + content = self._template.render(resources=renderable_resources) + filename = f"{resource_type}.tf" + resource_files.append( + GeneratedFile( + filename=filename, + content=content, + resource_count=len(resources), + ) + ) + + # Placeholder files for tasks 5.3 and 5.4 + variables_file = GeneratedFile( + filename="variables.tf", + content="", + resource_count=0, + ) + provider_file = GeneratedFile( + filename="providers.tf", + content="", + resource_count=0, + ) + + return CodeGenerationResult( + resource_files=resource_files, + variables_file=variables_file, + provider_file=provider_file, + ) + + def _build_renderable( + self, + resource: DiscoveredResource, + relationships_by_source: dict[str, list[ResourceRelationship]], + resource_map: dict[str, DiscoveredResource], + ) -> _RenderableResource: + """Build a renderable resource with formatted attributes and references. + + For attributes that reference other resources in the graph, replaces + the hardcoded ID with a Terraform resource reference expression. + """ + tf_name = sanitize_identifier(resource.name) + + # Build a set of target IDs this resource references + target_ids_for_resource: dict[str, ResourceRelationship] = {} + for rel in relationships_by_source.get(resource.unique_id, []): + target_ids_for_resource[rel.target_id] = rel + + # Render attributes, replacing references with Terraform expressions + rendered_attributes: list[tuple[str, str]] = [] + for attr_key, attr_value in resource.attributes.items(): + resolved_value = self._resolve_attribute_value( + attr_value, target_ids_for_resource, resource_map + ) + rendered_attributes.append((attr_key, resolved_value)) + + # Generate architecture-specific tags + tags = self._generate_architecture_tags(resource) + + return _RenderableResource( + resource_type=resource.resource_type, + tf_name=tf_name, + unique_id=resource.unique_id, + rendered_attributes=rendered_attributes, + tags=tags, + ) + + def _resolve_attribute_value( + self, + value: object, + target_ids: dict[str, ResourceRelationship], + resource_map: dict[str, DiscoveredResource], + ) -> str: + """Resolve an attribute value, replacing resource IDs with Terraform references. + + If the value matches a target resource's unique_id or name, returns a + Terraform resource reference expression. Otherwise formats as HCL literal. + """ + if isinstance(value, str): + # Check if this string matches a target resource's unique_id + if value in target_ids: + target_resource = resource_map[value] + return self._make_terraform_reference(target_resource) + + # Check if this string matches a target resource's name + for target_id, rel in target_ids.items(): + target_resource = resource_map[target_id] + if value == target_resource.name: + return self._make_terraform_reference(target_resource) + + # Default: format as HCL literal + return _format_hcl_value(value) + + def _make_terraform_reference(self, target_resource: DiscoveredResource) -> str: + """Create a Terraform resource reference expression. + + Example: kubernetes_namespace.default.id + """ + target_tf_name = sanitize_identifier(target_resource.name) + return f"{target_resource.resource_type}.{target_tf_name}.id" + + def _generate_architecture_tags( + self, resource: DiscoveredResource + ) -> dict[str, str]: + """Generate architecture-specific tags/labels for a resource. + + Returns a dict of tag key-value pairs including the CPU architecture. + """ + tags: dict[str, str] = { + "arch": resource.architecture.value, + "managed_by": "iac-reverse", + } + return tags diff --git a/src/iac_reverse/generator/provider_block.py b/src/iac_reverse/generator/provider_block.py new file mode 100644 index 0000000..49fe4d6 --- /dev/null +++ b/src/iac_reverse/generator/provider_block.py @@ -0,0 +1,197 @@ +"""Provider block generator for Terraform HCL output. + +Generates a providers.tf file containing: +- A terraform { required_providers { ... } } block listing all providers used +- Individual provider configuration blocks with platform-specific settings + (endpoints, certificates, credentials) for each distinct provider type. +""" + +from __future__ import annotations + +from iac_reverse.models import ProviderType, ScanProfile, GeneratedFile + + +# --------------------------------------------------------------------------- +# Provider metadata: maps ProviderType to Terraform provider details +# --------------------------------------------------------------------------- + +# Each entry: (terraform_provider_name, source, version_constraint) +_PROVIDER_METADATA: dict[ProviderType, tuple[str, str, str]] = { + ProviderType.KUBERNETES: ( + "kubernetes", + "hashicorp/kubernetes", + "~> 2.0", + ), + ProviderType.DOCKER_SWARM: ( + "docker", + "kreuzwerker/docker", + "~> 3.0", + ), + ProviderType.SYNOLOGY: ( + "synology", + "synology-community/synology", + "~> 0.2", + ), + ProviderType.HARVESTER: ( + "harvester", + "harvester/harvester", + "~> 0.6", + ), + ProviderType.BARE_METAL: ( + "redfish", + "dell/redfish", + "~> 1.0", + ), + ProviderType.WINDOWS: ( + "windows", + "hashicorp/windows", + "~> 0.1", + ), +} + + +def _generate_provider_config( + provider_type: ProviderType, profile: ScanProfile +) -> str: + """Generate the provider configuration block for a given provider type. + + Uses credentials and endpoints from the ScanProfile to populate + platform-specific configuration attributes. + """ + tf_name = _PROVIDER_METADATA[provider_type][0] + lines: list[str] = [] + lines.append(f'provider "{tf_name}" {{') + + if provider_type == ProviderType.KUBERNETES: + host = profile.credentials.get("host", "") + cluster_ca = profile.credentials.get("cluster_ca_certificate", "") + token = profile.credentials.get("token", "") + lines.append(f' host = "{host}"') + lines.append(f' cluster_ca_certificate = "{cluster_ca}"') + lines.append(f' token = "{token}"') + + elif provider_type == ProviderType.DOCKER_SWARM: + host = profile.credentials.get("host", "") + cert_path = profile.credentials.get("cert_path", "") + lines.append(f' host = "{host}"') + lines.append(f' cert_path = "{cert_path}"') + + elif provider_type == ProviderType.SYNOLOGY: + url = profile.credentials.get("url", "") + username = profile.credentials.get("username", "") + password = profile.credentials.get("password", "") + lines.append(f' url = "{url}"') + lines.append(f' username = "{username}"') + lines.append(f' password = "{password}"') + + elif provider_type == ProviderType.HARVESTER: + kubeconfig = profile.credentials.get("kubeconfig", "") + lines.append(f' kubeconfig = "{kubeconfig}"') + + elif provider_type == ProviderType.BARE_METAL: + endpoint = profile.credentials.get("endpoint", "") + username = profile.credentials.get("username", "") + password = profile.credentials.get("password", "") + lines.append(f' endpoint = "{endpoint}"') + lines.append(f' username = "{username}"') + lines.append(f' password = "{password}"') + + elif provider_type == ProviderType.WINDOWS: + host = profile.credentials.get("host", "") + username = profile.credentials.get("username", "") + password = profile.credentials.get("password", "") + lines.append(f' host = "{host}"') + lines.append(f' username = "{username}"') + lines.append(f' password = "{password}"') + lines.append("") + lines.append(" winrm {") + winrm_port = profile.credentials.get("winrm_port", "5985") + winrm_use_ssl = profile.credentials.get("winrm_use_ssl", "false") + lines.append(f" port = {winrm_port}") + lines.append(f" use_ssl = {winrm_use_ssl}") + lines.append(" }") + + lines.append("}") + return "\n".join(lines) + + +def _generate_required_providers_block( + provider_types: set[ProviderType], +) -> str: + """Generate the terraform { required_providers { ... } } block.""" + lines: list[str] = [] + lines.append("terraform {") + lines.append(" required_providers {") + + for provider_type in sorted(provider_types, key=lambda p: p.value): + tf_name, source, version = _PROVIDER_METADATA[provider_type] + lines.append(f" {tf_name} = {{") + lines.append(f' source = "{source}"') + lines.append(f' version = "{version}"') + lines.append(" }") + + lines.append(" }") + lines.append("}") + return "\n".join(lines) + + +# --------------------------------------------------------------------------- +# ProviderBlockGenerator +# --------------------------------------------------------------------------- + + +class ProviderBlockGenerator: + """Generates Terraform provider configuration blocks. + + Accepts a list of ScanProfiles and a set of ProviderTypes used in the + generated code, and produces a providers.tf file containing: + - A terraform { required_providers { ... } } block + - Individual provider blocks with platform-specific configuration + """ + + def generate( + self, + profiles: list[ScanProfile], + provider_types: set[ProviderType], + ) -> GeneratedFile: + """Generate the providers.tf file content. + + Args: + profiles: List of ScanProfiles providing credentials/endpoints. + provider_types: Set of distinct ProviderTypes used in the code. + + Returns: + A GeneratedFile with filename "providers.tf" and the HCL content. + """ + # Build a map from ProviderType -> first matching profile + profile_map: dict[ProviderType, ScanProfile] = {} + for profile in profiles: + if profile.provider not in profile_map: + profile_map[profile.provider] = profile + + sections: list[str] = [] + + # 1. required_providers block + sections.append(_generate_required_providers_block(provider_types)) + + # 2. Individual provider configuration blocks + for provider_type in sorted(provider_types, key=lambda p: p.value): + profile = profile_map.get(provider_type) + if profile is not None: + sections.append( + _generate_provider_config(provider_type, profile) + ) + else: + # Generate a placeholder block if no profile matches + tf_name = _PROVIDER_METADATA[provider_type][0] + sections.append( + f'provider "{tf_name}" {{\n # No profile provided\n}}' + ) + + content = "\n\n".join(sections) + "\n" + + return GeneratedFile( + filename="providers.tf", + content=content, + resource_count=0, + ) diff --git a/src/iac_reverse/generator/resource_merger.py b/src/iac_reverse/generator/resource_merger.py new file mode 100644 index 0000000..888cbda --- /dev/null +++ b/src/iac_reverse/generator/resource_merger.py @@ -0,0 +1,59 @@ +"""Multi-provider resource merging with conflict resolution. + +Merges resources from multiple scan profiles into a unified inventory, +resolving naming conflicts by prefixing with the provider identifier. +""" + +from dataclasses import replace +from collections import defaultdict + +from iac_reverse.models import DiscoveredResource, ScanResult + + +class ResourceMerger: + """Merges resources from multiple ScanResult objects into a unified list. + + When resources from different providers share the same name, the merger + resolves the conflict by prefixing each conflicting resource's name with + its provider identifier (e.g., "kubernetes_nginx", "docker_swarm_nginx"). + + Provider-specific attributes are preserved unchanged. + """ + + def merge(self, scan_results: list[ScanResult]) -> list[DiscoveredResource]: + """Merge resources from multiple scan results into a unified list. + + Args: + scan_results: List of ScanResult objects, one per provider/scan profile. + + Returns: + A unified list of DiscoveredResource with naming conflicts resolved + by prefixing conflicting names with the provider identifier. + """ + # Collect all resources from all scan results + all_resources: list[DiscoveredResource] = [] + for result in scan_results: + all_resources.extend(result.resources) + + # Group resources by name to detect conflicts + resources_by_name: dict[str, list[DiscoveredResource]] = defaultdict(list) + for resource in all_resources: + resources_by_name[resource.name].append(resource) + + # Identify conflicting names: same name from different providers + conflicting_names: set[str] = set() + for name, resources in resources_by_name.items(): + providers = {r.provider for r in resources} + if len(providers) > 1: + conflicting_names.add(name) + + # Build the merged list, resolving conflicts + merged: list[DiscoveredResource] = [] + for resource in all_resources: + if resource.name in conflicting_names: + prefixed_name = f"{resource.provider.value}_{resource.name}" + merged.append(replace(resource, name=prefixed_name)) + else: + merged.append(resource) + + return merged diff --git a/src/iac_reverse/generator/sanitize.py b/src/iac_reverse/generator/sanitize.py new file mode 100644 index 0000000..e268175 --- /dev/null +++ b/src/iac_reverse/generator/sanitize.py @@ -0,0 +1,41 @@ +"""Identifier sanitization for Terraform resource names. + +Converts arbitrary resource names into valid Terraform identifiers +matching the pattern: ^[a-zA-Z_][a-zA-Z0-9_]*$ +""" + +import re + + +def sanitize_identifier(name: str) -> str: + """Convert a resource name to a valid Terraform identifier. + + Terraform identifiers must match: ^[a-zA-Z_][a-zA-Z0-9_]*$ + + Rules applied: + - Replace any character not in [a-zA-Z0-9_] with underscore + - Collapse multiple consecutive underscores into one + - If result starts with a digit, prepend an underscore + - If result is empty or only underscores, return "_resource" + + Args: + name: Any string resource name. + + Returns: + A valid Terraform identifier derived from the input. + """ + # Replace any non-alphanumeric/underscore character with underscore + result = re.sub(r"[^a-zA-Z0-9_]", "_", name) + + # Collapse multiple consecutive underscores into one + result = re.sub(r"_+", "_", result) + + # If result is only underscores or empty, return fallback + if not result or result.strip("_") == "": + return "_resource" + + # If starts with a digit, prepend underscore + if result[0].isdigit(): + result = "_" + result + + return result diff --git a/src/iac_reverse/generator/variable_extractor.py b/src/iac_reverse/generator/variable_extractor.py new file mode 100644 index 0000000..3d54c8d --- /dev/null +++ b/src/iac_reverse/generator/variable_extractor.py @@ -0,0 +1,203 @@ +"""Variable extraction logic for Terraform code generation. + +Identifies attribute values that appear in 2+ resources and extracts them +into Terraform variables with appropriate type expressions and defaults. +""" + +import logging +from collections import defaultdict + +from iac_reverse.models import DiscoveredResource, ExtractedVariable + +logger = logging.getLogger(__name__) + + +def _infer_type_expr(value: object) -> str: + """Infer a Terraform type expression from a Python value. + + Args: + value: The Python value to infer a type for. + + Returns: + A Terraform type expression string (e.g., "string", "number", "bool"). + """ + if isinstance(value, bool): + return "bool" + elif isinstance(value, int) or isinstance(value, float): + return "number" + elif isinstance(value, str): + return "string" + elif isinstance(value, list): + return "list(string)" + elif isinstance(value, dict): + return "map(string)" + else: + return "string" + + +def _format_default_value(value: object) -> str: + """Format a Python value as a Terraform default value literal. + + Args: + value: The Python value to format. + + Returns: + A string representation suitable for a Terraform variable default. + """ + if isinstance(value, bool): + return "true" if value else "false" + elif isinstance(value, int) or isinstance(value, float): + return str(value) + elif isinstance(value, str): + return f'"{value}"' + elif isinstance(value, list): + items = ", ".join(f'"{item}"' if isinstance(item, str) else str(item) for item in value) + return f"[{items}]" + elif isinstance(value, dict): + entries = ", ".join(f'"{k}" = "{v}"' for k, v in value.items()) + return "{" + entries + "}" + else: + return f'"{value}"' + + +def _make_hashable(value: object) -> object: + """Convert a value to a hashable representation for counting. + + Args: + value: Any Python value from resource attributes. + + Returns: + A hashable version of the value. + """ + if isinstance(value, dict): + return tuple(sorted(value.items())) + elif isinstance(value, list): + return tuple(value) + else: + return value + + +class VariableExtractor: + """Extracts shared attribute values into Terraform variables. + + Scans a list of DiscoveredResource objects, identifies attribute values + that appear in 2 or more resources, and creates ExtractedVariable instances + for each shared value. + """ + + def extract_variables( + self, resources: list[DiscoveredResource] + ) -> list[ExtractedVariable]: + """Identify shared attribute values and extract them as variables. + + For each attribute key, collects all values across all resources. + If a value appears in 2+ resources for the same attribute key, + it becomes a variable with the most common value as the default. + + Args: + resources: List of discovered resources to analyze. + + Returns: + List of ExtractedVariable instances for shared values. + """ + if len(resources) < 2: + return [] + + # Collect attribute values grouped by attribute key + # key -> {hashable_value -> [list of (resource_unique_id, original_value)]} + attr_values: dict[str, dict[object, list[tuple[str, object]]]] = defaultdict( + lambda: defaultdict(list) + ) + + for resource in resources: + for attr_key, attr_value in resource.attributes.items(): + # Skip complex nested structures (dicts/lists) for variable extraction + # as they are less likely to be meaningfully shared + if isinstance(attr_value, (dict, list)): + continue + hashable = _make_hashable(attr_value) + attr_values[attr_key][hashable].append( + (resource.unique_id, attr_value) + ) + + # Build extracted variables for values appearing in 2+ resources + variables: list[ExtractedVariable] = [] + + for attr_key, value_groups in sorted(attr_values.items()): + # Find all values that appear in 2+ resources for this key + shared_values = [ + (hv, entries) + for hv, entries in value_groups.items() + if len(entries) >= 2 + ] + + if not shared_values: + continue + + # If only one shared value exists for this key, use the key as the var name + # If multiple shared values exist, disambiguate with a suffix + for idx, (hashable_value, resource_entries) in enumerate(shared_values): + original_value = resource_entries[0][1] + used_by = [entry[0] for entry in resource_entries] + + # Determine the most common value among the shared values for this key + # The default is set to the most common value overall + most_common_entries = max(shared_values, key=lambda x: len(x[1])) + most_common_value = most_common_entries[1][0][1] + + # Use the most common value as default for the primary variable, + # but each variable's default is its own value + default_value = _format_default_value(original_value) + + if len(shared_values) == 1: + var_name = f"var_{attr_key}" + else: + # Disambiguate when multiple shared values exist for same key + var_name = f"var_{attr_key}_{idx}" + + type_expr = _infer_type_expr(original_value) + description = ( + f"Shared {attr_key} value extracted from " + f"{len(resource_entries)} resources" + ) + + variables.append( + ExtractedVariable( + name=var_name, + type_expr=type_expr, + default_value=default_value, + description=description, + used_by=used_by, + ) + ) + + return variables + + def generate_variables_tf( + self, variables: list[ExtractedVariable] + ) -> str: + """Generate Terraform variables.tf file content. + + Produces variable blocks with type, description, and default values. + + Args: + variables: List of extracted variables to render. + + Returns: + String content for a variables.tf file. + """ + if not variables: + return "" + + blocks: list[str] = [] + for var in variables: + block = ( + f'variable "{var.name}" {{\n' + f' type = {var.type_expr}\n' + f' description = "{var.description}"\n' + f' default = {var.default_value}\n' + f'}}' + ) + blocks.append(block) + + return "\n\n".join(blocks) + "\n" diff --git a/src/iac_reverse/incremental/__init__.py b/src/iac_reverse/incremental/__init__.py new file mode 100644 index 0000000..99b3dd0 --- /dev/null +++ b/src/iac_reverse/incremental/__init__.py @@ -0,0 +1,7 @@ +"""Incremental scan engine for change detection.""" + +from iac_reverse.incremental.change_detector import ChangeDetector +from iac_reverse.incremental.incremental_updater import IncrementalUpdater +from iac_reverse.incremental.snapshot_store import SnapshotStore + +__all__ = ["ChangeDetector", "IncrementalUpdater", "SnapshotStore"] diff --git a/src/iac_reverse/incremental/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/incremental/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..a386056 Binary files /dev/null and b/src/iac_reverse/incremental/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/incremental/__pycache__/change_detector.cpython-313.pyc b/src/iac_reverse/incremental/__pycache__/change_detector.cpython-313.pyc new file mode 100644 index 0000000..acb6420 Binary files /dev/null and b/src/iac_reverse/incremental/__pycache__/change_detector.cpython-313.pyc differ diff --git a/src/iac_reverse/incremental/__pycache__/incremental_updater.cpython-313.pyc b/src/iac_reverse/incremental/__pycache__/incremental_updater.cpython-313.pyc new file mode 100644 index 0000000..45e04e4 Binary files /dev/null and b/src/iac_reverse/incremental/__pycache__/incremental_updater.cpython-313.pyc differ diff --git a/src/iac_reverse/incremental/__pycache__/snapshot_store.cpython-313.pyc b/src/iac_reverse/incremental/__pycache__/snapshot_store.cpython-313.pyc new file mode 100644 index 0000000..6be5fcd Binary files /dev/null and b/src/iac_reverse/incremental/__pycache__/snapshot_store.cpython-313.pyc differ diff --git a/src/iac_reverse/incremental/change_detector.py b/src/iac_reverse/incremental/change_detector.py new file mode 100644 index 0000000..8f4fa9c --- /dev/null +++ b/src/iac_reverse/incremental/change_detector.py @@ -0,0 +1,144 @@ +"""Change detection and classification for incremental scans. + +Compares current scan results against previous snapshots to identify +added, removed, and modified resources. +""" + +from typing import Optional + +from iac_reverse.models import ( + ChangeSummary, + ChangeType, + DiscoveredResource, + ResourceChange, + ScanResult, +) + + +class ChangeDetector: + """Detects and classifies changes between scan results. + + Compares resources by unique_id to determine which resources + have been added, removed, or modified between scans. + """ + + def compare( + self, current: ScanResult, previous: Optional[ScanResult] + ) -> ChangeSummary: + """Compare current scan against a previous scan result. + + Args: + current: The current scan result. + previous: The previous scan result, or None for first scan. + + Returns: + A ChangeSummary with counts and list of ResourceChange objects. + If previous is None, all current resources are classified as ADDED. + """ + if previous is None: + return self._handle_first_scan(current) + + current_map = {r.unique_id: r for r in current.resources} + previous_map = {r.unique_id: r for r in previous.resources} + + changes: list[ResourceChange] = [] + + # Detect ADDED resources: in current but not in previous + for uid, resource in current_map.items(): + if uid not in previous_map: + changes.append( + ResourceChange( + resource_id=resource.unique_id, + resource_type=resource.resource_type, + resource_name=resource.name, + change_type=ChangeType.ADDED, + changed_attributes=None, + ) + ) + + # Detect REMOVED resources: in previous but not in current + for uid, resource in previous_map.items(): + if uid not in current_map: + changes.append( + ResourceChange( + resource_id=resource.unique_id, + resource_type=resource.resource_type, + resource_name=resource.name, + change_type=ChangeType.REMOVED, + changed_attributes=None, + ) + ) + + # Detect MODIFIED resources: same unique_id but attributes differ + for uid in current_map: + if uid in previous_map: + changed_attrs = self._diff_attributes( + current_map[uid], previous_map[uid] + ) + if changed_attrs: + resource = current_map[uid] + changes.append( + ResourceChange( + resource_id=resource.unique_id, + resource_type=resource.resource_type, + resource_name=resource.name, + change_type=ChangeType.MODIFIED, + changed_attributes=changed_attrs, + ) + ) + + added_count = sum(1 for c in changes if c.change_type == ChangeType.ADDED) + removed_count = sum(1 for c in changes if c.change_type == ChangeType.REMOVED) + modified_count = sum(1 for c in changes if c.change_type == ChangeType.MODIFIED) + + return ChangeSummary( + added_count=added_count, + removed_count=removed_count, + modified_count=modified_count, + changes=changes, + ) + + def _handle_first_scan(self, current: ScanResult) -> ChangeSummary: + """Handle first scan with no previous snapshot. + + All resources in the current scan are classified as ADDED. + """ + changes = [ + ResourceChange( + resource_id=resource.unique_id, + resource_type=resource.resource_type, + resource_name=resource.name, + change_type=ChangeType.ADDED, + changed_attributes=None, + ) + for resource in current.resources + ] + + return ChangeSummary( + added_count=len(changes), + removed_count=0, + modified_count=0, + changes=changes, + ) + + def _diff_attributes( + self, current: DiscoveredResource, previous: DiscoveredResource + ) -> Optional[dict]: + """Compare attributes between two versions of the same resource. + + Returns a dict of changed attributes with 'old' and 'new' values, + or None if no attributes differ. + """ + if current.attributes == previous.attributes: + return None + + changed: dict = {} + all_keys = set(current.attributes.keys()) | set(previous.attributes.keys()) + + for key in all_keys: + old_val = previous.attributes.get(key) + new_val = current.attributes.get(key) + if old_val != new_val: + changed[key] = {"old": old_val, "new": new_val} + + return changed if changed else None diff --git a/src/iac_reverse/incremental/incremental_updater.py b/src/iac_reverse/incremental/incremental_updater.py new file mode 100644 index 0000000..7385bfe --- /dev/null +++ b/src/iac_reverse/incremental/incremental_updater.py @@ -0,0 +1,339 @@ +"""Incremental updater for Terraform IaC files and state. + +Applies a ChangeSummary to an existing output directory, modifying only +the .tf files and state file that contain changed resources. Supports +adding new resource blocks, removing existing blocks, and updating +modified resource attributes without full regeneration. +""" + +import json +import logging +import re +from pathlib import Path +from typing import Optional + +from iac_reverse.generator.code_generator import _format_hcl_value +from iac_reverse.generator.sanitize import sanitize_identifier +from iac_reverse.models import ChangeSummary, ChangeType, ResourceChange + +logger = logging.getLogger(__name__) + + +class IncrementalUpdater: + """Applies incremental changes to Terraform IaC files and state. + + Accepts a ChangeSummary and an output directory path. Modifies only + the .tf files containing changed resources (one .tf file per resource + type) and updates the terraform.tfstate file accordingly. + + For REMOVED resources: removes the resource block from the .tf file + and the corresponding entry from the state file. + + For ADDED resources: appends a new resource block to the appropriate + .tf file (creating the file if it doesn't exist). + + For MODIFIED resources: updates the existing resource block with new + attribute values. + """ + + def __init__( + self, + change_summary: ChangeSummary, + output_dir: str, + resource_attributes: Optional[dict[str, dict]] = None, + ) -> None: + """Initialize the IncrementalUpdater. + + Args: + change_summary: The ChangeSummary describing what changed. + output_dir: Path to the output directory containing .tf and + state files. + resource_attributes: Optional mapping of resource_id to full + attribute dict for ADDED/MODIFIED resources. Required for + generating resource blocks for added resources. + """ + self._change_summary = change_summary + self._output_dir = Path(output_dir) + self._resource_attributes = resource_attributes or {} + self._modified_files: set[str] = set() + + @property + def modified_files(self) -> set[str]: + """Return the set of file paths that were modified during apply.""" + return set(self._modified_files) + + def apply(self) -> None: + """Apply all changes from the ChangeSummary to the output directory. + + Processes removed, added, and modified resources, updating only + the affected .tf files and the state file. + """ + for change in self._change_summary.changes: + if change.change_type == ChangeType.REMOVED: + self._handle_removed(change) + elif change.change_type == ChangeType.ADDED: + self._handle_added(change) + elif change.change_type == ChangeType.MODIFIED: + self._handle_modified(change) + + def _handle_removed(self, change: ResourceChange) -> None: + """Remove a resource block from its .tf file and state entry. + + Args: + change: The ResourceChange describing the removed resource. + """ + tf_file = self._get_tf_file_path(change.resource_type) + if tf_file.exists(): + self._remove_resource_block(tf_file, change) + self._modified_files.add(str(tf_file)) + + self._remove_state_entry(change) + + def _handle_added(self, change: ResourceChange) -> None: + """Add a new resource block to the appropriate .tf file. + + Args: + change: The ResourceChange describing the added resource. + """ + tf_file = self._get_tf_file_path(change.resource_type) + attributes = self._resource_attributes.get(change.resource_id, {}) + self._add_resource_block(tf_file, change, attributes) + self._modified_files.add(str(tf_file)) + + def _handle_modified(self, change: ResourceChange) -> None: + """Update an existing resource block with new attribute values. + + Args: + change: The ResourceChange describing the modified resource. + """ + tf_file = self._get_tf_file_path(change.resource_type) + if tf_file.exists(): + self._update_resource_block(tf_file, change) + self._modified_files.add(str(tf_file)) + + def _get_tf_file_path(self, resource_type: str) -> Path: + """Get the .tf file path for a given resource type. + + Each resource type maps to a file named .tf. + + Args: + resource_type: The Terraform resource type string. + + Returns: + Path to the .tf file for this resource type. + """ + return self._output_dir / f"{resource_type}.tf" + + def _remove_resource_block( + self, tf_file: Path, change: ResourceChange + ) -> None: + """Remove a resource block from a .tf file. + + Identifies the block by matching the resource type and sanitized + resource name in the resource declaration line. + + Args: + tf_file: Path to the .tf file. + change: The ResourceChange identifying the resource to remove. + """ + content = tf_file.read_text(encoding="utf-8") + tf_name = sanitize_identifier(change.resource_name) + + # Pattern to match the full resource block including optional comment + # Matches: optional comment line + resource "type" "name" { ... } + pattern = self._build_block_pattern(change.resource_type, tf_name) + new_content = re.sub(pattern, "", content) + + # Clean up excessive blank lines + new_content = re.sub(r"\n{3,}", "\n\n", new_content) + new_content = new_content.strip() + if new_content: + new_content += "\n" + + tf_file.write_text(new_content, encoding="utf-8") + + def _add_resource_block( + self, tf_file: Path, change: ResourceChange, attributes: dict + ) -> None: + """Add a new resource block to a .tf file. + + Creates the file if it doesn't exist. Appends the block at the end. + + Args: + tf_file: Path to the .tf file. + change: The ResourceChange describing the added resource. + attributes: The full attribute dict for the resource. + """ + tf_name = sanitize_identifier(change.resource_name) + block = self._render_resource_block( + change.resource_type, tf_name, change.resource_id, attributes + ) + + if tf_file.exists(): + content = tf_file.read_text(encoding="utf-8") + if content and not content.endswith("\n"): + content += "\n" + content += "\n" + block + else: + content = block + + tf_file.write_text(content, encoding="utf-8") + + def _update_resource_block( + self, tf_file: Path, change: ResourceChange + ) -> None: + """Update an existing resource block with changed attributes. + + Replaces only the changed attribute lines within the block. + + Args: + tf_file: Path to the .tf file. + change: The ResourceChange with changed_attributes dict. + """ + if not change.changed_attributes: + return + + content = tf_file.read_text(encoding="utf-8") + tf_name = sanitize_identifier(change.resource_name) + + # Find the resource block + pattern = self._build_block_pattern(change.resource_type, tf_name) + match = re.search(pattern, content) + if not match: + logger.warning( + "Could not find resource block for %s.%s in %s", + change.resource_type, + tf_name, + tf_file, + ) + return + + block = match.group(0) + updated_block = block + + for attr_name, attr_change in change.changed_attributes.items(): + new_value = attr_change.get("new") + if new_value is None: + # Attribute was removed - remove the line + attr_pattern = re.compile( + rf"^[ \t]*{re.escape(attr_name)}\s*=\s*.*$\n?", + re.MULTILINE, + ) + updated_block = attr_pattern.sub("", updated_block) + else: + # Attribute was added or changed - update/add the line + hcl_value = _format_hcl_value(new_value) + attr_pattern = re.compile( + rf"^([ \t]*){re.escape(attr_name)}\s*=\s*.*$", + re.MULTILINE, + ) + attr_match = attr_pattern.search(updated_block) + if attr_match: + # Replace existing attribute line + indent = attr_match.group(1) + replacement = f"{indent}{attr_name} = {hcl_value}" + updated_block = attr_pattern.sub(replacement, updated_block) + else: + # Add new attribute before the closing brace + updated_block = re.sub( + r"(\n})", + f"\n {attr_name} = {hcl_value}\\1", + updated_block, + count=1, + ) + + content = content.replace(block, updated_block) + tf_file.write_text(content, encoding="utf-8") + + def _remove_state_entry(self, change: ResourceChange) -> None: + """Remove a resource entry from the terraform.tfstate file. + + Args: + change: The ResourceChange identifying the resource to remove. + """ + state_file = self._output_dir / "terraform.tfstate" + if not state_file.exists(): + return + + content = state_file.read_text(encoding="utf-8") + try: + state = json.loads(content) + except json.JSONDecodeError: + logger.warning("Could not parse state file: %s", state_file) + return + + tf_name = sanitize_identifier(change.resource_name) + resources = state.get("resources", []) + state["resources"] = [ + r + for r in resources + if not ( + r.get("type") == change.resource_type + and r.get("name") == tf_name + ) + ] + + # Increment serial to indicate state change + state["serial"] = state.get("serial", 0) + 1 + + state_file.write_text( + json.dumps(state, indent=2), encoding="utf-8" + ) + self._modified_files.add(str(state_file)) + + def _build_block_pattern( + self, resource_type: str, tf_name: str + ) -> re.Pattern: + """Build a regex pattern to match a full resource block. + + Matches an optional comment line (# Source: ...) followed by the + resource declaration and its body enclosed in braces. + + Args: + resource_type: The Terraform resource type. + tf_name: The sanitized Terraform resource name. + + Returns: + A compiled regex pattern matching the full block. + """ + # Match optional comment + resource block with balanced braces + # The block body can contain nested braces (e.g., tags = { ... }) + escaped_type = re.escape(resource_type) + escaped_name = re.escape(tf_name) + pattern = ( + rf"(?:# Source:.*\n)?" + rf'resource\s+"{escaped_type}"\s+"{escaped_name}"\s*\{{' + rf"[^{{}}]*(?:\{{[^{{}}]*\}}[^{{}}]*)*" + rf"\}}\n?" + ) + return re.compile(pattern, re.DOTALL) + + def _render_resource_block( + self, + resource_type: str, + tf_name: str, + source_id: str, + attributes: dict, + ) -> str: + """Render a Terraform resource block as HCL text. + + Args: + resource_type: The Terraform resource type. + tf_name: The sanitized Terraform resource name. + source_id: The source resource identifier for traceability. + attributes: The attribute dict to render. + + Returns: + A string containing the HCL resource block. + """ + lines = [f"# Source: {source_id}"] + lines.append(f'resource "{resource_type}" "{tf_name}" {{') + + for key, value in attributes.items(): + hcl_value = _format_hcl_value(value) + lines.append(f" {key} = {hcl_value}") + + lines.append("}") + lines.append("") # trailing newline + + return "\n".join(lines) diff --git a/src/iac_reverse/incremental/snapshot_store.py b/src/iac_reverse/incremental/snapshot_store.py new file mode 100644 index 0000000..b500f33 --- /dev/null +++ b/src/iac_reverse/incremental/snapshot_store.py @@ -0,0 +1,177 @@ +"""Snapshot storage and retrieval for incremental scan comparison. + +Stores scan results as timestamped JSON files in `.iac-reverse/snapshots/` +and provides retrieval of previous snapshots for change detection. +""" + +import json +import os +from datetime import datetime, timezone +from pathlib import Path +from typing import Optional + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) + +# Default directory for snapshot storage +SNAPSHOT_DIR = os.path.join(".iac-reverse", "snapshots") + +# Minimum number of snapshots to retain per profile +MIN_RETAINED_SNAPSHOTS = 2 + + +def _serialize_scan_result(result: ScanResult) -> dict: + """Serialize a ScanResult to a JSON-compatible dictionary.""" + return { + "scan_timestamp": result.scan_timestamp, + "profile_hash": result.profile_hash, + "is_partial": result.is_partial, + "warnings": result.warnings, + "errors": result.errors, + "resources": [_serialize_resource(r) for r in result.resources], + } + + +def _serialize_resource(resource: DiscoveredResource) -> dict: + """Serialize a DiscoveredResource to a JSON-compatible dictionary.""" + return { + "resource_type": resource.resource_type, + "unique_id": resource.unique_id, + "name": resource.name, + "provider": resource.provider.value, + "platform_category": resource.platform_category.value, + "architecture": resource.architecture.value, + "endpoint": resource.endpoint, + "attributes": resource.attributes, + "raw_references": resource.raw_references, + } + + +def _deserialize_scan_result(data: dict) -> ScanResult: + """Deserialize a dictionary into a ScanResult.""" + resources = [_deserialize_resource(r) for r in data["resources"]] + return ScanResult( + resources=resources, + warnings=data["warnings"], + errors=data["errors"], + scan_timestamp=data["scan_timestamp"], + profile_hash=data["profile_hash"], + is_partial=data.get("is_partial", False), + ) + + +def _deserialize_resource(data: dict) -> DiscoveredResource: + """Deserialize a dictionary into a DiscoveredResource.""" + return DiscoveredResource( + resource_type=data["resource_type"], + unique_id=data["unique_id"], + name=data["name"], + provider=ProviderType(data["provider"]), + platform_category=PlatformCategory(data["platform_category"]), + architecture=CpuArchitecture(data["architecture"]), + endpoint=data["endpoint"], + attributes=data["attributes"], + raw_references=data.get("raw_references", []), + ) + + +class SnapshotStore: + """Manages storage and retrieval of scan result snapshots. + + Stores scan results as timestamped JSON files in a configurable + directory (defaults to `.iac-reverse/snapshots/`). Supports + retrieval of the most recent snapshot for a given profile hash + and automatic pruning of old snapshots. + """ + + def __init__(self, base_dir: Optional[str] = None) -> None: + """Initialize the snapshot store. + + Args: + base_dir: Base directory for snapshot storage. + Defaults to `.iac-reverse/snapshots/`. + """ + self._snapshot_dir = Path(base_dir) if base_dir else Path(SNAPSHOT_DIR) + + @property + def snapshot_dir(self) -> Path: + """Return the snapshot directory path.""" + return self._snapshot_dir + + def store_snapshot(self, result: ScanResult, profile_hash: str) -> None: + """Store a scan result as a timestamped JSON snapshot. + + Args: + result: The scan result to store. + profile_hash: Hash identifying the scan profile. + + The snapshot is saved with filename format: + {profile_hash}_{timestamp}.json + where timestamp is ISO format with colons replaced by dashes. + + After storing, old snapshots are pruned to retain at least + MIN_RETAINED_SNAPSHOTS most recent files per profile_hash. + """ + self._snapshot_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ") + filename = f"{profile_hash}_{timestamp}.json" + filepath = self._snapshot_dir / filename + + data = _serialize_scan_result(result) + with open(filepath, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + self._prune_snapshots(profile_hash) + + def load_previous(self, profile_hash: str) -> Optional[ScanResult]: + """Load the most recent snapshot for a given profile hash. + + Args: + profile_hash: Hash identifying the scan profile. + + Returns: + The most recent ScanResult for the profile, or None if + no snapshot exists. + """ + snapshots = self._list_snapshots(profile_hash) + if not snapshots: + return None + + # Sort by filename (which includes timestamp) to get most recent + snapshots.sort() + most_recent = snapshots[-1] + + with open(most_recent, "r", encoding="utf-8") as f: + data = json.load(f) + + return _deserialize_scan_result(data) + + def _list_snapshots(self, profile_hash: str) -> list[Path]: + """List all snapshot files for a given profile hash.""" + if not self._snapshot_dir.exists(): + return [] + + prefix = f"{profile_hash}_" + return [ + p + for p in self._snapshot_dir.iterdir() + if p.is_file() and p.name.startswith(prefix) and p.name.endswith(".json") + ] + + def _prune_snapshots(self, profile_hash: str) -> None: + """Remove old snapshots, keeping at least MIN_RETAINED_SNAPSHOTS most recent.""" + snapshots = self._list_snapshots(profile_hash) + if len(snapshots) <= MIN_RETAINED_SNAPSHOTS: + return + + # Sort by filename (timestamp is embedded) and remove oldest + snapshots.sort() + to_remove = snapshots[: len(snapshots) - MIN_RETAINED_SNAPSHOTS] + for snapshot_path in to_remove: + snapshot_path.unlink() diff --git a/src/iac_reverse/models.py b/src/iac_reverse/models.py new file mode 100644 index 0000000..a301005 --- /dev/null +++ b/src/iac_reverse/models.py @@ -0,0 +1,425 @@ +"""Core data models for the IaC Reverse Engineering tool. + +Contains enums, dataclasses, and type definitions used across all components +of the pipeline: Scanner, Dependency Resolver, Code Generator, State Builder, +Validator, and Incremental Scan Engine. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class ProviderType(Enum): + """Supported on-premises infrastructure provider types.""" + + DOCKER_SWARM = "docker_swarm" + KUBERNETES = "kubernetes" + SYNOLOGY = "synology" + HARVESTER = "harvester" + BARE_METAL = "bare_metal" + WINDOWS = "windows" + + +class PlatformCategory(Enum): + """Categorizes providers by their infrastructure model.""" + + CONTAINER_ORCHESTRATION = "container" # Docker Swarm, Kubernetes + STORAGE_APPLIANCE = "storage" # Synology Disk Station + HCI = "hci" # SUSE Harvester (Hyper-Converged Infrastructure) + BARE_METAL = "bare_metal" # Physical servers (Linux) + WINDOWS = "windows" # Standalone Windows machines + + +PROVIDER_PLATFORM_MAP: dict[ProviderType, PlatformCategory] = { + ProviderType.DOCKER_SWARM: PlatformCategory.CONTAINER_ORCHESTRATION, + ProviderType.KUBERNETES: PlatformCategory.CONTAINER_ORCHESTRATION, + ProviderType.SYNOLOGY: PlatformCategory.STORAGE_APPLIANCE, + ProviderType.HARVESTER: PlatformCategory.HCI, + ProviderType.BARE_METAL: PlatformCategory.BARE_METAL, + ProviderType.WINDOWS: PlatformCategory.WINDOWS, +} + + +class CpuArchitecture(Enum): + """CPU architecture of the host or resource.""" + + AMD64 = "amd64" + ARM = "arm" + AARCH64 = "aarch64" + + +class ChangeType(Enum): + """Classification of resource changes between scan runs.""" + + ADDED = "added" + REMOVED = "removed" + MODIFIED = "modified" + + +# --------------------------------------------------------------------------- +# Provider supported resource types +# --------------------------------------------------------------------------- + +PROVIDER_SUPPORTED_RESOURCE_TYPES: dict[ProviderType, list[str]] = { + ProviderType.DOCKER_SWARM: [ + "docker_service", + "docker_network", + "docker_volume", + "docker_config", + "docker_secret", + ], + ProviderType.KUBERNETES: [ + "kubernetes_deployment", + "kubernetes_service", + "kubernetes_ingress", + "kubernetes_config_map", + "kubernetes_persistent_volume", + "kubernetes_namespace", + ], + ProviderType.SYNOLOGY: [ + "synology_shared_folder", + "synology_volume", + "synology_storage_pool", + "synology_replication_task", + "synology_user", + ], + ProviderType.HARVESTER: [ + "harvester_virtualmachine", + "harvester_volume", + "harvester_image", + "harvester_network", + ], + ProviderType.BARE_METAL: [ + "bare_metal_hardware", + "bare_metal_bmc_config", + "bare_metal_network_interface", + "bare_metal_raid_config", + ], + ProviderType.WINDOWS: [ + "windows_service", + "windows_scheduled_task", + "windows_iis_site", + "windows_iis_app_pool", + "windows_network_adapter", + "windows_firewall_rule", + "windows_installed_software", + "windows_feature", + "windows_hyperv_vm", + "windows_hyperv_switch", + "windows_dns_record", + "windows_local_user", + "windows_local_group", + ], +} + +MAX_RESOURCE_TYPE_FILTERS = 200 + + +# --------------------------------------------------------------------------- +# Scanner dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ScanProfile: + """Configuration for a single infrastructure scan.""" + + provider: ProviderType + credentials: dict[str, str] + endpoints: Optional[list[str]] = None + resource_type_filters: Optional[list[str]] = None + authentik_token: Optional[str] = None + + def validate(self) -> list[str]: + """Returns list of validation errors, empty if valid. + + Validates: + - credentials must not be empty + - resource_type_filters must have at most MAX_RESOURCE_TYPE_FILTERS entries + - resource_type_filters entries must be supported by the provider + """ + errors: list[str] = [] + + if not self.credentials: + errors.append("credentials must not be empty") + + if self.resource_type_filters is not None: + if len(self.resource_type_filters) > MAX_RESOURCE_TYPE_FILTERS: + errors.append( + f"resource_type_filters must have at most " + f"{MAX_RESOURCE_TYPE_FILTERS} entries, " + f"got {len(self.resource_type_filters)}" + ) + + supported = set(PROVIDER_SUPPORTED_RESOURCE_TYPES[self.provider]) + unsupported = [ + rt for rt in self.resource_type_filters if rt not in supported + ] + if unsupported: + errors.append( + f"unsupported resource types for provider " + f"'{self.provider.value}': {unsupported}" + ) + + return errors + + @property + def platform_category(self) -> PlatformCategory: + """Return the platform category for this profile's provider.""" + return PROVIDER_PLATFORM_MAP[self.provider] + + +@dataclass +class DiscoveredResource: + """A single resource discovered from an infrastructure provider.""" + + resource_type: str + unique_id: str + name: str + provider: ProviderType + platform_category: PlatformCategory + architecture: CpuArchitecture + endpoint: str + attributes: dict + raw_references: list[str] = field(default_factory=list) + + +@dataclass +class ScanResult: + """Complete result of a scan operation.""" + + resources: list[DiscoveredResource] + warnings: list[str] + errors: list[str] + scan_timestamp: str + profile_hash: str + is_partial: bool = False + + +@dataclass +class ScanProgress: + """Progress update during a scan operation.""" + + current_resource_type: str + resources_discovered: int + resource_types_completed: int + total_resource_types: int + + +# --------------------------------------------------------------------------- +# Dependency Resolver dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ResourceRelationship: + """A relationship between two discovered resources.""" + + source_id: str + target_id: str + relationship_type: str # "parent-child", "reference", "dependency" + source_attribute: str + + +@dataclass +class UnresolvedReference: + """A reference that could not be resolved to a known resource.""" + + source_resource_id: str + source_attribute: str + referenced_id: str + suggested_resolution: str # "data_source" or "variable" + + +@dataclass +class CycleReport: + """Report of a detected circular dependency with resolution suggestions.""" + + cycle: list[str] # Resource IDs forming the cycle + suggested_break: tuple[str, str] # (source_id, target_id) edge to break + break_relationship_type: str # Type of the relationship to break + resolution_strategy: str # Human-readable suggestion for resolution + + +@dataclass +class DependencyGraph: + """Complete dependency graph of discovered resources.""" + + resources: list[DiscoveredResource] + relationships: list[ResourceRelationship] + topological_order: list[str] + cycles: list[list[str]] + unresolved_references: list[UnresolvedReference] + cycle_reports: list[CycleReport] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Code Generator dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class GeneratedFile: + """A single generated Terraform HCL file.""" + + filename: str + content: str + resource_count: int + + +@dataclass +class ExtractedVariable: + """A Terraform variable extracted from common resource values.""" + + name: str + type_expr: str + default_value: str + description: str + used_by: list[str] = field(default_factory=list) + + +@dataclass +class CodeGenerationResult: + """Complete result of code generation.""" + + resource_files: list[GeneratedFile] + variables_file: GeneratedFile + provider_file: GeneratedFile + outputs_file: Optional[GeneratedFile] = None + skipped_resources: list[tuple[str, str]] = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# State Builder dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class StateEntry: + """A single resource entry in the Terraform state file.""" + + resource_type: str + resource_name: str + provider_id: str + attributes: dict + sensitive_attributes: list[str] = field(default_factory=list) + schema_version: int = 0 + dependencies: list[str] = field(default_factory=list) + + +@dataclass +class StateFile: + """Terraform state file representation (format version 4).""" + + version: int = 4 + terraform_version: str = "" + serial: int = 1 + lineage: str = "" + resources: list[StateEntry] = field(default_factory=list) + + def to_json(self) -> str: + """Serialize to Terraform state JSON format.""" + import json + import uuid + + lineage = self.lineage or str(uuid.uuid4()) + + state_resources = [] + for entry in self.resources: + state_resources.append( + { + "mode": "managed", + "type": entry.resource_type, + "name": entry.resource_name, + "provider": f'provider["registry.terraform.io/hashicorp/{entry.resource_type.split("_")[0]}"]', + "instances": [ + { + "schema_version": entry.schema_version, + "attributes": { + "id": entry.provider_id, + **entry.attributes, + }, + "sensitive_attributes": entry.sensitive_attributes, + "dependencies": entry.dependencies, + } + ], + } + ) + + state = { + "version": self.version, + "terraform_version": self.terraform_version, + "serial": self.serial, + "lineage": lineage, + "outputs": {}, + "resources": state_resources, + } + + return json.dumps(state, indent=2) + + +# --------------------------------------------------------------------------- +# Validator dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class PlannedChange: + """A single planned change reported by terraform plan.""" + + resource_address: str + change_type: str # "add", "modify", "destroy" + details: str + + +@dataclass +class ValidationError: + """A validation error from terraform validate or plan.""" + + file: str + message: str + line: Optional[int] = None + + +@dataclass +class ValidationResult: + """Complete result of terraform validation.""" + + init_success: bool + validate_success: bool + plan_success: bool + planned_changes: list[PlannedChange] = field(default_factory=list) + errors: list[ValidationError] = field(default_factory=list) + correction_attempts: int = 0 + + +# --------------------------------------------------------------------------- +# Incremental Scan dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ResourceChange: + """A single resource change between scan runs.""" + + resource_id: str + resource_type: str + resource_name: str + change_type: ChangeType + changed_attributes: Optional[dict] = None + + +@dataclass +class ChangeSummary: + """Summary of changes between two scan runs.""" + + added_count: int + removed_count: int + modified_count: int + changes: list[ResourceChange] = field(default_factory=list) diff --git a/src/iac_reverse/plugin_base.py b/src/iac_reverse/plugin_base.py new file mode 100644 index 0000000..e4098e9 --- /dev/null +++ b/src/iac_reverse/plugin_base.py @@ -0,0 +1,103 @@ +"""Provider plugin abstract base class. + +Defines the interface that all infrastructure provider plugins must implement +to participate in the scanning pipeline. +""" + +from abc import ABC, abstractmethod +from typing import Callable + +from iac_reverse.models import ( + CpuArchitecture, + PlatformCategory, + ScanProgress, + ScanResult, +) + + +class ProviderPlugin(ABC): + """Interface that all provider plugins must implement. + + Each on-premises platform (Docker Swarm, Kubernetes, Synology, Harvester, + Bare Metal, Windows) provides a concrete implementation of this class to + handle platform-specific authentication, discovery, and architecture detection. + """ + + @abstractmethod + def authenticate(self, credentials: dict[str, str]) -> None: + """Authenticate with the platform API. + + Args: + credentials: Provider-specific authentication parameters + (API tokens, usernames, passwords, kubeconfig paths, etc.) + + Raises: + AuthenticationError: If authentication fails, with a descriptive + message including the provider name and failure reason. + """ + ... + + @abstractmethod + def get_platform_category(self) -> PlatformCategory: + """Return the platform category for this provider. + + Returns: + The PlatformCategory enum value representing this provider's + infrastructure model (container orchestration, storage, HCI, etc.) + """ + ... + + @abstractmethod + def list_endpoints(self) -> list[str]: + """Return all reachable endpoints/hosts for this provider. + + Returns: + List of endpoint URLs or host addresses that can be scanned. + """ + ... + + @abstractmethod + def list_supported_resource_types(self) -> list[str]: + """Return all resource types this plugin can discover. + + Returns: + List of resource type strings (e.g., "kubernetes_deployment", + "windows_iis_site", "synology_shared_folder"). + """ + ... + + @abstractmethod + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect the CPU architecture of the target host/node. + + Args: + endpoint: The endpoint URL or host address to query. + + Returns: + The CpuArchitecture enum value for the target. + """ + ... + + @abstractmethod + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover resources from the infrastructure provider. + + Connects to the specified endpoints and enumerates resources of the + requested types. Reports progress via the callback function. + + Args: + endpoints: List of endpoint URLs or host addresses to scan. + resource_types: List of resource type strings to discover. + Should be a subset of list_supported_resource_types(). + progress_callback: Callable that receives ScanProgress updates + during the discovery process. + + Returns: + ScanResult containing all discovered resources, warnings, and errors. + """ + ... diff --git a/src/iac_reverse/resolver/__init__.py b/src/iac_reverse/resolver/__init__.py new file mode 100644 index 0000000..7aafef5 --- /dev/null +++ b/src/iac_reverse/resolver/__init__.py @@ -0,0 +1,5 @@ +"""Dependency resolver module for resource relationship mapping.""" + +from iac_reverse.resolver.resolver import DependencyResolver + +__all__ = ["DependencyResolver"] diff --git a/src/iac_reverse/resolver/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/resolver/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..6c6da3b Binary files /dev/null and b/src/iac_reverse/resolver/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/resolver/__pycache__/resolver.cpython-313.pyc b/src/iac_reverse/resolver/__pycache__/resolver.cpython-313.pyc new file mode 100644 index 0000000..7202e37 Binary files /dev/null and b/src/iac_reverse/resolver/__pycache__/resolver.cpython-313.pyc differ diff --git a/src/iac_reverse/resolver/resolver.py b/src/iac_reverse/resolver/resolver.py new file mode 100644 index 0000000..65fb9ea --- /dev/null +++ b/src/iac_reverse/resolver/resolver.py @@ -0,0 +1,443 @@ +"""Dependency resolver for resource relationship mapping. + +Analyzes discovered resources and their raw_references to build a dependency +graph with topological ordering. Identifies parent-child, reference, and +dependency relationships between resources. Detects circular dependencies +and suggests resolution strategies. +""" + +import logging +import networkx as nx + +from iac_reverse.models import ( + CycleReport, + DependencyGraph, + DiscoveredResource, + ResourceRelationship, + ScanResult, + UnresolvedReference, +) + +logger = logging.getLogger(__name__) + +# Resource types that represent namespace/container resources (parent-child targets) +_NAMESPACE_RESOURCE_TYPES = frozenset( + [ + "kubernetes_namespace", + "docker_network", + "harvester_network", + ] +) + +# Resource types that represent infrastructure dependencies (must exist before dependents) +# Maps: dependent resource type -> set of resource types it depends on +_DEPENDENCY_RESOURCE_TYPES: dict[str, frozenset[str]] = { + "windows_iis_site": frozenset(["windows_iis_app_pool"]), + "windows_hyperv_vm": frozenset(["windows_hyperv_switch"]), + "kubernetes_deployment": frozenset(["kubernetes_namespace", "kubernetes_config_map"]), + "kubernetes_service": frozenset(["kubernetes_namespace"]), + "kubernetes_ingress": frozenset(["kubernetes_namespace", "kubernetes_service"]), + "harvester_virtualmachine": frozenset(["harvester_network", "harvester_image"]), +} + +# Priority for breaking relationships (lower = prefer to break first) +_RELATIONSHIP_BREAK_PRIORITY: dict[str, int] = { + "reference": 0, + "dependency": 1, + "parent-child": 2, +} + + +class DependencyResolver: + """Resolves dependencies between discovered infrastructure resources. + + Analyzes raw_references on each DiscoveredResource to identify relationships + and builds a networkx DiGraph for topological ordering. Detects cycles and + suggests resolution strategies. + """ + + def __init__(self, scan_result: ScanResult) -> None: + """Initialize the resolver with a scan result. + + Args: + scan_result: The ScanResult containing discovered resources. + """ + self._scan_result = scan_result + self._resource_map: dict[str, DiscoveredResource] = { + r.unique_id: r for r in scan_result.resources + } + + def resolve(self) -> DependencyGraph: + """Analyze relationships and produce a dependency graph. + + Builds the graph, detects cycles, suggests resolutions, and produces + a topological ordering (breaking cycle edges if necessary). + + Returns: + DependencyGraph with resources, relationships, topological ordering, + cycles, cycle_reports, and unresolved_references. + """ + graph = nx.DiGraph() + relationships: list[ResourceRelationship] = [] + unresolved_references: list[UnresolvedReference] = [] + + # Add all resources as nodes + for resource in self._scan_result.resources: + graph.add_node(resource.unique_id) + + # Analyze raw_references to build edges and relationships + for resource in self._scan_result.resources: + for ref_id in resource.raw_references: + if ref_id not in self._resource_map: + # Unresolved reference - track it + source_attribute = self._identify_source_attribute_for_ref( + resource, ref_id + ) + suggested_resolution = self._suggest_resolution(ref_id) + + unresolved_references.append( + UnresolvedReference( + source_resource_id=resource.unique_id, + source_attribute=source_attribute, + referenced_id=ref_id, + suggested_resolution=suggested_resolution, + ) + ) + + logger.warning( + "Unresolved reference from resource '%s' (attribute: '%s') " + "to '%s' - suggested resolution: %s", + resource.unique_id, + source_attribute, + ref_id, + suggested_resolution, + ) + continue + + target_resource = self._resource_map[ref_id] + relationship_type = self._classify_relationship( + resource, target_resource + ) + + # Edge direction: source depends on target + # So target must come before source in topological order + graph.add_edge(ref_id, resource.unique_id) + + source_attribute = self._identify_source_attribute( + resource, target_resource + ) + + relationships.append( + ResourceRelationship( + source_id=resource.unique_id, + target_id=ref_id, + relationship_type=relationship_type, + source_attribute=source_attribute, + ) + ) + + # Detect cycles + cycle_reports = self.detect_cycles(graph, relationships) + cycles = [report.cycle for report in cycle_reports] + + # Produce topological ordering by breaking cycle edges if needed + topological_order = self._topological_order_with_cycle_breaking( + graph, cycle_reports + ) + + return DependencyGraph( + resources=self._scan_result.resources, + relationships=relationships, + topological_order=topological_order, + cycles=cycles, + unresolved_references=unresolved_references, + cycle_reports=cycle_reports, + ) + + def detect_cycles( + self, graph: nx.DiGraph, relationships: list[ResourceRelationship] + ) -> list[CycleReport]: + """Detect circular dependencies and suggest resolution strategies. + + Finds all simple cycles in the graph and for each cycle suggests which + edge to break. Prefers breaking "reference" over "dependency" over + "parent-child" relationships. + + Args: + graph: The networkx DiGraph with resource dependencies. + relationships: The list of ResourceRelationship objects. + + Returns: + List of CycleReport objects with cycle info and suggestions. + """ + # Build a lookup for relationship types by edge (target_id, source_id) + # Note: graph edges are (target_id, source_id) because edge direction + # means "target must come before source" + edge_relationship_map: dict[tuple[str, str], ResourceRelationship] = {} + for rel in relationships: + # In the graph, edge is (rel.target_id -> rel.source_id) + edge_relationship_map[(rel.target_id, rel.source_id)] = rel + + # Find all simple cycles + raw_cycles = list(nx.simple_cycles(graph)) + + cycle_reports: list[CycleReport] = [] + for cycle_nodes in raw_cycles: + if len(cycle_nodes) < 2: + continue + + # Find the best edge to break in this cycle + suggested_break, break_type = self._suggest_cycle_break( + cycle_nodes, edge_relationship_map + ) + + # Build resolution strategy message + source_id, target_id = suggested_break + # The relationship source_id is the resource that holds the reference + # In graph edge (A, B), A is target_id in relationship, B is source_id + resolution_strategy = ( + f"Break the '{break_type}' relationship by replacing the direct " + f"reference from '{target_id}' to '{source_id}' with a " + f"data source lookup (e.g., terraform data source) to decouple " + f"the circular dependency." + ) + + cycle_reports.append( + CycleReport( + cycle=cycle_nodes, + suggested_break=suggested_break, + break_relationship_type=break_type, + resolution_strategy=resolution_strategy, + ) + ) + + return cycle_reports + + def _suggest_cycle_break( + self, + cycle_nodes: list[str], + edge_relationship_map: dict[tuple[str, str], ResourceRelationship], + ) -> tuple[tuple[str, str], str]: + """Suggest which edge to break in a cycle. + + Prefers breaking "reference" over "dependency" over "parent-child". + + Args: + cycle_nodes: List of node IDs forming the cycle. + edge_relationship_map: Map from graph edge to ResourceRelationship. + + Returns: + Tuple of ((source_node, target_node) edge to break, relationship_type). + """ + # Build edges in the cycle: each consecutive pair + wrap-around + cycle_edges: list[tuple[str, str]] = [] + for i in range(len(cycle_nodes)): + from_node = cycle_nodes[i] + to_node = cycle_nodes[(i + 1) % len(cycle_nodes)] + cycle_edges.append((from_node, to_node)) + + # Find the edge with lowest break priority (prefer to break "reference" first) + best_edge = cycle_edges[0] + best_type = "reference" + best_priority = _RELATIONSHIP_BREAK_PRIORITY.get("reference", 0) + + for edge in cycle_edges: + rel = edge_relationship_map.get(edge) + if rel: + rel_type = rel.relationship_type + else: + # If no relationship found, treat as reference (easiest to break) + rel_type = "reference" + + priority = _RELATIONSHIP_BREAK_PRIORITY.get(rel_type, 0) + if priority < best_priority or ( + priority == best_priority and edge < best_edge + ): + best_priority = priority + best_edge = edge + best_type = rel_type + + return best_edge, best_type + + def _topological_order_with_cycle_breaking( + self, graph: nx.DiGraph, cycle_reports: list[CycleReport] + ) -> list[str]: + """Produce topological order by temporarily removing cycle-breaking edges. + + If the graph has cycles, removes the suggested edges from each cycle + report and attempts topological sort on the resulting DAG. + + Args: + graph: The original DiGraph (may contain cycles). + cycle_reports: Cycle reports with suggested edges to break. + + Returns: + List of resource IDs in topological order. + """ + if not cycle_reports: + # No cycles - straightforward topological sort + try: + return list(nx.topological_sort(graph)) + except nx.NetworkXUnfeasible: + # Shouldn't happen if cycle detection is correct, but be safe + return list(graph.nodes) + + # Create a copy and remove suggested break edges + working_graph = graph.copy() + for report in cycle_reports: + edge = report.suggested_break + if working_graph.has_edge(*edge): + working_graph.remove_edge(*edge) + + # Try topological sort on the modified graph + try: + return list(nx.topological_sort(working_graph)) + except nx.NetworkXUnfeasible: + # Still has cycles (overlapping cycles may need more breaks) + # Fall back to removing all cycle edges iteratively + while True: + try: + return list(nx.topological_sort(working_graph)) + except nx.NetworkXUnfeasible: + # Find remaining cycle and break an edge + try: + cycle = nx.find_cycle(working_graph) + # Remove the first edge in the found cycle + working_graph.remove_edge(*cycle[0][:2]) + except nx.NetworkXNoCycle: + return list(nx.topological_sort(working_graph)) + + def _classify_relationship( + self, source: DiscoveredResource, target: DiscoveredResource + ) -> str: + """Classify the relationship type between source and target. + + Args: + source: The resource that holds the reference. + target: The resource being referenced. + + Returns: + One of "parent-child", "dependency", or "reference". + """ + # Parent-child: target is a namespace/container resource + if target.resource_type in _NAMESPACE_RESOURCE_TYPES: + return "parent-child" + + # Dependency: source resource type has a known dependency on target's type + dependent_types = _DEPENDENCY_RESOURCE_TYPES.get(source.resource_type) + if dependent_types and target.resource_type in dependent_types: + return "dependency" + + # Default: reference relationship + return "reference" + + def _identify_source_attribute( + self, source: DiscoveredResource, target: DiscoveredResource + ) -> str: + """Identify which attribute in the source holds the reference to target. + + Searches the source's attributes for values matching the target's unique_id + or name. Falls back to "raw_references" if no specific attribute is found. + + Args: + source: The resource holding the reference. + target: The resource being referenced. + + Returns: + The attribute name that holds the reference. + """ + # Search attributes for the target's unique_id or name + for attr_name, attr_value in source.attributes.items(): + if isinstance(attr_value, str): + if attr_value == target.unique_id or attr_value == target.name: + return attr_name + elif isinstance(attr_value, list): + for item in attr_value: + if isinstance(item, str) and ( + item == target.unique_id or item == target.name + ): + return attr_name + + return "raw_references" + + def _identify_source_attribute_for_ref( + self, source: DiscoveredResource, ref_id: str + ) -> str: + """Identify which attribute in the source holds an unresolved reference. + + Searches the source's attributes for values matching the given ref_id. + Falls back to "raw_references" if no specific attribute is found. + + Args: + source: The resource holding the reference. + ref_id: The unresolved reference ID string. + + Returns: + The attribute name that holds the reference. + """ + for attr_name, attr_value in source.attributes.items(): + if isinstance(attr_value, str): + if attr_value == ref_id: + return attr_name + elif isinstance(attr_value, list): + for item in attr_value: + if isinstance(item, str) and item == ref_id: + return attr_name + + return "raw_references" + + def _suggest_resolution(self, ref_id: str) -> str: + """Suggest a resolution strategy for an unresolved reference. + + Args: + ref_id: The unresolved reference ID. + + Returns: + Either "data_source" or "variable" as the suggested resolution. + """ + # If the reference looks like a structured resource ID (contains /), + # suggest a data source lookup. Otherwise suggest a variable. + if "/" in ref_id: + return "data_source" + return "variable" + + def _identify_source_attribute_for_ref( + self, source: DiscoveredResource, ref_id: str + ) -> str: + """Identify which attribute in the source holds an unresolved reference. + + Searches the source's attributes for values matching the referenced ID. + Falls back to "raw_references" if no specific attribute is found. + + Args: + source: The resource holding the reference. + ref_id: The unresolved reference ID. + + Returns: + The attribute name that holds the reference. + """ + for attr_name, attr_value in source.attributes.items(): + if isinstance(attr_value, str): + if attr_value == ref_id: + return attr_name + elif isinstance(attr_value, list): + for item in attr_value: + if isinstance(item, str) and item == ref_id: + return attr_name + + return "raw_references" + + @staticmethod + def _suggest_resolution(ref_id: str) -> str: + """Determine the suggested resolution for an unresolved reference. + + Args: + ref_id: The unresolved reference ID. + + Returns: + "data_source" if the reference looks like a resource ID (contains + "/" or ":"), otherwise "variable" for simple value/name references. + """ + if "/" in ref_id or ":" in ref_id: + return "data_source" + return "variable" diff --git a/src/iac_reverse/scanner/__init__.py b/src/iac_reverse/scanner/__init__.py new file mode 100644 index 0000000..0f516c6 --- /dev/null +++ b/src/iac_reverse/scanner/__init__.py @@ -0,0 +1,45 @@ +"""Scanner module for infrastructure discovery.""" + +from iac_reverse.scanner.bare_metal_plugin import BareMetalPlugin +from iac_reverse.scanner.docker_swarm_plugin import DockerSwarmPlugin +from iac_reverse.scanner.harvester_plugin import HarvesterPlugin +from iac_reverse.scanner.kubernetes_plugin import KubernetesPlugin +from iac_reverse.scanner.multi_provider_scanner import ( + MultiProviderScanner, + MultiProviderScanResult, + ProviderFailure, + ProviderScanEntry, +) +from iac_reverse.scanner.scanner import ( + AuthenticationError, + ConnectionLostError, + Scanner, + ScanTimeoutError, +) +from iac_reverse.scanner.synology_plugin import SynologyPlugin +from iac_reverse.scanner.windows_plugin import ( + InsufficientPrivilegesError, + WindowsDiscoveryPlugin, + WinRMNotEnabledError, + WMIQueryError, +) + +__all__ = [ + "AuthenticationError", + "BareMetalPlugin", + "ConnectionLostError", + "DockerSwarmPlugin", + "HarvesterPlugin", + "InsufficientPrivilegesError", + "KubernetesPlugin", + "MultiProviderScanner", + "MultiProviderScanResult", + "ProviderFailure", + "ProviderScanEntry", + "Scanner", + "ScanTimeoutError", + "SynologyPlugin", + "WindowsDiscoveryPlugin", + "WinRMNotEnabledError", + "WMIQueryError", +] diff --git a/src/iac_reverse/scanner/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..721aaea Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/bare_metal_plugin.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/bare_metal_plugin.cpython-313.pyc new file mode 100644 index 0000000..5d06137 Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/bare_metal_plugin.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/docker_swarm_plugin.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/docker_swarm_plugin.cpython-313.pyc new file mode 100644 index 0000000..4f3987f Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/docker_swarm_plugin.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/harvester_plugin.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/harvester_plugin.cpython-313.pyc new file mode 100644 index 0000000..989b03f Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/harvester_plugin.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/kubernetes_plugin.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/kubernetes_plugin.cpython-313.pyc new file mode 100644 index 0000000..13e9a87 Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/kubernetes_plugin.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/multi_provider_scanner.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/multi_provider_scanner.cpython-313.pyc new file mode 100644 index 0000000..0b00e3f Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/multi_provider_scanner.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/scanner.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/scanner.cpython-313.pyc new file mode 100644 index 0000000..a94d3f4 Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/scanner.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/synology_plugin.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/synology_plugin.cpython-313.pyc new file mode 100644 index 0000000..e589d68 Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/synology_plugin.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/__pycache__/windows_plugin.cpython-313.pyc b/src/iac_reverse/scanner/__pycache__/windows_plugin.cpython-313.pyc new file mode 100644 index 0000000..df2d9bf Binary files /dev/null and b/src/iac_reverse/scanner/__pycache__/windows_plugin.cpython-313.pyc differ diff --git a/src/iac_reverse/scanner/bare_metal_plugin.py b/src/iac_reverse/scanner/bare_metal_plugin.py new file mode 100644 index 0000000..87ca588 --- /dev/null +++ b/src/iac_reverse/scanner/bare_metal_plugin.py @@ -0,0 +1,497 @@ +"""Bare Metal provider plugin using Redfish/IPMI API. + +Discovers hardware inventory, BMC configurations, network interfaces, +and RAID configurations from physical servers via the Redfish REST API +(standard BMC management interface). +""" + +import logging +from typing import Callable +from urllib.parse import urljoin + +import requests + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import AuthenticationError + +logger = logging.getLogger(__name__) + + +class BareMetalPlugin(ProviderPlugin): + """Provider plugin for bare metal servers using Redfish/IPMI API. + + Connects to a server's BMC (Baseboard Management Controller) via the + Redfish REST API to discover hardware inventory, BMC configuration, + network interfaces, and RAID configurations. + + Expected credentials dict keys: + host: BMC hostname or IP address (required) + username: BMC username (required) + password: BMC password (required) + port: BMC port (optional, default 443) + use_ssl: Whether to use HTTPS (optional, default "true") + """ + + SUPPORTED_RESOURCE_TYPES = [ + "bare_metal_hardware", + "bare_metal_bmc_config", + "bare_metal_network_interface", + "bare_metal_raid_config", + ] + + def __init__(self) -> None: + self._session: requests.Session | None = None + self._base_url: str = "" + self._host: str = "" + + def authenticate(self, credentials: dict[str, str]) -> None: + """Authenticate with the BMC via Redfish session creation. + + Args: + credentials: Dict with keys: host, username, password, + and optionally port (default 443) and use_ssl (default "true"). + + Raises: + AuthenticationError: If connection or login fails. + """ + host = credentials.get("host", "") + username = credentials.get("username", "") + password = credentials.get("password", "") + port = credentials.get("port", "443") + use_ssl = credentials.get("use_ssl", "true").lower() == "true" + + if not host or not username or not password: + raise AuthenticationError( + provider_name="bare_metal", + reason="Missing required credentials: host, username, and password are required", + ) + + scheme = "https" if use_ssl else "http" + self._base_url = f"{scheme}://{host}:{port}" + self._host = host + + session = requests.Session() + session.verify = False # BMC certs are typically self-signed + session.headers.update({ + "Content-Type": "application/json", + "Accept": "application/json", + }) + + # Attempt Redfish session-based authentication + session_url = f"{self._base_url}/redfish/v1/SessionService/Sessions" + payload = {"UserName": username, "Password": password} + + try: + response = session.post(session_url, json=payload, timeout=30) + if response.status_code in (200, 201): + # Extract session token from response headers + token = response.headers.get("X-Auth-Token", "") + if token: + session.headers["X-Auth-Token"] = token + elif response.status_code == 401: + raise AuthenticationError( + provider_name="bare_metal", + reason="Invalid credentials (HTTP 401)", + ) + else: + raise AuthenticationError( + provider_name="bare_metal", + reason=f"Unexpected response status {response.status_code}", + ) + except requests.exceptions.ConnectionError as exc: + raise AuthenticationError( + provider_name="bare_metal", + reason=f"Cannot connect to BMC at {self._base_url}: {exc}", + ) from exc + except requests.exceptions.Timeout as exc: + raise AuthenticationError( + provider_name="bare_metal", + reason=f"Connection to BMC timed out: {exc}", + ) from exc + except AuthenticationError: + raise + except Exception as exc: + raise AuthenticationError( + provider_name="bare_metal", + reason=f"Unexpected error during authentication: {exc}", + ) from exc + + self._session = session + + def get_platform_category(self) -> PlatformCategory: + """Return PlatformCategory.BARE_METAL.""" + return PlatformCategory.BARE_METAL + + def list_endpoints(self) -> list[str]: + """Return the BMC host as the single endpoint.""" + return [self._host] if self._host else [] + + def list_supported_resource_types(self) -> list[str]: + """Return supported bare metal resource types.""" + return list(self.SUPPORTED_RESOURCE_TYPES) + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect CPU architecture from Redfish system hardware info. + + Queries /redfish/v1/Systems/1/Processors to determine the + processor architecture. + + Args: + endpoint: The BMC host address. + + Returns: + CpuArchitecture enum value based on processor info. + """ + if self._session is None: + return CpuArchitecture.AMD64 + + processors_url = f"{self._base_url}/redfish/v1/Systems/1/Processors" + try: + response = self._session.get(processors_url, timeout=30) + if response.status_code == 200: + data = response.json() + members = data.get("Members", []) + if members: + # Query first processor for architecture details + proc_uri = members[0].get("@odata.id", "") + if proc_uri: + proc_url = f"{self._base_url}{proc_uri}" + proc_response = self._session.get(proc_url, timeout=30) + if proc_response.status_code == 200: + proc_data = proc_response.json() + return self._parse_architecture(proc_data) + except Exception as exc: + logger.warning("Failed to detect architecture: %s", exc) + + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover bare metal resources via Redfish API. + + Args: + endpoints: List of BMC host addresses to scan. + resource_types: Resource types to discover. + progress_callback: Progress reporting callback. + + Returns: + ScanResult with discovered resources. + """ + resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + total_types = len(resource_types) + types_completed = 0 + + for endpoint in endpoints: + architecture = self.detect_architecture(endpoint) + + for resource_type in resource_types: + try: + discovered = self._discover_resource_type( + endpoint, resource_type, architecture + ) + resources.extend(discovered) + except Exception as exc: + error_msg = ( + f"Error discovering {resource_type} on {endpoint}: {exc}" + ) + errors.append(error_msg) + logger.error(error_msg) + + types_completed += 1 + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(resources), + resource_types_completed=types_completed, + total_resource_types=total_types, + ) + ) + + return ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp="", + profile_hash="", + ) + + # ----------------------------------------------------------------------- + # Private helpers + # ----------------------------------------------------------------------- + + def _discover_resource_type( + self, + endpoint: str, + resource_type: str, + architecture: CpuArchitecture, + ) -> list[DiscoveredResource]: + """Dispatch discovery to the appropriate handler.""" + handlers = { + "bare_metal_hardware": self._discover_hardware, + "bare_metal_bmc_config": self._discover_bmc_config, + "bare_metal_network_interface": self._discover_network_interfaces, + "bare_metal_raid_config": self._discover_raid_config, + } + handler = handlers.get(resource_type) + if handler is None: + return [] + return handler(endpoint, architecture) + + def _discover_hardware( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover hardware inventory via /redfish/v1/Systems/1.""" + if self._session is None: + return [] + + url = f"{self._base_url}/redfish/v1/Systems/1" + try: + response = self._session.get(url, timeout=30) + if response.status_code != 200: + return [] + data = response.json() + except Exception as exc: + logger.warning("Failed to discover hardware: %s", exc) + return [] + + system_id = data.get("Id", "System.1") + return [ + DiscoveredResource( + resource_type="bare_metal_hardware", + unique_id=f"{endpoint}:{system_id}", + name=data.get("Name", f"System {system_id}"), + provider=ProviderType.BARE_METAL, + platform_category=PlatformCategory.BARE_METAL, + architecture=architecture, + endpoint=endpoint, + attributes={ + "manufacturer": data.get("Manufacturer", ""), + "model": data.get("Model", ""), + "serial_number": data.get("SerialNumber", ""), + "sku": data.get("SKU", ""), + "bios_version": data.get("BiosVersion", ""), + "total_memory_gib": data.get("MemorySummary", {}).get( + "TotalSystemMemoryGiB", 0 + ), + "processor_count": data.get("ProcessorSummary", {}).get( + "Count", 0 + ), + "processor_model": data.get("ProcessorSummary", {}).get( + "Model", "" + ), + "power_state": data.get("PowerState", ""), + "status": data.get("Status", {}), + }, + ) + ] + + def _discover_bmc_config( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover BMC configuration via /redfish/v1/Managers/1.""" + if self._session is None: + return [] + + url = f"{self._base_url}/redfish/v1/Managers/1" + try: + response = self._session.get(url, timeout=30) + if response.status_code != 200: + return [] + data = response.json() + except Exception as exc: + logger.warning("Failed to discover BMC config: %s", exc) + return [] + + manager_id = data.get("Id", "BMC.1") + return [ + DiscoveredResource( + resource_type="bare_metal_bmc_config", + unique_id=f"{endpoint}:{manager_id}", + name=data.get("Name", f"BMC {manager_id}"), + provider=ProviderType.BARE_METAL, + platform_category=PlatformCategory.BARE_METAL, + architecture=architecture, + endpoint=endpoint, + attributes={ + "manager_type": data.get("ManagerType", ""), + "firmware_version": data.get("FirmwareVersion", ""), + "model": data.get("Model", ""), + "status": data.get("Status", {}), + "uuid": data.get("UUID", ""), + }, + ) + ] + + def _discover_network_interfaces( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover network interfaces via /redfish/v1/Systems/1/EthernetInterfaces.""" + if self._session is None: + return [] + + url = f"{self._base_url}/redfish/v1/Systems/1/EthernetInterfaces" + try: + response = self._session.get(url, timeout=30) + if response.status_code != 200: + return [] + data = response.json() + except Exception as exc: + logger.warning("Failed to discover network interfaces: %s", exc) + return [] + + resources: list[DiscoveredResource] = [] + for member in data.get("Members", []): + nic_uri = member.get("@odata.id", "") + if not nic_uri: + continue + + try: + nic_url = f"{self._base_url}{nic_uri}" + nic_response = self._session.get(nic_url, timeout=30) + if nic_response.status_code != 200: + continue + nic_data = nic_response.json() + except Exception as exc: + logger.warning("Failed to get NIC details at %s: %s", nic_uri, exc) + continue + + nic_id = nic_data.get("Id", "") + resources.append( + DiscoveredResource( + resource_type="bare_metal_network_interface", + unique_id=f"{endpoint}:{nic_id}", + name=nic_data.get("Name", f"NIC {nic_id}"), + provider=ProviderType.BARE_METAL, + platform_category=PlatformCategory.BARE_METAL, + architecture=architecture, + endpoint=endpoint, + attributes={ + "mac_address": nic_data.get("MACAddress", ""), + "speed_mbps": nic_data.get("SpeedMbps", 0), + "status": nic_data.get("Status", {}), + "ipv4_addresses": nic_data.get("IPv4Addresses", []), + "ipv6_addresses": nic_data.get("IPv6Addresses", []), + "vlan": nic_data.get("VLAN", {}), + "link_status": nic_data.get("LinkStatus", ""), + "auto_neg": nic_data.get("AutoNeg", False), + }, + ) + ) + + return resources + + def _discover_raid_config( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover RAID configuration via /redfish/v1/Systems/1/Storage.""" + if self._session is None: + return [] + + url = f"{self._base_url}/redfish/v1/Systems/1/Storage" + try: + response = self._session.get(url, timeout=30) + if response.status_code != 200: + return [] + data = response.json() + except Exception as exc: + logger.warning("Failed to discover RAID config: %s", exc) + return [] + + resources: list[DiscoveredResource] = [] + for member in data.get("Members", []): + storage_uri = member.get("@odata.id", "") + if not storage_uri: + continue + + try: + storage_url = f"{self._base_url}{storage_uri}" + storage_response = self._session.get(storage_url, timeout=30) + if storage_response.status_code != 200: + continue + storage_data = storage_response.json() + except Exception as exc: + logger.warning( + "Failed to get storage details at %s: %s", storage_uri, exc + ) + continue + + storage_id = storage_data.get("Id", "") + drives = [] + for drive in storage_data.get("Drives", []): + drive_uri = drive.get("@odata.id", "") + if drive_uri: + drives.append(drive_uri) + + volumes = [] + volumes_link = storage_data.get("Volumes", {}).get("@odata.id", "") + if volumes_link: + try: + vol_url = f"{self._base_url}{volumes_link}" + vol_response = self._session.get(vol_url, timeout=30) + if vol_response.status_code == 200: + vol_data = vol_response.json() + for vol_member in vol_data.get("Members", []): + vol_uri = vol_member.get("@odata.id", "") + if vol_uri: + volumes.append(vol_uri) + except Exception as exc: + logger.warning("Failed to get volumes: %s", exc) + + resources.append( + DiscoveredResource( + resource_type="bare_metal_raid_config", + unique_id=f"{endpoint}:{storage_id}", + name=storage_data.get("Name", f"Storage {storage_id}"), + provider=ProviderType.BARE_METAL, + platform_category=PlatformCategory.BARE_METAL, + architecture=architecture, + endpoint=endpoint, + attributes={ + "storage_controllers": [ + ctrl.get("Name", "") + for ctrl in storage_data.get( + "StorageControllers", [] + ) + ], + "drive_count": len(drives), + "drives": drives, + "volumes": volumes, + "status": storage_data.get("Status", {}), + }, + ) + ) + + return resources + + @staticmethod + def _parse_architecture(proc_data: dict) -> CpuArchitecture: + """Parse CPU architecture from Redfish processor data. + + Examines InstructionSet and Model fields to determine architecture. + """ + instruction_set = proc_data.get("InstructionSet", "").lower() + model = proc_data.get("Model", "").lower() + + if "aarch64" in instruction_set or "arm" in instruction_set: + return CpuArchitecture.AARCH64 + if "arm" in model: + if "64" in model or "aarch64" in model or "v8" in model: + return CpuArchitecture.AARCH64 + return CpuArchitecture.ARM + + # Default to AMD64 for x86/x86_64/IA-32e + return CpuArchitecture.AMD64 diff --git a/src/iac_reverse/scanner/docker_swarm_plugin.py b/src/iac_reverse/scanner/docker_swarm_plugin.py new file mode 100644 index 0000000..f733831 --- /dev/null +++ b/src/iac_reverse/scanner/docker_swarm_plugin.py @@ -0,0 +1,433 @@ +"""Docker Swarm provider plugin. + +Discovers services, networks, volumes, configs, and secrets from a Docker Swarm +cluster using the docker-sdk-python library. +""" + +import logging +from typing import Callable, Optional + +import docker +from docker.tls import TLSConfig + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import AuthenticationError + +logger = logging.getLogger(__name__) + +# Resource types supported by this plugin +SUPPORTED_RESOURCE_TYPES = [ + "docker_service", + "docker_network", + "docker_volume", + "docker_config", + "docker_secret", +] + +# Mapping from Docker platform architecture strings to CpuArchitecture enum +_ARCH_MAP: dict[str, CpuArchitecture] = { + "x86_64": CpuArchitecture.AMD64, + "amd64": CpuArchitecture.AMD64, + "aarch64": CpuArchitecture.AARCH64, + "arm64": CpuArchitecture.AARCH64, + "armv7l": CpuArchitecture.ARM, + "armhf": CpuArchitecture.ARM, + "arm": CpuArchitecture.ARM, +} + + +class DockerSwarmPlugin(ProviderPlugin): + """Provider plugin for Docker Swarm infrastructure discovery. + + Connects to a Docker daemon (in Swarm mode) and enumerates services, + networks, volumes, configs, and secrets. + + Expected credentials dict keys: + - host: Docker daemon URL (e.g., "tcp://192.168.1.10:2376") + - tls_verify: (optional) "true" or "false" to enable TLS verification + - cert_path: (optional) path to TLS certificates directory + """ + + def __init__(self) -> None: + self._client: Optional[docker.DockerClient] = None + self._host: str = "" + + def authenticate(self, credentials: dict[str, str]) -> None: + """Connect to the Docker daemon using the provided credentials. + + Args: + credentials: Dict with keys 'host' (required), 'tls_verify' (optional), + and 'cert_path' (optional). + + Raises: + AuthenticationError: If connection to the Docker daemon fails. + """ + host = credentials.get("host", "") + if not host: + raise AuthenticationError( + provider_name="docker_swarm", + reason="'host' is required in credentials", + ) + + tls_verify = credentials.get("tls_verify", "").lower() == "true" + cert_path = credentials.get("cert_path") + + tls_config: Optional[TLSConfig] = None + if tls_verify or cert_path: + tls_config = TLSConfig( + verify=tls_verify, + client_cert=( + (f"{cert_path}/cert.pem", f"{cert_path}/key.pem") + if cert_path + else None + ), + ca_cert=f"{cert_path}/ca.pem" if cert_path else None, + ) + + try: + self._client = docker.DockerClient( + base_url=host, + tls=tls_config if tls_config else False, + ) + # Verify connection by pinging the daemon + self._client.ping() + except Exception as exc: + raise AuthenticationError( + provider_name="docker_swarm", + reason=str(exc), + ) from exc + + self._host = host + + def get_platform_category(self) -> PlatformCategory: + """Return CONTAINER_ORCHESTRATION platform category.""" + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + """Return the Docker daemon host as the single endpoint.""" + if self._host: + return [self._host] + return [] + + def list_supported_resource_types(self) -> list[str]: + """Return supported Docker Swarm resource types.""" + return list(SUPPORTED_RESOURCE_TYPES) + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect CPU architecture from Docker node info. + + Queries the Docker daemon's system info to determine the architecture + of the Swarm node. + + Args: + endpoint: The Docker daemon endpoint (used for context). + + Returns: + CpuArchitecture enum value detected from node info. + """ + if self._client is None: + return CpuArchitecture.AMD64 + + try: + info = self._client.info() + arch_str = info.get("Architecture", "x86_64").lower() + return _ARCH_MAP.get(arch_str, CpuArchitecture.AMD64) + except Exception: + logger.warning( + "Failed to detect architecture for endpoint %s, defaulting to AMD64", + endpoint, + ) + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover Docker Swarm resources. + + Enumerates services, networks, volumes, configs, and secrets + based on the requested resource_types. + + Args: + endpoints: List of Docker daemon endpoints. + resource_types: Resource types to discover. + progress_callback: Callback for progress updates. + + Returns: + ScanResult with discovered resources. + """ + resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + + if self._client is None: + return ScanResult( + resources=[], + warnings=[], + errors=["Not authenticated. Call authenticate() first."], + scan_timestamp="", + profile_hash="", + ) + + endpoint = endpoints[0] if endpoints else self._host + architecture = self.detect_architecture(endpoint) + total_types = len(resource_types) + + discovery_methods = { + "docker_service": self._discover_services, + "docker_network": self._discover_networks, + "docker_volume": self._discover_volumes, + "docker_config": self._discover_configs, + "docker_secret": self._discover_secrets, + } + + for idx, resource_type in enumerate(resource_types): + method = discovery_methods.get(resource_type) + if method is None: + warnings.append(f"Unknown resource type: {resource_type}") + continue + + try: + discovered = method(endpoint, architecture) + resources.extend(discovered) + except Exception as exc: + error_msg = f"Error discovering {resource_type}: {exc}" + errors.append(error_msg) + logger.error(error_msg) + + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(resources), + resource_types_completed=idx + 1, + total_resource_types=total_types, + ) + ) + + return ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp="", + profile_hash="", + ) + + # ------------------------------------------------------------------ + # Private discovery methods + # ------------------------------------------------------------------ + + def _discover_services( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Docker Swarm services.""" + resources: list[DiscoveredResource] = [] + services = self._client.services.list() + + for svc in services: + attrs = svc.attrs + spec = attrs.get("Spec", {}) + task_template = spec.get("TaskTemplate", {}) + container_spec = task_template.get("ContainerSpec", {}) + + resources.append( + DiscoveredResource( + resource_type="docker_service", + unique_id=attrs.get("ID", ""), + name=spec.get("Name", ""), + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "image": container_spec.get("Image", ""), + "replicas": spec.get("Mode", {}) + .get("Replicated", {}) + .get("Replicas", 1), + "labels": spec.get("Labels", {}), + }, + raw_references=self._extract_service_references(spec), + ) + ) + + return resources + + def _discover_networks( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Docker networks.""" + resources: list[DiscoveredResource] = [] + networks = self._client.networks.list() + + for net in networks: + attrs = net.attrs + resources.append( + DiscoveredResource( + resource_type="docker_network", + unique_id=attrs.get("Id", ""), + name=attrs.get("Name", ""), + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "driver": attrs.get("Driver", ""), + "scope": attrs.get("Scope", ""), + "attachable": attrs.get("Attachable", False), + "ingress": attrs.get("Ingress", False), + "labels": attrs.get("Labels", {}), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_volumes( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Docker volumes.""" + resources: list[DiscoveredResource] = [] + volumes = self._client.volumes.list() + + for vol in volumes: + attrs = vol.attrs + resources.append( + DiscoveredResource( + resource_type="docker_volume", + unique_id=attrs.get("Name", ""), + name=attrs.get("Name", ""), + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "driver": attrs.get("Driver", ""), + "mountpoint": attrs.get("Mountpoint", ""), + "labels": attrs.get("Labels", {}), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_configs( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Docker configs (metadata only, no data content).""" + resources: list[DiscoveredResource] = [] + configs = self._client.configs.list() + + for cfg in configs: + attrs = cfg.attrs + spec = attrs.get("Spec", {}) + resources.append( + DiscoveredResource( + resource_type="docker_config", + unique_id=attrs.get("ID", ""), + name=spec.get("Name", ""), + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "labels": spec.get("Labels", {}), + "created_at": attrs.get("CreatedAt", ""), + "updated_at": attrs.get("UpdatedAt", ""), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_secrets( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Docker secrets (metadata only, no secret data).""" + resources: list[DiscoveredResource] = [] + secrets = self._client.secrets.list() + + for secret in secrets: + attrs = secret.attrs + spec = attrs.get("Spec", {}) + resources.append( + DiscoveredResource( + resource_type="docker_secret", + unique_id=attrs.get("ID", ""), + name=spec.get("Name", ""), + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "labels": spec.get("Labels", {}), + "created_at": attrs.get("CreatedAt", ""), + "updated_at": attrs.get("UpdatedAt", ""), + }, + raw_references=[], + ) + ) + + return resources + + @staticmethod + def _extract_service_references(spec: dict) -> list[str]: + """Extract resource references from a service spec. + + Looks for network attachments, volume mounts, config references, + and secret references. + """ + refs: list[str] = [] + + # Network references + networks = spec.get("TaskTemplate", {}).get("Networks", []) + for net in networks: + target = net.get("Target", "") + if target: + refs.append(f"network:{target}") + + # Volume mount references + mounts = ( + spec.get("TaskTemplate", {}) + .get("ContainerSpec", {}) + .get("Mounts", []) + ) + for mount in mounts: + source = mount.get("Source", "") + if source: + refs.append(f"volume:{source}") + + # Config references + configs = ( + spec.get("TaskTemplate", {}) + .get("ContainerSpec", {}) + .get("Configs", []) + ) + for cfg in configs: + config_id = cfg.get("ConfigID", "") + if config_id: + refs.append(f"config:{config_id}") + + # Secret references + secrets = ( + spec.get("TaskTemplate", {}) + .get("ContainerSpec", {}) + .get("Secrets", []) + ) + for secret in secrets: + secret_id = secret.get("SecretID", "") + if secret_id: + refs.append(f"secret:{secret_id}") + + return refs diff --git a/src/iac_reverse/scanner/harvester_plugin.py b/src/iac_reverse/scanner/harvester_plugin.py new file mode 100644 index 0000000..97771df --- /dev/null +++ b/src/iac_reverse/scanner/harvester_plugin.py @@ -0,0 +1,458 @@ +"""Harvester provider plugin for HCI infrastructure discovery. + +Uses the Kubernetes Python client to interact with Harvester's K8s-based API, +discovering virtual machines, volumes, images, and networks via custom resources. +""" + +import logging +from typing import Callable + +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import AuthenticationError + +logger = logging.getLogger(__name__) + +# Harvester CRD API groups and versions +HARVESTER_API_GROUP = "kubevirt.io" +HARVESTER_VM_VERSION = "v1" +HARVESTER_VM_PLURAL = "virtualmachines" + +HARVESTER_CDI_GROUP = "cdi.kubevirt.io" +HARVESTER_CDI_VERSION = "v1beta1" +HARVESTER_VOLUME_PLURAL = "datavolumes" + +HARVESTER_IMAGE_GROUP = "harvesterhci.io" +HARVESTER_IMAGE_VERSION = "v1beta1" +HARVESTER_IMAGE_PLURAL = "virtualmachineimages" + +HARVESTER_NETWORK_GROUP = "k8s.cni.cncf.io" +HARVESTER_NETWORK_VERSION = "v1" +HARVESTER_NETWORK_PLURAL = "network-attachment-definitions" + +# Default namespace for Harvester resources +DEFAULT_NAMESPACE = "default" + + +class HarvesterPlugin(ProviderPlugin): + """Provider plugin for SUSE Harvester HCI platform. + + Harvester runs on top of Kubernetes and exposes its resources as CRDs. + This plugin uses the kubernetes Python client to authenticate via kubeconfig + and discover VMs, volumes, images, and networks. + + Expected credentials: + kubeconfig_path: Path to the kubeconfig file for the Harvester cluster. + context: (optional) Kubernetes context name to use. + """ + + def __init__(self) -> None: + self._api_client: client.ApiClient | None = None + self._custom_api: client.CustomObjectsApi | None = None + self._core_api: client.CoreV1Api | None = None + self._kubeconfig_path: str | None = None + self._context: str | None = None + + def authenticate(self, credentials: dict[str, str]) -> None: + """Authenticate with the Harvester cluster via kubeconfig. + + Args: + credentials: Must contain 'kubeconfig_path'. May contain 'context'. + + Raises: + AuthenticationError: If kubeconfig cannot be loaded or is invalid. + """ + kubeconfig_path = credentials.get("kubeconfig_path") + if not kubeconfig_path: + raise AuthenticationError( + provider_name="harvester", + reason="'kubeconfig_path' is required in credentials", + ) + + context = credentials.get("context") or None + self._kubeconfig_path = kubeconfig_path + self._context = context + + try: + self._api_client = config.new_client_from_config( + config_file=kubeconfig_path, + context=context, + ) + self._custom_api = client.CustomObjectsApi(self._api_client) + self._core_api = client.CoreV1Api(self._api_client) + except Exception as exc: + raise AuthenticationError( + provider_name="harvester", + reason=f"Failed to load kubeconfig: {exc}", + ) from exc + + def get_platform_category(self) -> PlatformCategory: + """Return HCI platform category for Harvester.""" + return PlatformCategory.HCI + + def list_endpoints(self) -> list[str]: + """Return the Harvester cluster API endpoint. + + Extracts the server URL from the loaded kubeconfig. + """ + if self._api_client is None: + return [] + host = self._api_client.configuration.host or "" + return [host] if host else [] + + def list_supported_resource_types(self) -> list[str]: + """Return resource types supported by the Harvester plugin.""" + return [ + "harvester_virtualmachine", + "harvester_volume", + "harvester_image", + "harvester_network", + ] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect CPU architecture from Harvester cluster node info. + + Queries the Kubernetes node list and inspects the architecture label. + Harvester typically runs on AMD64 (Dell PowerEdge servers). + + Args: + endpoint: The cluster API endpoint (used for logging context). + + Returns: + CpuArchitecture detected from node info. + """ + if self._core_api is None: + return CpuArchitecture.AMD64 + + try: + nodes = self._core_api.list_node() + if nodes.items: + node = nodes.items[0] + arch = node.status.node_info.architecture + arch_lower = arch.lower() if arch else "" + if arch_lower in ("arm64", "aarch64"): + return CpuArchitecture.AARCH64 + elif arch_lower == "arm": + return CpuArchitecture.ARM + else: + return CpuArchitecture.AMD64 + except ApiException as exc: + logger.warning( + "Failed to detect architecture from node info for %s: %s", + endpoint, + exc, + ) + + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover Harvester resources via Kubernetes CRDs. + + Enumerates VMs, volumes, images, and networks from the Harvester cluster. + + Args: + endpoints: List of cluster API endpoints. + resource_types: Resource types to discover. + progress_callback: Callback for progress updates. + + Returns: + ScanResult with discovered resources. + """ + resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + + endpoint = endpoints[0] if endpoints else "" + architecture = self.detect_architecture(endpoint) + + total_types = len(resource_types) + completed = 0 + + discovery_map = { + "harvester_virtualmachine": self._discover_vms, + "harvester_volume": self._discover_volumes, + "harvester_image": self._discover_images, + "harvester_network": self._discover_networks, + } + + for resource_type in resource_types: + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(resources), + resource_types_completed=completed, + total_resource_types=total_types, + ) + ) + + discover_fn = discovery_map.get(resource_type) + if discover_fn is None: + warnings.append(f"Unknown resource type: {resource_type}") + completed += 1 + continue + + try: + discovered = discover_fn(endpoint, architecture) + resources.extend(discovered) + except ApiException as exc: + error_msg = ( + f"Failed to discover {resource_type}: " + f"HTTP {exc.status} - {exc.reason}" + ) + errors.append(error_msg) + logger.error(error_msg) + except Exception as exc: + error_msg = f"Failed to discover {resource_type}: {exc}" + errors.append(error_msg) + logger.error(error_msg) + + completed += 1 + + # Final progress update + progress_callback( + ScanProgress( + current_resource_type="", + resources_discovered=len(resources), + resource_types_completed=total_types, + total_resource_types=total_types, + ) + ) + + return ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp="", + profile_hash="", + ) + + # ------------------------------------------------------------------ + # Private discovery methods + # ------------------------------------------------------------------ + + def _discover_vms( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Harvester virtual machines via kubevirt.io CRD.""" + items = self._list_cluster_custom_objects( + group=HARVESTER_API_GROUP, + version=HARVESTER_VM_VERSION, + plural=HARVESTER_VM_PLURAL, + ) + + resources = [] + for item in items: + metadata = item.get("metadata", {}) + spec = item.get("spec", {}) + name = metadata.get("name", "unknown") + namespace = metadata.get("namespace", DEFAULT_NAMESPACE) + uid = metadata.get("uid", f"{namespace}/{name}") + + resources.append( + DiscoveredResource( + resource_type="harvester_virtualmachine", + unique_id=uid, + name=name, + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "running": spec.get("running", False), + "spec": spec, + "labels": metadata.get("labels", {}), + "annotations": metadata.get("annotations", {}), + }, + raw_references=self._extract_vm_references(spec), + ) + ) + + return resources + + def _discover_volumes( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Harvester data volumes via cdi.kubevirt.io CRD.""" + items = self._list_cluster_custom_objects( + group=HARVESTER_CDI_GROUP, + version=HARVESTER_CDI_VERSION, + plural=HARVESTER_VOLUME_PLURAL, + ) + + resources = [] + for item in items: + metadata = item.get("metadata", {}) + spec = item.get("spec", {}) + name = metadata.get("name", "unknown") + namespace = metadata.get("namespace", DEFAULT_NAMESPACE) + uid = metadata.get("uid", f"{namespace}/{name}") + + resources.append( + DiscoveredResource( + resource_type="harvester_volume", + unique_id=uid, + name=name, + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "spec": spec, + "labels": metadata.get("labels", {}), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_images( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Harvester VM images via harvesterhci.io CRD.""" + items = self._list_cluster_custom_objects( + group=HARVESTER_IMAGE_GROUP, + version=HARVESTER_IMAGE_VERSION, + plural=HARVESTER_IMAGE_PLURAL, + ) + + resources = [] + for item in items: + metadata = item.get("metadata", {}) + spec = item.get("spec", {}) + name = metadata.get("name", "unknown") + namespace = metadata.get("namespace", DEFAULT_NAMESPACE) + uid = metadata.get("uid", f"{namespace}/{name}") + + resources.append( + DiscoveredResource( + resource_type="harvester_image", + unique_id=uid, + name=name, + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "display_name": spec.get("displayName", name), + "url": spec.get("url", ""), + "spec": spec, + "labels": metadata.get("labels", {}), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_networks( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Harvester networks via k8s.cni.cncf.io CRD.""" + items = self._list_cluster_custom_objects( + group=HARVESTER_NETWORK_GROUP, + version=HARVESTER_NETWORK_VERSION, + plural=HARVESTER_NETWORK_PLURAL, + ) + + resources = [] + for item in items: + metadata = item.get("metadata", {}) + spec = item.get("spec", {}) + name = metadata.get("name", "unknown") + namespace = metadata.get("namespace", DEFAULT_NAMESPACE) + uid = metadata.get("uid", f"{namespace}/{name}") + + resources.append( + DiscoveredResource( + resource_type="harvester_network", + unique_id=uid, + name=name, + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "config": spec.get("config", ""), + "labels": metadata.get("labels", {}), + }, + raw_references=[], + ) + ) + + return resources + + def _list_cluster_custom_objects( + self, group: str, version: str, plural: str + ) -> list[dict]: + """List all custom objects across all namespaces. + + Args: + group: API group (e.g., 'kubevirt.io'). + version: API version (e.g., 'v1'). + plural: Resource plural name (e.g., 'virtualmachines'). + + Returns: + List of resource items as dicts. + """ + if self._custom_api is None: + return [] + + result = self._custom_api.list_cluster_custom_object( + group=group, + version=version, + plural=plural, + ) + return result.get("items", []) + + @staticmethod + def _extract_vm_references(spec: dict) -> list[str]: + """Extract resource references from a VM spec. + + Looks for volume and network references in the VM template spec. + """ + references: list[str] = [] + + template = spec.get("template", {}) + template_spec = template.get("spec", {}) + + # Extract volume references + volumes = template_spec.get("volumes", []) + for volume in volumes: + if "dataVolume" in volume: + dv_name = volume["dataVolume"].get("name", "") + if dv_name: + references.append(f"volume:{dv_name}") + if "persistentVolumeClaim" in volume: + pvc_name = volume["persistentVolumeClaim"].get("claimName", "") + if pvc_name: + references.append(f"volume:{pvc_name}") + + # Extract network references + networks = template_spec.get("networks", []) + for network in networks: + if "multus" in network: + net_name = network["multus"].get("networkName", "") + if net_name: + references.append(f"network:{net_name}") + + return references diff --git a/src/iac_reverse/scanner/kubernetes_plugin.py b/src/iac_reverse/scanner/kubernetes_plugin.py new file mode 100644 index 0000000..5c246ee --- /dev/null +++ b/src/iac_reverse/scanner/kubernetes_plugin.py @@ -0,0 +1,454 @@ +"""Kubernetes provider plugin for infrastructure discovery. + +Uses the official kubernetes-client library to discover deployments, services, +ingresses, config maps, persistent volumes, and namespaces from a Kubernetes +cluster. Detects CPU architecture from node labels. +""" + +import logging +from typing import Callable + +from kubernetes import client, config + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import AuthenticationError + +logger = logging.getLogger(__name__) + +# Mapping from kubernetes.io/arch label values to CpuArchitecture enum +_ARCH_LABEL_MAP: dict[str, CpuArchitecture] = { + "amd64": CpuArchitecture.AMD64, + "arm": CpuArchitecture.ARM, + "arm64": CpuArchitecture.AARCH64, + "aarch64": CpuArchitecture.AARCH64, +} + +_SUPPORTED_RESOURCE_TYPES = [ + "kubernetes_deployment", + "kubernetes_service", + "kubernetes_ingress", + "kubernetes_config_map", + "kubernetes_persistent_volume", + "kubernetes_namespace", +] + + +class KubernetesPlugin(ProviderPlugin): + """Kubernetes provider plugin using the official kubernetes-client. + + Authenticates via kubeconfig file and discovers cluster resources + including deployments, services, ingresses, config maps, persistent + volumes, and namespaces. + """ + + def __init__(self) -> None: + self._api_client: client.ApiClient | None = None + self._core_v1: client.CoreV1Api | None = None + self._apps_v1: client.AppsV1Api | None = None + self._networking_v1: client.NetworkingV1Api | None = None + + def authenticate(self, credentials: dict[str, str]) -> None: + """Load kubeconfig and initialize Kubernetes API clients. + + Args: + credentials: Dict with keys: + - kubeconfig_path: Path to the kubeconfig file (required) + - context: Kubernetes context name (optional) + + Raises: + AuthenticationError: If kubeconfig cannot be loaded. + """ + kubeconfig_path = credentials.get("kubeconfig_path") + if not kubeconfig_path: + raise AuthenticationError( + provider_name="kubernetes", + reason="kubeconfig_path is required in credentials", + ) + + context = credentials.get("context") or None + + try: + config.load_kube_config( + config_file=kubeconfig_path, + context=context, + ) + except Exception as exc: + raise AuthenticationError( + provider_name="kubernetes", + reason=f"Failed to load kubeconfig from '{kubeconfig_path}': {exc}", + ) from exc + + self._api_client = client.ApiClient() + self._core_v1 = client.CoreV1Api(self._api_client) + self._apps_v1 = client.AppsV1Api(self._api_client) + self._networking_v1 = client.NetworkingV1Api(self._api_client) + + def get_platform_category(self) -> PlatformCategory: + """Return CONTAINER_ORCHESTRATION platform category.""" + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + """Return node addresses as endpoints. + + Returns: + List of node internal IP addresses or hostnames. + """ + if self._core_v1 is None: + return [] + + try: + nodes = self._core_v1.list_node() + endpoints: list[str] = [] + for node in nodes.items: + if node.status and node.status.addresses: + for addr in node.status.addresses: + if addr.type == "InternalIP": + endpoints.append(addr.address) + break + else: + # Fallback to first address + endpoints.append(node.status.addresses[0].address) + return endpoints + except Exception as exc: + logger.warning("Failed to list node endpoints: %s", exc) + return [] + + def list_supported_resource_types(self) -> list[str]: + """Return all Kubernetes resource types this plugin can discover.""" + return list(_SUPPORTED_RESOURCE_TYPES) + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect CPU architecture from node labels. + + Queries node labels for 'kubernetes.io/arch' to determine the + CPU architecture. Falls back to AMD64 if the label is not found. + + Args: + endpoint: Node IP address or hostname to query. + + Returns: + CpuArchitecture enum value for the node. + """ + if self._core_v1 is None: + return CpuArchitecture.AMD64 + + try: + nodes = self._core_v1.list_node() + for node in nodes.items: + # Match node by address + if node.status and node.status.addresses: + node_addresses = [ + addr.address for addr in node.status.addresses + ] + if endpoint in node_addresses: + labels = node.metadata.labels or {} + arch_label = labels.get( + "kubernetes.io/arch", + labels.get("beta.kubernetes.io/arch", "amd64"), + ) + return _ARCH_LABEL_MAP.get( + arch_label, CpuArchitecture.AMD64 + ) + except Exception as exc: + logger.warning( + "Failed to detect architecture for endpoint '%s': %s", + endpoint, + exc, + ) + + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover Kubernetes resources across all namespaces. + + Args: + endpoints: List of node addresses (used for architecture detection). + resource_types: List of resource type strings to discover. + progress_callback: Callable for progress updates. + + Returns: + ScanResult with all discovered resources. + """ + resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + + # Determine architecture from first endpoint + architecture = CpuArchitecture.AMD64 + if endpoints: + architecture = self.detect_architecture(endpoints[0]) + + endpoint_str = endpoints[0] if endpoints else "cluster" + total_types = len(resource_types) + + for idx, resource_type in enumerate(resource_types): + try: + discovered = self._discover_resource_type( + resource_type, architecture, endpoint_str + ) + resources.extend(discovered) + except Exception as exc: + error_msg = ( + f"Error discovering {resource_type}: {exc}" + ) + errors.append(error_msg) + logger.error(error_msg) + + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(resources), + resource_types_completed=idx + 1, + total_resource_types=total_types, + ) + ) + + return ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp="", + profile_hash="", + ) + + def _discover_resource_type( + self, + resource_type: str, + architecture: CpuArchitecture, + endpoint: str, + ) -> list[DiscoveredResource]: + """Discover resources of a specific type. + + Args: + resource_type: The resource type string to discover. + architecture: Detected CPU architecture. + endpoint: Endpoint string for the resource. + + Returns: + List of DiscoveredResource objects. + """ + dispatch = { + "kubernetes_deployment": self._discover_deployments, + "kubernetes_service": self._discover_services, + "kubernetes_ingress": self._discover_ingresses, + "kubernetes_config_map": self._discover_config_maps, + "kubernetes_persistent_volume": self._discover_persistent_volumes, + "kubernetes_namespace": self._discover_namespaces, + } + + handler = dispatch.get(resource_type) + if handler is None: + return [] + + return handler(architecture, endpoint) + + def _discover_deployments( + self, architecture: CpuArchitecture, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all deployments across namespaces.""" + results: list[DiscoveredResource] = [] + deployments = self._apps_v1.list_deployment_for_all_namespaces() + + for dep in deployments.items: + name = dep.metadata.name + namespace = dep.metadata.namespace + results.append( + DiscoveredResource( + resource_type="kubernetes_deployment", + unique_id=f"{namespace}/{name}", + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "replicas": dep.spec.replicas if dep.spec else None, + "labels": dict(dep.metadata.labels or {}), + }, + raw_references=[ + f"kubernetes_namespace:{namespace}", + ], + ) + ) + + return results + + def _discover_services( + self, architecture: CpuArchitecture, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all services across namespaces.""" + results: list[DiscoveredResource] = [] + services = self._core_v1.list_service_for_all_namespaces() + + for svc in services.items: + name = svc.metadata.name + namespace = svc.metadata.namespace + results.append( + DiscoveredResource( + resource_type="kubernetes_service", + unique_id=f"{namespace}/{name}", + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "type": svc.spec.type if svc.spec else None, + "cluster_ip": svc.spec.cluster_ip if svc.spec else None, + "labels": dict(svc.metadata.labels or {}), + }, + raw_references=[ + f"kubernetes_namespace:{namespace}", + ], + ) + ) + + return results + + def _discover_ingresses( + self, architecture: CpuArchitecture, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all ingresses across namespaces.""" + results: list[DiscoveredResource] = [] + ingresses = self._networking_v1.list_ingress_for_all_namespaces() + + for ing in ingresses.items: + name = ing.metadata.name + namespace = ing.metadata.namespace + results.append( + DiscoveredResource( + resource_type="kubernetes_ingress", + unique_id=f"{namespace}/{name}", + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "labels": dict(ing.metadata.labels or {}), + }, + raw_references=[ + f"kubernetes_namespace:{namespace}", + ], + ) + ) + + return results + + def _discover_config_maps( + self, architecture: CpuArchitecture, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all config maps across namespaces.""" + results: list[DiscoveredResource] = [] + config_maps = self._core_v1.list_config_map_for_all_namespaces() + + for cm in config_maps.items: + name = cm.metadata.name + namespace = cm.metadata.namespace + results.append( + DiscoveredResource( + resource_type="kubernetes_config_map", + unique_id=f"{namespace}/{name}", + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "namespace": namespace, + "data_keys": list((cm.data or {}).keys()), + "labels": dict(cm.metadata.labels or {}), + }, + raw_references=[ + f"kubernetes_namespace:{namespace}", + ], + ) + ) + + return results + + def _discover_persistent_volumes( + self, architecture: CpuArchitecture, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all persistent volumes (cluster-scoped).""" + results: list[DiscoveredResource] = [] + pvs = self._core_v1.list_persistent_volume() + + for pv in pvs.items: + name = pv.metadata.name + results.append( + DiscoveredResource( + resource_type="kubernetes_persistent_volume", + unique_id=name, + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "capacity": ( + dict(pv.spec.capacity) + if pv.spec and pv.spec.capacity + else {} + ), + "access_modes": ( + list(pv.spec.access_modes) + if pv.spec and pv.spec.access_modes + else [] + ), + "storage_class": ( + pv.spec.storage_class_name if pv.spec else None + ), + "labels": dict(pv.metadata.labels or {}), + }, + raw_references=[], + ) + ) + + return results + + def _discover_namespaces( + self, architecture: CpuArchitecture, endpoint: str + ) -> list[DiscoveredResource]: + """Discover all namespaces.""" + results: list[DiscoveredResource] = [] + namespaces = self._core_v1.list_namespace() + + for ns in namespaces.items: + name = ns.metadata.name + results.append( + DiscoveredResource( + resource_type="kubernetes_namespace", + unique_id=name, + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=architecture, + endpoint=endpoint, + attributes={ + "status": ( + ns.status.phase if ns.status else None + ), + "labels": dict(ns.metadata.labels or {}), + }, + raw_references=[], + ) + ) + + return results diff --git a/src/iac_reverse/scanner/multi_provider_scanner.py b/src/iac_reverse/scanner/multi_provider_scanner.py new file mode 100644 index 0000000..98f71e5 --- /dev/null +++ b/src/iac_reverse/scanner/multi_provider_scanner.py @@ -0,0 +1,140 @@ +"""Multi-provider scanner for infrastructure discovery. + +Coordinates scanning across multiple providers independently, handling +partial failures gracefully. If one provider fails, scanning continues +for all remaining providers. Successfully discovered resources are +collected into a unified inventory, and failed providers are reported +with error details. + +Implements Requirement 5.5: IF one or more Provider scans fail during a +multi-provider scan, THEN THE Scanner SHALL complete scanning for all +remaining Providers, include successfully discovered Resources in the +inventory, and report which Providers failed along with the corresponding +error details. +""" + +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Callable, Optional + +from iac_reverse.models import ( + DiscoveredResource, + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import Scanner + +logger = logging.getLogger(__name__) + + +@dataclass +class ProviderFailure: + """Details about a provider that failed during multi-provider scanning.""" + + provider_name: str + error_type: str + error_message: str + + +@dataclass +class MultiProviderScanResult: + """Result of scanning across multiple providers. + + Contains all successfully discovered resources from providers that + completed scanning, plus details about any providers that failed. + """ + + resources: list[DiscoveredResource] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + errors: list[str] = field(default_factory=list) + failed_providers: list[ProviderFailure] = field(default_factory=list) + successful_providers: list[str] = field(default_factory=list) + scan_timestamp: str = "" + + +@dataclass +class ProviderScanEntry: + """A pairing of a ScanProfile with its corresponding ProviderPlugin.""" + + profile: ScanProfile + plugin: ProviderPlugin + + +class MultiProviderScanner: + """Orchestrates infrastructure discovery across multiple providers. + + Scans each provider independently. If one provider fails (auth error, + connection error, etc.), continues with remaining providers. Collects + all successfully discovered resources into a unified inventory and + reports which providers failed and why. + """ + + def __init__(self, entries: list[ProviderScanEntry]): + """Initialize with a list of provider scan entries. + + Args: + entries: List of ProviderScanEntry, each pairing a ScanProfile + with its corresponding ProviderPlugin. + """ + self.entries = entries + + def scan( + self, + progress_callback: Optional[Callable[[ScanProgress], None]] = None, + ) -> MultiProviderScanResult: + """Execute scans across all configured providers. + + Each provider is scanned independently. If a provider fails for + any reason (authentication, connection, timeout, validation, etc.), + the error is recorded and scanning continues with remaining providers. + + Args: + progress_callback: Optional callable invoked with ScanProgress + updates from each provider scan. + + Returns: + MultiProviderScanResult containing all successfully discovered + resources and details about any failed providers. + """ + result = MultiProviderScanResult( + scan_timestamp=datetime.now(timezone.utc).isoformat(), + ) + + for entry in self.entries: + provider_name = entry.profile.provider.value + try: + scanner = Scanner(entry.profile, entry.plugin) + scan_result = scanner.scan(progress_callback=progress_callback) + + # Collect successful resources + result.resources.extend(scan_result.resources) + result.warnings.extend(scan_result.warnings) + result.errors.extend(scan_result.errors) + result.successful_providers.append(provider_name) + + logger.info( + "Provider '%s' scan completed: %d resources discovered", + provider_name, + len(scan_result.resources), + ) + + except Exception as exc: + # Record the failure and continue with remaining providers + failure = ProviderFailure( + provider_name=provider_name, + error_type=type(exc).__name__, + error_message=str(exc), + ) + result.failed_providers.append(failure) + + logger.warning( + "Provider '%s' scan failed (%s): %s", + provider_name, + type(exc).__name__, + exc, + ) + + return result diff --git a/src/iac_reverse/scanner/scanner.py b/src/iac_reverse/scanner/scanner.py new file mode 100644 index 0000000..0780135 --- /dev/null +++ b/src/iac_reverse/scanner/scanner.py @@ -0,0 +1,287 @@ +"""Scanner orchestrator for infrastructure discovery. + +Coordinates provider plugins to discover infrastructure resources, +handling authentication, retries, progress reporting, and error recovery. +""" + +import hashlib +import logging +import time +from datetime import datetime, timezone +from typing import Callable, Optional + +from iac_reverse.models import ( + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Custom Exceptions +# --------------------------------------------------------------------------- + + +class AuthenticationError(Exception): + """Raised when authentication with a provider fails.""" + + def __init__(self, provider_name: str, reason: str): + self.provider_name = provider_name + self.reason = reason + super().__init__( + f"Authentication failed for provider '{provider_name}': {reason}" + ) + + +class ConnectionLostError(Exception): + """Raised when the provider connection is lost during a scan.""" + + def __init__(self, partial_result: ScanResult): + self.partial_result = partial_result + super().__init__("Connection lost during scan; partial results available") + + +class ScanTimeoutError(Exception): + """Raised when a scan operation exceeds the allowed timeout.""" + + def __init__(self, message: str = "Scan operation timed out"): + super().__init__(message) + + +# --------------------------------------------------------------------------- +# Scanner Orchestrator +# --------------------------------------------------------------------------- + +# Default constants +CONNECTION_TIMEOUT_SECONDS = 30 +MAX_RETRIES = 3 +INITIAL_BACKOFF_SECONDS = 1.0 + + +class Scanner: + """Orchestrates infrastructure discovery using a provider plugin. + + Accepts a ScanProfile and an optional ProviderPlugin instance. + Handles authentication, progress reporting, retry logic with + exponential backoff, and graceful degradation on errors. + """ + + def __init__( + self, + profile: ScanProfile, + plugin: Optional[ProviderPlugin] = None, + ): + self.profile = profile + self.plugin = plugin + + def scan( + self, + progress_callback: Optional[Callable[[ScanProgress], None]] = None, + ) -> ScanResult: + """Execute a full infrastructure scan. + + Args: + progress_callback: Optional callable invoked per resource type + completion with a ScanProgress update. + + Returns: + ScanResult containing discovered resources, warnings, and errors. + + Raises: + AuthenticationError: If authentication with the provider fails. + ScanTimeoutError: If the connection attempt exceeds 30 seconds. + ValueError: If the scan profile is invalid. + """ + # 1. Validate the scan profile (critical fields only) + validation_errors = self._validate_profile() + if validation_errors: + raise ValueError( + f"Invalid scan profile: {'; '.join(validation_errors)}" + ) + + if self.plugin is None: + raise ValueError("No provider plugin configured for scanning") + + # 2. Authenticate with the provider (30 second timeout) + self._authenticate() + + # 3. Determine resource types to scan + supported_types = self.plugin.list_supported_resource_types() + resource_types, warnings = self._resolve_resource_types(supported_types) + + # 4. Determine endpoints + endpoints = self.profile.endpoints or self.plugin.list_endpoints() + + # 5. Discover resources with retry logic + scan_result = self._discover_with_retries( + endpoints=endpoints, + resource_types=resource_types, + progress_callback=progress_callback, + ) + + # Merge any warnings from unsupported resource type filtering + scan_result.warnings = warnings + scan_result.warnings + + # Set metadata + scan_result.scan_timestamp = datetime.now(timezone.utc).isoformat() + scan_result.profile_hash = self._compute_profile_hash() + + return scan_result + + def _authenticate(self) -> None: + """Authenticate with the provider plugin, enforcing a 30s timeout.""" + provider_name = self.profile.provider.value + start_time = time.monotonic() + + try: + self.plugin.authenticate(self.profile.credentials) + except Exception as exc: + elapsed = time.monotonic() - start_time + if elapsed >= CONNECTION_TIMEOUT_SECONDS: + raise ScanTimeoutError( + f"Authentication with provider '{provider_name}' " + f"timed out after {CONNECTION_TIMEOUT_SECONDS} seconds" + ) + # Wrap any auth exception in our AuthenticationError + if isinstance(exc, AuthenticationError): + raise + raise AuthenticationError( + provider_name=provider_name, + reason=str(exc), + ) from exc + + elapsed = time.monotonic() - start_time + if elapsed >= CONNECTION_TIMEOUT_SECONDS: + raise ScanTimeoutError( + f"Authentication with provider '{provider_name}' " + f"timed out after {CONNECTION_TIMEOUT_SECONDS} seconds" + ) + + def _resolve_resource_types( + self, supported_types: list[str] + ) -> tuple[list[str], list[str]]: + """Determine which resource types to scan and log warnings for unsupported ones. + + Returns: + Tuple of (resource_types_to_scan, warnings_list) + """ + warnings: list[str] = [] + + if self.profile.resource_type_filters is None: + # No filters: scan all supported types + return supported_types, warnings + + # Filter requested types against supported types + valid_types: list[str] = [] + for rt in self.profile.resource_type_filters: + if rt in supported_types: + valid_types.append(rt) + else: + warning_msg = ( + f"Unsupported resource type '{rt}' for provider " + f"'{self.profile.provider.value}'; skipping" + ) + warnings.append(warning_msg) + logger.warning(warning_msg) + + return valid_types, warnings + + def _discover_with_retries( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Optional[Callable[[ScanProgress], None]], + ) -> ScanResult: + """Call the plugin's discover_resources with retry logic. + + Retries up to MAX_RETRIES times with exponential backoff for + transient errors. On connection loss, returns partial inventory. + """ + last_exception: Optional[Exception] = None + + for attempt in range(MAX_RETRIES + 1): + try: + result = self.plugin.discover_resources( + endpoints=endpoints, + resource_types=resource_types, + progress_callback=progress_callback or self._noop_callback, + ) + return result + except ConnectionLostError: + # Connection lost: return partial results immediately + raise + except ConnectionError as exc: + # Connection lost during scan: build partial result + logger.warning( + "Connection lost during scan (attempt %d/%d): %s", + attempt + 1, + MAX_RETRIES + 1, + exc, + ) + partial = ScanResult( + resources=[], + warnings=[f"Connection lost: {exc}"], + errors=[str(exc)], + scan_timestamp=datetime.now(timezone.utc).isoformat(), + profile_hash=self._compute_profile_hash(), + is_partial=True, + ) + raise ConnectionLostError(partial_result=partial) from exc + except Exception as exc: + last_exception = exc + if attempt < MAX_RETRIES: + backoff = INITIAL_BACKOFF_SECONDS * (2**attempt) + logger.warning( + "Transient error during scan (attempt %d/%d), " + "retrying in %.1fs: %s", + attempt + 1, + MAX_RETRIES + 1, + backoff, + exc, + ) + time.sleep(backoff) + else: + logger.error( + "Scan failed after %d attempts: %s", + MAX_RETRIES + 1, + exc, + ) + + # All retries exhausted — return error result + return ScanResult( + resources=[], + warnings=[], + errors=[f"Scan failed after {MAX_RETRIES + 1} attempts: {last_exception}"], + scan_timestamp=datetime.now(timezone.utc).isoformat(), + profile_hash=self._compute_profile_hash(), + is_partial=True, + ) + + def _validate_profile(self) -> list[str]: + """Validate critical scan profile fields. + + Only checks fields that prevent scanning entirely (e.g., missing + credentials). Unsupported resource types are handled as warnings + during the scan per Requirement 1.4. + """ + errors: list[str] = [] + if not self.profile.credentials: + errors.append("credentials must not be empty") + return errors + + def _compute_profile_hash(self) -> str: + """Compute a stable hash of the scan profile for snapshot matching.""" + content = ( + f"{self.profile.provider.value}:" + f"{sorted(self.profile.credentials.items())}:" + f"{self.profile.endpoints}:" + f"{self.profile.resource_type_filters}" + ) + return hashlib.sha256(content.encode()).hexdigest()[:16] + + @staticmethod + def _noop_callback(progress: ScanProgress) -> None: + """No-op progress callback used when none is provided.""" + pass diff --git a/src/iac_reverse/scanner/synology_plugin.py b/src/iac_reverse/scanner/synology_plugin.py new file mode 100644 index 0000000..240ce9d --- /dev/null +++ b/src/iac_reverse/scanner/synology_plugin.py @@ -0,0 +1,482 @@ +"""Synology DSM provider plugin. + +Discovers shared folders, volumes, storage pools, replication tasks, and users +from a Synology DiskStation Manager (DSM) appliance via its HTTP API. +""" + +import logging +from datetime import datetime, timezone +from typing import Callable, Optional + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import AuthenticationError + +try: + from synology_dsm import SynologyDSM +except ImportError: # pragma: no cover + SynologyDSM = None # type: ignore[assignment,misc] + +logger = logging.getLogger(__name__) + +# Resource type constants +SYNOLOGY_SHARED_FOLDER = "synology_shared_folder" +SYNOLOGY_VOLUME = "synology_volume" +SYNOLOGY_STORAGE_POOL = "synology_storage_pool" +SYNOLOGY_REPLICATION_TASK = "synology_replication_task" +SYNOLOGY_USER = "synology_user" + +SUPPORTED_RESOURCE_TYPES = [ + SYNOLOGY_SHARED_FOLDER, + SYNOLOGY_VOLUME, + SYNOLOGY_STORAGE_POOL, + SYNOLOGY_REPLICATION_TASK, + SYNOLOGY_USER, +] + + +class SynologyPlugin(ProviderPlugin): + """Provider plugin for Synology DiskStation Manager (DSM). + + Connects to the Synology DSM API to discover storage infrastructure + including shared folders, volumes, storage pools, replication tasks, + and local users. + + Expected credentials: + - host: DSM hostname or IP address + - port: DSM port (default "5001") + - username: DSM admin username + - password: DSM admin password + - use_ssl: "true" or "false" (default "true") + """ + + def __init__(self) -> None: + self._api: Optional[object] = None + self._host: str = "" + self._port: str = "5001" + self._use_ssl: bool = True + self._authenticated: bool = False + + def authenticate(self, credentials: dict[str, str]) -> None: + """Authenticate with the Synology DSM API. + + Args: + credentials: Dict with keys: host, port, username, password, + and optionally use_ssl. + + Raises: + AuthenticationError: If connection or login fails. + """ + host = credentials.get("host", "") + port = credentials.get("port", "5001") + username = credentials.get("username", "") + password = credentials.get("password", "") + use_ssl = credentials.get("use_ssl", "true").lower() == "true" + + if not host: + raise AuthenticationError("synology", "host is required") + if not username: + raise AuthenticationError("synology", "username is required") + if not password: + raise AuthenticationError("synology", "password is required") + + self._host = host + self._port = port + self._use_ssl = use_ssl + + try: + if SynologyDSM is None: + raise AuthenticationError( + "synology", + "python-synology library is not installed", + ) + + api = SynologyDSM( + host, + int(port), + username, + password, + use_https=use_ssl, + verify_ssl=False, + ) + # Attempt login + if not api.login(): + raise AuthenticationError( + "synology", + f"Login failed for user '{username}' on {host}:{port}", + ) + self._api = api + self._authenticated = True + logger.info("Authenticated with Synology DSM at %s:%s", host, port) + except AuthenticationError: + raise + except Exception as exc: + raise AuthenticationError( + "synology", + f"Failed to connect to DSM at {host}:{port}: {exc}", + ) from exc + + def get_platform_category(self) -> PlatformCategory: + """Return STORAGE_APPLIANCE platform category.""" + return PlatformCategory.STORAGE_APPLIANCE + + def list_endpoints(self) -> list[str]: + """Return the DSM endpoint address.""" + protocol = "https" if self._use_ssl else "http" + return [f"{protocol}://{self._host}:{self._port}"] + + def list_supported_resource_types(self) -> list[str]: + """Return all Synology resource types this plugin can discover.""" + return list(SUPPORTED_RESOURCE_TYPES) + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect CPU architecture from Synology system info. + + Queries the DSM information API to determine if the NAS + runs on ARM or AMD64 hardware. + + Args: + endpoint: The DSM endpoint (used for context, not connection). + + Returns: + CpuArchitecture.ARM for ARM-based models, + CpuArchitecture.AARCH64 for 64-bit ARM models, + CpuArchitecture.AMD64 for x86-64 models. + """ + if self._api is None: + return CpuArchitecture.AMD64 + + try: + info = self._api.information + if info is None: + return CpuArchitecture.AMD64 + + # The model name or CPU info can indicate architecture + model = getattr(info, "model", "") or "" + cpu_name = getattr(info, "cpu_hardware_name", "") or "" + + # Combine for matching + hw_info = f"{model} {cpu_name}".lower() + + if "aarch64" in hw_info or "arm64" in hw_info: + return CpuArchitecture.AARCH64 + elif "arm" in hw_info or "rtd" in hw_info or "alpine" in hw_info: + return CpuArchitecture.ARM + else: + return CpuArchitecture.AMD64 + except Exception as exc: + logger.warning("Failed to detect architecture: %s", exc) + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover Synology resources from the DSM API. + + Enumerates shared folders, volumes, storage pools, replication tasks, + and users based on the requested resource_types. + + Args: + endpoints: List of DSM endpoints (typically one). + resource_types: Resource types to discover. + progress_callback: Callback for progress updates. + + Returns: + ScanResult with discovered resources. + """ + resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + + endpoint = endpoints[0] if endpoints else self.list_endpoints()[0] + architecture = self.detect_architecture(endpoint) + + total_types = len(resource_types) + completed = 0 + + # Discovery dispatch table + discovery_methods = { + SYNOLOGY_SHARED_FOLDER: self._discover_shared_folders, + SYNOLOGY_VOLUME: self._discover_volumes, + SYNOLOGY_STORAGE_POOL: self._discover_storage_pools, + SYNOLOGY_REPLICATION_TASK: self._discover_replication_tasks, + SYNOLOGY_USER: self._discover_users, + } + + for rt in resource_types: + progress_callback( + ScanProgress( + current_resource_type=rt, + resources_discovered=len(resources), + resource_types_completed=completed, + total_resource_types=total_types, + ) + ) + + method = discovery_methods.get(rt) + if method is None: + warnings.append(f"Unsupported resource type: {rt}") + completed += 1 + continue + + try: + discovered = method(endpoint, architecture) + resources.extend(discovered) + except Exception as exc: + error_msg = f"Error discovering {rt}: {exc}" + errors.append(error_msg) + logger.error(error_msg) + + completed += 1 + + # Final progress update + progress_callback( + ScanProgress( + current_resource_type="", + resources_discovered=len(resources), + resource_types_completed=completed, + total_resource_types=total_types, + ) + ) + + return ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp=datetime.now(timezone.utc).isoformat(), + profile_hash="", + ) + + # ------------------------------------------------------------------ + # Private discovery methods + # ------------------------------------------------------------------ + + def _discover_shared_folders( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover shared folders from DSM.""" + resources: list[DiscoveredResource] = [] + + storage = self._api.storage + if storage is None: + return resources + + # Access shared folders via the storage API + shares = getattr(storage, "shares", None) + if shares is None: + return resources + + for share in shares: + name = share.get("name", "unknown") + resources.append( + DiscoveredResource( + resource_type=SYNOLOGY_SHARED_FOLDER, + unique_id=f"synology/shared_folder/{name}", + name=name, + provider=ProviderType.SYNOLOGY, + platform_category=PlatformCategory.STORAGE_APPLIANCE, + architecture=architecture, + endpoint=endpoint, + attributes={ + "name": name, + "path": share.get("path", ""), + "desc": share.get("desc", ""), + "encryption": share.get("is_encrypted", False), + "recycle_bin": share.get("enable_recycle_bin", False), + "vol_path": share.get("vol_path", ""), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_volumes( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover volumes from DSM.""" + resources: list[DiscoveredResource] = [] + + storage = self._api.storage + if storage is None: + return resources + + volumes = getattr(storage, "volumes", None) + if volumes is None: + return resources + + for volume in volumes: + vol_id = volume.get("id", "unknown") + name = volume.get("display_name", vol_id) + resources.append( + DiscoveredResource( + resource_type=SYNOLOGY_VOLUME, + unique_id=f"synology/volume/{vol_id}", + name=name, + provider=ProviderType.SYNOLOGY, + platform_category=PlatformCategory.STORAGE_APPLIANCE, + architecture=architecture, + endpoint=endpoint, + attributes={ + "id": vol_id, + "display_name": name, + "status": volume.get("status", ""), + "fs_type": volume.get("fs_type", ""), + "size_total": volume.get("size", {}).get("total", ""), + "size_used": volume.get("size", {}).get("used", ""), + "pool_path": volume.get("pool_path", ""), + }, + raw_references=[ + f"synology/storage_pool/{volume.get('pool_path', '')}" + ] + if volume.get("pool_path") + else [], + ) + ) + + return resources + + def _discover_storage_pools( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover storage pools from DSM.""" + resources: list[DiscoveredResource] = [] + + storage = self._api.storage + if storage is None: + return resources + + pools = getattr(storage, "storage_pools", None) + if pools is None: + return resources + + for pool in pools: + pool_id = pool.get("id", "unknown") + name = pool.get("display_name", pool_id) + resources.append( + DiscoveredResource( + resource_type=SYNOLOGY_STORAGE_POOL, + unique_id=f"synology/storage_pool/{pool_id}", + name=name, + provider=ProviderType.SYNOLOGY, + platform_category=PlatformCategory.STORAGE_APPLIANCE, + architecture=architecture, + endpoint=endpoint, + attributes={ + "id": pool_id, + "display_name": name, + "status": pool.get("status", ""), + "raid_type": pool.get("raid_type", ""), + "size_total": pool.get("size", {}).get("total", ""), + "size_used": pool.get("size", {}).get("used", ""), + "disk_count": len(pool.get("disks", [])), + }, + raw_references=[], + ) + ) + + return resources + + def _discover_replication_tasks( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover replication tasks from DSM.""" + resources: list[DiscoveredResource] = [] + + # Replication tasks are accessed via a separate API module + api = self._api + if api is None: + return resources + + # Try to access replication info if available + replication = getattr(api, "replication", None) + if replication is None: + return resources + + tasks = getattr(replication, "tasks", None) + if tasks is None: + return resources + + for task in tasks: + task_id = task.get("id", "unknown") + name = task.get("name", task_id) + resources.append( + DiscoveredResource( + resource_type=SYNOLOGY_REPLICATION_TASK, + unique_id=f"synology/replication_task/{task_id}", + name=name, + provider=ProviderType.SYNOLOGY, + platform_category=PlatformCategory.STORAGE_APPLIANCE, + architecture=architecture, + endpoint=endpoint, + attributes={ + "id": task_id, + "name": name, + "status": task.get("status", ""), + "type": task.get("type", ""), + "destination": task.get("destination", ""), + "schedule": task.get("schedule", {}), + "shared_folder": task.get("shared_folder", ""), + }, + raw_references=[ + f"synology/shared_folder/{task.get('shared_folder', '')}" + ] + if task.get("shared_folder") + else [], + ) + ) + + return resources + + def _discover_users( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover local users from DSM.""" + resources: list[DiscoveredResource] = [] + + api = self._api + if api is None: + return resources + + # Users are typically accessed via SYNO.Core.User API + users_api = getattr(api, "users", None) + if users_api is None: + return resources + + users = getattr(users_api, "users", None) + if users is None: + return resources + + for user in users: + username = user.get("name", "unknown") + resources.append( + DiscoveredResource( + resource_type=SYNOLOGY_USER, + unique_id=f"synology/user/{username}", + name=username, + provider=ProviderType.SYNOLOGY, + platform_category=PlatformCategory.STORAGE_APPLIANCE, + architecture=architecture, + endpoint=endpoint, + attributes={ + "name": username, + "description": user.get("description", ""), + "email": user.get("email", ""), + "expired": user.get("expired", False), + "groups": user.get("groups", []), + }, + raw_references=[], + ) + ) + + return resources diff --git a/src/iac_reverse/scanner/windows_plugin.py b/src/iac_reverse/scanner/windows_plugin.py new file mode 100644 index 0000000..61217c8 --- /dev/null +++ b/src/iac_reverse/scanner/windows_plugin.py @@ -0,0 +1,825 @@ +"""Windows provider plugin for infrastructure discovery via WinRM. + +Uses pywinrm to connect to Windows machines and discover services, +scheduled tasks, IIS sites, app pools, network adapters, firewall rules, +installed software, Windows features, Hyper-V VMs, Hyper-V switches, +DNS records, local users, and local groups. +""" + +import json +import logging +from typing import Callable + +import winrm + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import AuthenticationError + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Custom Exceptions +# --------------------------------------------------------------------------- + + +class WinRMNotEnabledError(Exception): + """Raised when WinRM is not enabled on the target host.""" + + def __init__(self, host: str, reason: str = ""): + self.host = host + self.reason = reason + super().__init__( + f"WinRM is not enabled or unreachable on host '{host}'" + + (f": {reason}" if reason else "") + ) + + +class WMIQueryError(Exception): + """Raised when a WMI query fails on the target host.""" + + def __init__(self, query: str, reason: str = ""): + self.query = query + self.reason = reason + super().__init__( + f"WMI query failed: '{query}'" + + (f": {reason}" if reason else "") + ) + + +class InsufficientPrivilegesError(Exception): + """Raised when the authenticated user lacks required privileges.""" + + def __init__(self, operation: str, reason: str = ""): + self.operation = operation + self.reason = reason + super().__init__( + f"Insufficient privileges for operation '{operation}'" + + (f": {reason}" if reason else "") + ) + + +# --------------------------------------------------------------------------- +# Windows Discovery Plugin +# --------------------------------------------------------------------------- + +WINDOWS_RESOURCE_TYPES = [ + "windows_service", + "windows_scheduled_task", + "windows_iis_site", + "windows_iis_app_pool", + "windows_network_adapter", + "windows_firewall_rule", + "windows_installed_software", + "windows_feature", + "windows_hyperv_vm", + "windows_hyperv_switch", + "windows_dns_record", + "windows_local_user", + "windows_local_group", +] + + +class WindowsDiscoveryPlugin(ProviderPlugin): + """Provider plugin for discovering Windows infrastructure via WinRM. + + Connects to Windows machines using pywinrm and discovers resources + through PowerShell commands and WMI queries executed over WinRM. + + Expected credentials dict keys: + host: Target hostname or IP address + username: Windows username (domain\\user or user@domain) + password: Password for authentication + transport: Authentication transport - "ntlm" (default) or "kerberos" + port: WinRM port - "5985" (HTTP) or "5986" (HTTPS, default) + use_ssl: Whether to use SSL - "true" (default) or "false" + """ + + def __init__(self) -> None: + self._session: winrm.Session | None = None + self._host: str = "" + self._credentials: dict[str, str] = {} + + def authenticate(self, credentials: dict[str, str]) -> None: + """Authenticate with the Windows host via WinRM. + + Args: + credentials: Dict with keys: host, username, password, + transport (default "ntlm"), port (default "5986"), + use_ssl (default "true"). + + Raises: + AuthenticationError: If authentication fails. + WinRMNotEnabledError: If WinRM is not reachable. + """ + host = credentials.get("host", "") + username = credentials.get("username", "") + password = credentials.get("password", "") + transport = credentials.get("transport", "ntlm") + port = credentials.get("port", "5986") + use_ssl = credentials.get("use_ssl", "true").lower() == "true" + + if not host: + raise AuthenticationError("windows", "host is required") + if not username: + raise AuthenticationError("windows", "username is required") + if not password: + raise AuthenticationError("windows", "password is required") + + self._host = host + self._credentials = credentials + + scheme = "https" if use_ssl else "http" + endpoint = f"{scheme}://{host}:{port}/wsman" + + try: + self._session = winrm.Session( + endpoint, + auth=(username, password), + transport=transport, + server_cert_validation="ignore" if use_ssl else "validate", + ) + # Test connectivity with a simple command + result = self._session.run_ps("$env:COMPUTERNAME") + if result.status_code != 0: + stderr = result.std_err.decode("utf-8", errors="replace").strip() + if "access" in stderr.lower() or "denied" in stderr.lower(): + raise InsufficientPrivilegesError( + "authenticate", stderr + ) + raise AuthenticationError("windows", stderr or "Authentication test failed") + except AuthenticationError: + raise + except InsufficientPrivilegesError as exc: + raise AuthenticationError("windows", str(exc)) from exc + except WinRMNotEnabledError: + raise + except Exception as exc: + error_msg = str(exc).lower() + if "connection" in error_msg or "refused" in error_msg or "unreachable" in error_msg: + raise WinRMNotEnabledError(host, str(exc)) from exc + raise AuthenticationError("windows", str(exc)) from exc + + def get_platform_category(self) -> PlatformCategory: + """Return PlatformCategory.WINDOWS.""" + return PlatformCategory.WINDOWS + + def list_endpoints(self) -> list[str]: + """Return the single Windows host as the endpoint.""" + return [self._host] if self._host else [] + + def list_supported_resource_types(self) -> list[str]: + """Return all 13 Windows resource types.""" + return list(WINDOWS_RESOURCE_TYPES) + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + """Detect CPU architecture via WMI Win32_Processor query. + + Args: + endpoint: The Windows host to query. + + Returns: + CpuArchitecture enum value. + + Raises: + WMIQueryError: If the WMI query fails. + """ + query = "Get-WmiObject Win32_Processor | Select-Object -First 1 -ExpandProperty Architecture" + result = self._run_powershell(query) + + if result.status_code != 0: + stderr = result.std_err.decode("utf-8", errors="replace").strip() + raise WMIQueryError("Win32_Processor.Architecture", stderr) + + arch_code = result.std_out.decode("utf-8", errors="replace").strip() + + # WMI Architecture codes: + # 0 = x86, 5 = ARM, 9 = x64, 12 = ARM64 + arch_map = { + "0": CpuArchitecture.AMD64, # x86 mapped to amd64 for simplicity + "5": CpuArchitecture.ARM, + "9": CpuArchitecture.AMD64, + "12": CpuArchitecture.AARCH64, + } + + return arch_map.get(arch_code, CpuArchitecture.AMD64) + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover Windows resources via WinRM/PowerShell. + + Args: + endpoints: List of Windows hosts to scan. + resource_types: List of resource type strings to discover. + progress_callback: Callable for progress updates. + + Returns: + ScanResult with discovered resources, warnings, and errors. + """ + all_resources: list[DiscoveredResource] = [] + warnings: list[str] = [] + errors: list[str] = [] + + total_types = len(resource_types) + + for endpoint in endpoints: + # Detect architecture for this endpoint + try: + architecture = self.detect_architecture(endpoint) + except (WMIQueryError, Exception) as exc: + warnings.append( + f"Could not detect architecture for {endpoint}: {exc}. " + f"Defaulting to AMD64." + ) + architecture = CpuArchitecture.AMD64 + + # Check if Hyper-V is installed (needed for hyperv resource types) + hyperv_installed = self._is_hyperv_installed() + + for idx, resource_type in enumerate(resource_types): + try: + # Skip Hyper-V resources if role not installed + if resource_type in ("windows_hyperv_vm", "windows_hyperv_switch"): + if not hyperv_installed: + warnings.append( + f"Skipping {resource_type}: Hyper-V role not installed on {endpoint}" + ) + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(all_resources), + resource_types_completed=idx + 1, + total_resource_types=total_types, + ) + ) + continue + + discovered = self._discover_resource_type( + endpoint, resource_type, architecture + ) + all_resources.extend(discovered) + + except InsufficientPrivilegesError as exc: + errors.append( + f"Insufficient privileges for {resource_type} on {endpoint}: {exc}" + ) + except WMIQueryError as exc: + errors.append( + f"WMI query failed for {resource_type} on {endpoint}: {exc}" + ) + except Exception as exc: + errors.append( + f"Error discovering {resource_type} on {endpoint}: {exc}" + ) + + progress_callback( + ScanProgress( + current_resource_type=resource_type, + resources_discovered=len(all_resources), + resource_types_completed=idx + 1, + total_resource_types=total_types, + ) + ) + + return ScanResult( + resources=all_resources, + warnings=warnings, + errors=errors, + scan_timestamp="", + profile_hash="", + ) + + # ----------------------------------------------------------------------- + # Private helpers + # ----------------------------------------------------------------------- + + def _run_powershell(self, script: str) -> winrm.Response: + """Execute a PowerShell script via WinRM. + + Args: + script: PowerShell script to execute. + + Returns: + winrm.Response object. + + Raises: + WinRMNotEnabledError: If the session is not established. + """ + if self._session is None: + raise WinRMNotEnabledError(self._host, "No active WinRM session") + return self._session.run_ps(script) + + def _run_powershell_json(self, script: str) -> list[dict]: + """Execute a PowerShell script and parse JSON output. + + The script should output ConvertTo-Json formatted data. + + Args: + script: PowerShell script that outputs JSON. + + Returns: + List of dicts parsed from JSON output. + + Raises: + WMIQueryError: If the command fails. + InsufficientPrivilegesError: If access is denied. + """ + result = self._run_powershell(script) + + if result.status_code != 0: + stderr = result.std_err.decode("utf-8", errors="replace").strip() + if "access" in stderr.lower() or "denied" in stderr.lower() or "privilege" in stderr.lower(): + raise InsufficientPrivilegesError(script, stderr) + raise WMIQueryError(script, stderr) + + stdout = result.std_out.decode("utf-8", errors="replace").strip() + if not stdout: + return [] + + try: + data = json.loads(stdout) + if isinstance(data, dict): + return [data] + return data if isinstance(data, list) else [] + except json.JSONDecodeError: + return [] + + def _is_hyperv_installed(self) -> bool: + """Check if the Hyper-V role is installed on the target. + + Returns: + True if Hyper-V is installed, False otherwise. + """ + script = ( + "Get-WindowsFeature -Name Hyper-V | " + "Select-Object -ExpandProperty Installed | " + "ConvertTo-Json" + ) + try: + result = self._run_powershell(script) + if result.status_code != 0: + return False + stdout = result.std_out.decode("utf-8", errors="replace").strip() + return stdout.lower() == "true" + except Exception: + return False + + def _discover_resource_type( + self, + endpoint: str, + resource_type: str, + architecture: CpuArchitecture, + ) -> list[DiscoveredResource]: + """Discover resources of a specific type. + + Args: + endpoint: The Windows host. + resource_type: The resource type to discover. + architecture: Detected CPU architecture. + + Returns: + List of DiscoveredResource objects. + """ + discovery_map = { + "windows_service": self._discover_services, + "windows_scheduled_task": self._discover_scheduled_tasks, + "windows_iis_site": self._discover_iis_sites, + "windows_iis_app_pool": self._discover_iis_app_pools, + "windows_network_adapter": self._discover_network_adapters, + "windows_firewall_rule": self._discover_firewall_rules, + "windows_installed_software": self._discover_installed_software, + "windows_feature": self._discover_windows_features, + "windows_hyperv_vm": self._discover_hyperv_vms, + "windows_hyperv_switch": self._discover_hyperv_switches, + "windows_dns_record": self._discover_dns_records, + "windows_local_user": self._discover_local_users, + "windows_local_group": self._discover_local_groups, + } + + discover_fn = discovery_map.get(resource_type) + if discover_fn is None: + return [] + + return discover_fn(endpoint, architecture) + + def _discover_services( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Windows services.""" + script = ( + "Get-Service | Select-Object Name, DisplayName, Status, StartType | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_service", + unique_id=f"{endpoint}/service/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "display_name": item.get("DisplayName", ""), + "status": str(item.get("Status", "")), + "start_type": str(item.get("StartType", "")), + }, + ) + ) + return resources + + def _discover_scheduled_tasks( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Windows scheduled tasks.""" + script = ( + "Get-ScheduledTask | Where-Object {$_.TaskPath -notlike '\\\\Microsoft\\\\*'} | " + "Select-Object TaskName, TaskPath, State | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("TaskName", "") + task_path = item.get("TaskPath", "\\") + resources.append( + DiscoveredResource( + resource_type="windows_scheduled_task", + unique_id=f"{endpoint}/scheduled_task/{task_path}{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "task_path": task_path, + "state": str(item.get("State", "")), + }, + ) + ) + return resources + + def _discover_iis_sites( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover IIS websites.""" + script = ( + "Import-Module WebAdministration; " + "Get-Website | Select-Object Name, ID, State, PhysicalPath | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_iis_site", + unique_id=f"{endpoint}/iis_site/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "site_id": str(item.get("ID", "")), + "state": str(item.get("State", "")), + "physical_path": item.get("PhysicalPath", ""), + }, + ) + ) + return resources + + def _discover_iis_app_pools( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover IIS application pools.""" + script = ( + "Import-Module WebAdministration; " + "Get-ChildItem IIS:\\AppPools | " + "Select-Object Name, State, ManagedRuntimeVersion | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_iis_app_pool", + unique_id=f"{endpoint}/iis_app_pool/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "state": str(item.get("State", "")), + "managed_runtime_version": item.get("ManagedRuntimeVersion", ""), + }, + ) + ) + return resources + + def _discover_network_adapters( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover network adapters.""" + script = ( + "Get-NetAdapter | Select-Object Name, InterfaceDescription, " + "Status, MacAddress, LinkSpeed | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_network_adapter", + unique_id=f"{endpoint}/network_adapter/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "interface_description": item.get("InterfaceDescription", ""), + "status": str(item.get("Status", "")), + "mac_address": item.get("MacAddress", ""), + "link_speed": item.get("LinkSpeed", ""), + }, + ) + ) + return resources + + def _discover_firewall_rules( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Windows firewall rules.""" + script = ( + "Get-NetFirewallRule | Where-Object {$_.Enabled -eq 'True'} | " + "Select-Object Name, DisplayName, Direction, Action, Profile | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_firewall_rule", + unique_id=f"{endpoint}/firewall_rule/{name}", + name=item.get("DisplayName", name), + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "rule_name": name, + "direction": str(item.get("Direction", "")), + "action": str(item.get("Action", "")), + "profile": str(item.get("Profile", "")), + }, + ) + ) + return resources + + def _discover_installed_software( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover installed software via registry.""" + script = ( + "Get-ItemProperty HKLM:\\Software\\Microsoft\\Windows\\CurrentVersion\\Uninstall\\* | " + "Where-Object {$_.DisplayName -ne $null} | " + "Select-Object DisplayName, DisplayVersion, Publisher, InstallDate | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("DisplayName", "") + resources.append( + DiscoveredResource( + resource_type="windows_installed_software", + unique_id=f"{endpoint}/installed_software/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "version": item.get("DisplayVersion", ""), + "publisher": item.get("Publisher", ""), + "install_date": item.get("InstallDate", ""), + }, + ) + ) + return resources + + def _discover_windows_features( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover installed Windows features.""" + script = ( + "Get-WindowsFeature | Where-Object {$_.Installed -eq $true} | " + "Select-Object Name, DisplayName, FeatureType | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_feature", + unique_id=f"{endpoint}/feature/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "display_name": item.get("DisplayName", ""), + "feature_type": item.get("FeatureType", ""), + }, + ) + ) + return resources + + def _discover_hyperv_vms( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Hyper-V virtual machines.""" + script = ( + "Get-VM | Select-Object Name, VMId, State, " + "MemoryAssigned, ProcessorCount, Generation | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + vm_id = str(item.get("VMId", "")) + resources.append( + DiscoveredResource( + resource_type="windows_hyperv_vm", + unique_id=f"{endpoint}/hyperv_vm/{vm_id}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "vm_id": vm_id, + "state": str(item.get("State", "")), + "memory_assigned": str(item.get("MemoryAssigned", "")), + "processor_count": str(item.get("ProcessorCount", "")), + "generation": str(item.get("Generation", "")), + }, + ) + ) + return resources + + def _discover_hyperv_switches( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover Hyper-V virtual switches.""" + script = ( + "Get-VMSwitch | Select-Object Name, Id, SwitchType, " + "NetAdapterInterfaceDescription | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + switch_id = str(item.get("Id", "")) + resources.append( + DiscoveredResource( + resource_type="windows_hyperv_switch", + unique_id=f"{endpoint}/hyperv_switch/{switch_id}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "switch_id": switch_id, + "switch_type": str(item.get("SwitchType", "")), + "net_adapter": item.get("NetAdapterInterfaceDescription", ""), + }, + ) + ) + return resources + + def _discover_dns_records( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover DNS records from local DNS server.""" + script = ( + "Get-DnsServerZone | ForEach-Object { " + "Get-DnsServerResourceRecord -ZoneName $_.ZoneName " + "-ErrorAction SilentlyContinue } | " + "Select-Object HostName, RecordType, " + "@{N='RecordData';E={$_.RecordData.IPv4Address.IPAddressToString}} | " + "ConvertTo-Json -Depth 3" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + hostname = item.get("HostName", "") + record_type = item.get("RecordType", "") + resources.append( + DiscoveredResource( + resource_type="windows_dns_record", + unique_id=f"{endpoint}/dns_record/{hostname}/{record_type}", + name=f"{hostname} ({record_type})", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "hostname": hostname, + "record_type": record_type, + "record_data": item.get("RecordData", ""), + }, + ) + ) + return resources + + def _discover_local_users( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover local user accounts.""" + script = ( + "Get-LocalUser | Select-Object Name, Enabled, " + "Description, LastLogon | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_local_user", + unique_id=f"{endpoint}/local_user/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "enabled": str(item.get("Enabled", "")), + "description": item.get("Description", ""), + "last_logon": str(item.get("LastLogon", "")), + }, + ) + ) + return resources + + def _discover_local_groups( + self, endpoint: str, architecture: CpuArchitecture + ) -> list[DiscoveredResource]: + """Discover local groups.""" + script = ( + "Get-LocalGroup | Select-Object Name, Description, SID | " + "ConvertTo-Json -Depth 2" + ) + items = self._run_powershell_json(script) + resources = [] + for item in items: + name = item.get("Name", "") + resources.append( + DiscoveredResource( + resource_type="windows_local_group", + unique_id=f"{endpoint}/local_group/{name}", + name=name, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=architecture, + endpoint=endpoint, + attributes={ + "description": item.get("Description", ""), + "sid": str(item.get("SID", "")), + }, + ) + ) + return resources diff --git a/src/iac_reverse/state_builder/__init__.py b/src/iac_reverse/state_builder/__init__.py new file mode 100644 index 0000000..98a99c0 --- /dev/null +++ b/src/iac_reverse/state_builder/__init__.py @@ -0,0 +1,5 @@ +"""State builder module for Terraform state file generation.""" + +from iac_reverse.state_builder.state_builder import StateBuilder + +__all__ = ["StateBuilder"] diff --git a/src/iac_reverse/state_builder/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/state_builder/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..d230ce7 Binary files /dev/null and b/src/iac_reverse/state_builder/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/state_builder/__pycache__/state_builder.cpython-313.pyc b/src/iac_reverse/state_builder/__pycache__/state_builder.cpython-313.pyc new file mode 100644 index 0000000..2763a24 Binary files /dev/null and b/src/iac_reverse/state_builder/__pycache__/state_builder.cpython-313.pyc differ diff --git a/src/iac_reverse/state_builder/state_builder.py b/src/iac_reverse/state_builder/state_builder.py new file mode 100644 index 0000000..d8a4fa6 --- /dev/null +++ b/src/iac_reverse/state_builder/state_builder.py @@ -0,0 +1,332 @@ +"""Terraform state file builder (format version 4). + +Generates a valid Terraform state file that binds generated resource blocks +to their corresponding live infrastructure resources using provider-assigned +unique identifiers. This enables Terraform to recognize existing resources +without attempting to recreate them. +""" + +import logging +import uuid + +from iac_reverse.generator.sanitize import sanitize_identifier +from iac_reverse.models import ( + CodeGenerationResult, + DependencyGraph, + DiscoveredResource, + PROVIDER_SUPPORTED_RESOURCE_TYPES, + ResourceRelationship, + StateEntry, + StateFile, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# All supported resource types across all providers (for state mapping) +# --------------------------------------------------------------------------- + +SUPPORTED_STATE_RESOURCE_TYPES: set[str] = set() +for _types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values(): + SUPPORTED_STATE_RESOURCE_TYPES.update(_types) + + +# --------------------------------------------------------------------------- +# Sensitive attribute patterns +# --------------------------------------------------------------------------- + +SENSITIVE_ATTRIBUTE_PATTERNS = [ + "password", + "secret", + "token", + "key", + "certificate", +] + + +# --------------------------------------------------------------------------- +# StateBuilder +# --------------------------------------------------------------------------- + + +class StateBuilder: + """Builds Terraform state files (format v4) from code generation results. + + Accepts a CodeGenerationResult, DependencyGraph, and provider_version string. + Produces a StateFile with version=4, unique UUID lineage, serial=1, and + state entries for each resource in the dependency graph. + + Resources that cannot be mapped (missing provider-assigned identifier or + unrecognized resource type) are excluded from the state file and tracked + in the ``unmapped_resources`` attribute. + """ + + def __init__(self, terraform_version: str = "1.7.0") -> None: + """Initialize the StateBuilder. + + Args: + terraform_version: The Terraform version string to embed in the + state file. Defaults to "1.7.0". + """ + self._terraform_version = terraform_version + self._unmapped_resources: list[tuple[str, str]] = [] + + @property + def unmapped_resources(self) -> list[tuple[str, str]]: + """Return the list of unmapped resources from the last build. + + Each entry is a tuple of (resource_identifier, reason) where + resource_identifier is a string combining type and name, and + reason explains why the resource was excluded. + """ + return list(self._unmapped_resources) + + def _is_mappable(self, resource: DiscoveredResource) -> tuple[bool, str]: + """Check whether a resource can be mapped to a state entry. + + A resource is unmappable if: + - Its unique_id is empty, None, or whitespace-only (missing + provider-assigned identifier) + - Its resource_type is not recognized/supported for state mapping + + Args: + resource: The DiscoveredResource to check. + + Returns: + A tuple of (is_mappable, reason). If mappable, reason is empty. + """ + # Check for missing provider-assigned identifier + if not resource.unique_id or not resource.unique_id.strip(): + return ( + False, + "missing provider-assigned resource identifier (empty unique_id)", + ) + + # Check for unrecognized resource type + if resource.resource_type not in SUPPORTED_STATE_RESOURCE_TYPES: + return ( + False, + f"resource type '{resource.resource_type}' is not recognized " + f"for state mapping", + ) + + return (True, "") + + def build( + self, + code_result: CodeGenerationResult, + graph: DependencyGraph, + provider_version: str, + ) -> StateFile: + """Build a Terraform state file from generated code and dependency graph. + + Resources that cannot be mapped are excluded from the state file. + Warnings are logged for each unmapped resource, and the list of + unmapped resources is available via the ``unmapped_resources`` property. + + Args: + code_result: The result of code generation (used for context). + graph: The DependencyGraph containing resources and relationships. + provider_version: The provider version string used to set + schema_version on state entries. + + Returns: + A StateFile instance ready for serialization via to_json(). + """ + # Reset unmapped resources tracking for this build + self._unmapped_resources = [] + + # Build lookup maps for dependency resolution + resource_map: dict[str, DiscoveredResource] = { + r.unique_id: r for r in graph.resources if r.unique_id + } + + # Build relationships by source for dependency lookup + relationships_by_source: dict[str, list[ResourceRelationship]] = {} + for rel in graph.relationships: + relationships_by_source.setdefault(rel.source_id, []).append(rel) + + # Parse schema version from provider_version string + schema_version = self._parse_schema_version(provider_version) + + # Build state entries for each resource, skipping unmappable ones + entries: list[StateEntry] = [] + for resource in graph.resources: + mappable, reason = self._is_mappable(resource) + if not mappable: + resource_identifier = ( + f"{resource.resource_type}.{resource.name}" + ) + logger.warning( + "Excluding resource '%s' from state file: %s", + resource_identifier, + reason, + ) + self._unmapped_resources.append( + (resource_identifier, reason) + ) + continue + + entry = self._build_state_entry( + resource=resource, + resource_map=resource_map, + relationships_by_source=relationships_by_source, + schema_version=schema_version, + ) + entries.append(entry) + + # Generate unique lineage UUID + lineage = str(uuid.uuid4()) + + return StateFile( + version=4, + terraform_version=self._terraform_version, + serial=1, + lineage=lineage, + resources=entries, + ) + + def _build_state_entry( + self, + resource: DiscoveredResource, + resource_map: dict[str, DiscoveredResource], + relationships_by_source: dict[str, list[ResourceRelationship]], + schema_version: int, + ) -> StateEntry: + """Build a single state entry for a discovered resource. + + Args: + resource: The DiscoveredResource to create a state entry for. + resource_map: Map of unique_id -> DiscoveredResource for lookups. + relationships_by_source: Map of source_id -> relationships. + schema_version: The schema version to set on the entry. + + Returns: + A StateEntry binding the resource to its live infrastructure ID. + """ + # Sanitize the resource name for Terraform identifier + resource_name = sanitize_identifier(resource.name) + + # Get full attribute set from discovery data + attributes = dict(resource.attributes) + + # Identify sensitive attributes + sensitive_attributes = self._identify_sensitive_attributes(attributes) + + # Build dependency references as Terraform resource addresses + dependencies = self._build_dependencies( + resource, resource_map, relationships_by_source + ) + + return StateEntry( + resource_type=resource.resource_type, + resource_name=resource_name, + provider_id=resource.unique_id, + attributes=attributes, + sensitive_attributes=sensitive_attributes, + schema_version=schema_version, + dependencies=dependencies, + ) + + def _identify_sensitive_attributes( + self, attributes: dict + ) -> list[str]: + """Identify attributes that should be marked as sensitive. + + Checks attribute keys against known sensitive patterns: + password, secret, token, key, certificate. + + Args: + attributes: The full attribute dictionary. + + Returns: + List of attribute key paths that are sensitive. + """ + sensitive: list[str] = [] + self._find_sensitive_keys(attributes, "", sensitive) + return sensitive + + def _find_sensitive_keys( + self, obj: object, prefix: str, sensitive: list[str] + ) -> None: + """Recursively find sensitive attribute keys in nested structures. + + Args: + obj: The current object to inspect (dict, list, or scalar). + prefix: The current key path prefix. + sensitive: Accumulator list for sensitive key paths. + """ + if isinstance(obj, dict): + for key, value in obj.items(): + current_path = f"{prefix}.{key}" if prefix else key + key_lower = key.lower() + if any( + pattern in key_lower + for pattern in SENSITIVE_ATTRIBUTE_PATTERNS + ): + sensitive.append(current_path) + # Recurse into nested dicts + if isinstance(value, dict): + self._find_sensitive_keys(value, current_path, sensitive) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + self._find_sensitive_keys( + item, f"{current_path}[{i}]", sensitive + ) + + def _build_dependencies( + self, + resource: DiscoveredResource, + resource_map: dict[str, DiscoveredResource], + relationships_by_source: dict[str, list[ResourceRelationship]], + ) -> list[str]: + """Build Terraform resource address references for dependencies. + + Converts relationship targets into Terraform resource addresses + of the form: resource_type.resource_name + + Args: + resource: The source resource. + resource_map: Map of unique_id -> DiscoveredResource. + relationships_by_source: Map of source_id -> relationships. + + Returns: + List of Terraform resource addresses for dependencies. + """ + dependencies: list[str] = [] + rels = relationships_by_source.get(resource.unique_id, []) + + for rel in rels: + target = resource_map.get(rel.target_id) + if target is not None: + target_tf_name = sanitize_identifier(target.name) + address = f"{target.resource_type}.{target_tf_name}" + if address not in dependencies: + dependencies.append(address) + + return dependencies + + def _parse_schema_version(self, provider_version: str) -> int: + """Parse a schema version integer from the provider version string. + + Extracts the major version number from a semver-like string. + For example, "3.2.1" returns 3, "1" returns 1. + + Args: + provider_version: A version string (e.g., "3.2.1", "1.0.0"). + + Returns: + The major version number as an integer, or 0 if parsing fails. + """ + try: + # Take the first numeric segment as the schema version + parts = provider_version.strip().split(".") + return int(parts[0]) + except (ValueError, IndexError): + logger.warning( + "Could not parse schema version from '%s', defaulting to 0", + provider_version, + ) + return 0 diff --git a/src/iac_reverse/validator/__init__.py b/src/iac_reverse/validator/__init__.py new file mode 100644 index 0000000..6bbb89f --- /dev/null +++ b/src/iac_reverse/validator/__init__.py @@ -0,0 +1,5 @@ +"""Validator module for Terraform output validation.""" + +from iac_reverse.validator.validator import Validator + +__all__ = ["Validator"] diff --git a/src/iac_reverse/validator/__pycache__/__init__.cpython-313.pyc b/src/iac_reverse/validator/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..fd15073 Binary files /dev/null and b/src/iac_reverse/validator/__pycache__/__init__.cpython-313.pyc differ diff --git a/src/iac_reverse/validator/__pycache__/validator.cpython-313.pyc b/src/iac_reverse/validator/__pycache__/validator.cpython-313.pyc new file mode 100644 index 0000000..35e911e Binary files /dev/null and b/src/iac_reverse/validator/__pycache__/validator.cpython-313.pyc differ diff --git a/src/iac_reverse/validator/validator.py b/src/iac_reverse/validator/validator.py new file mode 100644 index 0000000..b8a9131 --- /dev/null +++ b/src/iac_reverse/validator/validator.py @@ -0,0 +1,653 @@ +"""Terraform validation runner. + +Runs terraform init, validate, and plan against generated output +to verify syntactic correctness and detect infrastructure drift. +Includes auto-correction logic that attempts to fix common validation +errors heuristically. +""" + +import json +import re +import shutil +import subprocess +from pathlib import Path + +from iac_reverse.models import PlannedChange, ValidationError, ValidationResult + + +class Validator: + """Runs Terraform commands to validate generated IaC output. + + Validates generated .tf and .tfstate files by running terraform init, + terraform validate, and terraform plan. Reports validation errors and + planned changes (drift) back to the caller. + + When validation fails, attempts heuristic-based auto-corrections up to + max_correction_attempts times before reporting failure. + """ + + def validate( + self, output_dir: str, max_correction_attempts: int = 3 + ) -> ValidationResult: + """Run terraform init, validate, and plan against the output directory. + + After terraform validate fails, attempts auto-correction of common + errors (unknown attributes, missing required blocks, syntax issues) + up to max_correction_attempts times. Re-validates after each correction. + + Args: + output_dir: Path to directory containing generated .tf and .tfstate files. + max_correction_attempts: Maximum number of auto-correction attempts + before reporting failure. Defaults to 3. + + Returns: + ValidationResult with init/validate/plan success flags, + any planned changes (drift), validation errors, and the number + of correction attempts made. + """ + # Check terraform binary availability + terraform_bin = shutil.which("terraform") + if terraform_bin is None: + return ValidationResult( + init_success=False, + validate_success=False, + plan_success=False, + errors=[ + ValidationError( + file="", + message=( + "Terraform binary not found. " + "Terraform is required for validation. " + "Please install Terraform and ensure it is on your PATH." + ), + ) + ], + correction_attempts=0, + ) + + output_path = Path(output_dir) + errors: list[ValidationError] = [] + planned_changes: list[PlannedChange] = [] + + # Run terraform init + init_success = self._run_init(output_path, errors) + if not init_success: + return ValidationResult( + init_success=False, + validate_success=False, + plan_success=False, + errors=errors, + correction_attempts=0, + ) + + # Run terraform validate with auto-correction loop + correction_attempts = 0 + validate_success = self._run_validate(output_path, errors) + + while not validate_success and correction_attempts < max_correction_attempts: + # Attempt to correct the errors + corrected = self._attempt_correction(output_path, errors) + + if not corrected: + # No corrections could be applied, stop trying + break + + correction_attempts += 1 + + # Re-validate after correction + errors = [] + validate_success = self._run_validate(output_path, errors) + + if not validate_success: + return ValidationResult( + init_success=True, + validate_success=False, + plan_success=False, + errors=errors, + correction_attempts=correction_attempts, + ) + + # Run terraform plan + plan_success = self._run_plan(output_path, errors, planned_changes) + + return ValidationResult( + init_success=True, + validate_success=True, + plan_success=plan_success, + planned_changes=planned_changes, + errors=errors, + correction_attempts=correction_attempts, + ) + + def _run_init( + self, output_path: Path, errors: list[ValidationError] + ) -> bool: + """Run terraform init in the output directory. + + Returns True if init succeeds, False otherwise. + """ + try: + result = subprocess.run( + ["terraform", "init", "-no-color"], + cwd=str(output_path), + capture_output=True, + text=True, + timeout=120, + ) + if result.returncode != 0: + errors.append( + ValidationError( + file="", + message=f"terraform init failed: {result.stderr.strip()}", + ) + ) + return False + return True + except subprocess.TimeoutExpired: + errors.append( + ValidationError( + file="", + message="terraform init timed out after 120 seconds", + ) + ) + return False + except OSError as e: + errors.append( + ValidationError( + file="", + message=f"Failed to execute terraform init: {e}", + ) + ) + return False + + def _run_validate( + self, output_path: Path, errors: list[ValidationError] + ) -> bool: + """Run terraform validate with JSON output and parse errors. + + Returns True if validation passes, False otherwise. + """ + try: + result = subprocess.run( + ["terraform", "validate", "-json"], + cwd=str(output_path), + capture_output=True, + text=True, + timeout=60, + ) + return self._parse_validate_output(result.stdout, errors) + except subprocess.TimeoutExpired: + errors.append( + ValidationError( + file="", + message="terraform validate timed out after 60 seconds", + ) + ) + return False + except OSError as e: + errors.append( + ValidationError( + file="", + message=f"Failed to execute terraform validate: {e}", + ) + ) + return False + + def _parse_validate_output( + self, stdout: str, errors: list[ValidationError] + ) -> bool: + """Parse terraform validate JSON output. + + Expected format: + { + "valid": true/false, + "error_count": N, + "diagnostics": [ + { + "severity": "error", + "summary": "...", + "detail": "...", + "range": { + "filename": "main.tf", + "start": {"line": 1, "column": 1}, + ... + } + } + ] + } + """ + try: + data = json.loads(stdout) + except (json.JSONDecodeError, TypeError): + errors.append( + ValidationError( + file="", + message="Failed to parse terraform validate output as JSON", + ) + ) + return False + + if data.get("valid", False): + return True + + diagnostics = data.get("diagnostics", []) + for diag in diagnostics: + if diag.get("severity") != "error": + continue + + filename = "" + line = None + range_info = diag.get("range") + if range_info: + filename = range_info.get("filename", "") + start = range_info.get("start") + if start: + line = start.get("line") + + summary = diag.get("summary", "") + detail = diag.get("detail", "") + message = summary + if detail: + message = f"{summary}: {detail}" + + errors.append( + ValidationError(file=filename, message=message, line=line) + ) + + return False + + def _run_plan( + self, + output_path: Path, + errors: list[ValidationError], + planned_changes: list[PlannedChange], + ) -> bool: + """Run terraform plan with JSON output and parse planned changes. + + Returns True if zero changes are planned, False otherwise. + """ + try: + result = subprocess.run( + ["terraform", "plan", "-json", "-no-color"], + cwd=str(output_path), + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode not in (0, 2): + # returncode 2 means changes are planned, which is valid output + errors.append( + ValidationError( + file="", + message=f"terraform plan failed: {result.stderr.strip()}", + ) + ) + return False + + return self._parse_plan_output( + result.stdout, errors, planned_changes + ) + except subprocess.TimeoutExpired: + errors.append( + ValidationError( + file="", + message="terraform plan timed out after 300 seconds", + ) + ) + return False + except OSError as e: + errors.append( + ValidationError( + file="", + message=f"Failed to execute terraform plan: {e}", + ) + ) + return False + + def _parse_plan_output( + self, + stdout: str, + errors: list[ValidationError], + planned_changes: list[PlannedChange], + ) -> bool: + """Parse terraform plan JSON output (streaming JSON lines format). + + Terraform plan -json outputs one JSON object per line. We look for + lines with type "resource_drift" or "planned_change" to identify + changes, and "change_summary" for the overall result. + + Each resource change line looks like: + { + "type": "planned_change", + "change": { + "resource": { + "addr": "aws_instance.example" + }, + "action": "create" | "update" | "delete" + } + } + """ + has_changes = False + + for line in stdout.strip().splitlines(): + line = line.strip() + if not line: + continue + + try: + entry = json.loads(line) + except json.JSONDecodeError: + continue + + entry_type = entry.get("type", "") + + if entry_type in ("planned_change", "resource_drift"): + change = entry.get("change", {}) + resource = change.get("resource", {}) + resource_addr = resource.get("addr", "unknown") + action = change.get("action", "unknown") + + # Map terraform action names to our change types + change_type = self._map_action_to_change_type(action) + + # Build details from before/after if available + details = f"Action: {action}" + + planned_changes.append( + PlannedChange( + resource_address=resource_addr, + change_type=change_type, + details=details, + ) + ) + has_changes = True + + elif entry_type == "change_summary": + changes_info = entry.get("changes", {}) + add = changes_info.get("add", 0) + change = changes_info.get("change", 0) + remove = changes_info.get("remove", 0) + if add + change + remove > 0: + has_changes = True + + # plan_success is True only when there are zero planned changes + return not has_changes + + @staticmethod + def _map_action_to_change_type(action: str) -> str: + """Map terraform plan action to our change type vocabulary.""" + action_map = { + "create": "add", + "update": "modify", + "delete": "destroy", + "replace": "modify", + "read": "add", + } + return action_map.get(action, action) + + # ------------------------------------------------------------------ + # Auto-correction logic + # ------------------------------------------------------------------ + + def _attempt_correction( + self, output_path: Path, errors: list[ValidationError] + ) -> bool: + """Attempt to auto-correct validation errors using heuristics. + + Applies corrections for: + - Unknown/unsupported attributes (removes the offending line) + - Missing required provider blocks (adds empty provider block) + - Common syntax issues (unclosed braces, trailing commas) + + Args: + output_path: Path to the directory containing .tf files. + errors: List of validation errors to attempt to correct. + + Returns: + True if at least one correction was applied, False otherwise. + """ + any_corrected = False + + for error in errors: + corrected = self._correct_single_error(output_path, error) + if corrected: + any_corrected = True + + return any_corrected + + def _correct_single_error( + self, output_path: Path, error: ValidationError + ) -> bool: + """Attempt to correct a single validation error. + + Returns True if a correction was applied. + """ + message = error.message.lower() + + # Handle unknown/unsupported attribute errors + if self._is_unknown_attribute_error(message): + return self._remove_attribute_line(output_path, error) + + # Handle missing required provider block + if self._is_missing_provider_error(message): + return self._add_missing_provider_block(output_path, error) + + # Handle syntax errors (unclosed braces, trailing commas) + if self._is_syntax_error(message): + return self._fix_syntax_error(output_path, error) + + return False + + @staticmethod + def _is_unknown_attribute_error(message: str) -> bool: + """Check if the error is about an unknown or unsupported attribute.""" + patterns = [ + "unsupported argument", + "unsupported attribute", + "unknown attribute", + "an argument named", + "is not expected here", + "no such attribute", + ] + return any(p in message for p in patterns) + + @staticmethod + def _is_missing_provider_error(message: str) -> bool: + """Check if the error is about a missing required provider.""" + patterns = [ + "missing required provider", + "provider configuration not present", + "no provider", + "required provider", + ] + return any(p in message for p in patterns) + + @staticmethod + def _is_syntax_error(message: str) -> bool: + """Check if the error is a syntax error that might be fixable.""" + patterns = [ + "unexpected closing brace", + "unclosed configuration block", + "expected closing brace", + "invalid character", + "trailing comma", + "argument or block definition required", + ] + return any(p in message for p in patterns) + + def _remove_attribute_line( + self, output_path: Path, error: ValidationError + ) -> bool: + """Remove the line containing an unknown/unsupported attribute. + + If the error has file and line info, removes that specific line. + Otherwise, attempts to find and remove the attribute by name from + the error message. + """ + if not error.file: + return False + + file_path = output_path / error.file + if not file_path.exists(): + return False + + try: + lines = file_path.read_text(encoding="utf-8").splitlines() + except OSError: + return False + + if error.line is not None and 1 <= error.line <= len(lines): + # Remove the specific line + line_idx = error.line - 1 + removed_line = lines[line_idx].strip() + + # Only remove if it looks like an attribute assignment + if "=" in removed_line or removed_line.endswith("{"): + lines.pop(line_idx) + try: + file_path.write_text( + "\n".join(lines) + "\n", encoding="utf-8" + ) + return True + except OSError: + return False + + # Try to find the attribute name from the error message + attr_name = self._extract_attribute_name(error.message) + if attr_name: + return self._remove_attribute_by_name(file_path, attr_name, lines) + + return False + + @staticmethod + def _extract_attribute_name(message: str) -> str: + """Extract the attribute name from an error message. + + Looks for patterns like: + - "An argument named 'foo' is not expected here" + - "Unsupported argument: foo" + """ + # Pattern: quoted attribute name + match = re.search(r"['\"](\w+)['\"]", message) + if match: + return match.group(1) + + # Pattern: "named X is not" + match = re.search(r"named\s+(\w+)\s+is", message) + if match: + return match.group(1) + + return "" + + @staticmethod + def _remove_attribute_by_name( + file_path: Path, attr_name: str, lines: list[str] + ) -> bool: + """Remove lines containing the given attribute assignment.""" + pattern = re.compile(rf"^\s*{re.escape(attr_name)}\s*=") + new_lines = [line for line in lines if not pattern.match(line)] + + if len(new_lines) == len(lines): + return False # Nothing was removed + + try: + file_path.write_text("\n".join(new_lines) + "\n", encoding="utf-8") + return True + except OSError: + return False + + def _add_missing_provider_block( + self, output_path: Path, error: ValidationError + ) -> bool: + """Add a missing provider block to the configuration. + + Extracts the provider name from the error message and creates + an empty provider block in a providers.tf file. + """ + provider_name = self._extract_provider_name(error.message) + if not provider_name: + return False + + providers_file = output_path / "providers.tf" + provider_block = f'\nprovider "{provider_name}" {{}}\n' + + try: + if providers_file.exists(): + existing = providers_file.read_text(encoding="utf-8") + # Don't add if already present + if f'provider "{provider_name}"' in existing: + return False + providers_file.write_text( + existing + provider_block, encoding="utf-8" + ) + else: + providers_file.write_text(provider_block, encoding="utf-8") + return True + except OSError: + return False + + @staticmethod + def _extract_provider_name(message: str) -> str: + """Extract provider name from a missing provider error message. + + Looks for patterns like: + - "Missing required provider 'aws'" + - 'provider "kubernetes" configuration not present' + """ + match = re.search(r"provider\s+['\"](\w+)['\"]", message) + if match: + return match.group(1) + + match = re.search(r"['\"](\w+)['\"]", message) + if match: + return match.group(1) + + return "" + + def _fix_syntax_error( + self, output_path: Path, error: ValidationError + ) -> bool: + """Attempt to fix common syntax errors. + + Handles: + - Trailing commas before closing braces + - Missing closing braces + - Lines with 'argument or block definition required' (remove empty/bad lines) + """ + if not error.file: + return False + + file_path = output_path / error.file + if not file_path.exists(): + return False + + try: + content = file_path.read_text(encoding="utf-8") + except OSError: + return False + + original_content = content + + # Fix trailing commas before closing braces/brackets + content = re.sub(r",(\s*[}\]])", r"\1", content) + + # Fix 'argument or block definition required' - remove empty lines + # at the error location + if error.line is not None and "argument or block definition required" in error.message.lower(): + lines = content.splitlines() + if 1 <= error.line <= len(lines): + line_idx = error.line - 1 + line = lines[line_idx].strip() + # Remove the problematic line if it's empty or just whitespace/punctuation + if not line or line in (",", ";"): + lines.pop(line_idx) + content = "\n".join(lines) + "\n" + + if content != original_content: + try: + file_path.write_text(content, encoding="utf-8") + return True + except OSError: + return False + + return False diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..9135517 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for IaC Reverse Engineering Tool.""" diff --git a/tests/__pycache__/__init__.cpython-313.pyc b/tests/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..f5c61d8 Binary files /dev/null and b/tests/__pycache__/__init__.cpython-313.pyc differ diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 0000000..bb496ee --- /dev/null +++ b/tests/integration/__init__.py @@ -0,0 +1 @@ +"""Integration tests for IaC Reverse Engineering Tool.""" diff --git a/tests/property/__init__.py b/tests/property/__init__.py new file mode 100644 index 0000000..2cd0eaa --- /dev/null +++ b/tests/property/__init__.py @@ -0,0 +1 @@ +"""Property-based tests for IaC Reverse Engineering Tool.""" diff --git a/tests/property/__pycache__/__init__.cpython-313.pyc b/tests/property/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..12b29ab Binary files /dev/null and b/tests/property/__pycache__/__init__.cpython-313.pyc differ diff --git a/tests/property/__pycache__/test_code_generator_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_code_generator_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..dcca31b Binary files /dev/null and b/tests/property/__pycache__/test_code_generator_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_dependency_resolver_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_dependency_resolver_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..05b0009 Binary files /dev/null and b/tests/property/__pycache__/test_dependency_resolver_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_drift_report_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_drift_report_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..a7c9004 Binary files /dev/null and b/tests/property/__pycache__/test_drift_report_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_incremental_scan_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_incremental_scan_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..07544a8 Binary files /dev/null and b/tests/property/__pycache__/test_incremental_scan_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_multi_provider_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_multi_provider_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..646e8fc Binary files /dev/null and b/tests/property/__pycache__/test_multi_provider_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_resource_inventory_completeness_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_resource_inventory_completeness_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..7d33ff4 Binary files /dev/null and b/tests/property/__pycache__/test_resource_inventory_completeness_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_scan_profile_validation_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_scan_profile_validation_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..e6f752c Binary files /dev/null and b/tests/property/__pycache__/test_scan_profile_validation_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_scanner_behavior_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_scanner_behavior_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..f5f4bfa Binary files /dev/null and b/tests/property/__pycache__/test_scanner_behavior_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/__pycache__/test_state_builder_prop.cpython-313-pytest-9.0.3.pyc b/tests/property/__pycache__/test_state_builder_prop.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..7e71432 Binary files /dev/null and b/tests/property/__pycache__/test_state_builder_prop.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/property/test_code_generator_prop.py b/tests/property/test_code_generator_prop.py new file mode 100644 index 0000000..36639af --- /dev/null +++ b/tests/property/test_code_generator_prop.py @@ -0,0 +1,719 @@ +"""Property-based tests for the Code Generator. + +**Validates: Requirements 2.2, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6** + +Properties tested: +- Property 10: References in generated output use Terraform syntax +- Property 11: Generated HCL syntactic validity +- Property 12: File organization by resource type +- Property 13: Variable extraction for shared values +- Property 14: Identifier sanitization validity +- Property 15: Traceability comments in generated code +""" + +import re + +from hypothesis import given, settings, assume, HealthCheck +from hypothesis import strategies as st + +from iac_reverse.generator import CodeGenerator, VariableExtractor, sanitize_identifier +from iac_reverse.models import ( + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + PlatformCategory, + ProviderType, + ResourceRelationship, + ScanProfile, +) + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) +platform_category_strategy = st.sampled_from(list(PlatformCategory)) +cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture)) + +# Strategy for resource names (valid identifiers with some variety) +resource_name_strategy = st.text( + min_size=1, + max_size=20, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"), +).filter(lambda s: s.strip() != "") + +# Strategy for resource types (terraform-style: provider_type) +resource_type_strategy = st.sampled_from([ + "kubernetes_deployment", + "kubernetes_service", + "kubernetes_namespace", + "docker_service", + "docker_network", + "docker_volume", + "synology_shared_folder", + "synology_volume", + "harvester_virtualmachine", + "harvester_volume", + "bare_metal_hardware", + "windows_service", + "windows_iis_site", +]) + +# Strategy for simple attribute values (strings, ints, bools) +simple_attr_value_strategy = st.one_of( + st.text(min_size=1, max_size=30, alphabet=st.characters( + whitelist_categories=("L", "N"), whitelist_characters="_-./: " + )).filter(lambda s: s.strip() != ""), + st.integers(min_value=0, max_value=10000), + st.booleans(), +) + +# Strategy for attribute dictionaries +attributes_strategy = st.dictionaries( + keys=st.text( + min_size=1, + max_size=15, + alphabet=st.characters(whitelist_categories=("L",), whitelist_characters="_"), + ).filter(lambda s: s.strip() != "" and s[0].isalpha()), + values=simple_attr_value_strategy, + min_size=1, + max_size=5, +) + + +def make_resource( + unique_id: str, + resource_type: str = "kubernetes_deployment", + name: str = "my_resource", + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, + architecture: CpuArchitecture = CpuArchitecture.AMD64, + attributes: dict | None = None, + raw_references: list[str] | None = None, +) -> DiscoveredResource: + """Helper to create a DiscoveredResource with sensible defaults.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=architecture, + endpoint="https://api.internal.lab:6443", + attributes=attributes or {"key": "value"}, + raw_references=raw_references or [], + ) + + +def make_dependency_graph( + resources: list[DiscoveredResource], + relationships: list[ResourceRelationship] | None = None, +) -> DependencyGraph: + """Helper to create a DependencyGraph from resources.""" + return DependencyGraph( + resources=resources, + relationships=relationships or [], + topological_order=[r.unique_id for r in resources], + cycles=[], + unresolved_references=[], + ) + + +@st.composite +def resource_with_dependency_strategy(draw): + """Generate a pair of resources where one depends on the other. + + Returns (resources, relationships) where the first resource references the second. + """ + resource_type_a = draw(resource_type_strategy) + resource_type_b = draw(resource_type_strategy) + name_a = draw(resource_name_strategy) + name_b = draw(resource_name_strategy) + arch = draw(cpu_architecture_strategy) + + # Ensure unique IDs are different + uid_a = f"ns/{resource_type_a}/{name_a}" + uid_b = f"ns/{resource_type_b}/{name_b}" + assume(uid_a != uid_b) + + # Resource B is the dependency target + resource_b = make_resource( + unique_id=uid_b, + resource_type=resource_type_b, + name=name_b, + architecture=arch, + attributes={"port": 8080}, + ) + + # Resource A references resource B's unique_id in its attributes + resource_a = make_resource( + unique_id=uid_a, + resource_type=resource_type_a, + name=name_a, + architecture=arch, + attributes={"target_id": uid_b, "replicas": 3}, + raw_references=[uid_b], + ) + + relationship = ResourceRelationship( + source_id=uid_a, + target_id=uid_b, + relationship_type="reference", + source_attribute="target_id", + ) + + return [resource_a, resource_b], [relationship] + + +@st.composite +def multiple_resources_strategy(draw): + """Generate a list of resources with distinct types for file organization testing.""" + num_types = draw(st.integers(min_value=1, max_value=5)) + types = draw( + st.lists( + resource_type_strategy, + min_size=num_types, + max_size=num_types, + unique=True, + ) + ) + + resources = [] + for i, rtype in enumerate(types): + # Each type gets 1-3 resources + num_resources_of_type = draw(st.integers(min_value=1, max_value=3)) + for j in range(num_resources_of_type): + uid = f"{rtype}/instance_{i}_{j}" + name = f"res_{i}_{j}" + attrs = draw(attributes_strategy) + resource = make_resource( + unique_id=uid, + resource_type=rtype, + name=name, + attributes=attrs, + ) + resources.append(resource) + + return resources + + +@st.composite +def resources_with_shared_values_strategy(draw): + """Generate resources where at least one attribute value appears in 2+ resources.""" + shared_key = draw(st.sampled_from(["region", "environment", "zone", "cluster"])) + shared_value = draw(st.text( + min_size=1, + max_size=15, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"), + ).filter(lambda s: s.strip() != "")) + + num_resources = draw(st.integers(min_value=2, max_value=5)) + + resources = [] + for i in range(num_resources): + uid = f"resource_{i}" + name = f"res_{i}" + # All resources share the same key-value pair + attrs = {shared_key: shared_value, "name": f"instance_{i}"} + resource = make_resource( + unique_id=uid, + resource_type="kubernetes_deployment", + name=name, + attributes=attrs, + ) + resources.append(resource) + + return resources, shared_key, shared_value + + +# Strategy for arbitrary strings to test sanitize_identifier +arbitrary_string_strategy = st.text(min_size=0, max_size=50) + + +# --------------------------------------------------------------------------- +# Property 10: References in generated output use Terraform syntax +# --------------------------------------------------------------------------- + + +class TestReferencesUseTerraformSyntax: + """Property 10: References in generated output use Terraform syntax. + + **Validates: Requirements 2.2, 3.5** + + For any resource with dependencies, the generated HCL uses Terraform + resource references (type.name.id) not hardcoded IDs. + """ + + @given(data=resource_with_dependency_strategy()) + @settings(max_examples=100) + def test_references_use_terraform_resource_syntax( + self, data: tuple[list[DiscoveredResource], list[ResourceRelationship]] + ): + """Generated HCL uses type.name.id references instead of hardcoded IDs.""" + resources, relationships = data + graph = make_dependency_graph(resources, relationships) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + # The source resource (resources[0]) references resources[1] + target = resources[1] + target_tf_name = sanitize_identifier(target.name) + expected_ref = f"{target.resource_type}.{target_tf_name}.id" + + # Find the file containing the source resource + source = resources[0] + source_file = None + for f in result.resource_files: + if f.filename == f"{source.resource_type}.tf": + source_file = f + break + + assert source_file is not None, ( + f"Expected file {source.resource_type}.tf not found" + ) + + # The generated content should contain the Terraform reference + assert expected_ref in source_file.content, ( + f"Expected Terraform reference '{expected_ref}' not found in output. " + f"Content: {source_file.content[:500]}" + ) + + @given(data=resource_with_dependency_strategy()) + @settings(max_examples=100) + def test_hardcoded_ids_not_present_for_resolved_references( + self, data: tuple[list[DiscoveredResource], list[ResourceRelationship]] + ): + """The target resource's unique_id should not appear as a hardcoded string in the source resource's block.""" + resources, relationships = data + graph = make_dependency_graph(resources, relationships) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + target = resources[1] + source = resources[0] + + # Find the file containing the source resource + source_file = None + for f in result.resource_files: + if f.filename == f"{source.resource_type}.tf": + source_file = f + break + + assert source_file is not None + + # The hardcoded unique_id of the target should NOT appear as a quoted string + hardcoded_pattern = f'"{target.unique_id}"' + assert hardcoded_pattern not in source_file.content, ( + f"Hardcoded ID '{hardcoded_pattern}' should not appear in generated HCL. " + f"Should use Terraform reference instead." + ) + + +# --------------------------------------------------------------------------- +# Property 11: Generated HCL syntactic validity +# --------------------------------------------------------------------------- + + +class TestGeneratedHclSyntacticValidity: + """Property 11: Generated HCL syntactic validity. + + **Validates: Requirements 3.1** + + For any set of resources, the generated HCL contains valid resource blocks + with proper structure (resource keyword, type, name, braces). + """ + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow]) + def test_generated_hcl_has_valid_resource_blocks( + self, resources: list[DiscoveredResource] + ): + """Each generated file contains properly structured resource blocks.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + for gen_file in result.resource_files: + content = gen_file.content + + # Each resource block should have the pattern: + # resource "type" "name" { + resource_block_pattern = re.compile( + r'resource\s+"[^"]+"\s+"[^"]+"\s*\{' + ) + blocks_found = resource_block_pattern.findall(content) + assert len(blocks_found) == gen_file.resource_count, ( + f"Expected {gen_file.resource_count} resource blocks in " + f"{gen_file.filename}, found {len(blocks_found)}" + ) + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_generated_hcl_has_balanced_braces( + self, resources: list[DiscoveredResource] + ): + """Generated HCL has balanced opening and closing braces.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + for gen_file in result.resource_files: + content = gen_file.content + open_braces = content.count("{") + close_braces = content.count("}") + assert open_braces == close_braces, ( + f"Unbalanced braces in {gen_file.filename}: " + f"{open_braces} opening vs {close_braces} closing" + ) + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_generated_hcl_resource_type_matches_filename( + self, resources: list[DiscoveredResource] + ): + """Each resource block's type matches the file it's in (filename = type.tf).""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + for gen_file in result.resource_files: + expected_type = gen_file.filename.replace(".tf", "") + # All resource blocks in this file should be of the expected type + resource_types_in_file = re.findall( + r'resource\s+"([^"]+)"', gen_file.content + ) + for rtype in resource_types_in_file: + assert rtype == expected_type, ( + f"Resource type '{rtype}' found in {gen_file.filename} " + f"but expected only '{expected_type}'" + ) + + +# --------------------------------------------------------------------------- +# Property 12: File organization by resource type +# --------------------------------------------------------------------------- + + +class TestFileOrganizationByResourceType: + """Property 12: File organization by resource type. + + **Validates: Requirements 3.2** + + For any set of resources, each resource type gets its own .tf file. + """ + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_one_file_per_resource_type( + self, resources: list[DiscoveredResource] + ): + """The number of resource files equals the number of distinct resource types.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + distinct_types = {r.resource_type for r in resources} + assert len(result.resource_files) == len(distinct_types), ( + f"Expected {len(distinct_types)} files for {len(distinct_types)} " + f"distinct types, got {len(result.resource_files)}" + ) + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_each_file_named_after_resource_type( + self, resources: list[DiscoveredResource] + ): + """Each generated file is named .tf.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + distinct_types = {r.resource_type for r in resources} + expected_filenames = {f"{rt}.tf" for rt in distinct_types} + actual_filenames = {f.filename for f in result.resource_files} + + assert actual_filenames == expected_filenames, ( + f"Expected filenames {expected_filenames}, got {actual_filenames}" + ) + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_every_resource_appears_in_exactly_one_file( + self, resources: list[DiscoveredResource] + ): + """Every resource's unique_id appears in exactly one generated file.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + for resource in resources: + files_containing = [ + f.filename + for f in result.resource_files + if resource.unique_id in f.content + ] + assert len(files_containing) == 1, ( + f"Resource '{resource.unique_id}' found in {len(files_containing)} " + f"files: {files_containing}. Expected exactly 1." + ) + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_resource_count_per_file_matches( + self, resources: list[DiscoveredResource] + ): + """Each file's resource_count matches the actual number of resources of that type.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + # Count resources per type + from collections import Counter + type_counts = Counter(r.resource_type for r in resources) + + for gen_file in result.resource_files: + expected_type = gen_file.filename.replace(".tf", "") + assert gen_file.resource_count == type_counts[expected_type], ( + f"File {gen_file.filename} reports {gen_file.resource_count} resources " + f"but expected {type_counts[expected_type]}" + ) + + +# --------------------------------------------------------------------------- +# Property 13: Variable extraction for shared values +# --------------------------------------------------------------------------- + + +class TestVariableExtractionForSharedValues: + """Property 13: Variable extraction for shared values. + + **Validates: Requirements 3.3** + + For any set of resources where a value appears in 2+ resources, + a variable is extracted. + """ + + @given(data=resources_with_shared_values_strategy()) + @settings(max_examples=100) + def test_shared_value_produces_extracted_variable( + self, data: tuple[list[DiscoveredResource], str, str] + ): + """A value appearing in 2+ resources results in an extracted variable.""" + resources, shared_key, shared_value = data + + extractor = VariableExtractor() + variables = extractor.extract_variables(resources) + + # There should be at least one variable extracted for the shared key + var_names = [v.name for v in variables] + # The variable name should contain the shared key + matching_vars = [v for v in variables if shared_key in v.name] + assert len(matching_vars) >= 1, ( + f"Expected at least one variable for shared key '{shared_key}', " + f"got variables: {var_names}" + ) + + @given(data=resources_with_shared_values_strategy()) + @settings(max_examples=100) + def test_extracted_variable_has_correct_default( + self, data: tuple[list[DiscoveredResource], str, str] + ): + """The extracted variable's default value matches the shared value.""" + resources, shared_key, shared_value = data + + extractor = VariableExtractor() + variables = extractor.extract_variables(resources) + + matching_vars = [v for v in variables if shared_key in v.name] + assert len(matching_vars) >= 1 + + # The default should be the shared value (formatted as a string literal) + var = matching_vars[0] + assert shared_value in var.default_value, ( + f"Expected default to contain '{shared_value}', got '{var.default_value}'" + ) + + @given(data=resources_with_shared_values_strategy()) + @settings(max_examples=100) + def test_extracted_variable_tracks_usage( + self, data: tuple[list[DiscoveredResource], str, str] + ): + """The extracted variable's used_by list contains at least 2 resource IDs.""" + resources, shared_key, shared_value = data + + extractor = VariableExtractor() + variables = extractor.extract_variables(resources) + + matching_vars = [v for v in variables if shared_key in v.name] + assert len(matching_vars) >= 1 + + var = matching_vars[0] + assert len(var.used_by) >= 2, ( + f"Expected variable to be used by 2+ resources, " + f"got {len(var.used_by)}: {var.used_by}" + ) + + @given(data=resources_with_shared_values_strategy()) + @settings(max_examples=100) + def test_extracted_variable_has_type_and_description( + self, data: tuple[list[DiscoveredResource], str, str] + ): + """Each extracted variable has a non-empty type expression and description.""" + resources, shared_key, shared_value = data + + extractor = VariableExtractor() + variables = extractor.extract_variables(resources) + + for var in variables: + assert var.type_expr != "", f"Variable '{var.name}' has empty type_expr" + assert var.description != "", f"Variable '{var.name}' has empty description" + + +# --------------------------------------------------------------------------- +# Property 14: Identifier sanitization validity +# --------------------------------------------------------------------------- + + +class TestIdentifierSanitizationValidity: + """Property 14: Identifier sanitization validity. + + **Validates: Requirements 3.4** + + For any input string, sanitize_identifier produces a valid Terraform identifier. + """ + + TERRAFORM_IDENTIFIER_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + @given(name=arbitrary_string_strategy) + @settings(max_examples=200) + def test_sanitized_identifier_matches_terraform_pattern(self, name: str): + """The output always matches ^[a-zA-Z_][a-zA-Z0-9_]*$.""" + result = sanitize_identifier(name) + assert self.TERRAFORM_IDENTIFIER_REGEX.match(result), ( + f"sanitize_identifier({name!r}) = {result!r} does not match " + f"Terraform identifier pattern" + ) + + @given(name=arbitrary_string_strategy) + @settings(max_examples=200) + def test_sanitized_identifier_is_non_empty(self, name: str): + """The output is always a non-empty string.""" + result = sanitize_identifier(name) + assert len(result) > 0, ( + f"sanitize_identifier({name!r}) produced empty string" + ) + + @given(name=st.text(min_size=1, max_size=30, alphabet="0123456789")) + @settings(max_examples=100) + def test_digit_only_input_produces_valid_identifier(self, name: str): + """Input consisting only of digits still produces a valid identifier.""" + result = sanitize_identifier(name) + assert self.TERRAFORM_IDENTIFIER_REGEX.match(result), ( + f"sanitize_identifier({name!r}) = {result!r} is not valid for digit-only input" + ) + # Must not start with a digit + assert not result[0].isdigit(), ( + f"sanitize_identifier({name!r}) = {result!r} starts with a digit" + ) + + @given(name=st.text( + min_size=1, + max_size=30, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"), + ).filter(lambda s: s[0].isalpha() or s[0] == "_")) + @settings(max_examples=100) + def test_already_valid_identifiers_are_preserved_or_simplified(self, name: str): + """Input that is already a valid identifier produces a valid result.""" + result = sanitize_identifier(name) + assert self.TERRAFORM_IDENTIFIER_REGEX.match(result), ( + f"sanitize_identifier({name!r}) = {result!r} is not valid" + ) + + +# --------------------------------------------------------------------------- +# Property 15: Traceability comments in generated code +# --------------------------------------------------------------------------- + + +class TestTraceabilityCommentsInGeneratedCode: + """Property 15: Traceability comments in generated code. + + **Validates: Requirements 3.6** + + For any resource, the generated HCL includes a comment with the original unique_id. + """ + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_each_resource_has_traceability_comment( + self, resources: list[DiscoveredResource] + ): + """Every resource's unique_id appears in a comment in the generated output.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + # Collect all generated content + all_content = "\n".join(f.content for f in result.resource_files) + + for resource in resources: + # The unique_id should appear in a comment line + comment_pattern = f"# Source: {resource.unique_id}" + assert comment_pattern in all_content, ( + f"Traceability comment for resource '{resource.unique_id}' " + f"not found in generated output" + ) + + @given(resources=multiple_resources_strategy()) + @settings(max_examples=100) + def test_traceability_comment_precedes_resource_block( + self, resources: list[DiscoveredResource] + ): + """The traceability comment appears before its corresponding resource block.""" + graph = make_dependency_graph(resources) + profiles: list[ScanProfile] = [] + + generator = CodeGenerator() + result = generator.generate(graph, profiles) + + for resource in resources: + # Find the file containing this resource + target_file = None + for f in result.resource_files: + if resource.unique_id in f.content: + target_file = f + break + + assert target_file is not None + + content = target_file.content + comment_pos = content.find(f"# Source: {resource.unique_id}") + tf_name = sanitize_identifier(resource.name) + block_pattern = f'resource "{resource.resource_type}" "{tf_name}"' + block_pos = content.find(block_pattern, comment_pos) + + assert comment_pos < block_pos, ( + f"Comment for '{resource.unique_id}' (pos {comment_pos}) " + f"should precede resource block (pos {block_pos})" + ) diff --git a/tests/property/test_dependency_resolver_prop.py b/tests/property/test_dependency_resolver_prop.py new file mode 100644 index 0000000..885d39c --- /dev/null +++ b/tests/property/test_dependency_resolver_prop.py @@ -0,0 +1,565 @@ +"""Property-based tests for the Dependency Resolver. + +**Validates: Requirements 2.1, 2.3, 2.4, 2.5** + +Properties tested: +- Property 6: Dependency relationship identification +- Property 7: Cycle detection correctness +- Property 8: Topological order validity +- Property 9: Unresolved references become data sources or variables +""" + +from hypothesis import given, settings, assume +from hypothesis import strategies as st + +from iac_reverse.models import ( + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + PlatformCategory, + ProviderType, + ResourceRelationship, + ScanResult, + UnresolvedReference, +) +from iac_reverse.resolver import DependencyResolver + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) +platform_category_strategy = st.sampled_from(list(PlatformCategory)) +cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture)) + +# Strategy for generating valid resource IDs +resource_id_strategy = st.text( + min_size=3, + max_size=50, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-/"), +).filter(lambda s: s.strip() != "" and len(s) >= 3) + +# Strategy for resource names +resource_name_strategy = st.text( + min_size=1, + max_size=30, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"), +).filter(lambda s: s.strip() != "") + +# Strategy for resource types (simple identifiers) +resource_type_strategy = st.text( + min_size=3, + max_size=40, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"), +).filter(lambda s: s.strip() != "" and len(s) >= 3) + +# Strategy for endpoint strings +endpoint_strategy = st.text( + min_size=5, + max_size=50, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters=".-:/"), +).filter(lambda s: s.strip() != "") + + +def make_resource( + unique_id: str, + resource_type: str = "generic_resource", + name: str = "resource", + raw_references: list[str] | None = None, + attributes: dict | None = None, +) -> DiscoveredResource: + """Helper to create a DiscoveredResource with sensible defaults.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint="https://api.internal.lab:6443", + attributes=attributes or {"key": "value"}, + raw_references=raw_references or [], + ) + + +def make_scan_result(resources: list[DiscoveredResource]) -> ScanResult: + """Helper to create a ScanResult from a list of resources.""" + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test_hash", + is_partial=False, + ) + + +# Strategy to generate a list of resources with unique IDs and controlled references +@st.composite +def acyclic_resource_graph_strategy(draw): + """Generate a set of resources forming an acyclic dependency graph. + + Resources are created in order, and each resource can only reference + resources that were created before it (ensuring no cycles). + """ + num_resources = draw(st.integers(min_value=2, max_value=8)) + + resources = [] + ids = [] + for i in range(num_resources): + uid = f"resource_{i}" + ids.append(uid) + + # Each resource can only reference earlier resources (ensures acyclic) + if i > 0: + num_refs = draw(st.integers(min_value=0, max_value=min(i, 3))) + refs = draw( + st.lists( + st.sampled_from(ids[:i]), + min_size=num_refs, + max_size=num_refs, + unique=True, + ) + ) + else: + refs = [] + + resource = make_resource( + unique_id=uid, + name=f"res_{i}", + raw_references=refs, + ) + resources.append(resource) + + return resources + + +@st.composite +def cyclic_resource_graph_strategy(draw): + """Generate a set of resources that contain at least one cycle. + + Creates a base set of resources and then adds references to form a cycle. + """ + num_resources = draw(st.integers(min_value=2, max_value=6)) + + resources = [] + ids = [] + for i in range(num_resources): + uid = f"resource_{i}" + ids.append(uid) + resource = make_resource( + unique_id=uid, + name=f"res_{i}", + raw_references=[], + ) + resources.append(resource) + + # Create a cycle: pick a subset of at least 2 resources and form a ring + cycle_size = draw(st.integers(min_value=2, max_value=num_resources)) + cycle_indices = draw( + st.lists( + st.sampled_from(list(range(num_resources))), + min_size=cycle_size, + max_size=cycle_size, + unique=True, + ) + ) + + # Form a ring: each resource in the cycle references the next one + for j in range(len(cycle_indices)): + src_idx = cycle_indices[j] + tgt_idx = cycle_indices[(j + 1) % len(cycle_indices)] + target_id = ids[tgt_idx] + if target_id not in resources[src_idx].raw_references: + resources[src_idx].raw_references.append(target_id) + + return resources + + +@st.composite +def resources_with_unresolved_refs_strategy(draw): + """Generate resources where some raw_references point to IDs not in the inventory.""" + num_resources = draw(st.integers(min_value=1, max_value=5)) + + resources = [] + ids = [] + for i in range(num_resources): + uid = f"resource_{i}" + ids.append(uid) + + # Generate unresolved reference IDs (not in the inventory) + num_unresolved = draw(st.integers(min_value=1, max_value=4)) + unresolved_ids = [] + for i in range(num_unresolved): + # Mix of IDs with "/" (should suggest data_source) and without (should suggest variable) + if draw(st.booleans()): + unresolved_id = f"external/resource/{i}" + else: + unresolved_id = f"external_var_{i}" + unresolved_ids.append(unresolved_id) + + # Create resources, some referencing unresolved IDs + for i in range(num_resources): + # Pick some unresolved refs for this resource + num_ext_refs = draw(st.integers(min_value=0, max_value=min(num_unresolved, 2))) + ext_refs = draw( + st.lists( + st.sampled_from(unresolved_ids), + min_size=num_ext_refs, + max_size=num_ext_refs, + unique=True, + ) + ) + + resource = make_resource( + unique_id=ids[i], + name=f"res_{i}", + raw_references=ext_refs, + ) + resources.append(resource) + + return resources, unresolved_ids + + +# --------------------------------------------------------------------------- +# Property 6: Dependency relationship identification +# --------------------------------------------------------------------------- + + +class TestDependencyRelationshipIdentification: + """Property 6: Dependency relationship identification. + + **Validates: Requirements 2.1** + + For any resource with raw_references pointing to other resources in the + inventory, the resolver SHALL create a ResourceRelationship for each + resolved reference. + """ + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_relationship_created_for_each_resolved_reference( + self, resources: list[DiscoveredResource] + ): + """For each raw_reference pointing to a known resource, a relationship is created.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Count expected relationships: each raw_reference that points to a resource in inventory + resource_ids = {r.unique_id for r in resources} + expected_relationships = 0 + for resource in resources: + for ref in resource.raw_references: + if ref in resource_ids: + expected_relationships += 1 + + assert len(graph.relationships) == expected_relationships + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_relationship_source_and_target_are_correct( + self, resources: list[DiscoveredResource] + ): + """Each relationship has source_id as the referencing resource and target_id as the referenced.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + resource_ids = {r.unique_id for r in resources} + + for rel in graph.relationships: + # source_id is the resource that holds the reference + assert rel.source_id in resource_ids + # target_id is the resource being referenced + assert rel.target_id in resource_ids + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_relationship_type_is_valid( + self, resources: list[DiscoveredResource] + ): + """Each relationship has a valid relationship_type.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + valid_types = {"parent-child", "reference", "dependency"} + for rel in graph.relationships: + assert rel.relationship_type in valid_types + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_relationship_source_attribute_is_non_empty( + self, resources: list[DiscoveredResource] + ): + """Each relationship has a non-empty source_attribute.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for rel in graph.relationships: + assert isinstance(rel.source_attribute, str) + assert len(rel.source_attribute) > 0 + + +# --------------------------------------------------------------------------- +# Property 7: Cycle detection correctness +# --------------------------------------------------------------------------- + + +class TestCycleDetectionCorrectness: + """Property 7: Cycle detection correctness. + + **Validates: Requirements 2.3** + + For any graph containing a cycle, the resolver SHALL detect and report it + in the cycles list. For any acyclic dependency graph, the resolver SHALL + report zero cycles. + """ + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_acyclic_graph_reports_zero_cycles( + self, resources: list[DiscoveredResource] + ): + """An acyclic graph should have no cycles reported.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycles) == 0 + + @given(resources=cyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_cyclic_graph_reports_at_least_one_cycle( + self, resources: list[DiscoveredResource] + ): + """A graph with a cycle should have at least one cycle reported.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycles) >= 1 + + @given(resources=cyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_cycle_contains_valid_resource_ids( + self, resources: list[DiscoveredResource] + ): + """Each reported cycle contains only valid resource IDs from the inventory.""" + scan_result = make_scan_result(resources) + resource_ids = {r.unique_id for r in resources} + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for cycle in graph.cycles: + assert len(cycle) >= 2, "A cycle must involve at least 2 resources" + for resource_id in cycle: + assert resource_id in resource_ids + + @given(resources=cyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_cycle_reports_have_resolution_suggestions( + self, resources: list[DiscoveredResource] + ): + """Each cycle report includes a suggested break edge and resolution strategy.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for report in graph.cycle_reports: + assert report.suggested_break is not None + assert len(report.suggested_break) == 2 + assert report.break_relationship_type in {"parent-child", "reference", "dependency"} + assert isinstance(report.resolution_strategy, str) + assert len(report.resolution_strategy) > 0 + + +# --------------------------------------------------------------------------- +# Property 8: Topological order validity +# --------------------------------------------------------------------------- + + +class TestTopologicalOrderValidity: + """Property 8: Topological order validity. + + **Validates: Requirements 2.4** + + For any acyclic dependency graph, no resource SHALL appear before any + resource it depends on in the topological order. + """ + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_topological_order_contains_all_resources( + self, resources: list[DiscoveredResource] + ): + """The topological order must contain all resource IDs.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + resource_ids = {r.unique_id for r in resources} + assert set(graph.topological_order) == resource_ids + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_dependencies_appear_before_dependents( + self, resources: list[DiscoveredResource] + ): + """For every dependency edge (A depends on B), B appears before A in topological order. + + In the resolver, if resource A has B in raw_references, then A depends on B, + meaning B must appear before A in the topological order. + """ + scan_result = make_scan_result(resources) + resource_ids = {r.unique_id for r in resources} + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Build position map + position = {rid: idx for idx, rid in enumerate(graph.topological_order)} + + # For each resource, its referenced resources (that are in inventory) must come before it + for resource in resources: + for ref_id in resource.raw_references: + if ref_id in resource_ids: + assert position[ref_id] < position[resource.unique_id], ( + f"Resource '{ref_id}' (dependency) should appear before " + f"'{resource.unique_id}' (dependent) in topological order" + ) + + @given(resources=acyclic_resource_graph_strategy()) + @settings(max_examples=100) + def test_topological_order_has_no_duplicates( + self, resources: list[DiscoveredResource] + ): + """The topological order must not contain duplicate entries.""" + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.topological_order) == len(set(graph.topological_order)) + + +# --------------------------------------------------------------------------- +# Property 9: Unresolved references become data sources or variables +# --------------------------------------------------------------------------- + + +class TestUnresolvedReferences: + """Property 9: Unresolved references become data sources or variables. + + **Validates: Requirements 2.5** + + For any raw_reference pointing to an ID not in the inventory, the resolver + SHALL create an UnresolvedReference with suggested_resolution of either + "data_source" or "variable". + """ + + @given(data=resources_with_unresolved_refs_strategy()) + @settings(max_examples=100) + def test_unresolved_references_are_tracked( + self, data: tuple[list[DiscoveredResource], list[str]] + ): + """Each reference to an ID not in inventory creates an UnresolvedReference.""" + resources, unresolved_ids = data + scan_result = make_scan_result(resources) + resource_ids = {r.unique_id for r in resources} + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Count expected unresolved references + expected_unresolved = 0 + for resource in resources: + for ref in resource.raw_references: + if ref not in resource_ids: + expected_unresolved += 1 + + assert len(graph.unresolved_references) == expected_unresolved + + @given(data=resources_with_unresolved_refs_strategy()) + @settings(max_examples=100) + def test_unresolved_references_suggest_data_source_or_variable( + self, data: tuple[list[DiscoveredResource], list[str]] + ): + """Each UnresolvedReference has suggested_resolution of 'data_source' or 'variable'.""" + resources, _ = data + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for unresolved in graph.unresolved_references: + assert unresolved.suggested_resolution in {"data_source", "variable"}, ( + f"Expected 'data_source' or 'variable', got '{unresolved.suggested_resolution}'" + ) + + @given(data=resources_with_unresolved_refs_strategy()) + @settings(max_examples=100) + def test_unresolved_references_have_valid_source_resource( + self, data: tuple[list[DiscoveredResource], list[str]] + ): + """Each UnresolvedReference has a source_resource_id that exists in the inventory.""" + resources, _ = data + scan_result = make_scan_result(resources) + resource_ids = {r.unique_id for r in resources} + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for unresolved in graph.unresolved_references: + assert unresolved.source_resource_id in resource_ids + + @given(data=resources_with_unresolved_refs_strategy()) + @settings(max_examples=100) + def test_unresolved_references_have_non_empty_fields( + self, data: tuple[list[DiscoveredResource], list[str]] + ): + """Each UnresolvedReference has non-empty source_attribute and referenced_id.""" + resources, _ = data + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for unresolved in graph.unresolved_references: + assert isinstance(unresolved.source_attribute, str) + assert len(unresolved.source_attribute) > 0 + assert isinstance(unresolved.referenced_id, str) + assert len(unresolved.referenced_id) > 0 + + @given(data=resources_with_unresolved_refs_strategy()) + @settings(max_examples=100) + def test_ids_with_slash_or_colon_suggest_data_source( + self, data: tuple[list[DiscoveredResource], list[str]] + ): + """References containing '/' or ':' should suggest 'data_source' resolution.""" + resources, _ = data + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for unresolved in graph.unresolved_references: + if "/" in unresolved.referenced_id or ":" in unresolved.referenced_id: + assert unresolved.suggested_resolution == "data_source", ( + f"Reference '{unresolved.referenced_id}' contains '/' or ':' " + f"and should suggest 'data_source', got '{unresolved.suggested_resolution}'" + ) + + @given(data=resources_with_unresolved_refs_strategy()) + @settings(max_examples=100) + def test_ids_without_slash_or_colon_suggest_variable( + self, data: tuple[list[DiscoveredResource], list[str]] + ): + """References without '/' or ':' should suggest 'variable' resolution.""" + resources, _ = data + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for unresolved in graph.unresolved_references: + if "/" not in unresolved.referenced_id and ":" not in unresolved.referenced_id: + assert unresolved.suggested_resolution == "variable", ( + f"Reference '{unresolved.referenced_id}' has no '/' or ':' " + f"and should suggest 'variable', got '{unresolved.suggested_resolution}'" + ) diff --git a/tests/property/test_drift_report_prop.py b/tests/property/test_drift_report_prop.py new file mode 100644 index 0000000..7c73b94 --- /dev/null +++ b/tests/property/test_drift_report_prop.py @@ -0,0 +1,308 @@ +"""Property-based tests for drift report correctness. + +**Validates: Requirements 7.3** + +Properties tested: +- Property 22: Drift report correctness — For any terraform plan output + containing planned changes, the Validator SHALL report each change with + the correct resource address and change type (add, modify, destroy). +""" + +import json +import tempfile +from unittest.mock import MagicMock, patch + +from hypothesis import given, settings, assume +from hypothesis import strategies as st + +from iac_reverse.models import PlannedChange, ValidationResult +from iac_reverse.validator import Validator + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +# Terraform action types that map to our change types +TERRAFORM_ACTIONS = ["create", "update", "delete"] + +# Expected mapping from terraform actions to our change types +ACTION_TO_CHANGE_TYPE = { + "create": "add", + "update": "modify", + "delete": "destroy", +} + +# Strategy for valid terraform resource addresses +# Format: . or .. +resource_type_prefix_strategy = st.sampled_from([ + "aws_instance", + "kubernetes_deployment", + "docker_service", + "harvester_virtualmachine", + "synology_shared_folder", + "windows_service", + "bare_metal_hardware", + "null_resource", + "local_file", + "random_id", +]) + +resource_name_suffix_strategy = st.text( + min_size=1, + max_size=20, + alphabet=st.characters(whitelist_categories=("Ll",), whitelist_characters="_"), +).filter(lambda s: s[0].isalpha() or s[0] == "_") + + +@st.composite +def resource_address_strategy(draw): + """Generate a valid terraform resource address like 'aws_instance.my_server'.""" + prefix = draw(resource_type_prefix_strategy) + suffix = draw(resource_name_suffix_strategy) + # Optionally add a module prefix + use_module = draw(st.booleans()) + if use_module: + module_name = draw(st.text( + min_size=1, + max_size=10, + alphabet=st.characters(whitelist_categories=("Ll",), whitelist_characters="_"), + ).filter(lambda s: s[0].isalpha())) + return f"module.{module_name}.{prefix}.{suffix}" + return f"{prefix}.{suffix}" + + +terraform_action_strategy = st.sampled_from(TERRAFORM_ACTIONS) + + +@st.composite +def planned_change_entry_strategy(draw): + """Generate a single planned change entry as it appears in terraform plan JSON output.""" + addr = draw(resource_address_strategy()) + action = draw(terraform_action_strategy) + return (addr, action) + + +@st.composite +def planned_changes_list_strategy(draw): + """Generate a list of planned changes with unique resource addresses.""" + num_changes = draw(st.integers(min_value=1, max_value=10)) + changes = [] + seen_addrs = set() + + for _ in range(num_changes): + entry = draw(planned_change_entry_strategy()) + addr, action = entry + # Ensure unique addresses + if addr in seen_addrs: + continue + seen_addrs.add(addr) + changes.append((addr, action)) + + assume(len(changes) >= 1) + return changes + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +VALIDATE_SUCCESS_JSON = json.dumps( + {"valid": True, "error_count": 0, "diagnostics": []} +) + + +def _make_completed_process(returncode=0, stdout="", stderr=""): + """Create a mock CompletedProcess-like object.""" + mock = MagicMock() + mock.returncode = returncode + mock.stdout = stdout + mock.stderr = stderr + return mock + + +def build_plan_output(changes: list[tuple[str, str]]) -> str: + """Build terraform plan JSON streaming output from a list of (addr, action) tuples.""" + lines = [json.dumps({"type": "version", "terraform": "1.7.0"})] + + for addr, action in changes: + lines.append( + json.dumps( + { + "type": "planned_change", + "change": { + "resource": {"addr": addr}, + "action": action, + }, + } + ) + ) + + # Add change_summary + total_add = sum(1 for _, a in changes if a == "create") + total_change = sum(1 for _, a in changes if a == "update") + total_remove = sum(1 for _, a in changes if a == "delete") + lines.append( + json.dumps( + { + "type": "change_summary", + "changes": { + "add": total_add, + "change": total_change, + "remove": total_remove, + }, + } + ) + ) + + return "\n".join(lines) + + +def run_validator_with_plan(plan_output: str) -> ValidationResult: + """Run the Validator with mocked subprocess calls, returning the result.""" + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with tempfile.TemporaryDirectory() as tmp_dir: + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + return validator.validate(tmp_dir) + + +# --------------------------------------------------------------------------- +# Property 22: Drift report correctness +# --------------------------------------------------------------------------- + + +class TestDriftReportCorrectness: + """Property 22: Drift report correctness. + + **Validates: Requirements 7.3** + + For any terraform plan output containing N planned changes, the drift + report SHALL list exactly N entries, each with the correct resource + address and change type (add, modify, or destroy). + """ + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_count_matches_planned_changes( + self, changes: list[tuple[str, str]] + ): + """The number of reported planned changes equals the number in the plan output.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + assert len(result.planned_changes) == len(changes), ( + f"Expected {len(changes)} planned changes, " + f"got {len(result.planned_changes)}. " + f"Input changes: {changes}" + ) + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_resource_addresses_match( + self, changes: list[tuple[str, str]] + ): + """Each reported change has the correct resource address from the plan.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + expected_addrs = {addr for addr, _ in changes} + actual_addrs = {c.resource_address for c in result.planned_changes} + + assert actual_addrs == expected_addrs, ( + f"Resource address mismatch.\n" + f"Expected: {sorted(expected_addrs)}\n" + f"Actual: {sorted(actual_addrs)}" + ) + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_change_types_correct( + self, changes: list[tuple[str, str]] + ): + """Each reported change has the correct change type mapping.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + # Build expected mapping: addr -> change_type + expected_map = { + addr: ACTION_TO_CHANGE_TYPE[action] for addr, action in changes + } + + for planned_change in result.planned_changes: + addr = planned_change.resource_address + assert addr in expected_map, ( + f"Unexpected resource address '{addr}' in planned changes" + ) + expected_type = expected_map[addr] + assert planned_change.change_type == expected_type, ( + f"For resource '{addr}': expected change_type='{expected_type}', " + f"got '{planned_change.change_type}'" + ) + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_plan_success_is_false( + self, changes: list[tuple[str, str]] + ): + """When there are planned changes, plan_success is always False.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + assert result.plan_success is False, ( + f"plan_success should be False when there are {len(changes)} " + f"planned changes, but got True" + ) + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_each_change_is_planned_change_instance( + self, changes: list[tuple[str, str]] + ): + """Each entry in the drift report is a PlannedChange instance.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + for i, change in enumerate(result.planned_changes): + assert isinstance(change, PlannedChange), ( + f"Entry {i} is {type(change).__name__}, expected PlannedChange" + ) + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_change_type_in_valid_set( + self, changes: list[tuple[str, str]] + ): + """Every reported change_type is one of 'add', 'modify', or 'destroy'.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + valid_types = {"add", "modify", "destroy"} + for change in result.planned_changes: + assert change.change_type in valid_types, ( + f"Invalid change_type '{change.change_type}' for resource " + f"'{change.resource_address}'. Must be one of {valid_types}" + ) + + @given(changes=planned_changes_list_strategy()) + @settings(max_examples=100) + def test_drift_report_no_duplicate_addresses( + self, changes: list[tuple[str, str]] + ): + """No resource address appears more than once in the drift report.""" + plan_output = build_plan_output(changes) + result = run_validator_with_plan(plan_output) + + addresses = [c.resource_address for c in result.planned_changes] + assert len(addresses) == len(set(addresses)), ( + f"Duplicate resource addresses found in drift report: " + f"{[a for a in addresses if addresses.count(a) > 1]}" + ) diff --git a/tests/property/test_incremental_scan_prop.py b/tests/property/test_incremental_scan_prop.py new file mode 100644 index 0000000..e63fdb0 --- /dev/null +++ b/tests/property/test_incremental_scan_prop.py @@ -0,0 +1,790 @@ +"""Property-based tests for Incremental Scan Engine. + +**Validates: Requirements 8.1, 8.2, 8.3, 8.5, 8.6** + +Properties tested: +- Property 23: Change classification correctness +- Property 24: Incremental update scope +- Property 25: Removed resource exclusion +- Property 26: Snapshot retention +""" + +import json +import tempfile +from pathlib import Path + +from hypothesis import given, settings, assume +from hypothesis import strategies as st + +from iac_reverse.incremental import ChangeDetector, IncrementalUpdater, SnapshotStore +from iac_reverse.models import ( + ChangeSummary, + ChangeType, + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ResourceChange, + ScanResult, +) + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_strategy = st.sampled_from(list(ProviderType)) +platform_strategy = st.sampled_from(list(PlatformCategory)) +architecture_strategy = st.sampled_from(list(CpuArchitecture)) + +# Simple attribute values for resources +attribute_value_strategy = st.one_of( + st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz0123456789"), + st.integers(min_value=0, max_value=1000), + st.booleans(), +) + +attributes_strategy = st.dictionaries( + keys=st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz_"), + values=attribute_value_strategy, + min_size=1, + max_size=5, +) + +# Resource name strategy (valid identifiers) +resource_name_strategy = st.text( + min_size=1, + max_size=15, + alphabet="abcdefghijklmnopqrstuvwxyz_", +).filter(lambda s: s[0].isalpha()) + +# Resource type strategy +resource_type_strategy = st.sampled_from([ + "docker_service", + "kubernetes_deployment", + "synology_shared_folder", + "harvester_virtualmachine", + "bare_metal_hardware", + "windows_service", +]) + + +@st.composite +def discovered_resource_strategy(draw, uid=None): + """Generate a DiscoveredResource with valid fields.""" + resource_type = draw(resource_type_strategy) + unique_id = uid or draw(st.text( + min_size=5, max_size=30, + alphabet="abcdefghijklmnopqrstuvwxyz0123456789_-/", + ).filter(lambda s: s[0].isalpha())) + name = draw(resource_name_strategy) + provider = draw(provider_strategy) + platform = draw(platform_strategy) + arch = draw(architecture_strategy) + endpoint = draw(st.text(min_size=3, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz.")) + attributes = draw(attributes_strategy) + + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform, + architecture=arch, + endpoint=endpoint, + attributes=attributes, + raw_references=[], + ) + + +@st.composite +def scan_result_strategy(draw, min_resources=0, max_resources=8): + """Generate a ScanResult with unique resource IDs.""" + num_resources = draw(st.integers(min_value=min_resources, max_value=max_resources)) + resources = [] + seen_ids = set() + + for i in range(num_resources): + uid = f"resource_{i}_{draw(st.text(min_size=3, max_size=8, alphabet='abcdefghijklmnopqrstuvwxyz'))}" + if uid in seen_ids: + uid = f"resource_{i}_fallback" + seen_ids.add(uid) + + resource = draw(discovered_resource_strategy(uid=uid)) + resources.append(resource) + + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test_profile_hash", + is_partial=False, + ) + + +@st.composite +def scan_result_pair_strategy(draw): + """Generate a pair of scan results with some overlap for meaningful diffs. + + Creates a previous and current scan where: + - Some resources exist in both (potentially modified) + - Some resources only in previous (removed) + - Some resources only in current (added) + """ + # Shared resources (exist in both, may be modified) + num_shared = draw(st.integers(min_value=0, max_value=4)) + # Resources only in previous (will be removed) + num_removed = draw(st.integers(min_value=0, max_value=3)) + # Resources only in current (will be added) + num_added = draw(st.integers(min_value=0, max_value=3)) + + assume(num_shared + num_removed + num_added >= 1) + + previous_resources = [] + current_resources = [] + + # Generate shared resources + for i in range(num_shared): + uid = f"shared_{i}" + resource_type = draw(resource_type_strategy) + name = draw(resource_name_strategy) + provider = draw(provider_strategy) + platform = draw(platform_strategy) + arch = draw(architecture_strategy) + endpoint = draw(st.text(min_size=3, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz.")) + prev_attrs = draw(attributes_strategy) + + prev_resource = DiscoveredResource( + resource_type=resource_type, + unique_id=uid, + name=name, + provider=provider, + platform_category=platform, + architecture=arch, + endpoint=endpoint, + attributes=prev_attrs, + raw_references=[], + ) + previous_resources.append(prev_resource) + + # Possibly modify attributes for current version + modify = draw(st.booleans()) + if modify: + curr_attrs = draw(attributes_strategy) + else: + curr_attrs = dict(prev_attrs) + + curr_resource = DiscoveredResource( + resource_type=resource_type, + unique_id=uid, + name=name, + provider=provider, + platform_category=platform, + architecture=arch, + endpoint=endpoint, + attributes=curr_attrs, + raw_references=[], + ) + current_resources.append(curr_resource) + + # Generate removed resources (only in previous) + for i in range(num_removed): + uid = f"removed_{i}" + resource = draw(discovered_resource_strategy(uid=uid)) + previous_resources.append(resource) + + # Generate added resources (only in current) + for i in range(num_added): + uid = f"added_{i}" + resource = draw(discovered_resource_strategy(uid=uid)) + current_resources.append(resource) + + previous = ScanResult( + resources=previous_resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-14T09:00:00Z", + profile_hash="test_profile", + is_partial=False, + ) + current = ScanResult( + resources=current_resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test_profile", + is_partial=False, + ) + + return previous, current + + +# --------------------------------------------------------------------------- +# Property 23: Change classification correctness +# --------------------------------------------------------------------------- + + +class TestChangeClassificationCorrectness: + """Property 23: Change classification correctness. + + **Validates: Requirements 8.1, 8.5** + + For any pair of scan results (previous and current), every resource + SHALL be classified exactly once as: added, removed, or modified. + The summary counts SHALL equal the actual number of resources in each + category. + """ + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100) + def test_every_resource_classified_exactly_once(self, data): + """Every resource is classified as exactly one of: added, removed, or modified.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + prev_ids = {r.unique_id for r in previous.resources} + curr_ids = {r.unique_id for r in current.resources} + all_ids = prev_ids | curr_ids + + # Each change should reference a resource from either scan + change_ids = [c.resource_id for c in summary.changes] + + # No duplicates in changes + assert len(change_ids) == len(set(change_ids)), ( + f"Duplicate resource IDs in changes: " + f"{[rid for rid in change_ids if change_ids.count(rid) > 1]}" + ) + + # Every changed resource must be from the union of both scans + for change in summary.changes: + assert change.resource_id in all_ids, ( + f"Change references unknown resource: {change.resource_id}" + ) + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100) + def test_added_resources_in_current_not_previous(self, data): + """Resources classified as ADDED are in current but not in previous.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + prev_ids = {r.unique_id for r in previous.resources} + curr_ids = {r.unique_id for r in current.resources} + + added_changes = [c for c in summary.changes if c.change_type == ChangeType.ADDED] + for change in added_changes: + assert change.resource_id in curr_ids, ( + f"ADDED resource {change.resource_id} not in current scan" + ) + assert change.resource_id not in prev_ids, ( + f"ADDED resource {change.resource_id} exists in previous scan" + ) + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100) + def test_removed_resources_in_previous_not_current(self, data): + """Resources classified as REMOVED are in previous but not in current.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + prev_ids = {r.unique_id for r in previous.resources} + curr_ids = {r.unique_id for r in current.resources} + + removed_changes = [c for c in summary.changes if c.change_type == ChangeType.REMOVED] + for change in removed_changes: + assert change.resource_id in prev_ids, ( + f"REMOVED resource {change.resource_id} not in previous scan" + ) + assert change.resource_id not in curr_ids, ( + f"REMOVED resource {change.resource_id} exists in current scan" + ) + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100) + def test_modified_resources_in_both_with_differing_attributes(self, data): + """Resources classified as MODIFIED exist in both scans with differing attributes.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + prev_map = {r.unique_id: r for r in previous.resources} + curr_map = {r.unique_id: r for r in current.resources} + + modified_changes = [c for c in summary.changes if c.change_type == ChangeType.MODIFIED] + for change in modified_changes: + assert change.resource_id in prev_map, ( + f"MODIFIED resource {change.resource_id} not in previous scan" + ) + assert change.resource_id in curr_map, ( + f"MODIFIED resource {change.resource_id} not in current scan" + ) + # Attributes must actually differ + assert prev_map[change.resource_id].attributes != curr_map[change.resource_id].attributes, ( + f"MODIFIED resource {change.resource_id} has identical attributes" + ) + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100) + def test_summary_counts_match_actual_changes(self, data): + """Summary counts equal the actual number of resources in each category.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + actual_added = sum(1 for c in summary.changes if c.change_type == ChangeType.ADDED) + actual_removed = sum(1 for c in summary.changes if c.change_type == ChangeType.REMOVED) + actual_modified = sum(1 for c in summary.changes if c.change_type == ChangeType.MODIFIED) + + assert summary.added_count == actual_added, ( + f"added_count={summary.added_count} != actual={actual_added}" + ) + assert summary.removed_count == actual_removed, ( + f"removed_count={summary.removed_count} != actual={actual_removed}" + ) + assert summary.modified_count == actual_modified, ( + f"modified_count={summary.modified_count} != actual={actual_modified}" + ) + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100) + def test_change_types_are_valid(self, data): + """Every change has a valid ChangeType value.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + valid_types = {ChangeType.ADDED, ChangeType.REMOVED, ChangeType.MODIFIED} + for change in summary.changes: + assert change.change_type in valid_types, ( + f"Invalid change_type: {change.change_type}" + ) + + +# --------------------------------------------------------------------------- +# Property 24: Incremental update scope +# --------------------------------------------------------------------------- + + +class TestIncrementalUpdateScope: + """Property 24: Incremental update scope. + + **Validates: Requirements 8.2** + + For any change set applied to existing IaC files, only files containing + added, modified, or removed resources SHALL be modified. Files containing + only unchanged resources SHALL remain identical. + """ + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100, deadline=None) + def test_only_changed_resource_files_are_modified(self, data): + """Only .tf files for resource types with changes are modified.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + # Skip if no changes (nothing to test) + assume(len(summary.changes) > 0) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Create initial .tf files for all resource types in previous scan + resource_types_in_previous = {r.resource_type for r in previous.resources} + # Also create a file for an "unchanged" resource type + unchanged_type = "unchanged_resource_type" + resource_types_in_previous.add(unchanged_type) + + for rt in resource_types_in_previous: + tf_path = Path(tmp_dir) / f"{rt}.tf" + tf_path.write_text(f'# Placeholder for {rt}\n', encoding="utf-8") + + # Record original content of the unchanged file + unchanged_path = Path(tmp_dir) / f"{unchanged_type}.tf" + original_unchanged_content = unchanged_path.read_text(encoding="utf-8") + + # Build resource_attributes for added resources + resource_attributes = {} + for change in summary.changes: + if change.change_type == ChangeType.ADDED: + # Find the resource in current scan + for r in current.resources: + if r.unique_id == change.resource_id: + resource_attributes[change.resource_id] = r.attributes + break + + # Apply incremental update + updater = IncrementalUpdater( + change_summary=summary, + output_dir=tmp_dir, + resource_attributes=resource_attributes, + ) + updater.apply() + + # The unchanged file should not be modified + assert unchanged_path.read_text(encoding="utf-8") == original_unchanged_content, ( + "File for unchanged resource type was modified" + ) + + # Modified files should only be for resource types with changes + changed_resource_types = {c.resource_type for c in summary.changes} + for modified_file in updater.modified_files: + file_name = Path(modified_file).name + # Modified files should be .tf files for changed resource types + # or the state file + if file_name == "terraform.tfstate": + continue + assert file_name.endswith(".tf"), ( + f"Unexpected modified file: {file_name}" + ) + rt = file_name[:-3] # strip .tf + assert rt in changed_resource_types, ( + f"File {file_name} was modified but resource type " + f"'{rt}' has no changes" + ) + + +# --------------------------------------------------------------------------- +# Property 25: Removed resource exclusion +# --------------------------------------------------------------------------- + + +class TestRemovedResourceExclusion: + """Property 25: Removed resource exclusion. + + **Validates: Requirements 8.3** + + For any resource classified as removed, the updated IaC output SHALL + not contain a resource block for that resource, AND the updated state + file SHALL not contain a state entry for that resource. + """ + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100, deadline=None) + def test_removed_resources_not_in_tf_files(self, data): + """Removed resources do not appear in .tf files after update.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + removed_changes = [c for c in summary.changes if c.change_type == ChangeType.REMOVED] + assume(len(removed_changes) > 0) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Create .tf files with resource blocks for previous resources + from iac_reverse.generator.sanitize import sanitize_identifier + + resources_by_type: dict[str, list] = {} + for r in previous.resources: + resources_by_type.setdefault(r.resource_type, []).append(r) + + for rt, resources in resources_by_type.items(): + tf_path = Path(tmp_dir) / f"{rt}.tf" + lines = [] + for r in resources: + tf_name = sanitize_identifier(r.name) + lines.append(f'# Source: {r.unique_id}') + lines.append(f'resource "{rt}" "{tf_name}" {{') + for k, v in r.attributes.items(): + lines.append(f' {k} = "{v}"') + lines.append("}") + lines.append("") + tf_path.write_text("\n".join(lines), encoding="utf-8") + + # Build resource_attributes for added resources + resource_attributes = {} + for change in summary.changes: + if change.change_type == ChangeType.ADDED: + for r in current.resources: + if r.unique_id == change.resource_id: + resource_attributes[change.resource_id] = r.attributes + break + + # Apply incremental update + updater = IncrementalUpdater( + change_summary=summary, + output_dir=tmp_dir, + resource_attributes=resource_attributes, + ) + updater.apply() + + # Verify removed resources are not in any .tf file + for change in removed_changes: + tf_path = Path(tmp_dir) / f"{change.resource_type}.tf" + if tf_path.exists(): + content = tf_path.read_text(encoding="utf-8") + tf_name = sanitize_identifier(change.resource_name) + # The resource block should not exist + block_header = f'resource "{change.resource_type}" "{tf_name}"' + assert block_header not in content, ( + f"Removed resource {change.resource_id} still has a " + f"resource block in {tf_path.name}" + ) + + @given(data=scan_result_pair_strategy()) + @settings(max_examples=100, deadline=None) + def test_removed_resources_not_in_state_file(self, data): + """Removed resources do not appear in the state file after update.""" + previous, current = data + detector = ChangeDetector() + summary = detector.compare(current, previous) + + removed_changes = [c for c in summary.changes if c.change_type == ChangeType.REMOVED] + assume(len(removed_changes) > 0) + + with tempfile.TemporaryDirectory() as tmp_dir: + from iac_reverse.generator.sanitize import sanitize_identifier + + # Create initial state file with entries for previous resources + state = { + "version": 4, + "terraform_version": "1.7.0", + "serial": 1, + "lineage": "test-lineage", + "outputs": {}, + "resources": [], + } + for r in previous.resources: + tf_name = sanitize_identifier(r.name) + state["resources"].append({ + "mode": "managed", + "type": r.resource_type, + "name": tf_name, + "provider": f'provider["registry.terraform.io/hashicorp/{r.resource_type.split("_")[0]}"]', + "instances": [{ + "schema_version": 0, + "attributes": {"id": r.unique_id, **r.attributes}, + "sensitive_attributes": [], + "dependencies": [], + }], + }) + + state_path = Path(tmp_dir) / "terraform.tfstate" + state_path.write_text(json.dumps(state, indent=2), encoding="utf-8") + + # Create .tf files so updater can process removals + resources_by_type: dict[str, list] = {} + for r in previous.resources: + resources_by_type.setdefault(r.resource_type, []).append(r) + + for rt, resources in resources_by_type.items(): + tf_path = Path(tmp_dir) / f"{rt}.tf" + lines = [] + for r in resources: + tf_name = sanitize_identifier(r.name) + lines.append(f'# Source: {r.unique_id}') + lines.append(f'resource "{rt}" "{tf_name}" {{') + for k, v in r.attributes.items(): + lines.append(f' {k} = "{v}"') + lines.append("}") + lines.append("") + tf_path.write_text("\n".join(lines), encoding="utf-8") + + # Build resource_attributes for added resources + resource_attributes = {} + for change in summary.changes: + if change.change_type == ChangeType.ADDED: + for r in current.resources: + if r.unique_id == change.resource_id: + resource_attributes[change.resource_id] = r.attributes + break + + # Apply incremental update + updater = IncrementalUpdater( + change_summary=summary, + output_dir=tmp_dir, + resource_attributes=resource_attributes, + ) + updater.apply() + + # Verify removed resources are not in state file + updated_state = json.loads( + state_path.read_text(encoding="utf-8") + ) + state_entries = updated_state.get("resources", []) + + for change in removed_changes: + tf_name = sanitize_identifier(change.resource_name) + matching = [ + e for e in state_entries + if e.get("type") == change.resource_type + and e.get("name") == tf_name + ] + assert len(matching) == 0, ( + f"Removed resource {change.resource_id} still has a " + f"state entry (type={change.resource_type}, name={tf_name})" + ) + + +# --------------------------------------------------------------------------- +# Property 26: Snapshot retention +# --------------------------------------------------------------------------- + + +class TestSnapshotRetention: + """Property 26: Snapshot retention. + + **Validates: Requirements 8.6** + + For any sequence of N scans (N >= 2) for the same Scan_Profile, at + least the two most recent scan results SHALL be retained in storage + after each scan completes. + """ + + @given(num_scans=st.integers(min_value=2, max_value=8)) + @settings(max_examples=100) + def test_at_least_two_snapshots_retained(self, num_scans): + """After N scans, at least 2 most recent snapshots are retained.""" + from unittest.mock import patch + from datetime import datetime, timezone + + with tempfile.TemporaryDirectory() as tmp_dir: + store = SnapshotStore(base_dir=tmp_dir) + profile_hash = "retention_test_profile" + + # Store N scan results with mocked timestamps to ensure unique filenames + for i in range(num_scans): + result = ScanResult( + resources=[ + DiscoveredResource( + resource_type="docker_service", + unique_id=f"svc_{i}", + name=f"service_{i}", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint="localhost", + attributes={"version": str(i)}, + raw_references=[], + ) + ], + warnings=[], + errors=[], + scan_timestamp=f"2024-01-{15 + i:02d}T10:00:00Z", + profile_hash=profile_hash, + is_partial=False, + ) + # Mock datetime.now to return unique timestamps + mock_time = datetime(2024, 1, 15 + i, 10, 0, 0, tzinfo=timezone.utc) + with patch( + "iac_reverse.incremental.snapshot_store.datetime" + ) as mock_dt: + mock_dt.now.return_value = mock_time + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + store.store_snapshot(result, profile_hash) + + # Count remaining snapshots + snapshot_files = list(store.snapshot_dir.glob(f"{profile_hash}_*.json")) + assert len(snapshot_files) >= 2, ( + f"After {num_scans} scans, only {len(snapshot_files)} " + f"snapshots retained (expected >= 2)" + ) + + @given(num_scans=st.integers(min_value=2, max_value=8)) + @settings(max_examples=100) + def test_most_recent_snapshot_is_loadable(self, num_scans): + """The most recent snapshot can be loaded after multiple stores.""" + from unittest.mock import patch + from datetime import datetime, timezone + + with tempfile.TemporaryDirectory() as tmp_dir: + store = SnapshotStore(base_dir=tmp_dir) + profile_hash = "loadable_test_profile" + + last_resource_id = None + for i in range(num_scans): + last_resource_id = f"svc_{i}" + result = ScanResult( + resources=[ + DiscoveredResource( + resource_type="kubernetes_deployment", + unique_id=last_resource_id, + name=f"deploy_{i}", + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + endpoint="k8s-api.local", + attributes={"replicas": i + 1}, + raw_references=[], + ) + ], + warnings=[], + errors=[], + scan_timestamp=f"2024-01-{15 + i:02d}T10:00:00Z", + profile_hash=profile_hash, + is_partial=False, + ) + mock_time = datetime(2024, 1, 15 + i, 10, 0, 0, tzinfo=timezone.utc) + with patch( + "iac_reverse.incremental.snapshot_store.datetime" + ) as mock_dt: + mock_dt.now.return_value = mock_time + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + store.store_snapshot(result, profile_hash) + + # Load the most recent snapshot + loaded = store.load_previous(profile_hash) + assert loaded is not None, "Could not load most recent snapshot" + assert len(loaded.resources) == 1 + assert loaded.resources[0].unique_id == last_resource_id, ( + f"Expected most recent resource '{last_resource_id}', " + f"got '{loaded.resources[0].unique_id}'" + ) + + @given(num_scans=st.integers(min_value=3, max_value=10)) + @settings(max_examples=100) + def test_different_profiles_retain_independently(self, num_scans): + """Snapshots for different profiles are retained independently.""" + from unittest.mock import patch + from datetime import datetime, timezone + + with tempfile.TemporaryDirectory() as tmp_dir: + store = SnapshotStore(base_dir=tmp_dir) + profile_a = "profile_alpha" + profile_b = "profile_beta" + + scan_idx = 0 + for i in range(num_scans): + for profile_hash in [profile_a, profile_b]: + result = ScanResult( + resources=[ + DiscoveredResource( + resource_type="docker_service", + unique_id=f"{profile_hash}_svc_{i}", + name=f"svc_{i}", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint="localhost", + attributes={"idx": i}, + raw_references=[], + ) + ], + warnings=[], + errors=[], + scan_timestamp=f"2024-01-{15 + i:02d}T10:00:00Z", + profile_hash=profile_hash, + is_partial=False, + ) + # Use unique timestamps per store call + mock_time = datetime(2024, 1, 15, 10, scan_idx, 0, tzinfo=timezone.utc) + scan_idx += 1 + with patch( + "iac_reverse.incremental.snapshot_store.datetime" + ) as mock_dt: + mock_dt.now.return_value = mock_time + mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw) + store.store_snapshot(result, profile_hash) + + # Both profiles should have at least 2 snapshots + snapshots_a = list(store.snapshot_dir.glob(f"{profile_a}_*.json")) + snapshots_b = list(store.snapshot_dir.glob(f"{profile_b}_*.json")) + + assert len(snapshots_a) >= 2, ( + f"Profile A has {len(snapshots_a)} snapshots (expected >= 2)" + ) + assert len(snapshots_b) >= 2, ( + f"Profile B has {len(snapshots_b)} snapshots (expected >= 2)" + ) diff --git a/tests/property/test_multi_provider_prop.py b/tests/property/test_multi_provider_prop.py new file mode 100644 index 0000000..cf0ed9a --- /dev/null +++ b/tests/property/test_multi_provider_prop.py @@ -0,0 +1,803 @@ +"""Property-based tests for multi-provider merging and filtering. + +**Validates: Requirements 5.3, 5.4, 6.1, 6.2, 6.4, 6.6, 6.7** + +Property 18: Multi-provider merge with naming conflict resolution +For any two or more resource inventories from different on-premises providers +where resource names collide, the merged inventory SHALL contain all resources +from all providers, with conflicting names prefixed by the provider identifier, +and no resources lost. + +Property 19: Provider block generation +For any resource set spanning N distinct on-premises providers, the generated +provider configuration SHALL contain exactly N provider blocks, one per distinct +provider. + +Property 20: Scan profile validation completeness (additional multi-provider scenarios) +Already covered in test_scan_profile_validation_prop.py; this adds multi-provider +scenarios. + +Property 21: Filtering correctness +For any scan profile with resource type filters, the discovered resources SHALL +be a subset where every resource's type is in the filter list. No resource outside +the filter criteria shall appear. +""" + +from typing import Callable + +from hypothesis import given, assume, settings +from hypothesis import strategies as st + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + GeneratedFile, + PlatformCategory, + PROVIDER_PLATFORM_MAP, + PROVIDER_SUPPORTED_RESOURCE_TYPES, + ProviderType, + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.generator import ProviderBlockGenerator, ResourceMerger +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import Scanner + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) + +architecture_strategy = st.sampled_from(list(CpuArchitecture)) + +non_empty_credentials_strategy = st.dictionaries( + keys=st.text( + min_size=1, max_size=20, + alphabet=st.characters(whitelist_categories=("L", "N")), + ), + values=st.text(min_size=1, max_size=50), + min_size=1, + max_size=5, +) + +# Strategy for resource names that could collide across providers +resource_name_strategy = st.text( + min_size=1, + max_size=30, + alphabet=st.characters(whitelist_categories=("L", "N", "Pd")), +).filter(lambda s: s.strip()) + + +def discovered_resource_strategy( + provider: ProviderType, + name: str | None = None, +) -> st.SearchStrategy[DiscoveredResource]: + """Generate a DiscoveredResource for a given provider with optional fixed name.""" + supported_types = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + platform_category = PROVIDER_PLATFORM_MAP[provider] + + return st.builds( + DiscoveredResource, + resource_type=st.sampled_from(supported_types), + unique_id=st.uuids().map(str), + name=st.just(name) if name else resource_name_strategy, + provider=st.just(provider), + platform_category=st.just(platform_category), + architecture=architecture_strategy, + endpoint=st.just("http://localhost:8080"), + attributes=st.just({"key": "value"}), + raw_references=st.just([]), + ) + + +def scan_result_strategy( + provider: ProviderType, + resources: list[DiscoveredResource] | None = None, +) -> st.SearchStrategy[ScanResult]: + """Generate a ScanResult for a given provider.""" + if resources is not None: + return st.just(ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-01T00:00:00Z", + profile_hash="abc123", + )) + return st.builds( + ScanResult, + resources=st.lists( + discovered_resource_strategy(provider), + min_size=1, + max_size=5, + ), + warnings=st.just([]), + errors=st.just([]), + scan_timestamp=st.just("2024-01-01T00:00:00Z"), + profile_hash=st.just("abc123"), + ) + + +# --------------------------------------------------------------------------- +# Mock Plugin for Filtering Tests +# --------------------------------------------------------------------------- + + +class FilteringPlugin(ProviderPlugin): + """A plugin that discovers resources only for requested resource types.""" + + def __init__(self, provider: ProviderType): + self._provider = provider + self._supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PROVIDER_PLATFORM_MAP[self._provider] + + def list_endpoints(self) -> list[str]: + return ["http://localhost:8080"] + + def list_supported_resource_types(self) -> list[str]: + return self._supported + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + """Discover exactly one resource per requested resource type.""" + resources = [] + for i, rt in enumerate(resource_types): + resources.append( + DiscoveredResource( + resource_type=rt, + unique_id=f"id-{rt}-{i}", + name=f"resource-{rt}-{i}", + provider=self._provider, + platform_category=PROVIDER_PLATFORM_MAP[self._provider], + architecture=CpuArchitecture.AMD64, + endpoint="http://localhost:8080", + attributes={"key": "value"}, + ) + ) + progress_callback(ScanProgress( + current_resource_type=rt, + resources_discovered=i + 1, + resource_types_completed=i + 1, + total_resource_types=len(resource_types), + )) + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-01T00:00:00Z", + profile_hash="abc123", + ) + + +# --------------------------------------------------------------------------- +# Property 18: Multi-provider merge with naming conflict resolution +# --------------------------------------------------------------------------- + + +class TestMultiProviderMergeConflictResolution: + """Property 18: Multi-provider merge with naming conflict resolution. + + When resources from different providers share the same name, the merger + prefixes with provider identifier. + + **Validates: Requirements 5.3** + """ + + @given( + provider_a=provider_type_strategy, + provider_b=provider_type_strategy, + shared_name=resource_name_strategy, + ) + @settings(max_examples=100) + def test_conflicting_names_are_prefixed(self, provider_a, provider_b, shared_name): + """Resources with the same name from different providers get prefixed.""" + assume(provider_a != provider_b) + + resource_a = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0], + unique_id="id-a", + name=shared_name, + provider=provider_a, + platform_category=PROVIDER_PLATFORM_MAP[provider_a], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-a:8080", + attributes={"source": "a"}, + ) + resource_b = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_b][0], + unique_id="id-b", + name=shared_name, + provider=provider_b, + platform_category=PROVIDER_PLATFORM_MAP[provider_b], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-b:8080", + attributes={"source": "b"}, + ) + + scan_result_a = ScanResult( + resources=[resource_a], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + scan_result_b = ScanResult( + resources=[resource_b], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + + merger = ResourceMerger() + merged = merger.merge([scan_result_a, scan_result_b]) + + # Both resources must be present (no loss) + assert len(merged) == 2 + + # Conflicting names must be prefixed with provider identifier + merged_names = {r.name for r in merged} + expected_name_a = f"{provider_a.value}_{shared_name}" + expected_name_b = f"{provider_b.value}_{shared_name}" + assert expected_name_a in merged_names, ( + f"Expected '{expected_name_a}' in merged names, got: {merged_names}" + ) + assert expected_name_b in merged_names, ( + f"Expected '{expected_name_b}' in merged names, got: {merged_names}" + ) + + @given( + provider_a=provider_type_strategy, + provider_b=provider_type_strategy, + name_a=resource_name_strategy, + name_b=resource_name_strategy, + ) + @settings(max_examples=100) + def test_non_conflicting_names_unchanged(self, provider_a, provider_b, name_a, name_b): + """Resources with unique names across providers are not prefixed.""" + assume(provider_a != provider_b) + assume(name_a != name_b) + + resource_a = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0], + unique_id="id-a", + name=name_a, + provider=provider_a, + platform_category=PROVIDER_PLATFORM_MAP[provider_a], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-a:8080", + attributes={}, + ) + resource_b = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_b][0], + unique_id="id-b", + name=name_b, + provider=provider_b, + platform_category=PROVIDER_PLATFORM_MAP[provider_b], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-b:8080", + attributes={}, + ) + + scan_result_a = ScanResult( + resources=[resource_a], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + scan_result_b = ScanResult( + resources=[resource_b], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + + merger = ResourceMerger() + merged = merger.merge([scan_result_a, scan_result_b]) + + # No resources lost + assert len(merged) == 2 + + # Names should remain unchanged (no prefix) + merged_names = {r.name for r in merged} + assert name_a in merged_names, ( + f"Expected original name '{name_a}' preserved, got: {merged_names}" + ) + assert name_b in merged_names, ( + f"Expected original name '{name_b}' preserved, got: {merged_names}" + ) + + @given( + provider_a=provider_type_strategy, + provider_b=provider_type_strategy, + provider_c=provider_type_strategy, + shared_name=resource_name_strategy, + ) + @settings(max_examples=100) + def test_three_provider_conflict_all_prefixed( + self, provider_a, provider_b, provider_c, shared_name + ): + """When 3 providers share a name, all get prefixed.""" + assume(len({provider_a, provider_b, provider_c}) == 3) + + resources_and_results = [] + for provider in [provider_a, provider_b, provider_c]: + resource = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider][0], + unique_id=f"id-{provider.value}", + name=shared_name, + provider=provider, + platform_category=PROVIDER_PLATFORM_MAP[provider], + architecture=CpuArchitecture.AMD64, + endpoint=f"http://{provider.value}:8080", + attributes={}, + ) + result = ScanResult( + resources=[resource], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + resources_and_results.append(result) + + merger = ResourceMerger() + merged = merger.merge(resources_and_results) + + # All 3 resources preserved + assert len(merged) == 3 + + # All must be prefixed + for provider in [provider_a, provider_b, provider_c]: + expected = f"{provider.value}_{shared_name}" + assert any(r.name == expected for r in merged), ( + f"Expected prefixed name '{expected}' in merged results" + ) + + @given( + provider_a=provider_type_strategy, + provider_b=provider_type_strategy, + shared_name=resource_name_strategy, + ) + @settings(max_examples=100) + def test_merge_preserves_all_resources_no_loss( + self, provider_a, provider_b, shared_name + ): + """Merging never loses resources regardless of naming conflicts.""" + assume(provider_a != provider_b) + + resource_a = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0], + unique_id="id-a", + name=shared_name, + provider=provider_a, + platform_category=PROVIDER_PLATFORM_MAP[provider_a], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-a:8080", + attributes={"source": "a"}, + ) + resource_b = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_b][0], + unique_id="id-b", + name=shared_name, + provider=provider_b, + platform_category=PROVIDER_PLATFORM_MAP[provider_b], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-b:8080", + attributes={"source": "b"}, + ) + + # Also add a non-conflicting resource + resource_c = DiscoveredResource( + resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0], + unique_id="id-c", + name="unique_resource_name", + provider=provider_a, + platform_category=PROVIDER_PLATFORM_MAP[provider_a], + architecture=CpuArchitecture.AMD64, + endpoint="http://host-a:8080", + attributes={"source": "c"}, + ) + + scan_result_a = ScanResult( + resources=[resource_a, resource_c], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + scan_result_b = ScanResult( + resources=[resource_b], + warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + + merger = ResourceMerger() + merged = merger.merge([scan_result_a, scan_result_b]) + + # Total resources = 3 (no loss) + assert len(merged) == 3 + + # Provider-specific attributes preserved + unique_ids = {r.unique_id for r in merged} + assert "id-a" in unique_ids + assert "id-b" in unique_ids + assert "id-c" in unique_ids + + +# --------------------------------------------------------------------------- +# Property 19: Provider block generation +# --------------------------------------------------------------------------- + + +class TestProviderBlockGeneration: + """Property 19: Provider block generation. + + For any set of providers used, a provider block is generated for each. + + **Validates: Requirements 5.4** + """ + + @given( + providers=st.lists( + provider_type_strategy, + min_size=1, + max_size=6, + ).map(lambda ps: list(set(ps))), # Deduplicate + ) + @settings(max_examples=100) + def test_one_provider_block_per_distinct_provider(self, providers): + """Generated output contains exactly one provider block per distinct provider.""" + assume(len(providers) >= 1) + + # Create profiles for each provider + profiles = [ + ScanProfile( + provider=p, + credentials={"token": "test-token"}, + ) + for p in providers + ] + + provider_types = set(providers) + + generator = ProviderBlockGenerator() + result = generator.generate(profiles=profiles, provider_types=provider_types) + + # Result should be a GeneratedFile + assert isinstance(result, GeneratedFile) + assert result.filename == "providers.tf" + + content = result.content + + # Count provider blocks: each provider type should have exactly one + # provider "name" { block + for provider_type in provider_types: + # Get the terraform provider name for this type + from iac_reverse.generator.provider_block import _PROVIDER_METADATA + tf_name = _PROVIDER_METADATA[provider_type][0] + provider_block_marker = f'provider "{tf_name}"' + count = content.count(provider_block_marker) + assert count == 1, ( + f"Expected exactly 1 provider block for '{tf_name}', " + f"found {count} in:\n{content}" + ) + + @given( + providers=st.lists( + provider_type_strategy, + min_size=2, + max_size=6, + ).map(lambda ps: list(set(ps))), + ) + @settings(max_examples=100) + def test_required_providers_block_lists_all(self, providers): + """The terraform required_providers block lists all providers used.""" + assume(len(providers) >= 2) + + profiles = [ + ScanProfile( + provider=p, + credentials={"token": "test-token"}, + ) + for p in providers + ] + + provider_types = set(providers) + + generator = ProviderBlockGenerator() + result = generator.generate(profiles=profiles, provider_types=provider_types) + + content = result.content + + # The required_providers block must exist + assert "required_providers" in content + + # Each provider must appear in the required_providers block + from iac_reverse.generator.provider_block import _PROVIDER_METADATA + for provider_type in provider_types: + tf_name, source, _ = _PROVIDER_METADATA[provider_type] + assert tf_name in content, ( + f"Expected provider name '{tf_name}' in required_providers block" + ) + assert source in content, ( + f"Expected source '{source}' in required_providers block" + ) + + @given(provider=provider_type_strategy) + @settings(max_examples=100) + def test_single_provider_generates_one_block(self, provider): + """A single provider generates exactly one provider block.""" + profiles = [ + ScanProfile( + provider=provider, + credentials={"token": "test-token"}, + ) + ] + + generator = ProviderBlockGenerator() + result = generator.generate( + profiles=profiles, + provider_types={provider}, + ) + + content = result.content + + from iac_reverse.generator.provider_block import _PROVIDER_METADATA + tf_name = _PROVIDER_METADATA[provider][0] + + # Exactly one provider block + provider_block_marker = f'provider "{tf_name}"' + assert content.count(provider_block_marker) == 1 + + # terraform block with required_providers + assert "terraform {" in content + assert "required_providers" in content + + +# --------------------------------------------------------------------------- +# Property 20: Scan profile validation completeness (multi-provider scenarios) +# --------------------------------------------------------------------------- + + +class TestScanProfileValidationMultiProvider: + """Property 20: Scan profile validation completeness (multi-provider scenarios). + + Additional multi-provider scenarios beyond what's in test_scan_profile_validation_prop.py. + + **Validates: Requirements 6.1, 6.6, 6.7** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_valid_multi_provider_profiles_all_pass_validation( + self, provider, credentials + ): + """Each valid profile in a multi-provider set passes validation independently.""" + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + errors = profile.validate() + assert errors == [], f"Expected no errors for valid profile, got: {errors}" + + @given( + provider_a=provider_type_strategy, + provider_b=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_mixed_valid_invalid_profiles_detected_independently( + self, provider_a, provider_b, credentials + ): + """Invalid profiles are detected independently in a multi-provider set.""" + # Valid profile + valid_profile = ScanProfile( + provider=provider_a, + credentials=credentials, + resource_type_filters=None, + ) + # Invalid profile (empty credentials) + invalid_profile = ScanProfile( + provider=provider_b, + credentials={}, + resource_type_filters=None, + ) + + valid_errors = valid_profile.validate() + invalid_errors = invalid_profile.validate() + + assert valid_errors == [] + assert len(invalid_errors) >= 1 + assert any("credentials" in e for e in invalid_errors) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_cross_provider_resource_types_detected_as_unsupported( + self, provider, credentials + ): + """Resource types from a different provider are flagged as unsupported.""" + # Pick a resource type from a different provider + other_providers = [p for p in ProviderType if p != provider] + assume(len(other_providers) > 0) + other_provider = other_providers[0] + other_types = PROVIDER_SUPPORTED_RESOURCE_TYPES[other_provider] + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=other_types[:2], + ) + errors = profile.validate() + + # Should detect unsupported types (unless they happen to overlap) + supported = set(PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]) + unsupported = [t for t in other_types[:2] if t not in supported] + if unsupported: + assert any("unsupported" in e.lower() for e in errors), ( + f"Expected unsupported error for cross-provider types, got: {errors}" + ) + + +# --------------------------------------------------------------------------- +# Property 21: Filtering correctness +# --------------------------------------------------------------------------- + + +class TestFilteringCorrectness: + """Property 21: Filtering correctness. + + When resource type filters are specified, only those types appear in the + scan result. + + **Validates: Requirements 6.2, 6.4** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_filtered_scan_returns_only_filtered_types(self, provider, credentials): + """When filters are specified, only filtered resource types appear in results.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + assume(len(supported) >= 2) + + # Pick a subset of supported types as filter + filter_types = supported[:2] + + plugin = FilteringPlugin(provider=provider) + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=filter_types, + ) + + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # All discovered resources must have types in the filter list + for resource in result.resources: + assert resource.resource_type in filter_types, ( + f"Resource type '{resource.resource_type}' not in filter list " + f"{filter_types}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_no_filter_returns_all_supported_types(self, provider, credentials): + """When no filters are specified, all supported types are discovered.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + + plugin = FilteringPlugin(provider=provider) + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # All supported types should be discovered + discovered_types = {r.resource_type for r in result.resources} + for rt in supported: + assert rt in discovered_types, ( + f"Expected type '{rt}' to be discovered when no filter, " + f"got: {discovered_types}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_single_type_filter_returns_only_that_type(self, provider, credentials): + """A single-type filter returns only resources of that type.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + assume(len(supported) >= 1) + + single_filter = [supported[0]] + + plugin = FilteringPlugin(provider=provider) + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=single_filter, + ) + + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # Only the single filtered type should appear + for resource in result.resources: + assert resource.resource_type == single_filter[0], ( + f"Expected only '{single_filter[0]}', got '{resource.resource_type}'" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_empty_filter_returns_no_resources(self, provider, credentials): + """An empty filter list results in no resources discovered.""" + plugin = FilteringPlugin(provider=provider) + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=[], + ) + + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # Empty filter means nothing to discover + assert len(result.resources) == 0, ( + f"Expected 0 resources with empty filter, got {len(result.resources)}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_filter_subset_excludes_non_filtered_types(self, provider, credentials): + """Types not in the filter list do not appear in results.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + assume(len(supported) >= 3) + + # Filter to first 2 types only + filter_types = supported[:2] + excluded_types = supported[2:] + + plugin = FilteringPlugin(provider=provider) + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=filter_types, + ) + + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # None of the excluded types should appear + discovered_types = {r.resource_type for r in result.resources} + for excluded in excluded_types: + assert excluded not in discovered_types, ( + f"Excluded type '{excluded}' should not appear in filtered results" + ) diff --git a/tests/property/test_resource_inventory_completeness_prop.py b/tests/property/test_resource_inventory_completeness_prop.py new file mode 100644 index 0000000..e35a823 --- /dev/null +++ b/tests/property/test_resource_inventory_completeness_prop.py @@ -0,0 +1,222 @@ +"""Property-based tests for resource inventory completeness. + +**Validates: Requirements 1.2** + +Property 1: Resource inventory completeness +For any discovered resource from any on-premises provider (Docker Swarm, Kubernetes, +Synology, Harvester, Bare Metal, Windows), the resulting inventory entry SHALL contain +non-empty values for resource_type, unique_id, name, provider, platform_category, +architecture, and attributes fields. +""" + +from hypothesis import given, settings +from hypothesis import strategies as st + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) +platform_category_strategy = st.sampled_from(list(PlatformCategory)) +cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture)) + +non_empty_text_strategy = st.text( + min_size=1, + max_size=100, + alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")), +).filter(lambda s: s.strip() != "") + +resource_type_strategy = st.text( + min_size=1, + max_size=50, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"), +).filter(lambda s: len(s) > 0 and s.strip() != "") + +unique_id_strategy = st.text( + min_size=1, + max_size=200, + alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")), +).filter(lambda s: s.strip() != "") + +name_strategy = st.text( + min_size=1, + max_size=100, + alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")), +).filter(lambda s: s.strip() != "") + +endpoint_strategy = st.text( + min_size=1, + max_size=200, + alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")), +).filter(lambda s: s.strip() != "") + +non_empty_attributes_strategy = st.dictionaries( + keys=st.text( + min_size=1, + max_size=30, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"), + ), + values=st.one_of( + st.text(min_size=1, max_size=50), + st.integers(min_value=0, max_value=10000), + st.booleans(), + ), + min_size=1, + max_size=10, +) + +raw_references_strategy = st.lists( + st.text(min_size=0, max_size=100), + min_size=0, + max_size=5, +) + + +discovered_resource_strategy = st.builds( + DiscoveredResource, + resource_type=resource_type_strategy, + unique_id=unique_id_strategy, + name=name_strategy, + provider=provider_type_strategy, + platform_category=platform_category_strategy, + architecture=cpu_architecture_strategy, + endpoint=endpoint_strategy, + attributes=non_empty_attributes_strategy, + raw_references=raw_references_strategy, +) + + +# --------------------------------------------------------------------------- +# Property Tests +# --------------------------------------------------------------------------- + + +class TestResourceInventoryCompleteness: + """Property 1: Resource inventory completeness. + + **Validates: Requirements 1.2** + + For any discovered resource from any on-premises provider, the resulting + inventory entry SHALL contain non-empty values for resource_type, unique_id, + name, provider, platform_category, architecture, and attributes fields. + """ + + @given(resource=discovered_resource_strategy) + def test_resource_type_is_non_empty(self, resource: DiscoveredResource): + """resource_type field must be a non-empty string.""" + assert isinstance(resource.resource_type, str) + assert len(resource.resource_type) > 0 + assert resource.resource_type.strip() != "" + + @given(resource=discovered_resource_strategy) + def test_unique_id_is_non_empty(self, resource: DiscoveredResource): + """unique_id field must be a non-empty string.""" + assert isinstance(resource.unique_id, str) + assert len(resource.unique_id) > 0 + assert resource.unique_id.strip() != "" + + @given(resource=discovered_resource_strategy) + def test_name_is_non_empty(self, resource: DiscoveredResource): + """name field must be a non-empty string.""" + assert isinstance(resource.name, str) + assert len(resource.name) > 0 + assert resource.name.strip() != "" + + @given(resource=discovered_resource_strategy) + def test_provider_is_valid_enum(self, resource: DiscoveredResource): + """provider field must be a valid ProviderType enum value.""" + assert isinstance(resource.provider, ProviderType) + assert resource.provider is not None + + @given(resource=discovered_resource_strategy) + def test_platform_category_is_valid_enum(self, resource: DiscoveredResource): + """platform_category field must be a valid PlatformCategory enum value.""" + assert isinstance(resource.platform_category, PlatformCategory) + assert resource.platform_category is not None + + @given(resource=discovered_resource_strategy) + def test_architecture_is_valid_enum(self, resource: DiscoveredResource): + """architecture field must be a valid CpuArchitecture enum value.""" + assert isinstance(resource.architecture, CpuArchitecture) + assert resource.architecture is not None + + @given(resource=discovered_resource_strategy) + def test_attributes_is_non_empty_dict(self, resource: DiscoveredResource): + """attributes field must be a non-empty dictionary.""" + assert isinstance(resource.attributes, dict) + assert len(resource.attributes) > 0 + + @given(resource=discovered_resource_strategy) + def test_all_mandatory_fields_populated(self, resource: DiscoveredResource): + """All mandatory fields must be non-empty/non-None simultaneously.""" + # resource_type + assert isinstance(resource.resource_type, str) and len(resource.resource_type) > 0 + # unique_id + assert isinstance(resource.unique_id, str) and len(resource.unique_id) > 0 + # name + assert isinstance(resource.name, str) and len(resource.name) > 0 + # provider + assert isinstance(resource.provider, ProviderType) + # platform_category + assert isinstance(resource.platform_category, PlatformCategory) + # architecture + assert isinstance(resource.architecture, CpuArchitecture) + # attributes + assert isinstance(resource.attributes, dict) and len(resource.attributes) > 0 + + @given( + resources=st.lists(discovered_resource_strategy, min_size=1, max_size=10), + warnings=st.lists(st.text(min_size=0, max_size=50), max_size=3), + errors=st.lists(st.text(min_size=0, max_size=50), max_size=3), + ) + @settings(max_examples=50) + def test_scan_result_resources_all_have_mandatory_fields( + self, resources, warnings, errors + ): + """When a mock plugin produces resources in a ScanResult, every resource has all required fields.""" + scan_result = ScanResult( + resources=resources, + warnings=warnings, + errors=errors, + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test_hash_abc123", + is_partial=False, + ) + + for resource in scan_result.resources: + # resource_type is non-empty string + assert isinstance(resource.resource_type, str) + assert len(resource.resource_type) > 0 + assert resource.resource_type.strip() != "" + + # unique_id is non-empty string + assert isinstance(resource.unique_id, str) + assert len(resource.unique_id) > 0 + assert resource.unique_id.strip() != "" + + # name is non-empty string + assert isinstance(resource.name, str) + assert len(resource.name) > 0 + assert resource.name.strip() != "" + + # provider is valid enum + assert isinstance(resource.provider, ProviderType) + + # platform_category is valid enum + assert isinstance(resource.platform_category, PlatformCategory) + + # architecture is valid enum + assert isinstance(resource.architecture, CpuArchitecture) + + # attributes is non-empty dict + assert isinstance(resource.attributes, dict) + assert len(resource.attributes) > 0 diff --git a/tests/property/test_scan_profile_validation_prop.py b/tests/property/test_scan_profile_validation_prop.py new file mode 100644 index 0000000..8481f70 --- /dev/null +++ b/tests/property/test_scan_profile_validation_prop.py @@ -0,0 +1,257 @@ +"""Property-based tests for ScanProfile validation completeness. + +**Validates: Requirements 6.1, 6.6, 6.7** + +Property 20: Scan profile validation completeness +For any scan profile with K invalid fields (missing provider, empty credentials, +unreachable endpoints, filters exceeding 200 entries, or unsupported resource types), +the validation error SHALL list all K invalid fields in a single response. +""" + +from hypothesis import given, assume, settings +from hypothesis import strategies as st + +from iac_reverse.models import ( + MAX_RESOURCE_TYPE_FILTERS, + PROVIDER_SUPPORTED_RESOURCE_TYPES, + ProviderType, + ScanProfile, +) + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) + +non_empty_credentials_strategy = st.dictionaries( + keys=st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=("L", "N", "P"))), + values=st.text(min_size=1, max_size=50), + min_size=1, + max_size=5, +) + +empty_credentials_strategy = st.just({}) + + +def valid_resource_types_strategy(provider: ProviderType) -> st.SearchStrategy: + """Generate a list of valid resource types for the given provider.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + return st.lists(st.sampled_from(supported), min_size=0, max_size=min(len(supported), 10)) + + +invalid_resource_type_strategy = st.text( + min_size=5, max_size=30, + alphabet=st.characters(whitelist_categories=("L",)) +).filter( + lambda t: all(t not in types for types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values()) +) + + +# --------------------------------------------------------------------------- +# Property Tests +# --------------------------------------------------------------------------- + + +class TestScanProfileValidationCompleteness: + """Property 20: Scan profile validation completeness. + + **Validates: Requirements 6.1, 6.6, 6.7** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + def test_valid_profile_returns_no_errors(self, provider, credentials): + """A profile with non-empty credentials and no filters is always valid.""" + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + errors = profile.validate() + assert errors == [], f"Expected no errors for valid profile, got: {errors}" + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + def test_valid_profile_with_valid_filters_returns_no_errors(self, provider, credentials): + """A profile with valid credentials and valid resource type filters is valid.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=supported, + ) + errors = profile.validate() + assert errors == [], f"Expected no errors for valid profile with valid filters, got: {errors}" + + @given(provider=provider_type_strategy) + def test_empty_credentials_always_produces_credentials_error(self, provider): + """Empty credentials must always produce an error mentioning 'credentials'.""" + profile = ScanProfile( + provider=provider, + credentials={}, + resource_type_filters=None, + ) + errors = profile.validate() + assert len(errors) >= 1 + assert any("credentials" in e for e in errors), ( + f"Expected error mentioning 'credentials', got: {errors}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + extra_count=st.integers(min_value=1, max_value=50), + ) + def test_oversized_filters_produces_count_error(self, provider, credentials, extra_count): + """Filters exceeding MAX_RESOURCE_TYPE_FILTERS must produce an error about the count limit.""" + # Build a list that exceeds the limit using valid types repeated + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + oversized_count = MAX_RESOURCE_TYPE_FILTERS + extra_count + filters = (supported * (oversized_count // len(supported) + 1))[:oversized_count] + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=filters, + ) + errors = profile.validate() + assert any( + "at most" in e or str(MAX_RESOURCE_TYPE_FILTERS) in e + for e in errors + ), f"Expected error mentioning count limit, got: {errors}" + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + invalid_types=st.lists(invalid_resource_type_strategy, min_size=1, max_size=5), + ) + def test_unsupported_types_produces_unsupported_error(self, provider, credentials, invalid_types): + """Unsupported resource types must produce an error mentioning them.""" + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=invalid_types, + ) + errors = profile.validate() + assert any("unsupported" in e.lower() for e in errors), ( + f"Expected error mentioning unsupported types, got: {errors}" + ) + + @given( + provider=provider_type_strategy, + invalid_types=st.lists(invalid_resource_type_strategy, min_size=1, max_size=3), + ) + def test_no_short_circuit_credentials_and_unsupported(self, provider, invalid_types): + """When both credentials are empty AND unsupported types exist, both errors are reported.""" + profile = ScanProfile( + provider=provider, + credentials={}, + resource_type_filters=invalid_types, + ) + errors = profile.validate() + assert len(errors) >= 2, f"Expected at least 2 errors, got {len(errors)}: {errors}" + assert any("credentials" in e for e in errors), ( + f"Expected credentials error, got: {errors}" + ) + assert any("unsupported" in e.lower() for e in errors), ( + f"Expected unsupported types error, got: {errors}" + ) + + @given( + provider=provider_type_strategy, + extra_count=st.integers(min_value=1, max_value=20), + ) + def test_no_short_circuit_credentials_and_oversized(self, provider, extra_count): + """When both credentials are empty AND filters exceed limit, both errors are reported.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + oversized_count = MAX_RESOURCE_TYPE_FILTERS + extra_count + filters = (supported * (oversized_count // len(supported) + 1))[:oversized_count] + + profile = ScanProfile( + provider=provider, + credentials={}, + resource_type_filters=filters, + ) + errors = profile.validate() + assert len(errors) >= 2, f"Expected at least 2 errors, got {len(errors)}: {errors}" + assert any("credentials" in e for e in errors), ( + f"Expected credentials error, got: {errors}" + ) + assert any( + "at most" in e or str(MAX_RESOURCE_TYPE_FILTERS) in e + for e in errors + ), f"Expected count limit error, got: {errors}" + + @given( + provider=provider_type_strategy, + extra_count=st.integers(min_value=1, max_value=10), + invalid_types=st.lists(invalid_resource_type_strategy, min_size=1, max_size=3), + ) + def test_no_short_circuit_all_three_issues(self, provider, extra_count, invalid_types): + """When credentials empty, filters oversized, AND unsupported types exist, all errors reported.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + oversized_count = MAX_RESOURCE_TYPE_FILTERS + extra_count + # Mix valid types (to reach oversized count) with invalid types + valid_padding = (supported * (oversized_count // len(supported) + 1))[:oversized_count] + filters = valid_padding + invalid_types + + profile = ScanProfile( + provider=provider, + credentials={}, + resource_type_filters=filters, + ) + errors = profile.validate() + assert len(errors) >= 3, f"Expected at least 3 errors, got {len(errors)}: {errors}" + assert any("credentials" in e for e in errors), ( + f"Expected credentials error, got: {errors}" + ) + assert any( + "at most" in e or str(MAX_RESOURCE_TYPE_FILTERS) in e + for e in errors + ), f"Expected count limit error, got: {errors}" + assert any("unsupported" in e.lower() for e in errors), ( + f"Expected unsupported types error, got: {errors}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + def test_empty_list_filters_is_valid(self, provider, credentials): + """An empty resource_type_filters list (not None) should be valid.""" + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=[], + ) + errors = profile.validate() + assert errors == [], f"Expected no errors for empty filter list, got: {errors}" + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + count=st.integers(min_value=1, max_value=MAX_RESOURCE_TYPE_FILTERS), + ) + def test_filters_at_or_below_limit_with_valid_types_is_valid(self, provider, credentials, count): + """Any number of valid filters at or below the limit should produce no count error.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + # Repeat valid types to reach the desired count + filters = (supported * (count // len(supported) + 1))[:count] + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=filters, + ) + errors = profile.validate() + # Should not have a count limit error + assert not any( + "at most" in e or (str(MAX_RESOURCE_TYPE_FILTERS) in e and "entries" in e) + for e in errors + ), f"Unexpected count limit error for {count} filters: {errors}" diff --git a/tests/property/test_scanner_behavior_prop.py b/tests/property/test_scanner_behavior_prop.py new file mode 100644 index 0000000..a5216ec --- /dev/null +++ b/tests/property/test_scanner_behavior_prop.py @@ -0,0 +1,608 @@ +"""Property-based tests for Scanner behavior. + +**Validates: Requirements 1.3, 1.4, 1.5, 1.7** + +Property 2: Authentication error descriptiveness +For any provider type and any authentication failure reason, the error returned +by the Scanner SHALL contain both the provider name string and the failure reason string. + +Property 3: Graceful degradation on unsupported resource types +For any scan request containing a mix of supported and unsupported resource types, +the Scanner SHALL produce warnings for each unsupported type AND return a complete +inventory for all supported types. + +Property 4: Progress reporting frequency +The Scanner SHALL report progress at least once per resource type completion. + +Property 5: Partial inventory preservation on failure +If the Provider API connection is lost during an active scan, the Scanner SHALL +return a partial resource inventory. +""" + +from typing import Callable + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + PROVIDER_SUPPORTED_RESOURCE_TYPES, + ProviderType, + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import ( + AuthenticationError, + ConnectionLostError, + Scanner, +) + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) + +non_empty_string_strategy = st.text( + min_size=1, + max_size=100, + alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")), +).filter(lambda s: s.strip()) + +non_empty_credentials_strategy = st.dictionaries( + keys=st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=("L", "N"))), + values=st.text(min_size=1, max_size=50), + min_size=1, + max_size=5, +) + +unsupported_resource_type_strategy = st.text( + min_size=5, + max_size=30, + alphabet=st.characters(whitelist_categories=("L",)), +).filter( + lambda t: all(t not in types for types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values()) +) + + +# --------------------------------------------------------------------------- +# Mock Plugin Implementations +# --------------------------------------------------------------------------- + + +class FailingAuthPlugin(ProviderPlugin): + """A plugin that always fails authentication with a given reason.""" + + def __init__(self, failure_reason: str): + self.failure_reason = failure_reason + + def authenticate(self, credentials: dict[str, str]) -> None: + raise RuntimeError(self.failure_reason) + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["http://localhost:8080"] + + def list_supported_resource_types(self) -> list[str]: + return ["mock_resource"] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + return ScanResult( + resources=[], warnings=[], errors=[], + scan_timestamp="", profile_hash="", + ) + + +class GracefulDegradationPlugin(ProviderPlugin): + """A plugin that supports specific resource types and discovers resources for them.""" + + def __init__(self, supported_types: list[str]): + self._supported_types = supported_types + + def authenticate(self, credentials: dict[str, str]) -> None: + pass # Always succeeds + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["http://localhost:8080"] + + def list_supported_resource_types(self) -> list[str]: + return self._supported_types + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + # Create one resource per supported resource type requested + resources = [] + for i, rt in enumerate(resource_types): + resources.append( + DiscoveredResource( + resource_type=rt, + unique_id=f"id-{rt}-{i}", + name=f"resource-{rt}-{i}", + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint="http://localhost:8080", + attributes={"key": "value"}, + ) + ) + progress_callback(ScanProgress( + current_resource_type=rt, + resources_discovered=i + 1, + resource_types_completed=i + 1, + total_resource_types=len(resource_types), + )) + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + +class ProgressTrackingPlugin(ProviderPlugin): + """A plugin that reports progress per resource type.""" + + def __init__(self, supported_types: list[str]): + self._supported_types = supported_types + + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["http://localhost:8080"] + + def list_supported_resource_types(self) -> list[str]: + return self._supported_types + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + resources = [] + for i, rt in enumerate(resource_types): + resource = DiscoveredResource( + resource_type=rt, + unique_id=f"id-{rt}-{i}", + name=f"resource-{rt}-{i}", + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint="http://localhost:8080", + attributes={}, + ) + resources.append(resource) + progress_callback(ScanProgress( + current_resource_type=rt, + resources_discovered=i + 1, + resource_types_completed=i + 1, + total_resource_types=len(resource_types), + )) + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + +class ConnectionLossPlugin(ProviderPlugin): + """A plugin that loses connection after discovering some resources.""" + + def __init__(self, supported_types: list[str], fail_after: int): + self._supported_types = supported_types + self._fail_after = fail_after + + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["http://localhost:8080"] + + def list_supported_resource_types(self) -> list[str]: + return self._supported_types + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback: Callable[[ScanProgress], None], + ) -> ScanResult: + # Simulate connection loss by raising ConnectionError + raise ConnectionError( + f"Connection lost after discovering {self._fail_after} resources" + ) + + +# --------------------------------------------------------------------------- +# Property Tests +# --------------------------------------------------------------------------- + + +class TestAuthenticationErrorDescriptiveness: + """Property 2: Authentication error descriptiveness. + + For any provider type and any authentication failure reason, the error + returned by the Scanner SHALL contain both the provider name string and + the failure reason string. + + **Validates: Requirements 1.3** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + failure_reason=non_empty_string_strategy, + ) + @settings(max_examples=100) + def test_auth_error_contains_provider_name_and_reason( + self, provider, credentials, failure_reason + ): + """AuthenticationError must contain both provider name and failure reason.""" + plugin = FailingAuthPlugin(failure_reason=failure_reason) + profile = ScanProfile( + provider=provider, + credentials=credentials, + ) + scanner = Scanner(profile=profile, plugin=plugin) + + with pytest.raises(AuthenticationError) as exc_info: + scanner.scan() + + error = exc_info.value + # The error must contain the provider name + assert provider.value in str(error), ( + f"Expected provider name '{provider.value}' in error message, " + f"got: '{str(error)}'" + ) + # The error must contain the failure reason + assert failure_reason in str(error), ( + f"Expected failure reason '{failure_reason}' in error message, " + f"got: '{str(error)}'" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + failure_reason=non_empty_string_strategy, + ) + @settings(max_examples=100) + def test_auth_error_attributes_match(self, provider, credentials, failure_reason): + """AuthenticationError attributes must store provider_name and reason.""" + plugin = FailingAuthPlugin(failure_reason=failure_reason) + profile = ScanProfile( + provider=provider, + credentials=credentials, + ) + scanner = Scanner(profile=profile, plugin=plugin) + + with pytest.raises(AuthenticationError) as exc_info: + scanner.scan() + + error = exc_info.value + assert error.provider_name == provider.value + assert error.reason == failure_reason + + +class TestGracefulDegradationOnUnsupportedTypes: + """Property 3: Graceful degradation on unsupported resource types. + + For any scan request containing a mix of supported and unsupported resource + types, the Scanner SHALL produce warnings for each unsupported type AND + return a complete inventory for all supported types. + + **Validates: Requirements 1.4** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + unsupported_types=st.lists(unsupported_resource_type_strategy, min_size=1, max_size=5), + ) + @settings(max_examples=100) + def test_unsupported_types_produce_warnings( + self, provider, credentials, unsupported_types + ): + """Each unsupported resource type must produce a warning.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = GracefulDegradationPlugin(supported_types=supported) + + # Mix supported and unsupported types + mixed_filters = list(supported[:2]) + unsupported_types + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=mixed_filters, + ) + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # There must be a warning for each unsupported type + for unsupported in unsupported_types: + assert any(unsupported in w for w in result.warnings), ( + f"Expected warning for unsupported type '{unsupported}', " + f"got warnings: {result.warnings}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + unsupported_types=st.lists(unsupported_resource_type_strategy, min_size=1, max_size=5), + ) + @settings(max_examples=100) + def test_supported_types_still_discovered( + self, provider, credentials, unsupported_types + ): + """Supported types must still be fully discovered despite unsupported types.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = GracefulDegradationPlugin(supported_types=supported) + + # Use at least one supported type plus unsupported types + supported_subset = supported[:2] + mixed_filters = supported_subset + unsupported_types + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=mixed_filters, + ) + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # All supported types in the filter should have resources discovered + discovered_types = {r.resource_type for r in result.resources} + for st_type in supported_subset: + assert st_type in discovered_types, ( + f"Expected supported type '{st_type}' to be discovered, " + f"but only found: {discovered_types}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + unsupported_types=st.lists(unsupported_resource_type_strategy, min_size=1, max_size=5), + ) + @settings(max_examples=100) + def test_warning_count_matches_unsupported_count( + self, provider, credentials, unsupported_types + ): + """Number of warnings must be at least the number of unsupported types.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = GracefulDegradationPlugin(supported_types=supported) + + # Only unsupported types in filter + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=unsupported_types, + ) + scanner = Scanner(profile=profile, plugin=plugin) + result = scanner.scan() + + # Deduplicate unsupported types for comparison + unique_unsupported = set(unsupported_types) + assert len(result.warnings) >= len(unique_unsupported), ( + f"Expected at least {len(unique_unsupported)} warnings, " + f"got {len(result.warnings)}: {result.warnings}" + ) + + +class TestProgressReportingFrequency: + """Property 4: Progress reporting frequency. + + For any scan across N resource types, the progress callback SHALL be + invoked at least N times, once per resource type completion. + + **Validates: Requirements 1.5** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_progress_reported_at_least_once_per_resource_type( + self, provider, credentials + ): + """Progress callback must be invoked at least once per resource type.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = ProgressTrackingPlugin(supported_types=supported) + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, # Scan all supported types + ) + + progress_reports: list[ScanProgress] = [] + + def track_progress(progress: ScanProgress) -> None: + progress_reports.append(progress) + + scanner = Scanner(profile=profile, plugin=plugin) + scanner.scan(progress_callback=track_progress) + + # Must have at least N progress reports for N resource types + assert len(progress_reports) >= len(supported), ( + f"Expected at least {len(supported)} progress reports, " + f"got {len(progress_reports)}" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + ) + @settings(max_examples=100) + def test_progress_reports_cover_all_resource_types( + self, provider, credentials + ): + """Progress reports must cover every resource type being scanned.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = ProgressTrackingPlugin(supported_types=supported) + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + + progress_reports: list[ScanProgress] = [] + + def track_progress(progress: ScanProgress) -> None: + progress_reports.append(progress) + + scanner = Scanner(profile=profile, plugin=plugin) + scanner.scan(progress_callback=track_progress) + + # Every resource type should appear in at least one progress report + reported_types = {p.current_resource_type for p in progress_reports} + for rt in supported: + assert rt in reported_types, ( + f"Expected resource type '{rt}' in progress reports, " + f"but only found: {reported_types}" + ) + + +class TestPartialInventoryPreservationOnFailure: + """Property 5: Partial inventory preservation on failure. + + If the Provider API connection is lost during an active scan, the Scanner + SHALL return a partial resource inventory. + + **Validates: Requirements 1.7** + """ + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + fail_after=st.integers(min_value=0, max_value=10), + ) + @settings(max_examples=100) + def test_connection_loss_raises_with_partial_result( + self, provider, credentials, fail_after + ): + """Connection loss must raise ConnectionLostError with partial result.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = ConnectionLossPlugin( + supported_types=supported, + fail_after=fail_after, + ) + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + scanner = Scanner(profile=profile, plugin=plugin) + + with pytest.raises(ConnectionLostError) as exc_info: + scanner.scan() + + error = exc_info.value + # Must have a partial_result attribute + assert hasattr(error, "partial_result") + partial = error.partial_result + assert isinstance(partial, ScanResult) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + fail_after=st.integers(min_value=0, max_value=10), + ) + @settings(max_examples=100) + def test_partial_result_is_marked_as_partial( + self, provider, credentials, fail_after + ): + """Partial result from connection loss must have is_partial=True.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = ConnectionLossPlugin( + supported_types=supported, + fail_after=fail_after, + ) + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + scanner = Scanner(profile=profile, plugin=plugin) + + with pytest.raises(ConnectionLostError) as exc_info: + scanner.scan() + + partial = exc_info.value.partial_result + assert partial.is_partial is True, ( + "Partial result from connection loss must have is_partial=True" + ) + + @given( + provider=provider_type_strategy, + credentials=non_empty_credentials_strategy, + fail_after=st.integers(min_value=0, max_value=10), + ) + @settings(max_examples=100) + def test_partial_result_contains_error_info( + self, provider, credentials, fail_after + ): + """Partial result must contain error information about the failure.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + plugin = ConnectionLossPlugin( + supported_types=supported, + fail_after=fail_after, + ) + + profile = ScanProfile( + provider=provider, + credentials=credentials, + resource_type_filters=None, + ) + scanner = Scanner(profile=profile, plugin=plugin) + + with pytest.raises(ConnectionLostError) as exc_info: + scanner.scan() + + partial = exc_info.value.partial_result + # Must have at least one error or warning indicating the failure + assert len(partial.errors) > 0 or len(partial.warnings) > 0, ( + "Partial result must contain error/warning info about the connection loss" + ) diff --git a/tests/property/test_state_builder_prop.py b/tests/property/test_state_builder_prop.py new file mode 100644 index 0000000..db193a4 --- /dev/null +++ b/tests/property/test_state_builder_prop.py @@ -0,0 +1,567 @@ +"""Property-based tests for the State Builder. + +**Validates: Requirements 4.1, 4.2, 4.4, 4.5** + +Properties tested: +- Property 16: State file structural validity +- Property 17: State entry completeness and schema correctness +""" + +import json +import re +import uuid + +from hypothesis import given, settings, assume +from hypothesis import strategies as st + +from iac_reverse.generator.sanitize import sanitize_identifier +from iac_reverse.models import ( + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + GeneratedFile, + PlatformCategory, + PROVIDER_SUPPORTED_RESOURCE_TYPES, + ProviderType, + ResourceRelationship, +) +from iac_reverse.state_builder import StateBuilder + + +# --------------------------------------------------------------------------- +# Hypothesis Strategies +# --------------------------------------------------------------------------- + +provider_type_strategy = st.sampled_from(list(ProviderType)) +platform_category_strategy = st.sampled_from(list(PlatformCategory)) +cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture)) + +# All supported resource types across all providers (flat list) +ALL_SUPPORTED_RESOURCE_TYPES = [] +for _types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values(): + ALL_SUPPORTED_RESOURCE_TYPES.extend(_types) + +resource_type_strategy = st.sampled_from(ALL_SUPPORTED_RESOURCE_TYPES) + +# Strategy for resource names (valid identifiers with some variety) +resource_name_strategy = st.text( + min_size=1, + max_size=20, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"), +).filter(lambda s: s.strip() != "") + +# Strategy for unique IDs (non-empty strings) +unique_id_strategy = st.text( + min_size=1, + max_size=40, + alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-/:."), +).filter(lambda s: s.strip() != "") + +# Strategy for simple attribute values +simple_attr_value_strategy = st.one_of( + st.text( + min_size=1, + max_size=30, + alphabet=st.characters( + whitelist_categories=("L", "N"), whitelist_characters="_-./: " + ), + ).filter(lambda s: s.strip() != ""), + st.integers(min_value=0, max_value=10000), + st.booleans(), +) + +# Strategy for attribute dictionaries (non-empty) +attributes_strategy = st.dictionaries( + keys=st.text( + min_size=1, + max_size=15, + alphabet=st.characters(whitelist_categories=("L",), whitelist_characters="_"), + ).filter(lambda s: s.strip() != "" and s[0].isalpha()), + values=simple_attr_value_strategy, + min_size=1, + max_size=5, +) + +# Strategy for provider version strings (semver-like) +provider_version_strategy = st.from_regex(r"[1-9][0-9]{0,1}\.[0-9]{1,2}\.[0-9]{1,2}", fullmatch=True) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_resource( + unique_id: str, + resource_type: str = "kubernetes_deployment", + name: str = "my_resource", + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, + architecture: CpuArchitecture = CpuArchitecture.AMD64, + attributes: dict | None = None, + raw_references: list[str] | None = None, +) -> DiscoveredResource: + """Helper to create a DiscoveredResource with sensible defaults.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=architecture, + endpoint="https://api.internal.lab:6443", + attributes=attributes or {"key": "value"}, + raw_references=raw_references or [], + ) + + +def make_dependency_graph( + resources: list[DiscoveredResource], + relationships: list[ResourceRelationship] | None = None, +) -> DependencyGraph: + """Helper to create a DependencyGraph from resources.""" + return DependencyGraph( + resources=resources, + relationships=relationships or [], + topological_order=[r.unique_id for r in resources], + cycles=[], + unresolved_references=[], + ) + + +def make_code_generation_result() -> CodeGenerationResult: + """Helper to create a minimal CodeGenerationResult.""" + return CodeGenerationResult( + resource_files=[ + GeneratedFile(filename="main.tf", content="", resource_count=0) + ], + variables_file=GeneratedFile( + filename="variables.tf", content="", resource_count=0 + ), + provider_file=GeneratedFile( + filename="provider.tf", content="", resource_count=0 + ), + ) + + +# --------------------------------------------------------------------------- +# Composite strategies +# --------------------------------------------------------------------------- + + +@st.composite +def mappable_resource_strategy(draw): + """Generate a single DiscoveredResource that is mappable to state. + + A mappable resource has a non-empty unique_id and a recognized resource type. + """ + resource_type = draw(resource_type_strategy) + name = draw(resource_name_strategy) + unique_id = draw(unique_id_strategy) + provider = draw(provider_type_strategy) + platform_category = draw(platform_category_strategy) + architecture = draw(cpu_architecture_strategy) + attributes = draw(attributes_strategy) + + return make_resource( + unique_id=unique_id, + resource_type=resource_type, + name=name, + provider=provider, + platform_category=platform_category, + architecture=architecture, + attributes=attributes, + ) + + +@st.composite +def multiple_mappable_resources_strategy(draw): + """Generate a list of mappable resources with unique IDs.""" + num_resources = draw(st.integers(min_value=1, max_value=5)) + resources = [] + seen_ids = set() + + for _ in range(num_resources): + resource = draw(mappable_resource_strategy()) + # Ensure unique IDs are distinct + if resource.unique_id in seen_ids: + continue + seen_ids.add(resource.unique_id) + resources.append(resource) + + assume(len(resources) >= 1) + return resources + + +@st.composite +def resource_with_sensitive_attrs_strategy(draw): + """Generate a resource with attributes that include sensitive-looking keys.""" + resource_type = draw(resource_type_strategy) + name = draw(resource_name_strategy) + unique_id = draw(unique_id_strategy) + + # Include at least one sensitive key + sensitive_key = draw(st.sampled_from([ + "password", "api_secret", "auth_token", "private_key", "tls_certificate", + ])) + sensitive_value = draw(st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnop")) + + # Also include non-sensitive attributes + normal_attrs = draw(attributes_strategy) + normal_attrs[sensitive_key] = sensitive_value + + return make_resource( + unique_id=unique_id, + resource_type=resource_type, + name=name, + attributes=normal_attrs, + ) + + +# --------------------------------------------------------------------------- +# Property 16: State file structural validity +# --------------------------------------------------------------------------- + + +class TestStateFileStructuralValidity: + """Property 16: State file structural validity. + + **Validates: Requirements 4.1** + + For any set of resources, the generated state file has version=4, + valid UUID lineage, serial=1, and valid JSON structure. + """ + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_file_version_is_4( + self, resources: list[DiscoveredResource] + ): + """The generated state file always has version=4.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert state_file.version == 4, ( + f"Expected version=4, got version={state_file.version}" + ) + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_file_has_valid_uuid_lineage( + self, resources: list[DiscoveredResource] + ): + """The generated state file has a valid UUID lineage.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + # Lineage should be a valid UUID + try: + parsed_uuid = uuid.UUID(state_file.lineage) + except ValueError: + raise AssertionError( + f"Lineage '{state_file.lineage}' is not a valid UUID" + ) + + assert parsed_uuid.version == 4, ( + f"Expected UUID version 4, got version {parsed_uuid.version}" + ) + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_file_serial_is_1( + self, resources: list[DiscoveredResource] + ): + """The generated state file always has serial=1.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert state_file.serial == 1, ( + f"Expected serial=1, got serial={state_file.serial}" + ) + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_file_produces_valid_json( + self, resources: list[DiscoveredResource] + ): + """The state file serializes to valid JSON via to_json().""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + json_str = state_file.to_json() + + # Must parse as valid JSON + try: + parsed = json.loads(json_str) + except json.JSONDecodeError as e: + raise AssertionError( + f"State file to_json() produced invalid JSON: {e}" + ) + + assert isinstance(parsed, dict), "State JSON root must be a dict" + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_json_has_required_top_level_fields( + self, resources: list[DiscoveredResource] + ): + """The serialized state JSON has version, terraform_version, serial, lineage, resources.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + parsed = json.loads(state_file.to_json()) + + required_fields = {"version", "terraform_version", "serial", "lineage", "resources"} + missing = required_fields - set(parsed.keys()) + assert not missing, ( + f"State JSON missing required top-level fields: {missing}" + ) + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_json_resource_entries_have_required_fields( + self, resources: list[DiscoveredResource] + ): + """Each resource entry in the JSON has mode, type, name, provider, and instances.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + parsed = json.loads(state_file.to_json()) + + required_resource_fields = {"mode", "type", "name", "provider", "instances"} + + for i, entry in enumerate(parsed["resources"]): + missing = required_resource_fields - set(entry.keys()) + assert not missing, ( + f"Resource entry {i} missing required fields: {missing}. " + f"Entry keys: {list(entry.keys())}" + ) + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_json_instances_have_schema_and_attributes( + self, resources: list[DiscoveredResource] + ): + """Each instance in the state JSON has schema_version, attributes, sensitive_attributes, dependencies.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + parsed = json.loads(state_file.to_json()) + + required_instance_fields = { + "schema_version", "attributes", "sensitive_attributes", "dependencies" + } + + for i, entry in enumerate(parsed["resources"]): + for j, instance in enumerate(entry["instances"]): + missing = required_instance_fields - set(instance.keys()) + assert not missing, ( + f"Resource {i}, instance {j} missing fields: {missing}. " + f"Instance keys: {list(instance.keys())}" + ) + + +# --------------------------------------------------------------------------- +# Property 17: State entry completeness and schema correctness +# --------------------------------------------------------------------------- + + +class TestStateEntryCompletenessAndSchemaCorrectness: + """Property 17: State entry completeness and schema correctness. + + **Validates: Requirements 4.4, 4.5** + + For any resource, the state entry has non-empty resource_type, + resource_name, provider_id, and attributes matching the discovery data. + """ + + @given(resource=mappable_resource_strategy()) + @settings(max_examples=100) + def test_state_entry_has_non_empty_resource_type( + self, resource: DiscoveredResource + ): + """Each state entry has a non-empty resource_type.""" + builder = StateBuilder() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert len(state_file.resources) == 1 + entry = state_file.resources[0] + assert entry.resource_type != "", ( + "State entry resource_type must not be empty" + ) + assert entry.resource_type == resource.resource_type, ( + f"Expected resource_type '{resource.resource_type}', " + f"got '{entry.resource_type}'" + ) + + @given(resource=mappable_resource_strategy()) + @settings(max_examples=100) + def test_state_entry_has_non_empty_resource_name( + self, resource: DiscoveredResource + ): + """Each state entry has a non-empty resource_name (sanitized).""" + builder = StateBuilder() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert len(state_file.resources) == 1 + entry = state_file.resources[0] + assert entry.resource_name != "", ( + "State entry resource_name must not be empty" + ) + # The name should be a sanitized version of the original + expected_name = sanitize_identifier(resource.name) + assert entry.resource_name == expected_name, ( + f"Expected resource_name '{expected_name}', " + f"got '{entry.resource_name}'" + ) + + @given(resource=mappable_resource_strategy()) + @settings(max_examples=100) + def test_state_entry_has_non_empty_provider_id( + self, resource: DiscoveredResource + ): + """Each state entry has a non-empty provider_id matching the resource's unique_id.""" + builder = StateBuilder() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert len(state_file.resources) == 1 + entry = state_file.resources[0] + assert entry.provider_id != "", ( + "State entry provider_id must not be empty" + ) + assert entry.provider_id == resource.unique_id, ( + f"Expected provider_id '{resource.unique_id}', " + f"got '{entry.provider_id}'" + ) + + @given(resource=mappable_resource_strategy()) + @settings(max_examples=100) + def test_state_entry_attributes_match_discovery_data( + self, resource: DiscoveredResource + ): + """State entry attributes contain all attributes from the discovered resource.""" + builder = StateBuilder() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert len(state_file.resources) == 1 + entry = state_file.resources[0] + + # All discovery attributes should be present in the state entry + for key, value in resource.attributes.items(): + assert key in entry.attributes, ( + f"Discovery attribute '{key}' missing from state entry attributes. " + f"State attrs: {list(entry.attributes.keys())}" + ) + assert entry.attributes[key] == value, ( + f"Attribute '{key}' mismatch: discovery={value}, " + f"state={entry.attributes[key]}" + ) + + @given( + resource=mappable_resource_strategy(), + provider_version=provider_version_strategy, + ) + @settings(max_examples=100) + def test_state_entry_schema_version_matches_provider_version( + self, resource: DiscoveredResource, provider_version: str + ): + """State entry schema_version matches the major version from provider_version.""" + builder = StateBuilder() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, provider_version) + + assert len(state_file.resources) == 1 + entry = state_file.resources[0] + + # Schema version should be the major version number + expected_schema_version = int(provider_version.split(".")[0]) + assert entry.schema_version == expected_schema_version, ( + f"Expected schema_version={expected_schema_version} " + f"(from provider_version='{provider_version}'), " + f"got schema_version={entry.schema_version}" + ) + + @given(resource=resource_with_sensitive_attrs_strategy()) + @settings(max_examples=100) + def test_state_entry_marks_sensitive_attributes( + self, resource: DiscoveredResource + ): + """State entry identifies and marks sensitive attributes correctly.""" + builder = StateBuilder() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + + assert len(state_file.resources) == 1 + entry = state_file.resources[0] + + # Sensitive attributes list should not be empty when resource has + # attributes with sensitive patterns (password, secret, token, key, certificate) + sensitive_patterns = ["password", "secret", "token", "key", "certificate"] + has_sensitive = any( + any(pattern in attr_key.lower() for pattern in sensitive_patterns) + for attr_key in resource.attributes.keys() + ) + + if has_sensitive: + assert len(entry.sensitive_attributes) > 0, ( + f"Resource has sensitive-looking attributes " + f"{list(resource.attributes.keys())} but sensitive_attributes " + f"is empty" + ) + + @given(resources=multiple_mappable_resources_strategy()) + @settings(max_examples=100) + def test_state_json_id_field_matches_provider_id( + self, resources: list[DiscoveredResource] + ): + """In the serialized JSON, each instance's attributes.id matches the provider_id.""" + builder = StateBuilder() + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + state_file = builder.build(code_result, graph, "1.0.0") + parsed = json.loads(state_file.to_json()) + + for i, entry in enumerate(parsed["resources"]): + for instance in entry["instances"]: + assert "id" in instance["attributes"], ( + f"Resource entry {i} instance missing 'id' in attributes" + ) + # The id should be non-empty + assert instance["attributes"]["id"] != "", ( + f"Resource entry {i} has empty 'id' attribute" + ) diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..b3fcc59 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1 @@ +"""Unit tests for IaC Reverse Engineering Tool.""" diff --git a/tests/unit/__pycache__/__init__.cpython-313.pyc b/tests/unit/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000..1244262 Binary files /dev/null and b/tests/unit/__pycache__/__init__.cpython-313.pyc differ diff --git a/tests/unit/__pycache__/test_authentik.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_authentik.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..0d462d9 Binary files /dev/null and b/tests/unit/__pycache__/test_authentik.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_bare_metal_plugin.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_bare_metal_plugin.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..b43dfbe Binary files /dev/null and b/tests/unit/__pycache__/test_bare_metal_plugin.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_change_detector.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_change_detector.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..a72517b Binary files /dev/null and b/tests/unit/__pycache__/test_change_detector.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_cli.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_cli.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..57b9114 Binary files /dev/null and b/tests/unit/__pycache__/test_cli.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_code_generator.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_code_generator.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..5fcb57d Binary files /dev/null and b/tests/unit/__pycache__/test_code_generator.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_docker_swarm_plugin.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_docker_swarm_plugin.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..0738fe0 Binary files /dev/null and b/tests/unit/__pycache__/test_docker_swarm_plugin.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_harvester_plugin.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_harvester_plugin.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..572dd43 Binary files /dev/null and b/tests/unit/__pycache__/test_harvester_plugin.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_incremental_updater.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_incremental_updater.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..111f513 Binary files /dev/null and b/tests/unit/__pycache__/test_incremental_updater.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_kubernetes_plugin.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_kubernetes_plugin.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..831fced Binary files /dev/null and b/tests/unit/__pycache__/test_kubernetes_plugin.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_models.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_models.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..dc974e2 Binary files /dev/null and b/tests/unit/__pycache__/test_models.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_multi_provider_scanner.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_multi_provider_scanner.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..4fd4f96 Binary files /dev/null and b/tests/unit/__pycache__/test_multi_provider_scanner.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_plugin_base.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_plugin_base.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..b3c0d9e Binary files /dev/null and b/tests/unit/__pycache__/test_plugin_base.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_profile_loader.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_profile_loader.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..ebbf8ce Binary files /dev/null and b/tests/unit/__pycache__/test_profile_loader.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_provider_block.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_provider_block.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..5c60a3c Binary files /dev/null and b/tests/unit/__pycache__/test_provider_block.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_resolver.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_resolver.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..febbbbe Binary files /dev/null and b/tests/unit/__pycache__/test_resolver.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_resolver_cycles.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_resolver_cycles.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..7c6b91a Binary files /dev/null and b/tests/unit/__pycache__/test_resolver_cycles.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_resolver_unresolved.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_resolver_unresolved.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..d2a7958 Binary files /dev/null and b/tests/unit/__pycache__/test_resolver_unresolved.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_resource_merger.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_resource_merger.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..2f7acfe Binary files /dev/null and b/tests/unit/__pycache__/test_resource_merger.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_sanitize.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_sanitize.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..11aebdd Binary files /dev/null and b/tests/unit/__pycache__/test_sanitize.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_scan_profile_validation.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_scan_profile_validation.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..5cd81e0 Binary files /dev/null and b/tests/unit/__pycache__/test_scan_profile_validation.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_scanner.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_scanner.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..b964df6 Binary files /dev/null and b/tests/unit/__pycache__/test_scanner.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_scanner_filtering.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_scanner_filtering.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..b24373b Binary files /dev/null and b/tests/unit/__pycache__/test_scanner_filtering.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_snapshot_store.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_snapshot_store.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..aeed011 Binary files /dev/null and b/tests/unit/__pycache__/test_snapshot_store.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_state_builder.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_state_builder.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..293ac02 Binary files /dev/null and b/tests/unit/__pycache__/test_state_builder.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_state_builder_unmapped.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_state_builder_unmapped.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..7da4dc9 Binary files /dev/null and b/tests/unit/__pycache__/test_state_builder_unmapped.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_synology_plugin.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_synology_plugin.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..8ec8bc0 Binary files /dev/null and b/tests/unit/__pycache__/test_synology_plugin.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_validator.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_validator.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..232a019 Binary files /dev/null and b/tests/unit/__pycache__/test_validator.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_validator_autocorrect.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_validator_autocorrect.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..0dab27c Binary files /dev/null and b/tests/unit/__pycache__/test_validator_autocorrect.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_variable_extractor.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_variable_extractor.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..0f9260d Binary files /dev/null and b/tests/unit/__pycache__/test_variable_extractor.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/__pycache__/test_windows_plugin.cpython-313-pytest-9.0.3.pyc b/tests/unit/__pycache__/test_windows_plugin.cpython-313-pytest-9.0.3.pyc new file mode 100644 index 0000000..b7ffb13 Binary files /dev/null and b/tests/unit/__pycache__/test_windows_plugin.cpython-313-pytest-9.0.3.pyc differ diff --git a/tests/unit/test_authentik.py b/tests/unit/test_authentik.py new file mode 100644 index 0000000..8e844b6 --- /dev/null +++ b/tests/unit/test_authentik.py @@ -0,0 +1,495 @@ +"""Unit tests for Authentik authentication and discovery plugin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.auth.authentik_auth import ( + AuthenticationError, + AuthentikAuthProvider, + AuthentikConfig, + AuthentikSession, +) +from iac_reverse.auth.authentik_discovery import ( + AuthentikDiscoveryError, + AuthentikDiscoveryPlugin, +) +from iac_reverse.models import CpuArchitecture, PlatformCategory, ScanProgress + + +# --------------------------------------------------------------------------- +# AuthentikAuthProvider tests +# --------------------------------------------------------------------------- + + +class TestAuthentikAuthProvider: + """Tests for AuthentikAuthProvider SSO authentication.""" + + def setup_method(self): + self.provider = AuthentikAuthProvider() + self.config = AuthentikConfig( + base_url="https://auth.internal.lab", + client_id="iac-reverse-tool", + client_secret="test-secret", + ) + + @patch("iac_reverse.auth.authentik_auth.requests.post") + @patch("iac_reverse.auth.authentik_auth.requests.get") + def test_authenticate_user_success(self, mock_get, mock_post): + """Successful authentication returns a valid session.""" + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + "access_token": "access-123", + "refresh_token": "refresh-456", + }, + ) + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + "sub": "user-001", + "groups": ["admins", "infra-team"], + }, + ) + + session = self.provider.authenticate_user(self.config) + + assert session.access_token == "access-123" + assert session.refresh_token == "refresh-456" + assert session.user_id == "user-001" + assert session.groups == ["admins", "infra-team"] + + @patch("iac_reverse.auth.authentik_auth.requests.post") + def test_authenticate_user_failure_status(self, mock_post): + """Authentication failure raises AuthenticationError.""" + mock_post.return_value = MagicMock( + status_code=401, + text="Invalid client credentials", + ) + + with pytest.raises(AuthenticationError) as exc_info: + self.provider.authenticate_user(self.config) + + assert "Authentik" in str(exc_info.value) + assert "401" in str(exc_info.value) + + @patch("iac_reverse.auth.authentik_auth.requests.post") + def test_authenticate_user_connection_error(self, mock_post): + """Connection error raises AuthenticationError.""" + import requests + + mock_post.side_effect = requests.ConnectionError("Connection refused") + + with pytest.raises(AuthenticationError) as exc_info: + self.provider.authenticate_user(self.config) + + assert "Authentik" in str(exc_info.value) + assert "failed to connect" in str(exc_info.value) + + @patch("iac_reverse.auth.authentik_auth.requests.post") + @patch("iac_reverse.auth.authentik_auth.requests.get") + def test_refresh_session_success(self, mock_get, mock_post): + """Successful token refresh returns updated session.""" + session = AuthentikSession( + access_token="old-access", + refresh_token="old-refresh", + user_id="user-001", + groups=["admins"], + ) + + mock_post.return_value = MagicMock( + status_code=200, + json=lambda: { + "access_token": "new-access-789", + "refresh_token": "new-refresh-012", + }, + ) + mock_get.return_value = MagicMock( + status_code=200, + json=lambda: { + "sub": "user-001", + "groups": ["admins", "new-group"], + }, + ) + + new_session = self.provider.refresh_session(self.config, session) + + assert new_session.access_token == "new-access-789" + assert new_session.refresh_token == "new-refresh-012" + assert new_session.user_id == "user-001" + + @patch("iac_reverse.auth.authentik_auth.requests.post") + def test_refresh_session_failure(self, mock_post): + """Failed refresh raises AuthenticationError.""" + session = AuthentikSession( + access_token="old-access", + refresh_token="expired-refresh", + user_id="user-001", + groups=[], + ) + + mock_post.return_value = MagicMock( + status_code=400, + text="Invalid refresh token", + ) + + with pytest.raises(AuthenticationError) as exc_info: + self.provider.refresh_session(self.config, session) + + assert "refresh failed" in str(exc_info.value) + + @patch("iac_reverse.auth.authentik_auth.requests.get") + def test_validate_token_valid(self, mock_get): + """Valid token returns True.""" + mock_get.return_value = MagicMock(status_code=200) + + result = self.provider.validate_token(self.config, "valid-token") + + assert result is True + + @patch("iac_reverse.auth.authentik_auth.requests.get") + def test_validate_token_invalid(self, mock_get): + """Invalid token returns False.""" + mock_get.return_value = MagicMock(status_code=401) + + result = self.provider.validate_token(self.config, "invalid-token") + + assert result is False + + @patch("iac_reverse.auth.authentik_auth.requests.get") + def test_validate_token_connection_error(self, mock_get): + """Connection error during validation returns False.""" + import requests + + mock_get.side_effect = requests.ConnectionError("timeout") + + result = self.provider.validate_token(self.config, "some-token") + + assert result is False + + +# --------------------------------------------------------------------------- +# AuthentikDiscoveryPlugin tests +# --------------------------------------------------------------------------- + + +class TestAuthentikDiscoveryPlugin: + """Tests for AuthentikDiscoveryPlugin resource discovery.""" + + def setup_method(self): + self.plugin = AuthentikDiscoveryPlugin() + self.credentials = { + "base_url": "https://auth.internal.lab", + "api_token": "test-api-token", + } + + def test_get_platform_category(self): + """Plugin returns CONTAINER_ORCHESTRATION category.""" + assert self.plugin.get_platform_category() == PlatformCategory.CONTAINER_ORCHESTRATION + + def test_list_supported_resource_types(self): + """Plugin lists all expected Authentik resource types.""" + types = self.plugin.list_supported_resource_types() + + expected = [ + "authentik_flow", + "authentik_stage", + "authentik_provider", + "authentik_application", + "authentik_outpost", + "authentik_property_mapping", + "authentik_certificate", + "authentik_group", + "authentik_source", + ] + assert types == expected + + def test_detect_architecture_defaults_to_amd64(self): + """Architecture detection defaults to AMD64.""" + arch = self.plugin.detect_architecture("https://auth.internal.lab") + assert arch == CpuArchitecture.AMD64 + + def test_list_endpoints_before_auth(self): + """Endpoints list is empty before authentication.""" + assert self.plugin.list_endpoints() == [] + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_authenticate_success(self, mock_get): + """Successful authentication sets internal state.""" + mock_get.return_value = MagicMock(status_code=200, json=lambda: {"results": []}) + + self.plugin.authenticate(self.credentials) + + assert self.plugin.list_endpoints() == ["https://auth.internal.lab"] + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_authenticate_invalid_token(self, mock_get): + """Invalid API token raises AuthentikDiscoveryError.""" + mock_get.return_value = MagicMock(status_code=401) + + with pytest.raises(AuthentikDiscoveryError) as exc_info: + self.plugin.authenticate(self.credentials) + + assert "invalid API token" in str(exc_info.value) + + def test_authenticate_missing_base_url(self): + """Missing base_url raises AuthentikDiscoveryError.""" + with pytest.raises(AuthentikDiscoveryError) as exc_info: + self.plugin.authenticate({"api_token": "token"}) + + assert "base_url" in str(exc_info.value) + + def test_authenticate_missing_api_token(self): + """Missing api_token raises AuthentikDiscoveryError.""" + with pytest.raises(AuthentikDiscoveryError) as exc_info: + self.plugin.authenticate({"base_url": "https://auth.lab"}) + + assert "api_token" in str(exc_info.value) + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_authenticate_connection_error(self, mock_get): + """Connection error during auth raises AuthentikDiscoveryError.""" + import requests + + mock_get.side_effect = requests.ConnectionError("refused") + + with pytest.raises(AuthentikDiscoveryError) as exc_info: + self.plugin.authenticate(self.credentials) + + assert "failed to connect" in str(exc_info.value) + + def test_discover_resources_without_auth(self): + """Discovering without authentication raises error.""" + with pytest.raises(AuthentikDiscoveryError) as exc_info: + self.plugin.discover_resources( + ["https://auth.lab"], + ["authentik_flow"], + lambda p: None, + ) + + assert "must authenticate" in str(exc_info.value) + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_discover_resources_flows(self, mock_get): + """Discovers Authentik flows from the API.""" + # First call is for authentication check + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + + # Second call is for flow discovery + flow_response = MagicMock( + status_code=200, + json=lambda: { + "results": [ + { + "pk": "flow-uuid-1", + "name": "default-authentication-flow", + "slug": "default-authentication-flow", + "stages": ["stage-1", "stage-2"], + }, + { + "pk": "flow-uuid-2", + "name": "default-enrollment-flow", + "slug": "default-enrollment-flow", + "stages": ["stage-3"], + }, + ], + "pagination": {"next": 0, "count": 2}, + }, + ) + + mock_get.side_effect = [auth_response, flow_response] + + self.plugin.authenticate(self.credentials) + + progress_updates: list[ScanProgress] = [] + result = self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_flow"], + lambda p: progress_updates.append(p), + ) + + assert len(result.resources) == 2 + assert result.resources[0].resource_type == "authentik_flow" + assert result.resources[0].name == "default-authentication-flow" + assert result.resources[0].unique_id == "authentik/authentik_flow/flow-uuid-1" + assert "stage-1" in result.resources[0].raw_references + assert len(result.warnings) == 0 + assert len(result.errors) == 0 + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_discover_resources_multiple_types(self, mock_get): + """Discovers multiple resource types in one call.""" + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + + app_response = MagicMock( + status_code=200, + json=lambda: { + "results": [ + { + "pk": "app-1", + "name": "grafana", + "slug": "grafana", + "provider": "provider-1", + } + ], + "pagination": {"next": 0, "count": 1}, + }, + ) + + group_response = MagicMock( + status_code=200, + json=lambda: { + "results": [ + {"pk": "group-1", "name": "admins"}, + {"pk": "group-2", "name": "users"}, + ], + "pagination": {"next": 0, "count": 2}, + }, + ) + + mock_get.side_effect = [auth_response, app_response, group_response] + + self.plugin.authenticate(self.credentials) + + result = self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_application", "authentik_group"], + lambda p: None, + ) + + assert len(result.resources) == 3 + app_resources = [r for r in result.resources if r.resource_type == "authentik_application"] + group_resources = [r for r in result.resources if r.resource_type == "authentik_group"] + assert len(app_resources) == 1 + assert len(group_resources) == 2 + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_discover_resources_unsupported_type_warning(self, mock_get): + """Unsupported resource type produces a warning.""" + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + mock_get.side_effect = [auth_response] + + self.plugin.authenticate(self.credentials) + + result = self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_nonexistent"], + lambda p: None, + ) + + assert len(result.resources) == 0 + assert len(result.warnings) == 1 + assert "Unsupported" in result.warnings[0] + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_discover_resources_api_error(self, mock_get): + """API error during discovery is captured in errors list.""" + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + error_response = MagicMock(status_code=500) + + mock_get.side_effect = [auth_response, error_response] + + self.plugin.authenticate(self.credentials) + + result = self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_flow"], + lambda p: None, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 1 + assert "authentik_flow" in result.errors[0] + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_discover_resources_pagination(self, mock_get): + """Handles paginated API responses correctly.""" + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + + page1_response = MagicMock( + status_code=200, + json=lambda: { + "results": [{"pk": "cert-1", "name": "cert-one"}], + "pagination": {"next": 2, "count": 2}, + }, + ) + page2_response = MagicMock( + status_code=200, + json=lambda: { + "results": [{"pk": "cert-2", "name": "cert-two"}], + "pagination": {"next": 0, "count": 2}, + }, + ) + + mock_get.side_effect = [auth_response, page1_response, page2_response] + + self.plugin.authenticate(self.credentials) + + result = self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_certificate"], + lambda p: None, + ) + + assert len(result.resources) == 2 + assert result.resources[0].name == "cert-one" + assert result.resources[1].name == "cert-two" + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_discover_resources_progress_callback(self, mock_get): + """Progress callback is invoked correctly during discovery.""" + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + empty_response = MagicMock( + status_code=200, + json=lambda: {"results": [], "pagination": {"next": 0, "count": 0}}, + ) + + mock_get.side_effect = [auth_response, empty_response, empty_response] + + self.plugin.authenticate(self.credentials) + + progress_updates: list[ScanProgress] = [] + self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_flow", "authentik_stage"], + lambda p: progress_updates.append(p), + ) + + # Should have progress for each type + final "complete" + assert len(progress_updates) == 3 + assert progress_updates[0].current_resource_type == "authentik_flow" + assert progress_updates[1].current_resource_type == "authentik_stage" + assert progress_updates[2].current_resource_type == "complete" + + @patch("iac_reverse.auth.authentik_discovery.requests.get") + def test_extract_references_from_resource(self, mock_get): + """References are extracted from API response fields.""" + auth_response = MagicMock(status_code=200, json=lambda: {"results": []}) + app_response = MagicMock( + status_code=200, + json=lambda: { + "results": [ + { + "pk": "app-1", + "name": "my-app", + "provider": "provider-uuid-1", + "group": "group-uuid-1", + } + ], + "pagination": {"next": 0, "count": 1}, + }, + ) + + mock_get.side_effect = [auth_response, app_response] + + self.plugin.authenticate(self.credentials) + + result = self.plugin.discover_resources( + ["https://auth.internal.lab"], + ["authentik_application"], + lambda p: None, + ) + + refs = result.resources[0].raw_references + assert "provider-uuid-1" in refs + assert "group-uuid-1" in refs diff --git a/tests/unit/test_bare_metal_plugin.py b/tests/unit/test_bare_metal_plugin.py new file mode 100644 index 0000000..12d57fb --- /dev/null +++ b/tests/unit/test_bare_metal_plugin.py @@ -0,0 +1,664 @@ +"""Unit tests for the BareMetalPlugin provider plugin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + PlatformCategory, + ProviderType, + ScanProgress, +) +from iac_reverse.scanner import AuthenticationError +from iac_reverse.scanner.bare_metal_plugin import BareMetalPlugin + + +class TestBareMetalPluginInterface: + """Tests for basic plugin interface compliance.""" + + def test_implements_provider_plugin(self): + """BareMetalPlugin can be instantiated (implements all abstract methods).""" + plugin = BareMetalPlugin() + assert plugin is not None + + def test_get_platform_category(self): + """Returns PlatformCategory.BARE_METAL.""" + plugin = BareMetalPlugin() + assert plugin.get_platform_category() == PlatformCategory.BARE_METAL + + def test_list_supported_resource_types(self): + """Returns the expected bare metal resource types.""" + plugin = BareMetalPlugin() + expected = [ + "bare_metal_hardware", + "bare_metal_bmc_config", + "bare_metal_network_interface", + "bare_metal_raid_config", + ] + assert plugin.list_supported_resource_types() == expected + + def test_list_endpoints_before_auth(self): + """Returns empty list before authentication.""" + plugin = BareMetalPlugin() + assert plugin.list_endpoints() == [] + + +class TestBareMetalAuthentication: + """Tests for BMC authentication via Redfish.""" + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_authenticate_success(self, mock_session_cls): + """Successful authentication stores session and sets base URL.""" + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.headers = {"X-Auth-Token": "test-token-123"} + mock_session.post.return_value = mock_response + + plugin = BareMetalPlugin() + plugin.authenticate({ + "host": "192.168.1.100", + "username": "admin", + "password": "secret", + }) + + assert plugin._host == "192.168.1.100" + assert plugin._base_url == "https://192.168.1.100:443" + assert plugin._session is not None + mock_session.post.assert_called_once() + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_authenticate_custom_port(self, mock_session_cls): + """Authentication uses custom port when specified.""" + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"X-Auth-Token": "token"} + mock_session.post.return_value = mock_response + + plugin = BareMetalPlugin() + plugin.authenticate({ + "host": "10.0.0.1", + "username": "admin", + "password": "pass", + "port": "8443", + }) + + assert plugin._base_url == "https://10.0.0.1:8443" + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_authenticate_no_ssl(self, mock_session_cls): + """Authentication uses HTTP when use_ssl is false.""" + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"X-Auth-Token": "token"} + mock_session.post.return_value = mock_response + + plugin = BareMetalPlugin() + plugin.authenticate({ + "host": "10.0.0.1", + "username": "admin", + "password": "pass", + "use_ssl": "false", + }) + + assert plugin._base_url == "http://10.0.0.1:443" + + def test_authenticate_missing_host(self): + """Raises AuthenticationError when host is missing.""" + plugin = BareMetalPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"username": "admin", "password": "pass"}) + assert "Missing required credentials" in str(exc_info.value) + + def test_authenticate_missing_username(self): + """Raises AuthenticationError when username is missing.""" + plugin = BareMetalPlugin() + with pytest.raises(AuthenticationError): + plugin.authenticate({"host": "10.0.0.1", "password": "pass"}) + + def test_authenticate_missing_password(self): + """Raises AuthenticationError when password is missing.""" + plugin = BareMetalPlugin() + with pytest.raises(AuthenticationError): + plugin.authenticate({"host": "10.0.0.1", "username": "admin"}) + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_authenticate_401_unauthorized(self, mock_session_cls): + """Raises AuthenticationError on HTTP 401.""" + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_session.post.return_value = mock_response + + plugin = BareMetalPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({ + "host": "10.0.0.1", + "username": "admin", + "password": "wrong", + }) + assert "Invalid credentials" in str(exc_info.value) + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_authenticate_connection_error(self, mock_session_cls): + """Raises AuthenticationError on connection failure.""" + import requests + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + mock_session.post.side_effect = requests.exceptions.ConnectionError( + "Connection refused" + ) + + plugin = BareMetalPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({ + "host": "unreachable.host", + "username": "admin", + "password": "pass", + }) + assert "Cannot connect" in str(exc_info.value) + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_authenticate_timeout(self, mock_session_cls): + """Raises AuthenticationError on timeout.""" + import requests + + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + mock_session.post.side_effect = requests.exceptions.Timeout("Timed out") + + plugin = BareMetalPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({ + "host": "slow.host", + "username": "admin", + "password": "pass", + }) + assert "timed out" in str(exc_info.value) + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_list_endpoints_after_auth(self, mock_session_cls): + """Returns host as endpoint after successful authentication.""" + mock_session = MagicMock() + mock_session_cls.return_value = mock_session + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.headers = {"X-Auth-Token": "token"} + mock_session.post.return_value = mock_response + + plugin = BareMetalPlugin() + plugin.authenticate({ + "host": "192.168.1.50", + "username": "admin", + "password": "pass", + }) + + assert plugin.list_endpoints() == ["192.168.1.50"] + + +class TestBareMetalArchitectureDetection: + """Tests for CPU architecture detection via Redfish.""" + + def test_detect_architecture_no_session(self): + """Returns AMD64 default when no session is available.""" + plugin = BareMetalPlugin() + assert plugin.detect_architecture("10.0.0.1") == CpuArchitecture.AMD64 + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_detect_architecture_amd64(self, mock_session_cls): + """Detects AMD64 architecture from processor data.""" + plugin = BareMetalPlugin() + mock_session = MagicMock() + plugin._session = mock_session + plugin._base_url = "https://10.0.0.1:443" + + # Mock processors collection response + proc_collection_response = MagicMock() + proc_collection_response.status_code = 200 + proc_collection_response.json.return_value = { + "Members": [{"@odata.id": "/redfish/v1/Systems/1/Processors/CPU.1"}] + } + + # Mock individual processor response + proc_response = MagicMock() + proc_response.status_code = 200 + proc_response.json.return_value = { + "InstructionSet": "x86-64", + "Model": "Intel Xeon E5-2680 v4", + } + + mock_session.get.side_effect = [proc_collection_response, proc_response] + + result = plugin.detect_architecture("10.0.0.1") + assert result == CpuArchitecture.AMD64 + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_detect_architecture_aarch64(self, mock_session_cls): + """Detects AARCH64 architecture from processor data.""" + plugin = BareMetalPlugin() + mock_session = MagicMock() + plugin._session = mock_session + plugin._base_url = "https://10.0.0.1:443" + + proc_collection_response = MagicMock() + proc_collection_response.status_code = 200 + proc_collection_response.json.return_value = { + "Members": [{"@odata.id": "/redfish/v1/Systems/1/Processors/CPU.1"}] + } + + proc_response = MagicMock() + proc_response.status_code = 200 + proc_response.json.return_value = { + "InstructionSet": "AArch64", + "Model": "Ampere Altra Q80-30", + } + + mock_session.get.side_effect = [proc_collection_response, proc_response] + + result = plugin.detect_architecture("10.0.0.1") + assert result == CpuArchitecture.AARCH64 + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_detect_architecture_arm_from_model(self, mock_session_cls): + """Detects ARM architecture from model string.""" + plugin = BareMetalPlugin() + mock_session = MagicMock() + plugin._session = mock_session + plugin._base_url = "https://10.0.0.1:443" + + proc_collection_response = MagicMock() + proc_collection_response.status_code = 200 + proc_collection_response.json.return_value = { + "Members": [{"@odata.id": "/redfish/v1/Systems/1/Processors/CPU.1"}] + } + + proc_response = MagicMock() + proc_response.status_code = 200 + proc_response.json.return_value = { + "InstructionSet": "", + "Model": "ARM Cortex-A53", + } + + mock_session.get.side_effect = [proc_collection_response, proc_response] + + result = plugin.detect_architecture("10.0.0.1") + assert result == CpuArchitecture.ARM + + @patch("iac_reverse.scanner.bare_metal_plugin.requests.Session") + def test_detect_architecture_fallback_on_error(self, mock_session_cls): + """Falls back to AMD64 on request error.""" + plugin = BareMetalPlugin() + mock_session = MagicMock() + plugin._session = mock_session + plugin._base_url = "https://10.0.0.1:443" + + mock_session.get.side_effect = Exception("Network error") + + result = plugin.detect_architecture("10.0.0.1") + assert result == CpuArchitecture.AMD64 + + +class TestBareMetalDiscoverResources: + """Tests for resource discovery via Redfish.""" + + def _make_authenticated_plugin(self): + """Create a plugin with a mocked session.""" + plugin = BareMetalPlugin() + plugin._session = MagicMock() + plugin._base_url = "https://10.0.0.1:443" + plugin._host = "10.0.0.1" + return plugin + + def test_discover_hardware(self): + """Discovers hardware inventory from /redfish/v1/Systems/1.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection (for architecture) + proc_collection = MagicMock() + proc_collection.status_code = 200 + proc_collection.json.return_value = { + "Members": [{"@odata.id": "/redfish/v1/Systems/1/Processors/CPU.1"}] + } + proc_detail = MagicMock() + proc_detail.status_code = 200 + proc_detail.json.return_value = { + "InstructionSet": "x86-64", + "Model": "Intel Xeon", + } + + # Mock system response + system_response = MagicMock() + system_response.status_code = 200 + system_response.json.return_value = { + "Id": "System.Embedded.1", + "Name": "Dell PowerEdge R740", + "Manufacturer": "Dell Inc.", + "Model": "PowerEdge R740", + "SerialNumber": "ABC123", + "SKU": "R740", + "BiosVersion": "2.12.2", + "MemorySummary": {"TotalSystemMemoryGiB": 256}, + "ProcessorSummary": {"Count": 2, "Model": "Intel Xeon Gold 6248"}, + "PowerState": "On", + "Status": {"State": "Enabled", "Health": "OK"}, + } + + plugin._session.get.side_effect = [ + proc_collection, proc_detail, system_response + ] + + progress_calls = [] + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["bare_metal_hardware"], + progress_callback=lambda p: progress_calls.append(p), + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "bare_metal_hardware" + assert resource.unique_id == "10.0.0.1:System.Embedded.1" + assert resource.name == "Dell PowerEdge R740" + assert resource.provider == ProviderType.BARE_METAL + assert resource.platform_category == PlatformCategory.BARE_METAL + assert resource.architecture == CpuArchitecture.AMD64 + assert resource.attributes["manufacturer"] == "Dell Inc." + assert resource.attributes["total_memory_gib"] == 256 + assert resource.attributes["processor_count"] == 2 + assert len(progress_calls) == 1 + + def test_discover_bmc_config(self): + """Discovers BMC configuration from /redfish/v1/Managers/1.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection + proc_collection = MagicMock() + proc_collection.status_code = 200 + proc_collection.json.return_value = {"Members": []} + + # Mock manager response + manager_response = MagicMock() + manager_response.status_code = 200 + manager_response.json.return_value = { + "Id": "iDRAC.Embedded.1", + "Name": "iDRAC Manager", + "ManagerType": "BMC", + "FirmwareVersion": "5.10.50.00", + "Model": "iDRAC9", + "Status": {"State": "Enabled", "Health": "OK"}, + "UUID": "abc-def-123", + } + + plugin._session.get.side_effect = [proc_collection, manager_response] + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["bare_metal_bmc_config"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "bare_metal_bmc_config" + assert resource.unique_id == "10.0.0.1:iDRAC.Embedded.1" + assert resource.attributes["firmware_version"] == "5.10.50.00" + assert resource.attributes["manager_type"] == "BMC" + + def test_discover_network_interfaces(self): + """Discovers network interfaces from Redfish EthernetInterfaces.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection + proc_collection = MagicMock() + proc_collection.status_code = 200 + proc_collection.json.return_value = {"Members": []} + + # Mock NIC collection + nic_collection = MagicMock() + nic_collection.status_code = 200 + nic_collection.json.return_value = { + "Members": [ + {"@odata.id": "/redfish/v1/Systems/1/EthernetInterfaces/NIC.1"}, + {"@odata.id": "/redfish/v1/Systems/1/EthernetInterfaces/NIC.2"}, + ] + } + + # Mock individual NIC responses + nic1_response = MagicMock() + nic1_response.status_code = 200 + nic1_response.json.return_value = { + "Id": "NIC.Integrated.1-1", + "Name": "Ethernet Interface 1", + "MACAddress": "AA:BB:CC:DD:EE:01", + "SpeedMbps": 10000, + "Status": {"State": "Enabled", "Health": "OK"}, + "IPv4Addresses": [{"Address": "192.168.1.10"}], + "IPv6Addresses": [], + "LinkStatus": "LinkUp", + "AutoNeg": True, + } + + nic2_response = MagicMock() + nic2_response.status_code = 200 + nic2_response.json.return_value = { + "Id": "NIC.Integrated.1-2", + "Name": "Ethernet Interface 2", + "MACAddress": "AA:BB:CC:DD:EE:02", + "SpeedMbps": 1000, + "Status": {"State": "Enabled", "Health": "OK"}, + "IPv4Addresses": [], + "IPv6Addresses": [], + "LinkStatus": "LinkDown", + "AutoNeg": True, + } + + plugin._session.get.side_effect = [ + proc_collection, nic_collection, nic1_response, nic2_response + ] + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["bare_metal_network_interface"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 2 + assert result.resources[0].attributes["mac_address"] == "AA:BB:CC:DD:EE:01" + assert result.resources[0].attributes["speed_mbps"] == 10000 + assert result.resources[1].attributes["mac_address"] == "AA:BB:CC:DD:EE:02" + + def test_discover_raid_config(self): + """Discovers RAID configuration from Redfish Storage.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection + proc_collection = MagicMock() + proc_collection.status_code = 200 + proc_collection.json.return_value = {"Members": []} + + # Mock storage collection + storage_collection = MagicMock() + storage_collection.status_code = 200 + storage_collection.json.return_value = { + "Members": [ + {"@odata.id": "/redfish/v1/Systems/1/Storage/RAID.Integrated.1-1"} + ] + } + + # Mock storage controller detail + storage_detail = MagicMock() + storage_detail.status_code = 200 + storage_detail.json.return_value = { + "Id": "RAID.Integrated.1-1", + "Name": "PERC H740P Mini", + "StorageControllers": [{"Name": "PERC H740P Mini"}], + "Drives": [ + {"@odata.id": "/redfish/v1/Systems/1/Storage/Drives/Disk.0"}, + {"@odata.id": "/redfish/v1/Systems/1/Storage/Drives/Disk.1"}, + ], + "Volumes": { + "@odata.id": "/redfish/v1/Systems/1/Storage/RAID.Integrated.1-1/Volumes" + }, + "Status": {"State": "Enabled", "Health": "OK"}, + } + + # Mock volumes collection + volumes_response = MagicMock() + volumes_response.status_code = 200 + volumes_response.json.return_value = { + "Members": [ + {"@odata.id": "/redfish/v1/Systems/1/Storage/Volumes/Disk.Virtual.0"} + ] + } + + plugin._session.get.side_effect = [ + proc_collection, storage_collection, storage_detail, volumes_response + ] + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["bare_metal_raid_config"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "bare_metal_raid_config" + assert resource.attributes["drive_count"] == 2 + assert len(resource.attributes["volumes"]) == 1 + assert resource.attributes["storage_controllers"] == ["PERC H740P Mini"] + + def test_discover_multiple_resource_types(self): + """Discovers multiple resource types in a single call.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection + proc_collection = MagicMock() + proc_collection.status_code = 200 + proc_collection.json.return_value = {"Members": []} + + # Mock system response + system_response = MagicMock() + system_response.status_code = 200 + system_response.json.return_value = { + "Id": "System.1", + "Name": "Server", + "Manufacturer": "Dell", + "Model": "R740", + } + + # Mock manager response + manager_response = MagicMock() + manager_response.status_code = 200 + manager_response.json.return_value = { + "Id": "BMC.1", + "Name": "BMC", + "ManagerType": "BMC", + "FirmwareVersion": "1.0", + } + + plugin._session.get.side_effect = [ + proc_collection, system_response, manager_response + ] + + progress_calls = [] + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["bare_metal_hardware", "bare_metal_bmc_config"], + progress_callback=lambda p: progress_calls.append(p), + ) + + assert len(result.resources) == 2 + assert len(progress_calls) == 2 + assert progress_calls[0].resource_types_completed == 1 + assert progress_calls[1].resource_types_completed == 2 + + def test_discover_handles_errors_gracefully(self): + """Errors during discovery are captured, not raised.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection that raises - this causes detect_architecture + # to fail, which is caught internally. Then the resource handler also + # needs to raise to trigger the error capture in discover_resources. + plugin._session.get.side_effect = Exception("Server unreachable") + + # Patch _discover_resource_type to raise so the outer handler catches it + with patch.object( + plugin, "_discover_resource_type", side_effect=Exception("Server unreachable") + ): + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["bare_metal_hardware"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 1 + assert "Server unreachable" in result.errors[0] + + def test_discover_unsupported_resource_type(self): + """Unsupported resource types return empty results without error.""" + plugin = self._make_authenticated_plugin() + + # Mock processor detection + proc_collection = MagicMock() + proc_collection.status_code = 200 + proc_collection.json.return_value = {"Members": []} + + plugin._session.get.side_effect = [proc_collection] + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["unknown_type"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 0 + + +class TestParseArchitecture: + """Tests for the _parse_architecture static method.""" + + def test_x86_64_instruction_set(self): + assert BareMetalPlugin._parse_architecture( + {"InstructionSet": "x86-64", "Model": "Intel Xeon"} + ) == CpuArchitecture.AMD64 + + def test_aarch64_instruction_set(self): + assert BareMetalPlugin._parse_architecture( + {"InstructionSet": "AArch64", "Model": "Ampere"} + ) == CpuArchitecture.AARCH64 + + def test_arm_instruction_set(self): + assert BareMetalPlugin._parse_architecture( + {"InstructionSet": "ARM", "Model": "Cortex"} + ) == CpuArchitecture.AARCH64 + + def test_arm_model_32bit(self): + assert BareMetalPlugin._parse_architecture( + {"InstructionSet": "", "Model": "ARM Cortex-A7"} + ) == CpuArchitecture.ARM + + def test_arm_model_64bit(self): + assert BareMetalPlugin._parse_architecture( + {"InstructionSet": "", "Model": "ARM v8 Processor"} + ) == CpuArchitecture.AARCH64 + + def test_empty_data_defaults_amd64(self): + assert BareMetalPlugin._parse_architecture( + {"InstructionSet": "", "Model": ""} + ) == CpuArchitecture.AMD64 diff --git a/tests/unit/test_change_detector.py b/tests/unit/test_change_detector.py new file mode 100644 index 0000000..aa8468d --- /dev/null +++ b/tests/unit/test_change_detector.py @@ -0,0 +1,335 @@ +"""Unit tests for the ChangeDetector class.""" + +import pytest + +from iac_reverse.incremental.change_detector import ChangeDetector +from iac_reverse.models import ( + ChangeSummary, + ChangeType, + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) + + +def _make_resource( + unique_id: str = "res-1", + name: str = "test-resource", + resource_type: str = "kubernetes_deployment", + attributes: dict | None = None, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AMD64, + endpoint="https://k8s-api.internal.lab:6443", + attributes=attributes if attributes is not None else {"replicas": 3}, + ) + + +def _make_scan_result( + resources: list[DiscoveredResource] | None = None, +) -> ScanResult: + """Create a sample ScanResult for testing.""" + return ScanResult( + resources=resources if resources is not None else [], + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="abc123", + ) + + +class TestNoChanges: + """Tests for identical scans producing no changes.""" + + def test_identical_scans_produce_no_changes(self) -> None: + """Comparing identical scans returns empty change summary.""" + resource = _make_resource(unique_id="res-1", attributes={"replicas": 3}) + current = _make_scan_result(resources=[resource]) + previous = _make_scan_result(resources=[resource]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 0 + assert summary.removed_count == 0 + assert summary.modified_count == 0 + assert summary.changes == [] + + def test_empty_scans_produce_no_changes(self) -> None: + """Comparing two empty scans returns empty change summary.""" + current = _make_scan_result(resources=[]) + previous = _make_scan_result(resources=[]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 0 + assert summary.removed_count == 0 + assert summary.modified_count == 0 + assert summary.changes == [] + + +class TestAddedResources: + """Tests for detecting added resources.""" + + def test_new_resource_detected_as_added(self) -> None: + """A resource in current but not in previous is classified as ADDED.""" + resource = _make_resource(unique_id="new-res", name="new-service") + current = _make_scan_result(resources=[resource]) + previous = _make_scan_result(resources=[]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 1 + assert summary.removed_count == 0 + assert summary.modified_count == 0 + assert len(summary.changes) == 1 + + change = summary.changes[0] + assert change.resource_id == "new-res" + assert change.resource_name == "new-service" + assert change.change_type == ChangeType.ADDED + assert change.changed_attributes is None + + def test_multiple_added_resources(self) -> None: + """Multiple new resources are all classified as ADDED.""" + res1 = _make_resource(unique_id="res-1", name="service-1") + res2 = _make_resource(unique_id="res-2", name="service-2") + current = _make_scan_result(resources=[res1, res2]) + previous = _make_scan_result(resources=[]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 2 + added_ids = {c.resource_id for c in summary.changes} + assert added_ids == {"res-1", "res-2"} + + +class TestRemovedResources: + """Tests for detecting removed resources.""" + + def test_missing_resource_detected_as_removed(self) -> None: + """A resource in previous but not in current is classified as REMOVED.""" + resource = _make_resource(unique_id="old-res", name="old-service") + current = _make_scan_result(resources=[]) + previous = _make_scan_result(resources=[resource]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 0 + assert summary.removed_count == 1 + assert summary.modified_count == 0 + assert len(summary.changes) == 1 + + change = summary.changes[0] + assert change.resource_id == "old-res" + assert change.resource_name == "old-service" + assert change.change_type == ChangeType.REMOVED + assert change.changed_attributes is None + + def test_multiple_removed_resources(self) -> None: + """Multiple missing resources are all classified as REMOVED.""" + res1 = _make_resource(unique_id="res-1", name="service-1") + res2 = _make_resource(unique_id="res-2", name="service-2") + current = _make_scan_result(resources=[]) + previous = _make_scan_result(resources=[res1, res2]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.removed_count == 2 + removed_ids = {c.resource_id for c in summary.changes} + assert removed_ids == {"res-1", "res-2"} + + +class TestModifiedResources: + """Tests for detecting modified resources.""" + + def test_changed_attributes_detected_as_modified(self) -> None: + """A resource with changed attributes is classified as MODIFIED.""" + prev_resource = _make_resource( + unique_id="res-1", attributes={"replicas": 3, "image": "nginx:1.24"} + ) + curr_resource = _make_resource( + unique_id="res-1", attributes={"replicas": 5, "image": "nginx:1.24"} + ) + current = _make_scan_result(resources=[curr_resource]) + previous = _make_scan_result(resources=[prev_resource]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 0 + assert summary.removed_count == 0 + assert summary.modified_count == 1 + assert len(summary.changes) == 1 + + change = summary.changes[0] + assert change.resource_id == "res-1" + assert change.change_type == ChangeType.MODIFIED + assert change.changed_attributes == {"replicas": {"old": 3, "new": 5}} + + def test_added_attribute_detected_as_modified(self) -> None: + """A resource with a new attribute key is classified as MODIFIED.""" + prev_resource = _make_resource( + unique_id="res-1", attributes={"replicas": 3} + ) + curr_resource = _make_resource( + unique_id="res-1", attributes={"replicas": 3, "image": "nginx:1.25"} + ) + current = _make_scan_result(resources=[curr_resource]) + previous = _make_scan_result(resources=[prev_resource]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.modified_count == 1 + change = summary.changes[0] + assert change.changed_attributes == { + "image": {"old": None, "new": "nginx:1.25"} + } + + def test_removed_attribute_detected_as_modified(self) -> None: + """A resource with a removed attribute key is classified as MODIFIED.""" + prev_resource = _make_resource( + unique_id="res-1", attributes={"replicas": 3, "image": "nginx:1.25"} + ) + curr_resource = _make_resource( + unique_id="res-1", attributes={"replicas": 3} + ) + current = _make_scan_result(resources=[curr_resource]) + previous = _make_scan_result(resources=[prev_resource]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.modified_count == 1 + change = summary.changes[0] + assert change.changed_attributes == { + "image": {"old": "nginx:1.25", "new": None} + } + + +class TestMixedChanges: + """Tests for scans with a mix of added, removed, and modified resources.""" + + def test_mixed_changes_detected_correctly(self) -> None: + """A scan with added, removed, and modified resources is classified correctly.""" + # Shared resource (modified) + prev_shared = _make_resource( + unique_id="shared", name="shared-svc", attributes={"replicas": 2} + ) + curr_shared = _make_resource( + unique_id="shared", name="shared-svc", attributes={"replicas": 4} + ) + + # Removed resource + removed = _make_resource(unique_id="old-res", name="old-svc") + + # Added resource + added = _make_resource(unique_id="new-res", name="new-svc") + + previous = _make_scan_result(resources=[prev_shared, removed]) + current = _make_scan_result(resources=[curr_shared, added]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.added_count == 1 + assert summary.removed_count == 1 + assert summary.modified_count == 1 + assert len(summary.changes) == 3 + + change_map = {c.resource_id: c for c in summary.changes} + assert change_map["new-res"].change_type == ChangeType.ADDED + assert change_map["old-res"].change_type == ChangeType.REMOVED + assert change_map["shared"].change_type == ChangeType.MODIFIED + + +class TestFirstScan: + """Tests for first scan (no previous snapshot).""" + + def test_first_scan_treats_all_as_added(self) -> None: + """When previous is None, all current resources are classified as ADDED.""" + res1 = _make_resource(unique_id="res-1", name="service-1") + res2 = _make_resource(unique_id="res-2", name="service-2") + current = _make_scan_result(resources=[res1, res2]) + + detector = ChangeDetector() + summary = detector.compare(current, previous=None) + + assert summary.added_count == 2 + assert summary.removed_count == 0 + assert summary.modified_count == 0 + assert len(summary.changes) == 2 + assert all(c.change_type == ChangeType.ADDED for c in summary.changes) + + def test_first_scan_empty_produces_empty_summary(self) -> None: + """First scan with no resources produces empty change summary.""" + current = _make_scan_result(resources=[]) + + detector = ChangeDetector() + summary = detector.compare(current, previous=None) + + assert summary.added_count == 0 + assert summary.removed_count == 0 + assert summary.modified_count == 0 + assert summary.changes == [] + + +class TestChangeSummaryCounts: + """Tests that ChangeSummary counts are always correct.""" + + def test_counts_match_change_list(self) -> None: + """The counts in ChangeSummary always match the actual changes list.""" + # Set up a scenario with 2 added, 1 removed, 1 modified + prev_mod = _make_resource( + unique_id="mod-1", name="mod-svc", attributes={"port": 80} + ) + curr_mod = _make_resource( + unique_id="mod-1", name="mod-svc", attributes={"port": 8080} + ) + removed = _make_resource(unique_id="rem-1", name="rem-svc") + added1 = _make_resource(unique_id="add-1", name="add-svc-1") + added2 = _make_resource(unique_id="add-2", name="add-svc-2") + + previous = _make_scan_result(resources=[prev_mod, removed]) + current = _make_scan_result(resources=[curr_mod, added1, added2]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + # Verify counts match actual changes + actual_added = [c for c in summary.changes if c.change_type == ChangeType.ADDED] + actual_removed = [c for c in summary.changes if c.change_type == ChangeType.REMOVED] + actual_modified = [c for c in summary.changes if c.change_type == ChangeType.MODIFIED] + + assert summary.added_count == len(actual_added) == 2 + assert summary.removed_count == len(actual_removed) == 1 + assert summary.modified_count == len(actual_modified) == 1 + + def test_resource_type_preserved_in_changes(self) -> None: + """ResourceChange objects preserve the resource_type from the resource.""" + resource = _make_resource( + unique_id="res-1", + resource_type="docker_service", + name="my-service", + ) + current = _make_scan_result(resources=[resource]) + previous = _make_scan_result(resources=[]) + + detector = ChangeDetector() + summary = detector.compare(current, previous) + + assert summary.changes[0].resource_type == "docker_service" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py new file mode 100644 index 0000000..9057e1b --- /dev/null +++ b/tests/unit/test_cli.py @@ -0,0 +1,437 @@ +"""Unit tests for the CLI entry point. + +Tests command registration, help text, and basic invocation with mocked dependencies. +""" + +from unittest.mock import patch, MagicMock +from pathlib import Path + +import pytest +from click.testing import CliRunner + +from iac_reverse.cli.cli import cli, _load_scan_profile + + +@pytest.fixture +def runner(): + """Create a Click CliRunner for testing.""" + return CliRunner() + + +class TestCommandRegistration: + """Test that all commands are registered on the CLI group.""" + + def test_scan_command_registered(self, runner): + result = runner.invoke(cli, ["scan", "--help"]) + assert result.exit_code == 0 + assert "Scan infrastructure" in result.output + + def test_generate_command_registered(self, runner): + result = runner.invoke(cli, ["generate", "--help"]) + assert result.exit_code == 0 + assert "Run full pipeline" in result.output + + def test_diff_command_registered(self, runner): + result = runner.invoke(cli, ["diff", "--help"]) + assert result.exit_code == 0 + assert "incremental scan" in result.output + + def test_validate_command_registered(self, runner): + result = runner.invoke(cli, ["validate", "--help"]) + assert result.exit_code == 0 + assert "Validate existing Terraform output" in result.output + + def test_login_command_registered(self, runner): + result = runner.invoke(cli, ["login", "--help"]) + assert result.exit_code == 0 + assert "Authenticate with Authentik SSO" in result.output + + def test_all_commands_in_group(self): + command_names = list(cli.commands.keys()) + assert "scan" in command_names + assert "generate" in command_names + assert "diff" in command_names + assert "validate" in command_names + assert "login" in command_names + + +class TestHelpText: + """Test that help text is available and informative.""" + + def test_main_help(self, runner): + result = runner.invoke(cli, ["--help"]) + assert result.exit_code == 0 + assert "IaC Reverse Engineering Tool" in result.output + assert "scan" in result.output + assert "generate" in result.output + assert "diff" in result.output + assert "validate" in result.output + assert "login" in result.output + + def test_version_option(self, runner): + result = runner.invoke(cli, ["--version"]) + assert result.exit_code == 0 + assert "0.1.0" in result.output + + def test_scan_help_shows_profile_option(self, runner): + result = runner.invoke(cli, ["scan", "--help"]) + assert "--profile" in result.output + + def test_generate_help_shows_options(self, runner): + result = runner.invoke(cli, ["generate", "--help"]) + assert "--profile" in result.output + assert "--output-dir" in result.output + + def test_diff_help_shows_profile_option(self, runner): + result = runner.invoke(cli, ["diff", "--help"]) + assert "--profile" in result.output + + def test_validate_help_shows_dir_option(self, runner): + result = runner.invoke(cli, ["validate", "--help"]) + assert "--dir" in result.output + + def test_login_help_shows_options(self, runner): + result = runner.invoke(cli, ["login", "--help"]) + assert "--url" in result.output + assert "--client-id" in result.output + assert "--client-secret" in result.output + + +class TestScanCommand: + """Test the scan command with mocked dependencies.""" + + def test_scan_missing_profile_option(self, runner): + result = runner.invoke(cli, ["scan"]) + assert result.exit_code != 0 + assert "Missing option" in result.output or "required" in result.output.lower() + + def test_scan_nonexistent_profile(self, runner): + result = runner.invoke(cli, ["scan", "--profile", "nonexistent.yaml"]) + assert result.exit_code != 0 + + @patch("iac_reverse.cli.cli._create_plugin") + def test_scan_with_valid_profile(self, mock_create_plugin, runner, tmp_path): + """Test scan command with a valid profile and mocked plugin.""" + from iac_reverse.models import ScanResult + + # Create a valid profile YAML + profile_file = tmp_path / "profile.yaml" + profile_file.write_text( + "provider: kubernetes\n" + "credentials:\n" + " kubeconfig: /path/to/config\n" + "endpoints:\n" + " - https://k8s.local:6443\n" + ) + + # Mock the plugin and scanner + mock_plugin = MagicMock() + mock_plugin.list_supported_resource_types.return_value = [ + "kubernetes_deployment" + ] + mock_plugin.list_endpoints.return_value = ["https://k8s.local:6443"] + mock_plugin.discover_resources.return_value = ScanResult( + resources=[], + warnings=[], + errors=[], + scan_timestamp="2024-01-01T00:00:00Z", + profile_hash="abc123", + ) + mock_create_plugin.return_value = mock_plugin + + result = runner.invoke(cli, ["scan", "--profile", str(profile_file)]) + assert result.exit_code == 0 + assert "0 resources discovered" in result.output + + +class TestGenerateCommand: + """Test the generate command with mocked dependencies.""" + + def test_generate_missing_options(self, runner): + result = runner.invoke(cli, ["generate"]) + assert result.exit_code != 0 + + @patch("iac_reverse.validator.validator.Validator.validate") + @patch("iac_reverse.state_builder.state_builder.StateBuilder.build") + @patch("iac_reverse.generator.code_generator.CodeGenerator.generate") + @patch("iac_reverse.resolver.resolver.DependencyResolver.resolve") + @patch("iac_reverse.scanner.scanner.Scanner.scan") + @patch("iac_reverse.cli.cli._create_plugin") + def test_generate_full_pipeline( + self, + mock_create_plugin, + mock_scan, + mock_resolve, + mock_generate, + mock_build, + mock_validate, + runner, + tmp_path, + ): + """Test generate command runs the full pipeline.""" + from iac_reverse.models import ( + ScanResult, + DependencyGraph, + CodeGenerationResult, + GeneratedFile, + StateFile, + ValidationResult, + ) + + # Create profile + profile_file = tmp_path / "profile.yaml" + profile_file.write_text( + "provider: kubernetes\n" + "credentials:\n" + " kubeconfig: /path/to/config\n" + ) + + output_dir = tmp_path / "output" + + # Mock plugin + mock_plugin = MagicMock() + mock_create_plugin.return_value = mock_plugin + + # Mock scanner.scan + mock_scan.return_value = ScanResult( + resources=[], + warnings=[], + errors=[], + scan_timestamp="2024-01-01T00:00:00Z", + profile_hash="abc123", + ) + + # Mock resolver.resolve + mock_resolve.return_value = DependencyGraph( + resources=[], + relationships=[], + topological_order=[], + cycles=[], + unresolved_references=[], + ) + + # Mock generator.generate + mock_generate.return_value = CodeGenerationResult( + resource_files=[ + GeneratedFile(filename="test.tf", content="# test", resource_count=1) + ], + variables_file=GeneratedFile(filename="variables.tf", content="", resource_count=0), + provider_file=GeneratedFile(filename="providers.tf", content="", resource_count=0), + ) + + # Mock state builder.build + mock_state_file = MagicMock() + mock_state_file.resources = [] + mock_state_file.to_json.return_value = "{}" + mock_build.return_value = mock_state_file + + # Mock validator.validate + mock_validate.return_value = ValidationResult( + init_success=True, + validate_success=True, + plan_success=True, + ) + + result = runner.invoke( + cli, ["generate", "--profile", str(profile_file), "--output-dir", str(output_dir)] + ) + assert result.exit_code == 0 + assert "Generation complete" in result.output + + +class TestDiffCommand: + """Test the diff command with mocked dependencies.""" + + def test_diff_missing_profile(self, runner): + result = runner.invoke(cli, ["diff"]) + assert result.exit_code != 0 + + @patch("iac_reverse.incremental.change_detector.ChangeDetector.compare") + @patch("iac_reverse.incremental.snapshot_store.SnapshotStore.store_snapshot") + @patch("iac_reverse.incremental.snapshot_store.SnapshotStore.load_previous") + @patch("iac_reverse.scanner.scanner.Scanner.scan") + @patch("iac_reverse.scanner.scanner.Scanner._compute_profile_hash") + @patch("iac_reverse.cli.cli._create_plugin") + def test_diff_first_scan( + self, + mock_create_plugin, + mock_compute_hash, + mock_scan, + mock_load_previous, + mock_store_snapshot, + mock_compare, + runner, + tmp_path, + ): + """Test diff command on first scan (no previous snapshot).""" + from iac_reverse.models import ScanResult, ChangeSummary + + profile_file = tmp_path / "profile.yaml" + profile_file.write_text( + "provider: kubernetes\n" + "credentials:\n" + " kubeconfig: /path/to/config\n" + ) + + mock_plugin = MagicMock() + mock_create_plugin.return_value = mock_plugin + + mock_compute_hash.return_value = "abc123" + mock_load_previous.return_value = None + + mock_scan.return_value = ScanResult( + resources=[], + warnings=[], + errors=[], + scan_timestamp="2024-01-01T00:00:00Z", + profile_hash="abc123", + ) + + mock_compare.return_value = ChangeSummary( + added_count=0, removed_count=0, modified_count=0, changes=[] + ) + + result = runner.invoke(cli, ["diff", "--profile", str(profile_file)]) + assert result.exit_code == 0 + assert "Change Summary" in result.output + + +class TestValidateCommand: + """Test the validate command with mocked dependencies.""" + + def test_validate_missing_dir(self, runner): + result = runner.invoke(cli, ["validate"]) + assert result.exit_code != 0 + + @patch("iac_reverse.validator.validator.Validator.validate") + def test_validate_success(self, mock_validate, runner, tmp_path): + """Test validate command with successful validation.""" + from iac_reverse.models import ValidationResult + + mock_validate.return_value = ValidationResult( + init_success=True, + validate_success=True, + plan_success=True, + ) + + result = runner.invoke(cli, ["validate", "--dir", str(tmp_path)]) + assert result.exit_code == 0 + assert "All validations passed" in result.output + + @patch("iac_reverse.validator.validator.Validator.validate") + def test_validate_failure(self, mock_validate, runner, tmp_path): + """Test validate command with validation errors.""" + from iac_reverse.models import ValidationResult, ValidationError + + mock_validate.return_value = ValidationResult( + init_success=True, + validate_success=False, + plan_success=False, + errors=[ + ValidationError(file="main.tf", message="syntax error", line=5) + ], + ) + + result = runner.invoke(cli, ["validate", "--dir", str(tmp_path)]) + assert result.exit_code == 0 + assert "syntax error" in result.output + + +class TestLoginCommand: + """Test the login command with mocked dependencies.""" + + def test_login_missing_options(self, runner): + result = runner.invoke(cli, ["login"]) + assert result.exit_code != 0 + + @patch("iac_reverse.auth.authentik_auth.AuthentikAuthProvider.authenticate_user") + def test_login_success(self, mock_authenticate, runner, tmp_path): + """Test login command with successful authentication.""" + from iac_reverse.auth.authentik_auth import AuthentikSession + + mock_authenticate.return_value = AuthentikSession( + access_token="test-token-123", + refresh_token="refresh-456", + user_id="user@example.com", + groups=["admins", "devops"], + ) + + with runner.isolated_filesystem(temp_dir=tmp_path): + result = runner.invoke( + cli, + [ + "login", + "--url", "https://auth.internal.lab", + "--client-id", "iac-reverse", + "--client-secret", "secret123", + ], + ) + assert result.exit_code == 0 + assert "Authenticated as user" in result.output + assert "user@example.com" in result.output + + @patch("iac_reverse.auth.authentik_auth.AuthentikAuthProvider.authenticate_user") + def test_login_failure(self, mock_authenticate, runner, tmp_path): + """Test login command with authentication failure.""" + from iac_reverse.auth.authentik_auth import AuthenticationError + + mock_authenticate.side_effect = AuthenticationError( + "Invalid credentials" + ) + + with runner.isolated_filesystem(temp_dir=tmp_path): + result = runner.invoke( + cli, + [ + "login", + "--url", "https://auth.internal.lab", + "--client-id", "iac-reverse", + "--client-secret", "bad-secret", + ], + ) + assert result.exit_code != 0 + assert "Authentication failed" in result.output + + +class TestLoadScanProfile: + """Test the _load_scan_profile helper function.""" + + def test_load_valid_profile(self, tmp_path): + profile_file = tmp_path / "profile.yaml" + profile_file.write_text( + "provider: kubernetes\n" + "credentials:\n" + " kubeconfig: /path/to/config\n" + ) + + profile = _load_scan_profile(str(profile_file)) + assert profile.provider == ProviderType.KUBERNETES + assert profile.credentials == {"kubeconfig": "/path/to/config"} + + def test_load_nonexistent_profile(self): + with pytest.raises(Exception) as exc_info: + _load_scan_profile("/nonexistent/path.yaml") + assert "not found" in str(exc_info.value).lower() or "Profile not found" in str(exc_info.value) + + def test_load_invalid_yaml(self, tmp_path): + profile_file = tmp_path / "bad.yaml" + profile_file.write_text(": : : invalid yaml [[[") + + with pytest.raises(Exception): + _load_scan_profile(str(profile_file)) + + def test_load_unknown_provider(self, tmp_path): + profile_file = tmp_path / "profile.yaml" + profile_file.write_text( + "provider: unknown_provider\n" + "credentials:\n" + " key: value\n" + ) + + with pytest.raises(Exception) as exc_info: + _load_scan_profile(str(profile_file)) + assert "Unknown provider" in str(exc_info.value) + + +# Import for type reference in tests +from iac_reverse.models import ProviderType diff --git a/tests/unit/test_code_generator.py b/tests/unit/test_code_generator.py new file mode 100644 index 0000000..47eb342 --- /dev/null +++ b/tests/unit/test_code_generator.py @@ -0,0 +1,639 @@ +"""Unit tests for the CodeGenerator.""" + +import pytest + +from iac_reverse.models import ( + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + GeneratedFile, + PlatformCategory, + ProviderType, + ResourceRelationship, + ScanProfile, +) +from iac_reverse.generator import CodeGenerator + + +# --------------------------------------------------------------------------- +# Helpers / Fixtures +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "default/deployments/nginx", + name: str = "nginx", + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, + architecture: CpuArchitecture = CpuArchitecture.AARCH64, + attributes: dict | None = None, + raw_references: list[str] | None = None, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=architecture, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {}, + raw_references=raw_references or [], + ) + + +def make_graph( + resources: list[DiscoveredResource], + relationships: list[ResourceRelationship] | None = None, +) -> DependencyGraph: + """Create a DependencyGraph from resources and optional relationships.""" + return DependencyGraph( + resources=resources, + relationships=relationships or [], + topological_order=[r.unique_id for r in resources], + cycles=[], + unresolved_references=[], + ) + + +def make_profiles() -> list[ScanProfile]: + """Create a default list of scan profiles for testing.""" + return [ + ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={"kubeconfig_path": "/home/user/.kube/config"}, + ) + ] + + +# --------------------------------------------------------------------------- +# Tests: Single resource generates valid HCL +# --------------------------------------------------------------------------- + + +class TestSingleResourceGeneration: + """Tests for generating HCL from a single resource.""" + + def test_single_resource_produces_one_file(self): + """A single resource produces exactly one resource file.""" + resource = make_resource( + attributes={"replicas": 3, "image": "nginx:1.25"}, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert len(result.resource_files) == 1 + + def test_single_resource_file_has_correct_filename(self): + """The generated file is named after the resource type.""" + resource = make_resource(resource_type="kubernetes_deployment") + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert result.resource_files[0].filename == "kubernetes_deployment.tf" + + def test_single_resource_file_contains_resource_block(self): + """The generated file contains a resource block with the correct type.""" + resource = make_resource( + resource_type="kubernetes_deployment", + name="nginx", + attributes={"replicas": 3}, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert 'resource "kubernetes_deployment" "nginx"' in content + + def test_single_resource_includes_attributes(self): + """The generated resource block includes all attributes.""" + resource = make_resource( + attributes={"replicas": 3, "image": "nginx:1.25"}, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert "replicas = 3" in content + assert 'image = "nginx:1.25"' in content + + def test_single_resource_resource_count_is_one(self): + """The resource_count for a single resource file is 1.""" + resource = make_resource() + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert result.resource_files[0].resource_count == 1 + + +# --------------------------------------------------------------------------- +# Tests: Multiple resources of same type go in one file +# --------------------------------------------------------------------------- + + +class TestSameTypeGrouping: + """Tests for grouping multiple resources of the same type into one file.""" + + def test_two_resources_same_type_produce_one_file(self): + """Two resources of the same type produce exactly one file.""" + resource_a = make_resource( + unique_id="default/deployments/app-a", + name="app-a", + attributes={"replicas": 2}, + ) + resource_b = make_resource( + unique_id="default/deployments/app-b", + name="app-b", + attributes={"replicas": 1}, + ) + graph = make_graph([resource_a, resource_b]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert len(result.resource_files) == 1 + + def test_two_resources_same_type_both_in_file(self): + """Both resources appear in the same file.""" + resource_a = make_resource( + unique_id="default/deployments/app-a", + name="app-a", + attributes={"replicas": 2}, + ) + resource_b = make_resource( + unique_id="default/deployments/app-b", + name="app-b", + attributes={"replicas": 1}, + ) + graph = make_graph([resource_a, resource_b]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert 'resource "kubernetes_deployment" "app_a"' in content + assert 'resource "kubernetes_deployment" "app_b"' in content + + def test_resource_count_matches_number_of_resources(self): + """The resource_count reflects the number of resources in the file.""" + resources = [ + make_resource( + unique_id=f"default/deployments/app-{i}", + name=f"app-{i}", + ) + for i in range(3) + ] + graph = make_graph(resources) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert result.resource_files[0].resource_count == 3 + + +# --------------------------------------------------------------------------- +# Tests: Different resource types go in separate files +# --------------------------------------------------------------------------- + + +class TestDifferentTypesSeparateFiles: + """Tests for separating different resource types into different files.""" + + def test_two_different_types_produce_two_files(self): + """Two resources of different types produce two files.""" + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/nginx", + name="nginx", + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/nginx-svc", + name="nginx-svc", + ) + graph = make_graph([deployment, service]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert len(result.resource_files) == 2 + + def test_different_types_have_correct_filenames(self): + """Each file is named after its resource type.""" + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/nginx", + name="nginx", + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/nginx-svc", + name="nginx-svc", + ) + graph = make_graph([deployment, service]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + filenames = {f.filename for f in result.resource_files} + + assert "kubernetes_deployment.tf" in filenames + assert "kubernetes_service.tf" in filenames + + def test_each_file_contains_only_its_type(self): + """Each file contains only resource blocks of its designated type.""" + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/nginx", + name="nginx", + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/nginx-svc", + name="nginx-svc", + ) + graph = make_graph([deployment, service]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + for f in result.resource_files: + if f.filename == "kubernetes_deployment.tf": + assert "kubernetes_deployment" in f.content + assert 'resource "kubernetes_service"' not in f.content + elif f.filename == "kubernetes_service.tf": + assert "kubernetes_service" in f.content + assert 'resource "kubernetes_deployment"' not in f.content + + def test_three_types_produce_three_files(self): + """Three distinct resource types produce three files.""" + resources = [ + make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + ), + make_resource( + resource_type="kubernetes_service", + unique_id="default/services/app-svc", + name="app-svc", + ), + make_resource( + resource_type="windows_service", + unique_id="win/services/iis", + name="iis", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=CpuArchitecture.AMD64, + ), + ] + graph = make_graph(resources) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert len(result.resource_files) == 3 + + +# --------------------------------------------------------------------------- +# Tests: Traceability comments are present +# --------------------------------------------------------------------------- + + +class TestTraceabilityComments: + """Tests for traceability comments in generated HCL.""" + + def test_resource_block_has_source_comment(self): + """Each resource block is preceded by a comment with the unique_id.""" + resource = make_resource( + unique_id="apps/v1/deployments/default/nginx", + name="nginx", + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert "# Source: apps/v1/deployments/default/nginx" in content + + def test_multiple_resources_each_have_source_comment(self): + """Each resource in a multi-resource file has its own source comment.""" + resource_a = make_resource( + unique_id="default/deployments/app-a", + name="app-a", + ) + resource_b = make_resource( + unique_id="default/deployments/app-b", + name="app-b", + ) + graph = make_graph([resource_a, resource_b]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert "# Source: default/deployments/app-a" in content + assert "# Source: default/deployments/app-b" in content + + def test_windows_resource_has_source_comment(self): + """Windows resources also have traceability comments.""" + resource = make_resource( + resource_type="windows_service", + unique_id="win-server-01/services/W3SVC", + name="W3SVC", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=CpuArchitecture.AMD64, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert "# Source: win-server-01/services/W3SVC" in content + + +# --------------------------------------------------------------------------- +# Tests: Architecture tags are included +# --------------------------------------------------------------------------- + + +class TestArchitectureTags: + """Tests for architecture-specific tags/labels on resources.""" + + def test_aarch64_resource_has_arch_tag(self): + """An AArch64 resource includes arch = aarch64 in tags.""" + resource = make_resource( + architecture=CpuArchitecture.AARCH64, + attributes={"replicas": 1}, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert '"arch" = "aarch64"' in content + + def test_amd64_resource_has_arch_tag(self): + """An AMD64 resource includes arch = amd64 in tags.""" + resource = make_resource( + resource_type="windows_service", + unique_id="win/services/svc", + name="svc", + architecture=CpuArchitecture.AMD64, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert '"arch" = "amd64"' in content + + def test_arm_resource_has_arch_tag(self): + """An ARM resource includes arch = arm in tags.""" + resource = make_resource( + architecture=CpuArchitecture.ARM, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert '"arch" = "arm"' in content + + def test_managed_by_tag_is_present(self): + """All resources include a managed_by = iac-reverse tag.""" + resource = make_resource() + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert '"managed_by" = "iac-reverse"' in content + + +# --------------------------------------------------------------------------- +# Tests: Dependencies use Terraform references +# --------------------------------------------------------------------------- + + +class TestTerraformReferences: + """Tests for Terraform resource references in generated HCL.""" + + def test_dependency_uses_terraform_reference(self): + """A resource referencing another uses a Terraform reference expression.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/nginx", + name="nginx", + attributes={"namespace": "default"}, + raw_references=["ns/default"], + ) + relationships = [ + ResourceRelationship( + source_id="default/deployments/nginx", + target_id="ns/default", + relationship_type="parent-child", + source_attribute="namespace", + ) + ] + graph = make_graph([namespace, deployment], relationships) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + # Find the deployment file + deploy_file = next( + f for f in result.resource_files + if f.filename == "kubernetes_deployment.tf" + ) + content = deploy_file.content + + # Should use Terraform reference, not hardcoded "default" + assert "kubernetes_namespace.default.id" in content + + def test_reference_by_unique_id_in_attribute(self): + """An attribute containing a target's unique_id is replaced with a reference.""" + app_pool = make_resource( + resource_type="windows_iis_app_pool", + unique_id="win/iis/app_pools/DefaultAppPool", + name="DefaultAppPool", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=CpuArchitecture.AMD64, + ) + site = make_resource( + resource_type="windows_iis_site", + unique_id="win/iis/sites/MySite", + name="MySite", + attributes={"app_pool": "win/iis/app_pools/DefaultAppPool"}, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=CpuArchitecture.AMD64, + ) + relationships = [ + ResourceRelationship( + source_id="win/iis/sites/MySite", + target_id="win/iis/app_pools/DefaultAppPool", + relationship_type="dependency", + source_attribute="app_pool", + ) + ] + graph = make_graph([app_pool, site], relationships) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + site_file = next( + f for f in result.resource_files + if f.filename == "windows_iis_site.tf" + ) + content = site_file.content + + assert "windows_iis_app_pool.DefaultAppPool.id" in content + + def test_non_reference_attributes_remain_literal(self): + """Attributes that don't reference other resources remain as literals.""" + resource = make_resource( + attributes={"replicas": 3, "image": "nginx:1.25"}, + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + assert "replicas = 3" in content + assert 'image = "nginx:1.25"' in content + + def test_multiple_dependencies_all_use_references(self): + """A resource with multiple dependencies uses references for all.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/prod", + name="prod", + ) + config_map = make_resource( + resource_type="kubernetes_config_map", + unique_id="prod/configmaps/app-config", + name="app-config", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="prod/deployments/app", + name="app", + attributes={ + "namespace": "prod", + "config_map": "app-config", + "replicas": 2, + }, + raw_references=["ns/prod", "prod/configmaps/app-config"], + ) + relationships = [ + ResourceRelationship( + source_id="prod/deployments/app", + target_id="ns/prod", + relationship_type="parent-child", + source_attribute="namespace", + ), + ResourceRelationship( + source_id="prod/deployments/app", + target_id="prod/configmaps/app-config", + relationship_type="dependency", + source_attribute="config_map", + ), + ] + graph = make_graph([namespace, config_map, deployment], relationships) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + deploy_file = next( + f for f in result.resource_files + if f.filename == "kubernetes_deployment.tf" + ) + content = deploy_file.content + + assert "kubernetes_namespace.prod.id" in content + assert "kubernetes_config_map.app_config.id" in content + assert "replicas = 2" in content + + +# --------------------------------------------------------------------------- +# Tests: Result structure +# --------------------------------------------------------------------------- + + +class TestResultStructure: + """Tests for the CodeGenerationResult structure.""" + + def test_result_has_variables_file(self): + """The result includes a variables_file (empty placeholder for now).""" + resource = make_resource() + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert result.variables_file is not None + assert result.variables_file.filename == "variables.tf" + + def test_result_has_provider_file(self): + """The result includes a provider_file (empty placeholder for now).""" + resource = make_resource() + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert result.provider_file is not None + assert result.provider_file.filename == "providers.tf" + + def test_empty_graph_produces_no_resource_files(self): + """An empty graph produces no resource files.""" + graph = make_graph([]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + + assert result.resource_files == [] + + def test_name_sanitization_replaces_special_chars(self): + """Resource names with special characters are sanitized for Terraform.""" + resource = make_resource( + name="my-app.v2", + unique_id="default/deployments/my-app.v2", + ) + graph = make_graph([resource]) + generator = CodeGenerator() + + result = generator.generate(graph, make_profiles()) + content = result.resource_files[0].content + + # Should use sanitized name (hyphens and dots replaced with underscores) + assert 'resource "kubernetes_deployment" "my_app_v2"' in content diff --git a/tests/unit/test_docker_swarm_plugin.py b/tests/unit/test_docker_swarm_plugin.py new file mode 100644 index 0000000..d49c86c --- /dev/null +++ b/tests/unit/test_docker_swarm_plugin.py @@ -0,0 +1,444 @@ +"""Unit tests for the DockerSwarmPlugin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, +) +from iac_reverse.scanner import AuthenticationError +from iac_reverse.scanner.docker_swarm_plugin import DockerSwarmPlugin + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def plugin(): + """Create a fresh DockerSwarmPlugin instance.""" + return DockerSwarmPlugin() + + +@pytest.fixture +def mock_docker_client(): + """Create a mock Docker client with common attributes.""" + client = MagicMock() + client.ping.return_value = True + client.info.return_value = {"Architecture": "x86_64"} + client.services.list.return_value = [] + client.networks.list.return_value = [] + client.volumes.list.return_value = [] + client.configs.list.return_value = [] + client.secrets.list.return_value = [] + return client + + +@pytest.fixture +def authenticated_plugin(plugin, mock_docker_client): + """Return a plugin that has been authenticated with a mock client.""" + with patch("iac_reverse.scanner.docker_swarm_plugin.docker.DockerClient") as mock_cls: + mock_cls.return_value = mock_docker_client + plugin.authenticate({"host": "tcp://localhost:2376"}) + return plugin + + +# --------------------------------------------------------------------------- +# Authentication tests +# --------------------------------------------------------------------------- + + +class TestAuthenticate: + def test_authenticate_success(self, plugin): + """Successful authentication connects to Docker daemon.""" + with patch("iac_reverse.scanner.docker_swarm_plugin.docker.DockerClient") as mock_cls: + mock_client = MagicMock() + mock_client.ping.return_value = True + mock_cls.return_value = mock_client + + plugin.authenticate({"host": "tcp://192.168.1.10:2376"}) + + mock_cls.assert_called_once_with( + base_url="tcp://192.168.1.10:2376", + tls=False, + ) + mock_client.ping.assert_called_once() + + def test_authenticate_missing_host(self, plugin): + """Authentication fails when host is not provided.""" + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({}) + + assert "host" in str(exc_info.value).lower() + + def test_authenticate_empty_host(self, plugin): + """Authentication fails when host is empty string.""" + with pytest.raises(AuthenticationError): + plugin.authenticate({"host": ""}) + + def test_authenticate_connection_failure(self, plugin): + """Authentication fails when Docker daemon is unreachable.""" + with patch("iac_reverse.scanner.docker_swarm_plugin.docker.DockerClient") as mock_cls: + mock_client = MagicMock() + mock_client.ping.side_effect = ConnectionError("Connection refused") + mock_cls.return_value = mock_client + + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"host": "tcp://unreachable:2376"}) + + assert "Connection refused" in str(exc_info.value) + + def test_authenticate_with_tls(self, plugin): + """Authentication configures TLS when tls_verify and cert_path are set.""" + with patch("iac_reverse.scanner.docker_swarm_plugin.docker.DockerClient") as mock_cls: + with patch("iac_reverse.scanner.docker_swarm_plugin.TLSConfig") as mock_tls: + mock_client = MagicMock() + mock_client.ping.return_value = True + mock_cls.return_value = mock_client + mock_tls_instance = MagicMock() + mock_tls.return_value = mock_tls_instance + + plugin.authenticate({ + "host": "tcp://secure-host:2376", + "tls_verify": "true", + "cert_path": "/certs", + }) + + mock_tls.assert_called_once_with( + verify=True, + client_cert=("/certs/cert.pem", "/certs/key.pem"), + ca_cert="/certs/ca.pem", + ) + mock_cls.assert_called_once_with( + base_url="tcp://secure-host:2376", + tls=mock_tls_instance, + ) + + +# --------------------------------------------------------------------------- +# Platform category and resource types +# --------------------------------------------------------------------------- + + +class TestPlatformInfo: + def test_get_platform_category(self, plugin): + """Returns CONTAINER_ORCHESTRATION category.""" + assert plugin.get_platform_category() == PlatformCategory.CONTAINER_ORCHESTRATION + + def test_list_supported_resource_types(self, plugin): + """Returns all five Docker Swarm resource types.""" + types = plugin.list_supported_resource_types() + assert types == [ + "docker_service", + "docker_network", + "docker_volume", + "docker_config", + "docker_secret", + ] + + def test_list_endpoints_before_auth(self, plugin): + """Returns empty list before authentication.""" + assert plugin.list_endpoints() == [] + + def test_list_endpoints_after_auth(self, authenticated_plugin): + """Returns the host URL after authentication.""" + assert authenticated_plugin.list_endpoints() == ["tcp://localhost:2376"] + + +# --------------------------------------------------------------------------- +# Architecture detection +# --------------------------------------------------------------------------- + + +class TestDetectArchitecture: + def test_detect_amd64(self, authenticated_plugin, mock_docker_client): + """Detects AMD64 architecture from x86_64 info.""" + mock_docker_client.info.return_value = {"Architecture": "x86_64"} + arch = authenticated_plugin.detect_architecture("tcp://localhost:2376") + assert arch == CpuArchitecture.AMD64 + + def test_detect_aarch64(self, authenticated_plugin, mock_docker_client): + """Detects AARCH64 architecture from aarch64 info.""" + mock_docker_client.info.return_value = {"Architecture": "aarch64"} + arch = authenticated_plugin.detect_architecture("tcp://localhost:2376") + assert arch == CpuArchitecture.AARCH64 + + def test_detect_arm(self, authenticated_plugin, mock_docker_client): + """Detects ARM architecture from armv7l info.""" + mock_docker_client.info.return_value = {"Architecture": "armv7l"} + arch = authenticated_plugin.detect_architecture("tcp://localhost:2376") + assert arch == CpuArchitecture.ARM + + def test_detect_unknown_defaults_to_amd64(self, authenticated_plugin, mock_docker_client): + """Unknown architecture string defaults to AMD64.""" + mock_docker_client.info.return_value = {"Architecture": "sparc"} + arch = authenticated_plugin.detect_architecture("tcp://localhost:2376") + assert arch == CpuArchitecture.AMD64 + + def test_detect_without_client(self, plugin): + """Returns AMD64 when no client is connected.""" + arch = plugin.detect_architecture("tcp://localhost:2376") + assert arch == CpuArchitecture.AMD64 + + +# --------------------------------------------------------------------------- +# Resource discovery +# --------------------------------------------------------------------------- + + +class TestDiscoverResources: + def _noop_callback(self, progress: ScanProgress) -> None: + pass + + def test_discover_services(self, authenticated_plugin, mock_docker_client): + """Discovers Docker services with correct attributes.""" + mock_service = MagicMock() + mock_service.attrs = { + "ID": "svc123", + "Spec": { + "Name": "web-app", + "Mode": {"Replicated": {"Replicas": 3}}, + "Labels": {"env": "prod"}, + "TaskTemplate": { + "ContainerSpec": { + "Image": "nginx:latest", + }, + "Networks": [{"Target": "net456"}], + }, + }, + } + mock_docker_client.services.list.return_value = [mock_service] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_service"], + progress_callback=self._noop_callback, + ) + + assert len(result.resources) == 1 + svc = result.resources[0] + assert svc.resource_type == "docker_service" + assert svc.unique_id == "svc123" + assert svc.name == "web-app" + assert svc.provider == ProviderType.DOCKER_SWARM + assert svc.platform_category == PlatformCategory.CONTAINER_ORCHESTRATION + assert svc.attributes["image"] == "nginx:latest" + assert svc.attributes["replicas"] == 3 + assert "network:net456" in svc.raw_references + + def test_discover_networks(self, authenticated_plugin, mock_docker_client): + """Discovers Docker networks with correct attributes.""" + mock_network = MagicMock() + mock_network.attrs = { + "Id": "net789", + "Name": "overlay-net", + "Driver": "overlay", + "Scope": "swarm", + "Attachable": True, + "Ingress": False, + "Labels": {}, + } + mock_docker_client.networks.list.return_value = [mock_network] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_network"], + progress_callback=self._noop_callback, + ) + + assert len(result.resources) == 1 + net = result.resources[0] + assert net.resource_type == "docker_network" + assert net.unique_id == "net789" + assert net.name == "overlay-net" + assert net.attributes["driver"] == "overlay" + assert net.attributes["scope"] == "swarm" + assert net.attributes["attachable"] is True + + def test_discover_volumes(self, authenticated_plugin, mock_docker_client): + """Discovers Docker volumes with correct attributes.""" + mock_volume = MagicMock() + mock_volume.attrs = { + "Name": "data-vol", + "Driver": "local", + "Mountpoint": "/var/lib/docker/volumes/data-vol/_data", + "Labels": {"backup": "daily"}, + } + mock_docker_client.volumes.list.return_value = [mock_volume] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_volume"], + progress_callback=self._noop_callback, + ) + + assert len(result.resources) == 1 + vol = result.resources[0] + assert vol.resource_type == "docker_volume" + assert vol.unique_id == "data-vol" + assert vol.name == "data-vol" + assert vol.attributes["driver"] == "local" + + def test_discover_configs(self, authenticated_plugin, mock_docker_client): + """Discovers Docker configs (metadata only).""" + mock_config = MagicMock() + mock_config.attrs = { + "ID": "cfg001", + "Spec": { + "Name": "app-config", + "Labels": {"version": "2"}, + }, + "CreatedAt": "2024-01-01T00:00:00Z", + "UpdatedAt": "2024-01-02T00:00:00Z", + } + mock_docker_client.configs.list.return_value = [mock_config] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_config"], + progress_callback=self._noop_callback, + ) + + assert len(result.resources) == 1 + cfg = result.resources[0] + assert cfg.resource_type == "docker_config" + assert cfg.unique_id == "cfg001" + assert cfg.name == "app-config" + assert cfg.attributes["created_at"] == "2024-01-01T00:00:00Z" + + def test_discover_secrets(self, authenticated_plugin, mock_docker_client): + """Discovers Docker secrets (metadata only, no secret data).""" + mock_secret = MagicMock() + mock_secret.attrs = { + "ID": "sec001", + "Spec": { + "Name": "db-password", + "Labels": {}, + }, + "CreatedAt": "2024-01-01T00:00:00Z", + "UpdatedAt": "2024-01-01T00:00:00Z", + } + mock_docker_client.secrets.list.return_value = [mock_secret] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_secret"], + progress_callback=self._noop_callback, + ) + + assert len(result.resources) == 1 + sec = result.resources[0] + assert sec.resource_type == "docker_secret" + assert sec.unique_id == "sec001" + assert sec.name == "db-password" + + def test_discover_all_types(self, authenticated_plugin, mock_docker_client): + """Discovers all resource types in a single call.""" + mock_docker_client.services.list.return_value = [] + mock_docker_client.networks.list.return_value = [] + mock_docker_client.volumes.list.return_value = [] + mock_docker_client.configs.list.return_value = [] + mock_docker_client.secrets.list.return_value = [] + + all_types = [ + "docker_service", + "docker_network", + "docker_volume", + "docker_config", + "docker_secret", + ] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=all_types, + progress_callback=self._noop_callback, + ) + + assert result.errors == [] + assert result.warnings == [] + + def test_discover_reports_progress(self, authenticated_plugin, mock_docker_client): + """Progress callback is invoked for each resource type.""" + progress_updates: list[ScanProgress] = [] + + def track_progress(p: ScanProgress): + progress_updates.append(p) + + authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_service", "docker_network"], + progress_callback=track_progress, + ) + + assert len(progress_updates) == 2 + assert progress_updates[0].current_resource_type == "docker_service" + assert progress_updates[0].resource_types_completed == 1 + assert progress_updates[0].total_resource_types == 2 + assert progress_updates[1].current_resource_type == "docker_network" + assert progress_updates[1].resource_types_completed == 2 + + def test_discover_handles_api_error(self, authenticated_plugin, mock_docker_client): + """API errors are captured in the result errors list.""" + mock_docker_client.services.list.side_effect = Exception("API timeout") + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_service"], + progress_callback=self._noop_callback, + ) + + assert len(result.errors) == 1 + assert "API timeout" in result.errors[0] + + def test_discover_without_authentication(self, plugin): + """Returns error when called without authentication.""" + result = plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_service"], + progress_callback=lambda p: None, + ) + + assert len(result.errors) == 1 + assert "Not authenticated" in result.errors[0] + + def test_service_references_include_volumes_configs_secrets( + self, authenticated_plugin, mock_docker_client + ): + """Service references include network, volume, config, and secret refs.""" + mock_service = MagicMock() + mock_service.attrs = { + "ID": "svc-ref", + "Spec": { + "Name": "full-service", + "Mode": {"Replicated": {"Replicas": 1}}, + "Labels": {}, + "TaskTemplate": { + "ContainerSpec": { + "Image": "app:v1", + "Mounts": [{"Source": "data-vol"}], + "Configs": [{"ConfigID": "cfg-abc"}], + "Secrets": [{"SecretID": "sec-xyz"}], + }, + "Networks": [{"Target": "net-123"}], + }, + }, + } + mock_docker_client.services.list.return_value = [mock_service] + + result = authenticated_plugin.discover_resources( + endpoints=["tcp://localhost:2376"], + resource_types=["docker_service"], + progress_callback=self._noop_callback, + ) + + refs = result.resources[0].raw_references + assert "network:net-123" in refs + assert "volume:data-vol" in refs + assert "config:cfg-abc" in refs + assert "secret:sec-xyz" in refs diff --git a/tests/unit/test_harvester_plugin.py b/tests/unit/test_harvester_plugin.py new file mode 100644 index 0000000..1ada318 --- /dev/null +++ b/tests/unit/test_harvester_plugin.py @@ -0,0 +1,569 @@ +"""Unit tests for the HarvesterPlugin provider plugin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + PlatformCategory, + ProviderType, + ScanProgress, +) +from iac_reverse.scanner.harvester_plugin import HarvesterPlugin + + +# Patch targets for kubernetes client classes +PATCH_NEW_CLIENT = "iac_reverse.scanner.harvester_plugin.config.new_client_from_config" +PATCH_CUSTOM_API = "iac_reverse.scanner.harvester_plugin.client.CustomObjectsApi" +PATCH_CORE_API = "iac_reverse.scanner.harvester_plugin.client.CoreV1Api" + + +class TestHarvesterPluginAuthentication: + """Tests for HarvesterPlugin.authenticate().""" + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_authenticate_with_kubeconfig_path( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """Authenticate loads kubeconfig from the provided path.""" + mock_api_client = MagicMock() + mock_new_client.return_value = mock_api_client + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + mock_new_client.assert_called_once_with( + config_file="/path/to/kubeconfig", + context=None, + ) + assert plugin._api_client is mock_api_client + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_authenticate_with_context( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """Authenticate uses the optional context parameter.""" + mock_api_client = MagicMock() + mock_new_client.return_value = mock_api_client + + plugin = HarvesterPlugin() + plugin.authenticate({ + "kubeconfig_path": "/path/to/kubeconfig", + "context": "harvester-cluster", + }) + + mock_new_client.assert_called_once_with( + config_file="/path/to/kubeconfig", + context="harvester-cluster", + ) + + def test_authenticate_missing_kubeconfig_path(self): + """Authenticate raises AuthenticationError when kubeconfig_path is missing.""" + from iac_reverse.scanner import AuthenticationError + + plugin = HarvesterPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({}) + + assert "kubeconfig_path" in str(exc_info.value) + assert "harvester" in str(exc_info.value) + + @patch(PATCH_NEW_CLIENT) + def test_authenticate_invalid_kubeconfig(self, mock_new_client): + """Authenticate raises AuthenticationError when kubeconfig is invalid.""" + from iac_reverse.scanner import AuthenticationError + + mock_new_client.side_effect = Exception("Invalid kubeconfig format") + + plugin = HarvesterPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"kubeconfig_path": "/bad/path"}) + + assert "Failed to load kubeconfig" in str(exc_info.value) + + +class TestHarvesterPluginMetadata: + """Tests for HarvesterPlugin metadata methods.""" + + def test_get_platform_category(self): + """get_platform_category returns HCI.""" + plugin = HarvesterPlugin() + assert plugin.get_platform_category() == PlatformCategory.HCI + + def test_list_supported_resource_types(self): + """list_supported_resource_types returns all Harvester resource types.""" + plugin = HarvesterPlugin() + types = plugin.list_supported_resource_types() + + assert types == [ + "harvester_virtualmachine", + "harvester_volume", + "harvester_image", + "harvester_network", + ] + + def test_list_endpoints_unauthenticated(self): + """list_endpoints returns empty list when not authenticated.""" + plugin = HarvesterPlugin() + assert plugin.list_endpoints() == [] + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_list_endpoints_authenticated( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """list_endpoints returns the cluster API server URL.""" + mock_api_client = MagicMock() + mock_api_client.configuration.host = "https://harvester.local:6443" + mock_new_client.return_value = mock_api_client + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + endpoints = plugin.list_endpoints() + assert endpoints == ["https://harvester.local:6443"] + + +class TestHarvesterPluginDetectArchitecture: + """Tests for HarvesterPlugin.detect_architecture().""" + + def test_detect_architecture_unauthenticated(self): + """detect_architecture returns AMD64 when not authenticated.""" + plugin = HarvesterPlugin() + arch = plugin.detect_architecture("https://harvester.local:6443") + assert arch == CpuArchitecture.AMD64 + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_detect_architecture_amd64( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """detect_architecture returns AMD64 for amd64 nodes.""" + mock_api_client = MagicMock() + mock_new_client.return_value = mock_api_client + + mock_core_instance = MagicMock() + mock_core_cls.return_value = mock_core_instance + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + # Mock node list response + mock_node = MagicMock() + mock_node.status.node_info.architecture = "amd64" + mock_core_instance.list_node.return_value = MagicMock(items=[mock_node]) + + arch = plugin.detect_architecture("https://harvester.local:6443") + assert arch == CpuArchitecture.AMD64 + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_detect_architecture_arm64( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """detect_architecture returns AARCH64 for arm64 nodes.""" + mock_api_client = MagicMock() + mock_new_client.return_value = mock_api_client + + mock_core_instance = MagicMock() + mock_core_cls.return_value = mock_core_instance + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + mock_node = MagicMock() + mock_node.status.node_info.architecture = "arm64" + mock_core_instance.list_node.return_value = MagicMock(items=[mock_node]) + + arch = plugin.detect_architecture("https://harvester.local:6443") + assert arch == CpuArchitecture.AARCH64 + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_detect_architecture_arm( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """detect_architecture returns ARM for arm nodes.""" + mock_api_client = MagicMock() + mock_new_client.return_value = mock_api_client + + mock_core_instance = MagicMock() + mock_core_cls.return_value = mock_core_instance + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + mock_node = MagicMock() + mock_node.status.node_info.architecture = "arm" + mock_core_instance.list_node.return_value = MagicMock(items=[mock_node]) + + arch = plugin.detect_architecture("https://harvester.local:6443") + assert arch == CpuArchitecture.ARM + + @patch(PATCH_CORE_API) + @patch(PATCH_CUSTOM_API) + @patch(PATCH_NEW_CLIENT) + def test_detect_architecture_api_error_defaults_amd64( + self, mock_new_client, mock_custom_cls, mock_core_cls + ): + """detect_architecture defaults to AMD64 on API errors.""" + from kubernetes.client.rest import ApiException + + mock_api_client = MagicMock() + mock_new_client.return_value = mock_api_client + + mock_core_instance = MagicMock() + mock_core_cls.return_value = mock_core_instance + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + mock_core_instance.list_node.side_effect = ApiException(status=403) + + arch = plugin.detect_architecture("https://harvester.local:6443") + assert arch == CpuArchitecture.AMD64 + + +def _make_authenticated_plugin(): + """Create an authenticated plugin with mocked APIs. + + Returns (plugin, mock_custom_api, mock_core_api) tuple. + """ + with patch(PATCH_NEW_CLIENT) as mock_new_client, \ + patch(PATCH_CUSTOM_API) as mock_custom_cls, \ + patch(PATCH_CORE_API) as mock_core_cls: + + mock_api_client = MagicMock() + mock_api_client.configuration.host = "https://harvester.local:6443" + mock_new_client.return_value = mock_api_client + + mock_custom_instance = MagicMock() + mock_custom_cls.return_value = mock_custom_instance + + mock_core_instance = MagicMock() + mock_core_cls.return_value = mock_core_instance + + plugin = HarvesterPlugin() + plugin.authenticate({"kubeconfig_path": "/path/to/kubeconfig"}) + + # Mock node for architecture detection + mock_node = MagicMock() + mock_node.status.node_info.architecture = "amd64" + mock_core_instance.list_node.return_value = MagicMock(items=[mock_node]) + + return plugin, mock_custom_instance, mock_core_instance + + +class TestHarvesterPluginDiscoverResources: + """Tests for HarvesterPlugin.discover_resources().""" + + def test_discover_vms(self): + """discover_resources discovers virtual machines.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + + mock_custom_api.list_cluster_custom_object.return_value = { + "items": [ + { + "metadata": { + "name": "test-vm", + "namespace": "default", + "uid": "vm-uid-123", + "labels": {"app": "web"}, + "annotations": {}, + }, + "spec": { + "running": True, + "template": { + "spec": { + "volumes": [ + {"dataVolume": {"name": "test-disk"}}, + ], + "networks": [ + {"multus": {"networkName": "vlan100"}}, + ], + } + }, + }, + } + ] + } + + progress_updates = [] + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_virtualmachine"], + progress_callback=lambda p: progress_updates.append(p), + ) + + assert len(result.resources) == 1 + vm = result.resources[0] + assert vm.resource_type == "harvester_virtualmachine" + assert vm.name == "test-vm" + assert vm.unique_id == "vm-uid-123" + assert vm.provider == ProviderType.HARVESTER + assert vm.platform_category == PlatformCategory.HCI + assert vm.architecture == CpuArchitecture.AMD64 + assert vm.attributes["running"] is True + assert "volume:test-disk" in vm.raw_references + assert "network:vlan100" in vm.raw_references + + def test_discover_volumes(self): + """discover_resources discovers data volumes.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + + mock_custom_api.list_cluster_custom_object.return_value = { + "items": [ + { + "metadata": { + "name": "test-disk", + "namespace": "default", + "uid": "vol-uid-456", + "labels": {}, + }, + "spec": { + "source": {"http": {"url": "https://images.example.com/disk.img"}}, + "pvc": {"accessModes": ["ReadWriteOnce"], "resources": {"requests": {"storage": "10Gi"}}}, + }, + } + ] + } + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_volume"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + vol = result.resources[0] + assert vol.resource_type == "harvester_volume" + assert vol.name == "test-disk" + assert vol.unique_id == "vol-uid-456" + + def test_discover_images(self): + """discover_resources discovers VM images.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + + mock_custom_api.list_cluster_custom_object.return_value = { + "items": [ + { + "metadata": { + "name": "ubuntu-22.04", + "namespace": "default", + "uid": "img-uid-789", + "labels": {"os": "linux"}, + }, + "spec": { + "displayName": "Ubuntu 22.04 LTS", + "url": "https://cloud-images.ubuntu.com/jammy/current/jammy-server-cloudimg-amd64.img", + }, + } + ] + } + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_image"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + img = result.resources[0] + assert img.resource_type == "harvester_image" + assert img.name == "ubuntu-22.04" + assert img.attributes["display_name"] == "Ubuntu 22.04 LTS" + assert "ubuntu.com" in img.attributes["url"] + + def test_discover_networks(self): + """discover_resources discovers network attachment definitions.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + + mock_custom_api.list_cluster_custom_object.return_value = { + "items": [ + { + "metadata": { + "name": "vlan100", + "namespace": "default", + "uid": "net-uid-abc", + "labels": {}, + }, + "spec": { + "config": '{"cniVersion":"0.3.1","name":"vlan100","type":"bridge"}', + }, + } + ] + } + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_network"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + net = result.resources[0] + assert net.resource_type == "harvester_network" + assert net.name == "vlan100" + assert "vlan100" in net.attributes["config"] + + def test_discover_multiple_resource_types(self): + """discover_resources handles multiple resource types in one call.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + + # Return different items based on the CRD being queried + def mock_list_custom_object(group, version, plural): + if plural == "virtualmachines": + return {"items": [{"metadata": {"name": "vm1", "namespace": "default", "uid": "uid-1", "labels": {}, "annotations": {}}, "spec": {"running": True, "template": {"spec": {}}}}]} + elif plural == "network-attachment-definitions": + return {"items": [{"metadata": {"name": "net1", "namespace": "default", "uid": "uid-2", "labels": {}}, "spec": {"config": "{}"}}]} + return {"items": []} + + mock_custom_api.list_cluster_custom_object.side_effect = mock_list_custom_object + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_virtualmachine", "harvester_network"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 2 + types = {r.resource_type for r in result.resources} + assert types == {"harvester_virtualmachine", "harvester_network"} + + def test_discover_resources_api_error(self): + """discover_resources records errors when API calls fail.""" + from kubernetes.client.rest import ApiException + + plugin, mock_custom_api, _ = _make_authenticated_plugin() + mock_custom_api.list_cluster_custom_object.side_effect = ApiException( + status=403, reason="Forbidden" + ) + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_virtualmachine"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 1 + assert "403" in result.errors[0] + + def test_discover_resources_progress_callback(self): + """discover_resources invokes progress_callback correctly.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + mock_custom_api.list_cluster_custom_object.return_value = {"items": []} + + progress_updates: list[ScanProgress] = [] + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_virtualmachine", "harvester_volume"], + progress_callback=lambda p: progress_updates.append(p), + ) + + # Should have progress updates: one per resource type + final + assert len(progress_updates) == 3 + assert progress_updates[0].current_resource_type == "harvester_virtualmachine" + assert progress_updates[0].total_resource_types == 2 + assert progress_updates[-1].resource_types_completed == 2 + + def test_discover_resources_unknown_type_warning(self): + """discover_resources warns about unknown resource types.""" + plugin, mock_custom_api, _ = _make_authenticated_plugin() + + result = plugin.discover_resources( + endpoints=["https://harvester.local:6443"], + resource_types=["harvester_unknown"], + progress_callback=lambda p: None, + ) + + assert len(result.warnings) == 1 + assert "Unknown resource type" in result.warnings[0] + + +class TestHarvesterPluginVMReferences: + """Tests for VM reference extraction.""" + + def test_extract_vm_references_data_volume(self): + """Extracts dataVolume references from VM spec.""" + spec = { + "template": { + "spec": { + "volumes": [ + {"dataVolume": {"name": "my-disk"}}, + ], + "networks": [], + } + } + } + refs = HarvesterPlugin._extract_vm_references(spec) + assert refs == ["volume:my-disk"] + + def test_extract_vm_references_pvc(self): + """Extracts persistentVolumeClaim references from VM spec.""" + spec = { + "template": { + "spec": { + "volumes": [ + {"persistentVolumeClaim": {"claimName": "my-pvc"}}, + ], + "networks": [], + } + } + } + refs = HarvesterPlugin._extract_vm_references(spec) + assert refs == ["volume:my-pvc"] + + def test_extract_vm_references_multus_network(self): + """Extracts multus network references from VM spec.""" + spec = { + "template": { + "spec": { + "volumes": [], + "networks": [ + {"multus": {"networkName": "vlan200"}}, + ], + } + } + } + refs = HarvesterPlugin._extract_vm_references(spec) + assert refs == ["network:vlan200"] + + def test_extract_vm_references_empty_spec(self): + """Returns empty list for empty spec.""" + refs = HarvesterPlugin._extract_vm_references({}) + assert refs == [] + + def test_extract_vm_references_mixed(self): + """Extracts both volume and network references.""" + spec = { + "template": { + "spec": { + "volumes": [ + {"dataVolume": {"name": "disk-1"}}, + {"persistentVolumeClaim": {"claimName": "pvc-1"}}, + ], + "networks": [ + {"multus": {"networkName": "mgmt-net"}}, + {"multus": {"networkName": "data-net"}}, + ], + } + } + } + refs = HarvesterPlugin._extract_vm_references(spec) + assert refs == [ + "volume:disk-1", + "volume:pvc-1", + "network:mgmt-net", + "network:data-net", + ] diff --git a/tests/unit/test_incremental_updater.py b/tests/unit/test_incremental_updater.py new file mode 100644 index 0000000..ea96f9d --- /dev/null +++ b/tests/unit/test_incremental_updater.py @@ -0,0 +1,513 @@ +"""Unit tests for the IncrementalUpdater class.""" + +import json +from pathlib import Path + +import pytest + +from iac_reverse.incremental.incremental_updater import IncrementalUpdater +from iac_reverse.models import ChangeSummary, ChangeType, ResourceChange + + +def _make_change( + resource_id: str = "res-1", + resource_type: str = "kubernetes_deployment", + resource_name: str = "nginx", + change_type: ChangeType = ChangeType.ADDED, + changed_attributes: dict | None = None, +) -> ResourceChange: + """Create a ResourceChange for testing.""" + return ResourceChange( + resource_id=resource_id, + resource_type=resource_type, + resource_name=resource_name, + change_type=change_type, + changed_attributes=changed_attributes, + ) + + +def _make_summary(changes: list[ResourceChange]) -> ChangeSummary: + """Create a ChangeSummary from a list of changes.""" + added = sum(1 for c in changes if c.change_type == ChangeType.ADDED) + removed = sum(1 for c in changes if c.change_type == ChangeType.REMOVED) + modified = sum(1 for c in changes if c.change_type == ChangeType.MODIFIED) + return ChangeSummary( + added_count=added, + removed_count=removed, + modified_count=modified, + changes=changes, + ) + + +def _write_tf_file(output_dir: Path, resource_type: str, content: str) -> Path: + """Write a .tf file to the output directory.""" + tf_file = output_dir / f"{resource_type}.tf" + tf_file.write_text(content, encoding="utf-8") + return tf_file + + +def _write_state_file(output_dir: Path, resources: list[dict]) -> Path: + """Write a terraform.tfstate file to the output directory.""" + state = { + "version": 4, + "terraform_version": "1.7.0", + "serial": 1, + "lineage": "test-lineage-uuid", + "outputs": {}, + "resources": resources, + } + state_file = output_dir / "terraform.tfstate" + state_file.write_text(json.dumps(state, indent=2), encoding="utf-8") + return state_file + + +class TestAddedResource: + """Tests for adding new resource blocks.""" + + def test_added_resource_creates_block_in_correct_file( + self, tmp_path: Path + ) -> None: + """An ADDED resource creates a new block in the resource type .tf file.""" + change = _make_change( + resource_id="apps/v1/deployments/default/nginx", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.ADDED, + ) + summary = _make_summary([change]) + attributes = { + "apps/v1/deployments/default/nginx": { + "namespace": "default", + "replicas": 3, + } + } + + updater = IncrementalUpdater(summary, str(tmp_path), attributes) + updater.apply() + + tf_file = tmp_path / "kubernetes_deployment.tf" + assert tf_file.exists() + content = tf_file.read_text(encoding="utf-8") + assert 'resource "kubernetes_deployment" "nginx"' in content + assert "namespace" in content + assert "replicas" in content + assert "# Source: apps/v1/deployments/default/nginx" in content + + def test_added_resource_appends_to_existing_file( + self, tmp_path: Path + ) -> None: + """An ADDED resource appends to an existing .tf file without overwriting.""" + existing_content = ( + '# Source: existing-id\n' + 'resource "kubernetes_deployment" "existing" {\n' + ' replicas = 1\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", existing_content) + + change = _make_change( + resource_id="new-id", + resource_type="kubernetes_deployment", + resource_name="new-service", + change_type=ChangeType.ADDED, + ) + summary = _make_summary([change]) + attributes = {"new-id": {"replicas": 2}} + + updater = IncrementalUpdater(summary, str(tmp_path), attributes) + updater.apply() + + tf_file = tmp_path / "kubernetes_deployment.tf" + content = tf_file.read_text(encoding="utf-8") + # Both resources should be present + assert 'resource "kubernetes_deployment" "existing"' in content + assert 'resource "kubernetes_deployment" "new_service"' in content + + def test_added_resource_creates_file_if_not_exists( + self, tmp_path: Path + ) -> None: + """An ADDED resource creates the .tf file if it doesn't exist.""" + change = _make_change( + resource_id="svc-1", + resource_type="docker_service", + resource_name="my-app", + change_type=ChangeType.ADDED, + ) + summary = _make_summary([change]) + attributes = {"svc-1": {"image": "nginx:latest"}} + + updater = IncrementalUpdater(summary, str(tmp_path), attributes) + updater.apply() + + tf_file = tmp_path / "docker_service.tf" + assert tf_file.exists() + content = tf_file.read_text(encoding="utf-8") + assert 'resource "docker_service" "my_app"' in content + + +class TestRemovedResource: + """Tests for removing resource blocks.""" + + def test_removed_resource_removes_block_from_file( + self, tmp_path: Path + ) -> None: + """A REMOVED resource removes its block from the .tf file.""" + content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + '}\n' + '\n' + '# Source: res-2\n' + 'resource "kubernetes_deployment" "redis" {\n' + ' replicas = 1\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", content) + + change = _make_change( + resource_id="res-1", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.REMOVED, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + tf_file = tmp_path / "kubernetes_deployment.tf" + result = tf_file.read_text(encoding="utf-8") + assert "nginx" not in result + assert 'resource "kubernetes_deployment" "redis"' in result + + def test_removed_resource_updates_state_file( + self, tmp_path: Path + ) -> None: + """A REMOVED resource removes its entry from the state file.""" + state_resources = [ + { + "mode": "managed", + "type": "kubernetes_deployment", + "name": "nginx", + "provider": 'provider["registry.terraform.io/hashicorp/kubernetes"]', + "instances": [ + { + "schema_version": 1, + "attributes": {"id": "res-1", "replicas": 3}, + "sensitive_attributes": [], + "dependencies": [], + } + ], + }, + { + "mode": "managed", + "type": "kubernetes_deployment", + "name": "redis", + "provider": 'provider["registry.terraform.io/hashicorp/kubernetes"]', + "instances": [ + { + "schema_version": 1, + "attributes": {"id": "res-2", "replicas": 1}, + "sensitive_attributes": [], + "dependencies": [], + } + ], + }, + ] + _write_state_file(tmp_path, state_resources) + + # Also write the .tf file so the removal can proceed + tf_content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", tf_content) + + change = _make_change( + resource_id="res-1", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.REMOVED, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + state_file = tmp_path / "terraform.tfstate" + state = json.loads(state_file.read_text(encoding="utf-8")) + # Only redis should remain + assert len(state["resources"]) == 1 + assert state["resources"][0]["name"] == "redis" + # Serial should be incremented + assert state["serial"] == 2 + + +class TestModifiedResource: + """Tests for updating resource blocks.""" + + def test_modified_resource_updates_block_in_file( + self, tmp_path: Path + ) -> None: + """A MODIFIED resource updates the attribute values in the .tf file.""" + content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + ' image = "nginx:1.24"\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", content) + + change = _make_change( + resource_id="res-1", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.MODIFIED, + changed_attributes={ + "replicas": {"old": 3, "new": 5}, + }, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + tf_file = tmp_path / "kubernetes_deployment.tf" + result = tf_file.read_text(encoding="utf-8") + assert "replicas = 5" in result + assert "replicas = 3" not in result + # Unchanged attribute should remain + assert 'image = "nginx:1.24"' in result + + def test_modified_resource_adds_new_attribute( + self, tmp_path: Path + ) -> None: + """A MODIFIED resource with a new attribute adds it to the block.""" + content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", content) + + change = _make_change( + resource_id="res-1", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.MODIFIED, + changed_attributes={ + "image": {"old": None, "new": "nginx:1.25"}, + }, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + tf_file = tmp_path / "kubernetes_deployment.tf" + result = tf_file.read_text(encoding="utf-8") + assert '"nginx:1.25"' in result + assert "replicas = 3" in result + + def test_modified_resource_removes_attribute( + self, tmp_path: Path + ) -> None: + """A MODIFIED resource with a removed attribute removes the line.""" + content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + ' image = "nginx:1.24"\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", content) + + change = _make_change( + resource_id="res-1", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.MODIFIED, + changed_attributes={ + "image": {"old": "nginx:1.24", "new": None}, + }, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + tf_file = tmp_path / "kubernetes_deployment.tf" + result = tf_file.read_text(encoding="utf-8") + assert "image" not in result + assert "replicas = 3" in result + + +class TestOnlyAffectedFilesModified: + """Tests that only files with changed resources are modified.""" + + def test_unrelated_files_are_not_modified(self, tmp_path: Path) -> None: + """Files for resource types without changes are not touched.""" + # Write two .tf files + k8s_content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + '}\n' + ) + docker_content = ( + '# Source: svc-1\n' + 'resource "docker_service" "app" {\n' + ' image = "app:latest"\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", k8s_content) + docker_file = _write_tf_file(tmp_path, "docker_service", docker_content) + + # Record the modification time of the docker file + docker_mtime_before = docker_file.stat().st_mtime + + # Only change the kubernetes resource + change = _make_change( + resource_id="res-1", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.MODIFIED, + changed_attributes={"replicas": {"old": 3, "new": 5}}, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + # Docker file should not be in modified_files + assert str(docker_file) not in updater.modified_files + # Kubernetes file should be in modified_files + k8s_file = tmp_path / "kubernetes_deployment.tf" + assert str(k8s_file) in updater.modified_files + + def test_modified_files_tracks_only_changed(self, tmp_path: Path) -> None: + """The modified_files property only contains files that were changed.""" + content = ( + '# Source: res-1\n' + 'resource "kubernetes_deployment" "nginx" {\n' + ' replicas = 3\n' + '}\n' + ) + _write_tf_file(tmp_path, "kubernetes_deployment", content) + + change = _make_change( + resource_id="new-id", + resource_type="docker_service", + resource_name="new-svc", + change_type=ChangeType.ADDED, + ) + summary = _make_summary([change]) + attributes = {"new-id": {"image": "app:1.0"}} + + updater = IncrementalUpdater(summary, str(tmp_path), attributes) + updater.apply() + + # Only docker_service.tf should be modified + modified = updater.modified_files + assert any("docker_service.tf" in f for f in modified) + assert not any("kubernetes_deployment.tf" in f for f in modified) + + +class TestStateFileUpdatedForRemovedResources: + """Tests that state file is properly updated when resources are removed.""" + + def test_state_entry_removed_for_removed_resource( + self, tmp_path: Path + ) -> None: + """Removing a resource also removes its state entry.""" + state_resources = [ + { + "mode": "managed", + "type": "docker_service", + "name": "my_app", + "provider": 'provider["registry.terraform.io/hashicorp/docker"]', + "instances": [ + { + "schema_version": 0, + "attributes": {"id": "svc-1", "image": "app:1.0"}, + "sensitive_attributes": [], + "dependencies": [], + } + ], + } + ] + _write_state_file(tmp_path, state_resources) + + tf_content = ( + '# Source: svc-1\n' + 'resource "docker_service" "my_app" {\n' + ' image = "app:1.0"\n' + '}\n' + ) + _write_tf_file(tmp_path, "docker_service", tf_content) + + change = _make_change( + resource_id="svc-1", + resource_type="docker_service", + resource_name="my-app", + change_type=ChangeType.REMOVED, + ) + summary = _make_summary([change]) + + updater = IncrementalUpdater(summary, str(tmp_path)) + updater.apply() + + state_file = tmp_path / "terraform.tfstate" + state = json.loads(state_file.read_text(encoding="utf-8")) + assert len(state["resources"]) == 0 + assert state["serial"] == 2 + + def test_state_file_not_modified_when_no_removals( + self, tmp_path: Path + ) -> None: + """State file is not modified when there are no REMOVED changes.""" + state_resources = [ + { + "mode": "managed", + "type": "kubernetes_deployment", + "name": "nginx", + "provider": 'provider["registry.terraform.io/hashicorp/kubernetes"]', + "instances": [ + { + "schema_version": 1, + "attributes": {"id": "res-1"}, + "sensitive_attributes": [], + "dependencies": [], + } + ], + } + ] + _write_state_file(tmp_path, state_resources) + + # Only an ADDED change (no removal) + change = _make_change( + resource_id="new-id", + resource_type="docker_service", + resource_name="new-svc", + change_type=ChangeType.ADDED, + ) + summary = _make_summary([change]) + attributes = {"new-id": {"image": "app:1.0"}} + + updater = IncrementalUpdater(summary, str(tmp_path), attributes) + updater.apply() + + # State file should not be in modified_files + state_path = str(tmp_path / "terraform.tfstate") + assert state_path not in updater.modified_files + + # State should still have the original entry + state_file = tmp_path / "terraform.tfstate" + state = json.loads(state_file.read_text(encoding="utf-8")) + assert len(state["resources"]) == 1 + assert state["serial"] == 1 diff --git a/tests/unit/test_kubernetes_plugin.py b/tests/unit/test_kubernetes_plugin.py new file mode 100644 index 0000000..839f920 --- /dev/null +++ b/tests/unit/test_kubernetes_plugin.py @@ -0,0 +1,508 @@ +"""Unit tests for the KubernetesPlugin provider plugin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProgress, + ScanResult, +) +from iac_reverse.scanner.kubernetes_plugin import KubernetesPlugin +from iac_reverse.scanner.scanner import AuthenticationError + + +class TestKubernetesPluginAuthenticate: + """Tests for KubernetesPlugin.authenticate().""" + + @patch("iac_reverse.scanner.kubernetes_plugin.config") + @patch("iac_reverse.scanner.kubernetes_plugin.client") + def test_authenticate_with_kubeconfig_path(self, mock_client, mock_config): + """Successfully authenticates with a kubeconfig path.""" + plugin = KubernetesPlugin() + credentials = {"kubeconfig_path": "/home/user/.kube/config"} + + plugin.authenticate(credentials) + + mock_config.load_kube_config.assert_called_once_with( + config_file="/home/user/.kube/config", + context=None, + ) + mock_client.ApiClient.assert_called_once() + assert plugin._core_v1 is not None + assert plugin._apps_v1 is not None + assert plugin._networking_v1 is not None + + @patch("iac_reverse.scanner.kubernetes_plugin.config") + @patch("iac_reverse.scanner.kubernetes_plugin.client") + def test_authenticate_with_context(self, mock_client, mock_config): + """Authenticates with a specific context.""" + plugin = KubernetesPlugin() + credentials = { + "kubeconfig_path": "/home/user/.kube/config", + "context": "production", + } + + plugin.authenticate(credentials) + + mock_config.load_kube_config.assert_called_once_with( + config_file="/home/user/.kube/config", + context="production", + ) + + def test_authenticate_missing_kubeconfig_path(self): + """Raises AuthenticationError when kubeconfig_path is missing.""" + plugin = KubernetesPlugin() + + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({}) + + assert "kubeconfig_path is required" in str(exc_info.value) + + @patch("iac_reverse.scanner.kubernetes_plugin.config") + def test_authenticate_invalid_kubeconfig(self, mock_config): + """Raises AuthenticationError when kubeconfig is invalid.""" + mock_config.load_kube_config.side_effect = Exception("file not found") + plugin = KubernetesPlugin() + + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"kubeconfig_path": "/invalid/path"}) + + assert "Failed to load kubeconfig" in str(exc_info.value) + + +class TestKubernetesPluginPlatformCategory: + """Tests for KubernetesPlugin.get_platform_category().""" + + def test_returns_container_orchestration(self): + """Returns CONTAINER_ORCHESTRATION category.""" + plugin = KubernetesPlugin() + assert plugin.get_platform_category() == PlatformCategory.CONTAINER_ORCHESTRATION + + +class TestKubernetesPluginSupportedResourceTypes: + """Tests for KubernetesPlugin.list_supported_resource_types().""" + + def test_returns_all_kubernetes_resource_types(self): + """Returns all six supported Kubernetes resource types.""" + plugin = KubernetesPlugin() + types = plugin.list_supported_resource_types() + + expected = [ + "kubernetes_deployment", + "kubernetes_service", + "kubernetes_ingress", + "kubernetes_config_map", + "kubernetes_persistent_volume", + "kubernetes_namespace", + ] + assert types == expected + + def test_returns_new_list_each_call(self): + """Returns a new list instance each call (not mutable reference).""" + plugin = KubernetesPlugin() + types1 = plugin.list_supported_resource_types() + types2 = plugin.list_supported_resource_types() + assert types1 == types2 + assert types1 is not types2 + + +class TestKubernetesPluginDetectArchitecture: + """Tests for KubernetesPlugin.detect_architecture().""" + + def test_detects_amd64_from_node_label(self): + """Detects AMD64 architecture from kubernetes.io/arch label.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + + node = _make_node( + addresses=[("InternalIP", "192.168.1.10")], + labels={"kubernetes.io/arch": "amd64"}, + ) + plugin._core_v1.list_node.return_value = MagicMock(items=[node]) + + result = plugin.detect_architecture("192.168.1.10") + assert result == CpuArchitecture.AMD64 + + def test_detects_arm64_from_node_label(self): + """Detects AARCH64 architecture from arm64 label.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + + node = _make_node( + addresses=[("InternalIP", "192.168.1.20")], + labels={"kubernetes.io/arch": "arm64"}, + ) + plugin._core_v1.list_node.return_value = MagicMock(items=[node]) + + result = plugin.detect_architecture("192.168.1.20") + assert result == CpuArchitecture.AARCH64 + + def test_detects_arm_from_node_label(self): + """Detects ARM architecture from arm label.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + + node = _make_node( + addresses=[("InternalIP", "192.168.1.30")], + labels={"kubernetes.io/arch": "arm"}, + ) + plugin._core_v1.list_node.return_value = MagicMock(items=[node]) + + result = plugin.detect_architecture("192.168.1.30") + assert result == CpuArchitecture.ARM + + def test_falls_back_to_beta_label(self): + """Falls back to beta.kubernetes.io/arch label.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + + node = _make_node( + addresses=[("InternalIP", "192.168.1.40")], + labels={"beta.kubernetes.io/arch": "arm64"}, + ) + plugin._core_v1.list_node.return_value = MagicMock(items=[node]) + + result = plugin.detect_architecture("192.168.1.40") + assert result == CpuArchitecture.AARCH64 + + def test_defaults_to_amd64_when_no_label(self): + """Defaults to AMD64 when no arch label is present.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + + node = _make_node( + addresses=[("InternalIP", "192.168.1.50")], + labels={}, + ) + plugin._core_v1.list_node.return_value = MagicMock(items=[node]) + + result = plugin.detect_architecture("192.168.1.50") + assert result == CpuArchitecture.AMD64 + + def test_defaults_to_amd64_when_not_authenticated(self): + """Returns AMD64 when plugin is not authenticated.""" + plugin = KubernetesPlugin() + result = plugin.detect_architecture("192.168.1.1") + assert result == CpuArchitecture.AMD64 + + def test_defaults_to_amd64_on_api_error(self): + """Returns AMD64 when API call fails.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + plugin._core_v1.list_node.side_effect = Exception("API error") + + result = plugin.detect_architecture("192.168.1.1") + assert result == CpuArchitecture.AMD64 + + +class TestKubernetesPluginListEndpoints: + """Tests for KubernetesPlugin.list_endpoints().""" + + def test_returns_node_internal_ips(self): + """Returns InternalIP addresses from nodes.""" + plugin = KubernetesPlugin() + plugin._core_v1 = MagicMock() + + nodes = [ + _make_node(addresses=[("InternalIP", "10.0.0.1")]), + _make_node(addresses=[("InternalIP", "10.0.0.2")]), + ] + plugin._core_v1.list_node.return_value = MagicMock(items=nodes) + + result = plugin.list_endpoints() + assert result == ["10.0.0.1", "10.0.0.2"] + + def test_returns_empty_when_not_authenticated(self): + """Returns empty list when not authenticated.""" + plugin = KubernetesPlugin() + assert plugin.list_endpoints() == [] + + +class TestKubernetesPluginDiscoverResources: + """Tests for KubernetesPlugin.discover_resources().""" + + def test_discovers_deployments(self): + """Discovers deployments and returns DiscoveredResource objects.""" + plugin = _make_authenticated_plugin() + + dep = MagicMock() + dep.metadata.name = "nginx" + dep.metadata.namespace = "default" + dep.metadata.labels = {"app": "nginx"} + dep.spec.replicas = 3 + plugin._apps_v1.list_deployment_for_all_namespaces.return_value = MagicMock( + items=[dep] + ) + _stub_empty_apis(plugin, exclude="deployments") + + progress_updates = [] + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_deployment"], + progress_callback=progress_updates.append, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "kubernetes_deployment" + assert resource.unique_id == "default/nginx" + assert resource.name == "nginx" + assert resource.provider == ProviderType.KUBERNETES + assert resource.platform_category == PlatformCategory.CONTAINER_ORCHESTRATION + assert resource.attributes["namespace"] == "default" + assert resource.attributes["replicas"] == 3 + + def test_discovers_services(self): + """Discovers services and returns DiscoveredResource objects.""" + plugin = _make_authenticated_plugin() + + svc = MagicMock() + svc.metadata.name = "my-service" + svc.metadata.namespace = "production" + svc.metadata.labels = {"app": "web"} + svc.spec.type = "ClusterIP" + svc.spec.cluster_ip = "10.96.0.1" + plugin._core_v1.list_service_for_all_namespaces.return_value = MagicMock( + items=[svc] + ) + _stub_empty_apis(plugin, exclude="services") + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_service"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "kubernetes_service" + assert resource.unique_id == "production/my-service" + assert resource.attributes["type"] == "ClusterIP" + + def test_discovers_ingresses(self): + """Discovers ingresses and returns DiscoveredResource objects.""" + plugin = _make_authenticated_plugin() + + ing = MagicMock() + ing.metadata.name = "web-ingress" + ing.metadata.namespace = "default" + ing.metadata.labels = {"app": "web"} + plugin._networking_v1.list_ingress_for_all_namespaces.return_value = MagicMock( + items=[ing] + ) + _stub_empty_apis(plugin, exclude="ingresses") + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_ingress"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "kubernetes_ingress" + assert result.resources[0].unique_id == "default/web-ingress" + + def test_discovers_config_maps(self): + """Discovers config maps and returns DiscoveredResource objects.""" + plugin = _make_authenticated_plugin() + + cm = MagicMock() + cm.metadata.name = "app-config" + cm.metadata.namespace = "default" + cm.metadata.labels = {} + cm.data = {"key1": "value1", "key2": "value2"} + plugin._core_v1.list_config_map_for_all_namespaces.return_value = MagicMock( + items=[cm] + ) + _stub_empty_apis(plugin, exclude="config_maps") + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_config_map"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "kubernetes_config_map" + assert resource.attributes["data_keys"] == ["key1", "key2"] + + def test_discovers_persistent_volumes(self): + """Discovers persistent volumes and returns DiscoveredResource objects.""" + plugin = _make_authenticated_plugin() + + pv = MagicMock() + pv.metadata.name = "pv-data" + pv.metadata.labels = {} + pv.spec.capacity = {"storage": "10Gi"} + pv.spec.access_modes = ["ReadWriteOnce"] + pv.spec.storage_class_name = "standard" + plugin._core_v1.list_persistent_volume.return_value = MagicMock( + items=[pv] + ) + _stub_empty_apis(plugin, exclude="persistent_volumes") + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_persistent_volume"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "kubernetes_persistent_volume" + assert resource.unique_id == "pv-data" + assert resource.attributes["capacity"] == {"storage": "10Gi"} + assert resource.attributes["access_modes"] == ["ReadWriteOnce"] + + def test_discovers_namespaces(self): + """Discovers namespaces and returns DiscoveredResource objects.""" + plugin = _make_authenticated_plugin() + + ns = MagicMock() + ns.metadata.name = "production" + ns.metadata.labels = {"env": "prod"} + ns.status.phase = "Active" + plugin._core_v1.list_namespace.return_value = MagicMock(items=[ns]) + _stub_empty_apis(plugin, exclude="namespaces") + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_namespace"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + resource = result.resources[0] + assert resource.resource_type == "kubernetes_namespace" + assert resource.unique_id == "production" + assert resource.attributes["status"] == "Active" + + def test_reports_progress_for_each_resource_type(self): + """Reports progress callback for each resource type scanned.""" + plugin = _make_authenticated_plugin() + _stub_empty_apis(plugin) + + progress_updates: list[ScanProgress] = [] + plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_deployment", "kubernetes_service"], + progress_callback=progress_updates.append, + ) + + assert len(progress_updates) == 2 + assert progress_updates[0].current_resource_type == "kubernetes_deployment" + assert progress_updates[0].resource_types_completed == 1 + assert progress_updates[0].total_resource_types == 2 + assert progress_updates[1].current_resource_type == "kubernetes_service" + assert progress_updates[1].resource_types_completed == 2 + + def test_handles_api_errors_gracefully(self): + """Records errors when API calls fail without crashing.""" + plugin = _make_authenticated_plugin() + plugin._apps_v1.list_deployment_for_all_namespaces.side_effect = Exception( + "API unavailable" + ) + _stub_empty_apis(plugin, exclude="deployments") + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_deployment"], + progress_callback=lambda p: None, + ) + + assert len(result.errors) == 1 + assert "API unavailable" in result.errors[0] + assert len(result.resources) == 0 + + def test_returns_scan_result_type(self): + """Returns a ScanResult instance.""" + plugin = _make_authenticated_plugin() + _stub_empty_apis(plugin) + + result = plugin.discover_resources( + endpoints=["10.0.0.1"], + resource_types=["kubernetes_namespace"], + progress_callback=lambda p: None, + ) + + assert isinstance(result, ScanResult) + + +# --------------------------------------------------------------------------- +# Test Helpers +# --------------------------------------------------------------------------- + + +def _make_node( + addresses: list[tuple[str, str]] | None = None, + labels: dict[str, str] | None = None, +) -> MagicMock: + """Create a mock Kubernetes node object.""" + node = MagicMock() + node.metadata.labels = labels or {} + + if addresses: + addr_objects = [] + for addr_type, addr_value in addresses: + addr = MagicMock() + addr.type = addr_type + addr.address = addr_value + addr_objects.append(addr) + node.status.addresses = addr_objects + else: + node.status.addresses = [] + + return node + + +def _make_authenticated_plugin() -> KubernetesPlugin: + """Create a KubernetesPlugin with mocked API clients.""" + plugin = KubernetesPlugin() + plugin._api_client = MagicMock() + plugin._core_v1 = MagicMock() + plugin._apps_v1 = MagicMock() + plugin._networking_v1 = MagicMock() + + # Default: detect_architecture returns AMD64 + node = _make_node( + addresses=[("InternalIP", "10.0.0.1")], + labels={"kubernetes.io/arch": "amd64"}, + ) + plugin._core_v1.list_node.return_value = MagicMock(items=[node]) + + return plugin + + +def _stub_empty_apis(plugin: KubernetesPlugin, exclude: str = "") -> None: + """Stub all discovery API calls to return empty lists. + + Args: + plugin: The plugin instance with mocked clients. + exclude: Resource type to exclude from stubbing (leave for test to set up). + """ + if exclude != "deployments": + plugin._apps_v1.list_deployment_for_all_namespaces.return_value = MagicMock( + items=[] + ) + if exclude != "services": + plugin._core_v1.list_service_for_all_namespaces.return_value = MagicMock( + items=[] + ) + if exclude != "ingresses": + plugin._networking_v1.list_ingress_for_all_namespaces.return_value = MagicMock( + items=[] + ) + if exclude != "config_maps": + plugin._core_v1.list_config_map_for_all_namespaces.return_value = MagicMock( + items=[] + ) + if exclude != "persistent_volumes": + plugin._core_v1.list_persistent_volume.return_value = MagicMock(items=[]) + if exclude != "namespaces": + plugin._core_v1.list_namespace.return_value = MagicMock(items=[]) diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py new file mode 100644 index 0000000..5429f65 --- /dev/null +++ b/tests/unit/test_models.py @@ -0,0 +1,403 @@ +"""Unit tests for core data models.""" + +import json + +import pytest + +from iac_reverse.models import ( + ChangeType, + ChangeSummary, + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + ExtractedVariable, + GeneratedFile, + PlannedChange, + PlatformCategory, + PROVIDER_PLATFORM_MAP, + ProviderType, + ResourceChange, + ResourceRelationship, + ScanProfile, + ScanProgress, + ScanResult, + StateEntry, + StateFile, + UnresolvedReference, + ValidationError, + ValidationResult, +) + + +class TestProviderType: + def test_all_values(self): + assert ProviderType.DOCKER_SWARM.value == "docker_swarm" + assert ProviderType.KUBERNETES.value == "kubernetes" + assert ProviderType.SYNOLOGY.value == "synology" + assert ProviderType.HARVESTER.value == "harvester" + assert ProviderType.BARE_METAL.value == "bare_metal" + assert ProviderType.WINDOWS.value == "windows" + + def test_member_count(self): + assert len(ProviderType) == 6 + + +class TestPlatformCategory: + def test_all_values(self): + assert PlatformCategory.CONTAINER_ORCHESTRATION.value == "container" + assert PlatformCategory.STORAGE_APPLIANCE.value == "storage" + assert PlatformCategory.HCI.value == "hci" + assert PlatformCategory.BARE_METAL.value == "bare_metal" + assert PlatformCategory.WINDOWS.value == "windows" + + def test_member_count(self): + assert len(PlatformCategory) == 5 + + +class TestProviderPlatformMap: + def test_all_providers_mapped(self): + for provider in ProviderType: + assert provider in PROVIDER_PLATFORM_MAP + + def test_container_orchestration_providers(self): + assert PROVIDER_PLATFORM_MAP[ProviderType.DOCKER_SWARM] == PlatformCategory.CONTAINER_ORCHESTRATION + assert PROVIDER_PLATFORM_MAP[ProviderType.KUBERNETES] == PlatformCategory.CONTAINER_ORCHESTRATION + + def test_other_providers(self): + assert PROVIDER_PLATFORM_MAP[ProviderType.SYNOLOGY] == PlatformCategory.STORAGE_APPLIANCE + assert PROVIDER_PLATFORM_MAP[ProviderType.HARVESTER] == PlatformCategory.HCI + assert PROVIDER_PLATFORM_MAP[ProviderType.BARE_METAL] == PlatformCategory.BARE_METAL + assert PROVIDER_PLATFORM_MAP[ProviderType.WINDOWS] == PlatformCategory.WINDOWS + + +class TestCpuArchitecture: + def test_all_values(self): + assert CpuArchitecture.AMD64.value == "amd64" + assert CpuArchitecture.ARM.value == "arm" + assert CpuArchitecture.AARCH64.value == "aarch64" + + def test_member_count(self): + assert len(CpuArchitecture) == 3 + + +class TestChangeType: + def test_all_values(self): + assert ChangeType.ADDED.value == "added" + assert ChangeType.REMOVED.value == "removed" + assert ChangeType.MODIFIED.value == "modified" + + +class TestScanProfile: + def test_valid_profile(self): + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={"kubeconfig_path": "/home/user/.kube/config"}, + ) + assert profile.validate() == [] + + def test_empty_credentials_invalid(self): + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={}, + ) + errors = profile.validate() + assert len(errors) > 0 + assert "credentials" in errors[0] + + def test_platform_category_property(self): + profile = ScanProfile( + provider=ProviderType.DOCKER_SWARM, + credentials={"host": "localhost"}, + ) + assert profile.platform_category == PlatformCategory.CONTAINER_ORCHESTRATION + + def test_optional_fields_default_none(self): + profile = ScanProfile( + provider=ProviderType.SYNOLOGY, + credentials={"host": "nas01"}, + ) + assert profile.endpoints is None + assert profile.resource_type_filters is None + assert profile.authentik_token is None + + +class TestDiscoveredResource: + def test_creation(self): + resource = DiscoveredResource( + resource_type="kubernetes_deployment", + unique_id="apps/v1/deployments/default/nginx", + name="nginx", + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api:6443", + attributes={"replicas": 3, "image": "nginx:1.25"}, + raw_references=["default/services/nginx-svc"], + ) + assert resource.resource_type == "kubernetes_deployment" + assert resource.unique_id == "apps/v1/deployments/default/nginx" + assert resource.provider == ProviderType.KUBERNETES + + def test_raw_references_default_empty(self): + resource = DiscoveredResource( + resource_type="windows_service", + unique_id="win01/services/nginx", + name="nginx", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + architecture=CpuArchitecture.AMD64, + endpoint="win01.internal.lab", + attributes={"state": "running"}, + ) + assert resource.raw_references == [] + + +class TestScanResult: + def test_creation(self): + result = ScanResult( + resources=[], + warnings=["unsupported type: foo"], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="abc123", + ) + assert result.is_partial is False + assert len(result.warnings) == 1 + + def test_partial_scan(self): + result = ScanResult( + resources=[], + warnings=[], + errors=["connection lost"], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="abc123", + is_partial=True, + ) + assert result.is_partial is True + + +class TestScanProgress: + def test_creation(self): + progress = ScanProgress( + current_resource_type="kubernetes_deployment", + resources_discovered=15, + resource_types_completed=2, + total_resource_types=5, + ) + assert progress.resources_discovered == 15 + assert progress.resource_types_completed == 2 + + +class TestResourceRelationship: + def test_creation(self): + rel = ResourceRelationship( + source_id="resource-a", + target_id="resource-b", + relationship_type="dependency", + source_attribute="network_id", + ) + assert rel.relationship_type == "dependency" + + +class TestUnresolvedReference: + def test_creation(self): + ref = UnresolvedReference( + source_resource_id="resource-a", + source_attribute="vpc_id", + referenced_id="vpc-unknown", + suggested_resolution="data_source", + ) + assert ref.suggested_resolution == "data_source" + + +class TestDependencyGraph: + def test_creation(self): + graph = DependencyGraph( + resources=[], + relationships=[], + topological_order=["a", "b", "c"], + cycles=[], + unresolved_references=[], + ) + assert graph.topological_order == ["a", "b", "c"] + assert graph.cycles == [] + + +class TestGeneratedFile: + def test_creation(self): + gf = GeneratedFile( + filename="kubernetes_deployment.tf", + content='resource "kubernetes_deployment" "nginx" {}', + resource_count=1, + ) + assert gf.filename == "kubernetes_deployment.tf" + assert gf.resource_count == 1 + + +class TestExtractedVariable: + def test_creation(self): + var = ExtractedVariable( + name="environment", + type_expr="string", + default_value="production", + description="Deployment environment", + used_by=["resource-a", "resource-b"], + ) + assert len(var.used_by) == 2 + + def test_used_by_default_empty(self): + var = ExtractedVariable( + name="region", + type_expr="string", + default_value="us-east-1", + description="Region", + ) + assert var.used_by == [] + + +class TestCodeGenerationResult: + def test_creation(self): + result = CodeGenerationResult( + resource_files=[GeneratedFile("main.tf", "content", 5)], + variables_file=GeneratedFile("variables.tf", "vars", 0), + provider_file=GeneratedFile("provider.tf", "provider", 0), + outputs_file=None, + skipped_resources=[("res-1", "unsupported type")], + ) + assert len(result.resource_files) == 1 + assert result.outputs_file is None + assert len(result.skipped_resources) == 1 + + +class TestStateEntry: + def test_creation(self): + entry = StateEntry( + resource_type="kubernetes_deployment", + resource_name="nginx", + provider_id="apps/v1/deployments/default/nginx", + attributes={"namespace": "default"}, + sensitive_attributes=["password"], + schema_version=1, + dependencies=["kubernetes_service.nginx_svc"], + ) + assert entry.schema_version == 1 + assert len(entry.dependencies) == 1 + + def test_defaults(self): + entry = StateEntry( + resource_type="kubernetes_service", + resource_name="svc", + provider_id="id-123", + attributes={}, + ) + assert entry.sensitive_attributes == [] + assert entry.schema_version == 0 + assert entry.dependencies == [] + + +class TestStateFile: + def test_defaults(self): + state = StateFile() + assert state.version == 4 + assert state.terraform_version == "" + assert state.serial == 1 + assert state.lineage == "" + assert state.resources == [] + + def test_to_json_structure(self): + state = StateFile( + terraform_version="1.7.0", + resources=[ + StateEntry( + resource_type="kubernetes_deployment", + resource_name="nginx", + provider_id="apps/v1/deployments/default/nginx", + attributes={"namespace": "default", "replicas": 3}, + schema_version=1, + dependencies=["kubernetes_service.nginx_svc"], + ) + ], + ) + parsed = json.loads(state.to_json()) + assert parsed["version"] == 4 + assert parsed["terraform_version"] == "1.7.0" + assert parsed["serial"] == 1 + assert "lineage" in parsed + assert len(parsed["resources"]) == 1 + + res = parsed["resources"][0] + assert res["mode"] == "managed" + assert res["type"] == "kubernetes_deployment" + assert res["name"] == "nginx" + assert res["instances"][0]["schema_version"] == 1 + assert res["instances"][0]["attributes"]["id"] == "apps/v1/deployments/default/nginx" + + def test_to_json_generates_lineage(self): + state = StateFile() + parsed = json.loads(state.to_json()) + assert len(parsed["lineage"]) > 0 # UUID generated + + +class TestValidationResult: + def test_creation(self): + result = ValidationResult( + init_success=True, + validate_success=True, + plan_success=False, + planned_changes=[ + PlannedChange( + resource_address="kubernetes_deployment.nginx", + change_type="modify", + details="replicas changed", + ) + ], + errors=[], + correction_attempts=1, + ) + assert result.plan_success is False + assert len(result.planned_changes) == 1 + + +class TestValidationError: + def test_creation_with_line(self): + err = ValidationError(file="main.tf", message="invalid block", line=42) + assert err.line == 42 + + def test_creation_without_line(self): + err = ValidationError(file="main.tf", message="missing provider") + assert err.line is None + + +class TestResourceChange: + def test_added(self): + change = ResourceChange( + resource_id="new-resource", + resource_type="kubernetes_service", + resource_name="new-svc", + change_type=ChangeType.ADDED, + ) + assert change.changed_attributes is None + + def test_modified(self): + change = ResourceChange( + resource_id="existing-resource", + resource_type="kubernetes_deployment", + resource_name="nginx", + change_type=ChangeType.MODIFIED, + changed_attributes={"replicas": {"old": 3, "new": 5}}, + ) + assert change.changed_attributes is not None + + +class TestChangeSummary: + def test_creation(self): + summary = ChangeSummary( + added_count=2, + removed_count=1, + modified_count=3, + changes=[], + ) + assert summary.added_count == 2 + assert summary.removed_count == 1 + assert summary.modified_count == 3 diff --git a/tests/unit/test_multi_provider_scanner.py b/tests/unit/test_multi_provider_scanner.py new file mode 100644 index 0000000..35012c5 --- /dev/null +++ b/tests/unit/test_multi_provider_scanner.py @@ -0,0 +1,505 @@ +"""Unit tests for the MultiProviderScanner. + +Tests cover: +- All providers succeed: all resources collected +- One provider fails: others still complete, failed one reported +- Multiple providers fail: remaining still complete +- Error details include provider name and reason +""" + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.multi_provider_scanner import ( + MultiProviderScanner, + MultiProviderScanResult, + ProviderFailure, + ProviderScanEntry, +) +from iac_reverse.scanner.scanner import AuthenticationError + + +# --------------------------------------------------------------------------- +# Helpers / Fixtures +# --------------------------------------------------------------------------- + + +def make_profile(provider: ProviderType = ProviderType.KUBERNETES) -> ScanProfile: + """Create a valid ScanProfile with sensible defaults.""" + return ScanProfile( + provider=provider, + credentials={"token": "test-token"}, + endpoints=["https://api.local:6443"], + resource_type_filters=None, + ) + + +def make_resource( + provider: ProviderType = ProviderType.KUBERNETES, + resource_type: str = "kubernetes_deployment", + name: str = "nginx", +) -> DiscoveredResource: + """Create a sample DiscoveredResource.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=f"{provider.value}/{resource_type}/{name}", + name=name, + provider=provider, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + endpoint="https://api.local:6443", + attributes={"replicas": 3}, + raw_references=[], + ) + + +class SuccessPlugin(ProviderPlugin): + """A plugin that always succeeds with configurable resources.""" + + def __init__(self, resources: list[DiscoveredResource] | None = None): + self._resources = resources or [] + + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["https://api.local:6443"] + + def list_supported_resource_types(self) -> list[str]: + return ["kubernetes_deployment", "kubernetes_service"] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AARCH64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback=None, + ) -> ScanResult: + return ScanResult( + resources=self._resources, + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + +class FailingAuthPlugin(ProviderPlugin): + """A plugin that fails during authentication.""" + + def __init__(self, error_message: str = "Invalid credentials"): + self._error_message = error_message + + def authenticate(self, credentials: dict[str, str]) -> None: + raise RuntimeError(self._error_message) + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return [] + + def list_supported_resource_types(self) -> list[str]: + return [] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback=None, + ) -> ScanResult: + return ScanResult( + resources=[], + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + +class FailingDiscoverPlugin(ProviderPlugin): + """A plugin that fails during resource discovery.""" + + def __init__(self, error_message: str = "Connection refused"): + self._error_message = error_message + + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.STORAGE_APPLIANCE + + def list_endpoints(self) -> list[str]: + return ["https://nas.local:5001"] + + def list_supported_resource_types(self) -> list[str]: + return ["synology_shared_folder"] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback=None, + ) -> ScanResult: + raise ConnectionError(self._error_message) + + +# --------------------------------------------------------------------------- +# Tests: All providers succeed +# --------------------------------------------------------------------------- + + +class TestAllProvidersSucceed: + """Tests for the happy path where all providers complete successfully.""" + + def test_single_provider_all_resources_collected(self): + resources = [ + make_resource(name="deploy-1"), + make_resource(name="deploy-2"), + ] + entry = ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=resources), + ) + scanner = MultiProviderScanner([entry]) + + result = scanner.scan() + + assert len(result.resources) == 2 + assert len(result.failed_providers) == 0 + assert "kubernetes" in result.successful_providers + + def test_multiple_providers_all_resources_merged(self): + k8s_resources = [ + make_resource(ProviderType.KUBERNETES, "kubernetes_deployment", "nginx"), + ] + docker_resources = [ + make_resource(ProviderType.DOCKER_SWARM, "docker_service", "web"), + ] + + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=k8s_resources), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.DOCKER_SWARM), + plugin=SuccessPlugin(resources=docker_resources), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert len(result.resources) == 2 + assert len(result.failed_providers) == 0 + assert len(result.successful_providers) == 2 + + def test_empty_entries_returns_empty_result(self): + scanner = MultiProviderScanner([]) + + result = scanner.scan() + + assert len(result.resources) == 0 + assert len(result.failed_providers) == 0 + assert len(result.successful_providers) == 0 + + def test_scan_timestamp_is_set(self): + entry = ProviderScanEntry( + profile=make_profile(), + plugin=SuccessPlugin(resources=[]), + ) + scanner = MultiProviderScanner([entry]) + + result = scanner.scan() + + assert result.scan_timestamp != "" + + +# --------------------------------------------------------------------------- +# Tests: One provider fails, others succeed +# --------------------------------------------------------------------------- + + +class TestOneProviderFails: + """Tests for partial failure: one provider fails, others complete.""" + + def test_failed_provider_does_not_block_others(self): + k8s_resources = [make_resource(ProviderType.KUBERNETES, name="nginx")] + + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=k8s_resources), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingAuthPlugin("Invalid API key"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + # Kubernetes resources should still be collected + assert len(result.resources) == 1 + assert result.resources[0].name == "nginx" + # Synology should be reported as failed + assert len(result.failed_providers) == 1 + assert result.failed_providers[0].provider_name == "synology" + + def test_failed_provider_reported_with_error_details(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=[]), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingAuthPlugin("Token expired at 2024-01-15"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + failure = result.failed_providers[0] + assert failure.provider_name == "synology" + assert "Token expired" in failure.error_message + assert failure.error_type == "AuthenticationError" + + def test_successful_providers_listed_correctly(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=[]), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.DOCKER_SWARM), + plugin=FailingAuthPlugin("Connection refused"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert "kubernetes" in result.successful_providers + assert "docker_swarm" not in result.successful_providers + + def test_order_does_not_matter_failed_first(self): + """Even if the first provider fails, subsequent ones still run.""" + docker_resources = [make_resource(ProviderType.DOCKER_SWARM, "docker_service", "web")] + + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingAuthPlugin("Auth failed"), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.DOCKER_SWARM), + plugin=SuccessPlugin(resources=docker_resources), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert len(result.resources) == 1 + assert result.resources[0].name == "web" + assert len(result.failed_providers) == 1 + + +# --------------------------------------------------------------------------- +# Tests: Multiple providers fail +# --------------------------------------------------------------------------- + + +class TestMultipleProvidersFail: + """Tests for scenarios where multiple providers fail.""" + + def test_multiple_failures_remaining_still_complete(self): + k8s_resources = [make_resource(ProviderType.KUBERNETES, name="app")] + + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingAuthPlugin("Synology auth failed"), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=k8s_resources), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.HARVESTER), + plugin=FailingAuthPlugin("Harvester unreachable"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert len(result.resources) == 1 + assert result.resources[0].name == "app" + assert len(result.failed_providers) == 2 + assert len(result.successful_providers) == 1 + + def test_all_providers_fail_returns_empty_resources(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingAuthPlugin("Auth error 1"), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.HARVESTER), + plugin=FailingAuthPlugin("Auth error 2"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert len(result.resources) == 0 + assert len(result.failed_providers) == 2 + assert len(result.successful_providers) == 0 + + def test_each_failure_has_distinct_error_details(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingAuthPlugin("Invalid API key"), + ), + ProviderScanEntry( + profile=make_profile(ProviderType.HARVESTER), + plugin=FailingAuthPlugin("Certificate expired"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + provider_names = [f.provider_name for f in result.failed_providers] + assert "synology" in provider_names + assert "harvester" in provider_names + + synology_failure = next( + f for f in result.failed_providers if f.provider_name == "synology" + ) + harvester_failure = next( + f for f in result.failed_providers if f.provider_name == "harvester" + ) + assert "Invalid API key" in synology_failure.error_message + assert "Certificate expired" in harvester_failure.error_message + + +# --------------------------------------------------------------------------- +# Tests: Error details include provider name and reason +# --------------------------------------------------------------------------- + + +class TestErrorDetails: + """Tests that error details contain provider name and failure reason.""" + + def test_auth_error_includes_provider_name(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.DOCKER_SWARM), + plugin=FailingAuthPlugin("Bad token"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert result.failed_providers[0].provider_name == "docker_swarm" + + def test_auth_error_includes_reason(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.DOCKER_SWARM), + plugin=FailingAuthPlugin("Token revoked by admin"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + assert "Token revoked by admin" in result.failed_providers[0].error_message + + def test_connection_error_includes_error_type(self): + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.SYNOLOGY), + plugin=FailingDiscoverPlugin("Connection timed out"), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + # ConnectionError is raised during discover; Scanner wraps it in + # ConnectionLostError. The error_type reflects the wrapped exception. + failure = result.failed_providers[0] + assert failure.provider_name == "synology" + assert failure.error_type == "ConnectionLostError" + assert "Connection lost" in failure.error_message + + def test_validation_error_includes_details(self): + """A provider with invalid profile still reports correctly.""" + # Create a profile with empty credentials to trigger ValueError + bad_profile = ScanProfile( + provider=ProviderType.BARE_METAL, + credentials={}, + endpoints=["https://bmc.local"], + ) + entries = [ + ProviderScanEntry( + profile=bad_profile, + plugin=SuccessPlugin(resources=[]), + ), + ] + scanner = MultiProviderScanner(entries) + + result = scanner.scan() + + failure = result.failed_providers[0] + assert failure.provider_name == "bare_metal" + assert failure.error_type == "ValueError" + assert "credentials" in failure.error_message.lower() + + def test_progress_callback_invoked_for_successful_providers(self): + """Progress callback is passed through to individual scanners.""" + resources = [make_resource(name="test")] + entries = [ + ProviderScanEntry( + profile=make_profile(ProviderType.KUBERNETES), + plugin=SuccessPlugin(resources=resources), + ), + ] + scanner = MultiProviderScanner(entries) + + progress_updates = [] + scanner.scan(progress_callback=progress_updates.append) + + # SuccessPlugin doesn't call progress_callback, but the scan still works + # This verifies the callback is accepted without error + assert len(scanner.entries) == 1 diff --git a/tests/unit/test_plugin_base.py b/tests/unit/test_plugin_base.py new file mode 100644 index 0000000..1641a11 --- /dev/null +++ b/tests/unit/test_plugin_base.py @@ -0,0 +1,74 @@ +"""Unit tests for the ProviderPlugin abstract base class.""" + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + PlatformCategory, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin + + +class TestProviderPluginInterface: + def test_cannot_instantiate_directly(self): + """ProviderPlugin is abstract and cannot be instantiated.""" + with pytest.raises(TypeError): + ProviderPlugin() + + def test_requires_all_abstract_methods(self): + """A subclass must implement all abstract methods.""" + + class IncompletePlugin(ProviderPlugin): + def authenticate(self, credentials): + pass + + with pytest.raises(TypeError): + IncompletePlugin() + + def test_concrete_implementation(self): + """A complete implementation can be instantiated.""" + + class ConcretePlugin(ProviderPlugin): + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["https://localhost:6443"] + + def list_supported_resource_types(self) -> list[str]: + return ["kubernetes_deployment"] + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AMD64 + + def discover_resources(self, endpoints, resource_types, progress_callback): + return ScanResult( + resources=[], + warnings=[], + errors=[], + scan_timestamp="2024-01-01T00:00:00Z", + profile_hash="test", + ) + + plugin = ConcretePlugin() + assert plugin.get_platform_category() == PlatformCategory.CONTAINER_ORCHESTRATION + assert plugin.list_endpoints() == ["https://localhost:6443"] + assert plugin.list_supported_resource_types() == ["kubernetes_deployment"] + assert plugin.detect_architecture("localhost") == CpuArchitecture.AMD64 + + def test_abstract_methods_list(self): + """Verify all expected abstract methods are defined.""" + expected_methods = { + "authenticate", + "get_platform_category", + "list_endpoints", + "list_supported_resource_types", + "detect_architecture", + "discover_resources", + } + assert ProviderPlugin.__abstractmethods__ == expected_methods diff --git a/tests/unit/test_profile_loader.py b/tests/unit/test_profile_loader.py new file mode 100644 index 0000000..fba2fd7 --- /dev/null +++ b/tests/unit/test_profile_loader.py @@ -0,0 +1,324 @@ +"""Unit tests for ProfileLoader - YAML scan profile loading with env var expansion.""" + +import os +import textwrap +from pathlib import Path + +import pytest + +from iac_reverse.cli.profile_loader import ProfileLoader, ProfileLoaderError +from iac_reverse.models import ProviderType + + +@pytest.fixture +def loader(): + """Create a ProfileLoader instance.""" + return ProfileLoader() + + +@pytest.fixture +def tmp_profile(tmp_path): + """Helper to write a YAML profile to a temp file and return its path.""" + + def _write(content: str) -> str: + profile_file = tmp_path / "profile.yaml" + profile_file.write_text(textwrap.dedent(content), encoding="utf-8") + return str(profile_file) + + return _write + + +class TestSingleProfileLoading: + """Tests for loading a single profile from YAML.""" + + def test_loads_single_profile(self, loader, tmp_profile): + path = tmp_profile(""" + provider: kubernetes + credentials: + kubeconfig_path: /home/user/.kube/config + context: pi-cluster + endpoints: + - https://k8s-api.internal.lab:6443 + resource_type_filters: + - kubernetes_deployment + - kubernetes_service + """) + + profiles = loader.load(path) + + assert len(profiles) == 1 + profile = profiles[0] + assert profile.provider == ProviderType.KUBERNETES + assert profile.credentials == { + "kubeconfig_path": "/home/user/.kube/config", + "context": "pi-cluster", + } + assert profile.endpoints == ["https://k8s-api.internal.lab:6443"] + assert profile.resource_type_filters == [ + "kubernetes_deployment", + "kubernetes_service", + ] + + def test_loads_profile_without_optional_fields(self, loader, tmp_profile): + path = tmp_profile(""" + provider: docker_swarm + credentials: + host: tcp://swarm-manager:2376 + """) + + profiles = loader.load(path) + + assert len(profiles) == 1 + profile = profiles[0] + assert profile.provider == ProviderType.DOCKER_SWARM + assert profile.endpoints is None + assert profile.resource_type_filters is None + assert profile.authentik_token is None + + +class TestMultiProfileLoading: + """Tests for loading multiple profiles from a YAML list.""" + + def test_loads_multi_profile_yaml(self, loader, tmp_profile): + path = tmp_profile(""" + - provider: kubernetes + credentials: + kubeconfig_path: /home/user/.kube/config + context: pi-cluster + endpoints: + - https://k8s-api.internal.lab:6443 + + - provider: synology + credentials: + host: nas01.internal.lab + port: "5001" + username: admin + password: secret + endpoints: + - nas01.internal.lab:5001 + """) + + profiles = loader.load(path) + + assert len(profiles) == 2 + assert profiles[0].provider == ProviderType.KUBERNETES + assert profiles[1].provider == ProviderType.SYNOLOGY + assert profiles[1].credentials["host"] == "nas01.internal.lab" + + def test_loads_three_profiles(self, loader, tmp_profile): + path = tmp_profile(""" + - provider: kubernetes + credentials: + context: cluster-1 + + - provider: docker_swarm + credentials: + host: tcp://swarm:2376 + + - provider: windows + credentials: + host: win-server-01 + username: admin + password: pass + """) + + profiles = loader.load(path) + + assert len(profiles) == 3 + assert profiles[0].provider == ProviderType.KUBERNETES + assert profiles[1].provider == ProviderType.DOCKER_SWARM + assert profiles[2].provider == ProviderType.WINDOWS + + +class TestEnvVarExpansion: + """Tests for ${ENV_VAR} and ${ENV_VAR:-default} expansion.""" + + def test_expands_env_var(self, loader, monkeypatch): + monkeypatch.setenv("MY_SECRET", "super-secret-value") + + result = loader.expand_env_vars("${MY_SECRET}") + + assert result == "super-secret-value" + + def test_expands_env_var_with_surrounding_text(self, loader, monkeypatch): + monkeypatch.setenv("HOST", "myserver.local") + + result = loader.expand_env_vars("https://${HOST}:8443") + + assert result == "https://myserver.local:8443" + + def test_expands_multiple_env_vars(self, loader, monkeypatch): + monkeypatch.setenv("USER", "admin") + monkeypatch.setenv("PASS", "secret123") + + result = loader.expand_env_vars("${USER}:${PASS}") + + assert result == "admin:secret123" + + def test_expands_env_var_with_default(self, loader, monkeypatch): + monkeypatch.delenv("MISSING_VAR", raising=False) + + result = loader.expand_env_vars("${MISSING_VAR:-fallback_value}") + + assert result == "fallback_value" + + def test_env_var_set_overrides_default(self, loader, monkeypatch): + monkeypatch.setenv("MY_VAR", "actual_value") + + result = loader.expand_env_vars("${MY_VAR:-default_value}") + + assert result == "actual_value" + + def test_empty_default_is_valid(self, loader, monkeypatch): + monkeypatch.delenv("UNSET_VAR", raising=False) + + result = loader.expand_env_vars("prefix_${UNSET_VAR:-}_suffix") + + assert result == "prefix__suffix" + + def test_missing_env_var_without_default_raises_error(self, loader, monkeypatch): + monkeypatch.delenv("NONEXISTENT_VAR", raising=False) + + with pytest.raises(ProfileLoaderError, match="NONEXISTENT_VAR"): + loader.expand_env_vars("${NONEXISTENT_VAR}") + + def test_no_env_vars_returns_unchanged(self, loader): + result = loader.expand_env_vars("plain text without vars") + + assert result == "plain text without vars" + + +class TestCredentialExpansion: + """Tests for env var expansion applied to credential fields in profiles.""" + + def test_expands_credentials_in_profile(self, loader, tmp_profile, monkeypatch): + monkeypatch.setenv("SYNOLOGY_USER", "admin") + monkeypatch.setenv("SYNOLOGY_PASSWORD", "my_password") + + path = tmp_profile(""" + provider: synology + credentials: + host: nas01.internal.lab + username: "${SYNOLOGY_USER}" + password: "${SYNOLOGY_PASSWORD}" + """) + + profiles = loader.load(path) + + assert profiles[0].credentials["username"] == "admin" + assert profiles[0].credentials["password"] == "my_password" + # Non-env-var values remain unchanged + assert profiles[0].credentials["host"] == "nas01.internal.lab" + + def test_expands_nested_credential_values(self, loader, tmp_profile, monkeypatch): + monkeypatch.setenv("INNER_SECRET", "nested_value") + + path = tmp_profile(""" + provider: windows + credentials: + host: win-server-01 + auth: + token: "${INNER_SECRET}" + type: bearer + """) + + profiles = loader.load(path) + + assert profiles[0].credentials["auth"]["token"] == "nested_value" + assert profiles[0].credentials["auth"]["type"] == "bearer" + + def test_expands_authentik_token(self, loader, tmp_profile, monkeypatch): + monkeypatch.setenv("AUTH_TOKEN", "my-sso-token") + + path = tmp_profile(""" + provider: kubernetes + credentials: + context: cluster-1 + authentik_token: "${AUTH_TOKEN}" + """) + + profiles = loader.load(path) + + assert profiles[0].authentik_token == "my-sso-token" + + def test_credential_with_default_value(self, loader, tmp_profile, monkeypatch): + monkeypatch.delenv("OPTIONAL_PORT", raising=False) + + path = tmp_profile(""" + provider: synology + credentials: + host: nas01 + port: "${OPTIONAL_PORT:-5001}" + """) + + profiles = loader.load(path) + + assert profiles[0].credentials["port"] == "5001" + + +class TestErrorHandling: + """Tests for error cases in profile loading.""" + + def test_file_not_found_raises_error(self, loader): + with pytest.raises(ProfileLoaderError, match="Profile not found"): + loader.load("/nonexistent/path/profile.yaml") + + def test_invalid_yaml_raises_error(self, loader, tmp_profile): + path = tmp_profile(""" + provider: kubernetes + credentials: [invalid: yaml: content + """) + + with pytest.raises(ProfileLoaderError, match="Invalid YAML"): + loader.load(path) + + def test_empty_file_raises_error(self, loader, tmp_path): + profile_file = tmp_path / "empty.yaml" + profile_file.write_text("", encoding="utf-8") + + with pytest.raises(ProfileLoaderError, match="empty"): + loader.load(str(profile_file)) + + def test_unknown_provider_raises_error(self, loader, tmp_profile): + path = tmp_profile(""" + provider: unknown_provider + credentials: + key: value + """) + + with pytest.raises(ProfileLoaderError, match="Unknown provider"): + loader.load(path) + + def test_missing_provider_raises_error(self, loader, tmp_profile): + path = tmp_profile(""" + credentials: + key: value + """) + + with pytest.raises(ProfileLoaderError, match="Missing 'provider'"): + loader.load(path) + + def test_missing_env_var_in_credentials_raises_error( + self, loader, tmp_profile, monkeypatch + ): + monkeypatch.delenv("MISSING_CRED", raising=False) + + path = tmp_profile(""" + provider: kubernetes + credentials: + token: "${MISSING_CRED}" + """) + + with pytest.raises(ProfileLoaderError, match="MISSING_CRED"): + loader.load(path) + + def test_non_dict_in_multi_profile_raises_error(self, loader, tmp_profile): + path = tmp_profile(""" + - provider: kubernetes + credentials: + context: cluster-1 + - just a string + """) + + with pytest.raises(ProfileLoaderError, match="index 1.*mapping"): + loader.load(path) diff --git a/tests/unit/test_provider_block.py b/tests/unit/test_provider_block.py new file mode 100644 index 0000000..8318c72 --- /dev/null +++ b/tests/unit/test_provider_block.py @@ -0,0 +1,414 @@ +"""Unit tests for the ProviderBlockGenerator.""" + +import pytest + +from iac_reverse.models import ProviderType, ScanProfile +from iac_reverse.generator import ProviderBlockGenerator + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_profile( + provider: ProviderType, + credentials: dict[str, str] | None = None, +) -> ScanProfile: + """Create a ScanProfile with sensible defaults for testing.""" + default_creds: dict[ProviderType, dict[str, str]] = { + ProviderType.KUBERNETES: { + "host": "https://k8s-api.local:6443", + "cluster_ca_certificate": "/path/to/ca.crt", + "token": "my-token-123", + }, + ProviderType.DOCKER_SWARM: { + "host": "tcp://swarm-manager:2376", + "cert_path": "/home/user/.docker/certs", + }, + ProviderType.SYNOLOGY: { + "url": "https://nas.local:5001", + "username": "admin", + "password": "secret123", + }, + ProviderType.HARVESTER: { + "kubeconfig": "/home/user/.kube/harvester.yaml", + }, + ProviderType.BARE_METAL: { + "endpoint": "https://bmc.local/redfish/v1", + "username": "root", + "password": "calvin", + }, + ProviderType.WINDOWS: { + "host": "win-server-01.local", + "username": "Administrator", + "password": "P@ssw0rd", + "winrm_port": "5986", + "winrm_use_ssl": "true", + }, + } + creds = credentials if credentials is not None else default_creds[provider] + return ScanProfile(provider=provider, credentials=creds) + + +# --------------------------------------------------------------------------- +# Tests: Single provider generates one provider block +# --------------------------------------------------------------------------- + + +class TestSingleProvider: + """Tests for generating a single provider block.""" + + def test_kubernetes_generates_one_provider_block(self): + """A single Kubernetes provider generates one provider block.""" + profile = make_profile(ProviderType.KUBERNETES) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.KUBERNETES}, + ) + + assert result.filename == "providers.tf" + assert result.content.count('provider "kubernetes"') == 1 + + def test_docker_generates_one_provider_block(self): + """A single Docker Swarm provider generates one provider block.""" + profile = make_profile(ProviderType.DOCKER_SWARM) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.DOCKER_SWARM}, + ) + + assert result.content.count('provider "docker"') == 1 + + def test_synology_generates_one_provider_block(self): + """A single Synology provider generates one provider block.""" + profile = make_profile(ProviderType.SYNOLOGY) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.SYNOLOGY}, + ) + + assert result.content.count('provider "synology"') == 1 + + def test_harvester_generates_one_provider_block(self): + """A single Harvester provider generates one provider block.""" + profile = make_profile(ProviderType.HARVESTER) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.HARVESTER}, + ) + + assert result.content.count('provider "harvester"') == 1 + + def test_bare_metal_generates_one_provider_block(self): + """A single Bare Metal provider generates one provider block.""" + profile = make_profile(ProviderType.BARE_METAL) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.BARE_METAL}, + ) + + assert result.content.count('provider "redfish"') == 1 + + def test_windows_generates_one_provider_block(self): + """A single Windows provider generates one provider block.""" + profile = make_profile(ProviderType.WINDOWS) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.WINDOWS}, + ) + + assert result.content.count('provider "windows"') == 1 + + +# --------------------------------------------------------------------------- +# Tests: Multiple providers generate multiple blocks +# --------------------------------------------------------------------------- + + +class TestMultipleProviders: + """Tests for generating multiple provider blocks.""" + + def test_two_providers_generate_two_blocks(self): + """Two distinct providers generate two provider blocks.""" + profiles = [ + make_profile(ProviderType.KUBERNETES), + make_profile(ProviderType.DOCKER_SWARM), + ] + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=profiles, + provider_types={ProviderType.KUBERNETES, ProviderType.DOCKER_SWARM}, + ) + + assert 'provider "kubernetes"' in result.content + assert 'provider "docker"' in result.content + + def test_three_providers_generate_three_blocks(self): + """Three distinct providers generate three provider blocks.""" + profiles = [ + make_profile(ProviderType.KUBERNETES), + make_profile(ProviderType.SYNOLOGY), + make_profile(ProviderType.WINDOWS), + ] + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=profiles, + provider_types={ + ProviderType.KUBERNETES, + ProviderType.SYNOLOGY, + ProviderType.WINDOWS, + }, + ) + + assert 'provider "kubernetes"' in result.content + assert 'provider "synology"' in result.content + assert 'provider "windows"' in result.content + + def test_all_six_providers(self): + """All six provider types generate six provider blocks.""" + all_types = set(ProviderType) + profiles = [make_profile(pt) for pt in all_types] + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=profiles, + provider_types=all_types, + ) + + assert 'provider "kubernetes"' in result.content + assert 'provider "docker"' in result.content + assert 'provider "synology"' in result.content + assert 'provider "harvester"' in result.content + assert 'provider "redfish"' in result.content + assert 'provider "windows"' in result.content + + +# --------------------------------------------------------------------------- +# Tests: Provider-specific configuration is included +# --------------------------------------------------------------------------- + + +class TestProviderSpecificConfig: + """Tests for provider-specific configuration attributes.""" + + def test_kubernetes_includes_host_ca_token(self): + """Kubernetes provider block includes host, cluster_ca_certificate, token.""" + profile = make_profile(ProviderType.KUBERNETES) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.KUBERNETES}, + ) + + assert "https://k8s-api.local:6443" in result.content + assert "/path/to/ca.crt" in result.content + assert "my-token-123" in result.content + assert "host" in result.content + assert "cluster_ca_certificate" in result.content + assert "token" in result.content + + def test_docker_includes_host_cert_path(self): + """Docker provider block includes host and cert_path.""" + profile = make_profile(ProviderType.DOCKER_SWARM) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.DOCKER_SWARM}, + ) + + assert "tcp://swarm-manager:2376" in result.content + assert "/home/user/.docker/certs" in result.content + assert "host" in result.content + assert "cert_path" in result.content + + def test_synology_includes_url_username_password(self): + """Synology provider block includes url, username, password.""" + profile = make_profile(ProviderType.SYNOLOGY) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.SYNOLOGY}, + ) + + assert "https://nas.local:5001" in result.content + assert "admin" in result.content + assert "secret123" in result.content + assert "url" in result.content + assert "username" in result.content + assert "password" in result.content + + def test_harvester_includes_kubeconfig(self): + """Harvester provider block includes kubeconfig.""" + profile = make_profile(ProviderType.HARVESTER) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.HARVESTER}, + ) + + assert "/home/user/.kube/harvester.yaml" in result.content + assert "kubeconfig" in result.content + + def test_bare_metal_includes_endpoint_username_password(self): + """Bare Metal (redfish) provider block includes endpoint, username, password.""" + profile = make_profile(ProviderType.BARE_METAL) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.BARE_METAL}, + ) + + assert "https://bmc.local/redfish/v1" in result.content + assert "root" in result.content + assert "calvin" in result.content + assert "endpoint" in result.content + assert "username" in result.content + assert "password" in result.content + + def test_windows_includes_host_username_password_winrm(self): + """Windows provider block includes host, username, password, winrm settings.""" + profile = make_profile(ProviderType.WINDOWS) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.WINDOWS}, + ) + + assert "win-server-01.local" in result.content + assert "Administrator" in result.content + assert "P@ssw0rd" in result.content + assert "winrm" in result.content + assert "5986" in result.content + assert "use_ssl" in result.content + + +# --------------------------------------------------------------------------- +# Tests: required_providers block is generated +# --------------------------------------------------------------------------- + + +class TestRequiredProvidersBlock: + """Tests for the terraform { required_providers { ... } } block.""" + + def test_required_providers_block_present(self): + """The output contains a terraform { required_providers { ... } } block.""" + profile = make_profile(ProviderType.KUBERNETES) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.KUBERNETES}, + ) + + assert "terraform {" in result.content + assert "required_providers {" in result.content + + def test_required_providers_includes_source(self): + """The required_providers block includes source for each provider.""" + profile = make_profile(ProviderType.KUBERNETES) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.KUBERNETES}, + ) + + assert 'source = "hashicorp/kubernetes"' in result.content + + def test_required_providers_includes_version(self): + """The required_providers block includes version constraint.""" + profile = make_profile(ProviderType.KUBERNETES) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.KUBERNETES}, + ) + + assert 'version = "~> 2.0"' in result.content + + def test_multiple_providers_all_in_required_providers(self): + """Multiple providers are all listed in required_providers.""" + profiles = [ + make_profile(ProviderType.KUBERNETES), + make_profile(ProviderType.DOCKER_SWARM), + make_profile(ProviderType.BARE_METAL), + ] + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=profiles, + provider_types={ + ProviderType.KUBERNETES, + ProviderType.DOCKER_SWARM, + ProviderType.BARE_METAL, + }, + ) + + assert 'source = "hashicorp/kubernetes"' in result.content + assert 'source = "kreuzwerker/docker"' in result.content + assert 'source = "dell/redfish"' in result.content + + def test_docker_swarm_maps_to_kreuzwerker_docker(self): + """Docker Swarm maps to kreuzwerker/docker provider source.""" + profile = make_profile(ProviderType.DOCKER_SWARM) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.DOCKER_SWARM}, + ) + + assert "docker = {" in result.content + assert 'source = "kreuzwerker/docker"' in result.content + + def test_harvester_maps_to_harvester_source(self): + """Harvester maps to harvester/harvester provider source.""" + profile = make_profile(ProviderType.HARVESTER) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.HARVESTER}, + ) + + assert "harvester = {" in result.content + assert 'source = "harvester/harvester"' in result.content + + def test_provider_type_without_matching_profile_gets_placeholder(self): + """A provider type with no matching profile gets a placeholder block.""" + # Profile is for Kubernetes, but we request Synology too + profile = make_profile(ProviderType.KUBERNETES) + generator = ProviderBlockGenerator() + + result = generator.generate( + profiles=[profile], + provider_types={ProviderType.KUBERNETES, ProviderType.SYNOLOGY}, + ) + + # Synology should still appear in required_providers + assert "synology = {" in result.content + assert 'source = "synology-community/synology"' in result.content + # And should have a placeholder provider block + assert '# No profile provided' in result.content diff --git a/tests/unit/test_resolver.py b/tests/unit/test_resolver.py new file mode 100644 index 0000000..90e2084 --- /dev/null +++ b/tests/unit/test_resolver.py @@ -0,0 +1,481 @@ +"""Unit tests for the DependencyResolver.""" + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) +from iac_reverse.resolver import DependencyResolver + + +# --------------------------------------------------------------------------- +# Helpers / Fixtures +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "default/deployments/nginx", + name: str = "nginx", + raw_references: list[str] | None = None, + attributes: dict | None = None, + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {}, + raw_references=raw_references or [], + ) + + +def make_scan_result(resources: list[DiscoveredResource]) -> ScanResult: + """Create a ScanResult from a list of resources.""" + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test-hash", + ) + + +# --------------------------------------------------------------------------- +# Tests: Simple linear dependency chain +# --------------------------------------------------------------------------- + + +class TestLinearDependencyChain: + """Tests for a simple A -> B -> C dependency chain.""" + + def test_linear_chain_produces_correct_topological_order(self): + """Resources in a linear chain appear in dependency order.""" + # C depends on B, B depends on A + resource_a = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/nginx-svc", + name="nginx-svc", + raw_references=["ns/default"], + attributes={"namespace": "default"}, + ) + resource_c = make_resource( + resource_type="kubernetes_ingress", + unique_id="default/ingresses/nginx-ingress", + name="nginx-ingress", + raw_references=["default/services/nginx-svc"], + attributes={"service": "nginx-svc"}, + ) + + scan_result = make_scan_result([resource_a, resource_b, resource_c]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # A must appear before B, B must appear before C + order = graph.topological_order + assert order.index("ns/default") < order.index("default/services/nginx-svc") + assert order.index("default/services/nginx-svc") < order.index( + "default/ingresses/nginx-ingress" + ) + + def test_linear_chain_produces_correct_relationships(self): + """Each link in the chain produces a relationship.""" + resource_a = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/nginx-svc", + name="nginx-svc", + raw_references=["ns/default"], + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 1 + rel = graph.relationships[0] + assert rel.source_id == "default/services/nginx-svc" + assert rel.target_id == "ns/default" + + +# --------------------------------------------------------------------------- +# Tests: Multiple resources with shared dependencies +# --------------------------------------------------------------------------- + + +class TestSharedDependencies: + """Tests for multiple resources depending on the same resource.""" + + def test_shared_dependency_appears_before_all_dependents(self): + """A shared dependency appears before all resources that depend on it.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/production", + name="production", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="production/deployments/app", + name="app", + raw_references=["ns/production"], + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="production/services/app-svc", + name="app-svc", + raw_references=["ns/production"], + ) + + scan_result = make_scan_result([namespace, deployment, service]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + order = graph.topological_order + ns_idx = order.index("ns/production") + deploy_idx = order.index("production/deployments/app") + svc_idx = order.index("production/services/app-svc") + + assert ns_idx < deploy_idx + assert ns_idx < svc_idx + + def test_shared_dependency_produces_multiple_relationships(self): + """A shared dependency creates one relationship per dependent.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/production", + name="production", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="production/deployments/app", + name="app", + raw_references=["ns/production"], + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="production/services/app-svc", + name="app-svc", + raw_references=["ns/production"], + ) + + scan_result = make_scan_result([namespace, deployment, service]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 2 + target_ids = [r.target_id for r in graph.relationships] + assert target_ids.count("ns/production") == 2 + + +# --------------------------------------------------------------------------- +# Tests: Resources with no references (standalone) +# --------------------------------------------------------------------------- + + +class TestStandaloneResources: + """Tests for resources that have no references to other resources.""" + + def test_standalone_resources_appear_in_topological_order(self): + """Resources with no references still appear in the topological order.""" + resource_a = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/standalone-a", + name="standalone-a", + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/standalone-b", + name="standalone-b", + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.topological_order) == 2 + assert "default/deployments/standalone-a" in graph.topological_order + assert "default/services/standalone-b" in graph.topological_order + + def test_standalone_resources_produce_no_relationships(self): + """Resources with no references produce no relationships.""" + resource_a = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/standalone", + name="standalone", + ) + + scan_result = make_scan_result([resource_a]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 0 + + def test_empty_scan_result_produces_empty_graph(self): + """An empty scan result produces an empty dependency graph.""" + scan_result = make_scan_result([]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.resources == [] + assert graph.relationships == [] + assert graph.topological_order == [] + assert graph.cycles == [] + assert graph.unresolved_references == [] + + +# --------------------------------------------------------------------------- +# Tests: Parent-child relationship detection +# --------------------------------------------------------------------------- + + +class TestParentChildRelationships: + """Tests for parent-child relationship classification.""" + + def test_namespace_reference_classified_as_parent_child(self): + """A reference to a kubernetes_namespace is classified as parent-child.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["ns/default"], + attributes={"namespace": "default"}, + ) + + scan_result = make_scan_result([namespace, deployment]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 1 + assert graph.relationships[0].relationship_type == "parent-child" + + def test_docker_network_reference_classified_as_parent_child(self): + """A reference to a docker_network is classified as parent-child.""" + network = make_resource( + resource_type="docker_network", + unique_id="networks/overlay-net", + name="overlay-net", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ) + service = make_resource( + resource_type="docker_service", + unique_id="services/web", + name="web", + raw_references=["networks/overlay-net"], + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ) + + scan_result = make_scan_result([network, service]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 1 + assert graph.relationships[0].relationship_type == "parent-child" + + def test_dependency_relationship_for_iis_site_to_app_pool(self): + """An IIS site referencing an app pool is classified as dependency.""" + app_pool = make_resource( + resource_type="windows_iis_app_pool", + unique_id="win-server/iis/app_pools/DefaultAppPool", + name="DefaultAppPool", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + iis_site = make_resource( + resource_type="windows_iis_site", + unique_id="win-server/iis/sites/Default Web Site", + name="Default Web Site", + raw_references=["win-server/iis/app_pools/DefaultAppPool"], + attributes={"app_pool": "DefaultAppPool"}, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + + scan_result = make_scan_result([app_pool, iis_site]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 1 + assert graph.relationships[0].relationship_type == "dependency" + + def test_generic_reference_classified_as_reference(self): + """A reference to a non-namespace, non-dependency resource is 'reference'.""" + service = make_resource( + resource_type="kubernetes_service", + unique_id="default/services/nginx-svc", + name="nginx-svc", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/nginx", + name="nginx", + raw_references=["default/services/nginx-svc"], + ) + + scan_result = make_scan_result([service, deployment]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 1 + assert graph.relationships[0].relationship_type == "reference" + + +# --------------------------------------------------------------------------- +# Tests: Topological order validity +# --------------------------------------------------------------------------- + + +class TestTopologicalOrderValidity: + """Tests that topological order is valid (no resource before its dependencies).""" + + def test_no_resource_appears_before_its_dependencies(self): + """In the topological order, no resource appears before any it depends on.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/prod", + name="prod", + ) + config_map = make_resource( + resource_type="kubernetes_config_map", + unique_id="prod/configmaps/app-config", + name="app-config", + raw_references=["ns/prod"], + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="prod/deployments/app", + name="app", + raw_references=["ns/prod", "prod/configmaps/app-config"], + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="prod/services/app-svc", + name="app-svc", + raw_references=["ns/prod"], + ) + ingress = make_resource( + resource_type="kubernetes_ingress", + unique_id="prod/ingresses/app-ingress", + name="app-ingress", + raw_references=["prod/services/app-svc"], + ) + + scan_result = make_scan_result( + [namespace, config_map, deployment, service, ingress] + ) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + order = graph.topological_order + + # Verify: for each relationship, target appears before source + for rel in graph.relationships: + target_idx = order.index(rel.target_id) + source_idx = order.index(rel.source_id) + assert target_idx < source_idx, ( + f"Target {rel.target_id} (idx={target_idx}) should appear before " + f"source {rel.source_id} (idx={source_idx})" + ) + + def test_topological_order_contains_all_resources(self): + """The topological order contains every resource exactly once.""" + resources = [ + make_resource( + resource_type="kubernetes_deployment", + unique_id=f"default/deployments/app-{i}", + name=f"app-{i}", + ) + for i in range(5) + ] + + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.topological_order) == 5 + assert len(set(graph.topological_order)) == 5 # All unique + + def test_graph_returns_all_resources_in_resources_field(self): + """The DependencyGraph.resources field contains all input resources.""" + resources = [ + make_resource( + unique_id=f"resource-{i}", + name=f"res-{i}", + ) + for i in range(3) + ] + + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.resources == resources + + def test_source_attribute_identified_from_attributes(self): + """The source_attribute is identified from the resource's attributes dict.""" + app_pool = make_resource( + resource_type="windows_iis_app_pool", + unique_id="win/iis/app_pools/MyPool", + name="MyPool", + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + site = make_resource( + resource_type="windows_iis_site", + unique_id="win/iis/sites/MySite", + name="MySite", + raw_references=["win/iis/app_pools/MyPool"], + attributes={"app_pool": "MyPool", "state": "Started"}, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + + scan_result = make_scan_result([app_pool, site]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.relationships[0].source_attribute == "app_pool" + + def test_unresolved_references_are_skipped(self): + """References to IDs not in the inventory are skipped (no relationship created).""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["nonexistent/resource/id"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 0 + # Unresolved references are now tracked (task 4.3) + assert len(graph.unresolved_references) == 1 + assert graph.unresolved_references[0].referenced_id == "nonexistent/resource/id" diff --git a/tests/unit/test_resolver_cycles.py b/tests/unit/test_resolver_cycles.py new file mode 100644 index 0000000..f35f02f --- /dev/null +++ b/tests/unit/test_resolver_cycles.py @@ -0,0 +1,537 @@ +"""Unit tests for cycle detection and resolution in the DependencyResolver.""" + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + CycleReport, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) +from iac_reverse.resolver import DependencyResolver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "default/deployments/nginx", + name: str = "nginx", + raw_references: list[str] | None = None, + attributes: dict | None = None, + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {}, + raw_references=raw_references or [], + ) + + +def make_scan_result(resources: list[DiscoveredResource]) -> ScanResult: + """Create a ScanResult from a list of resources.""" + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test-hash", + ) + + +# --------------------------------------------------------------------------- +# Tests: Simple A -> B -> A cycle detection +# --------------------------------------------------------------------------- + + +class TestSimpleTwoNodeCycle: + """Tests for a simple two-node cycle: A -> B -> A.""" + + def test_two_node_cycle_detected(self): + """A simple A -> B -> A cycle is detected.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-b"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Should detect exactly one cycle + assert len(graph.cycles) == 1 + # The cycle should contain both resource IDs + cycle = graph.cycles[0] + assert set(cycle) == {"svc-a", "svc-b"} + + def test_two_node_cycle_has_cycle_report(self): + """A two-node cycle produces a CycleReport with resolution suggestion.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-b"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycle_reports) == 1 + report = graph.cycle_reports[0] + assert isinstance(report, CycleReport) + assert "data source" in report.resolution_strategy.lower() + + def test_two_node_cycle_still_produces_topological_order(self): + """Despite a cycle, a topological order is still produced.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-b"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Topological order should contain both resources + assert len(graph.topological_order) == 2 + assert set(graph.topological_order) == {"svc-a", "svc-b"} + + +# --------------------------------------------------------------------------- +# Tests: Multi-node cycle (A -> B -> C -> A) +# --------------------------------------------------------------------------- + + +class TestMultiNodeCycle: + """Tests for a multi-node cycle: A -> B -> C -> A.""" + + def test_three_node_cycle_detected(self): + """A three-node cycle A -> B -> C -> A is detected.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-c"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + resource_c = make_resource( + resource_type="kubernetes_service", + unique_id="svc-c", + name="service-c", + raw_references=["svc-b"], + ) + + scan_result = make_scan_result([resource_a, resource_b, resource_c]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Should detect exactly one cycle containing all three nodes + assert len(graph.cycles) == 1 + cycle = graph.cycles[0] + assert set(cycle) == {"svc-a", "svc-b", "svc-c"} + + def test_three_node_cycle_produces_topological_order(self): + """A three-node cycle still produces a valid topological order.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-c"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + resource_c = make_resource( + resource_type="kubernetes_service", + unique_id="svc-c", + name="service-c", + raw_references=["svc-b"], + ) + + scan_result = make_scan_result([resource_a, resource_b, resource_c]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.topological_order) == 3 + assert set(graph.topological_order) == {"svc-a", "svc-b", "svc-c"} + + def test_three_node_cycle_report_has_all_nodes(self): + """The cycle report for a 3-node cycle lists all involved resources.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-c"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + resource_c = make_resource( + resource_type="kubernetes_service", + unique_id="svc-c", + name="service-c", + raw_references=["svc-b"], + ) + + scan_result = make_scan_result([resource_a, resource_b, resource_c]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycle_reports) == 1 + report = graph.cycle_reports[0] + assert set(report.cycle) == {"svc-a", "svc-b", "svc-c"} + + +# --------------------------------------------------------------------------- +# Tests: Graph with both cycles and acyclic portions +# --------------------------------------------------------------------------- + + +class TestMixedCyclicAndAcyclicGraph: + """Tests for graphs containing both cyclic and acyclic portions.""" + + def test_cycle_detected_alongside_acyclic_resources(self): + """Cycles are detected even when acyclic resources exist in the graph.""" + # Acyclic chain: D -> E (E depends on D) + resource_d = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/production", + name="production", + ) + resource_e = make_resource( + resource_type="kubernetes_deployment", + unique_id="prod/deployments/app", + name="app", + raw_references=["ns/production"], + ) + + # Cycle: A -> B -> A + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-x", + name="service-x", + raw_references=["svc-y"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-y", + name="service-y", + raw_references=["svc-x"], + ) + + scan_result = make_scan_result( + [resource_d, resource_e, resource_a, resource_b] + ) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Should detect the cycle + assert len(graph.cycles) == 1 + assert set(graph.cycles[0]) == {"svc-x", "svc-y"} + + def test_acyclic_ordering_preserved_despite_cycle_elsewhere(self): + """Acyclic portions maintain correct topological ordering.""" + # Acyclic chain: namespace -> deployment + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/prod", + name="prod", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="prod/deploy/app", + name="app", + raw_references=["ns/prod"], + ) + + # Cycle: X -> Y -> X + svc_x = make_resource( + resource_type="kubernetes_service", + unique_id="svc-x", + name="svc-x", + raw_references=["svc-y"], + ) + svc_y = make_resource( + resource_type="kubernetes_service", + unique_id="svc-y", + name="svc-y", + raw_references=["svc-x"], + ) + + scan_result = make_scan_result([namespace, deployment, svc_x, svc_y]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + order = graph.topological_order + # The acyclic relationship should still be respected + assert order.index("ns/prod") < order.index("prod/deploy/app") + + def test_all_resources_present_in_topological_order(self): + """All resources (cyclic and acyclic) appear in the topological order.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deploy/web", + name="web", + raw_references=["ns/default"], + ) + svc_a = make_resource( + resource_type="kubernetes_service", + unique_id="cycle-a", + name="cycle-a", + raw_references=["cycle-b"], + ) + svc_b = make_resource( + resource_type="kubernetes_service", + unique_id="cycle-b", + name="cycle-b", + raw_references=["cycle-a"], + ) + + scan_result = make_scan_result([namespace, deployment, svc_a, svc_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.topological_order) == 4 + assert set(graph.topological_order) == { + "ns/default", + "default/deploy/web", + "cycle-a", + "cycle-b", + } + + +# --------------------------------------------------------------------------- +# Tests: Resolution suggestion identifies correct relationship to break +# --------------------------------------------------------------------------- + + +class TestResolutionSuggestions: + """Tests that resolution suggestions prefer breaking 'reference' over + 'dependency' over 'parent-child'.""" + + def test_prefers_breaking_reference_over_dependency(self): + """When a cycle has both reference and dependency edges, suggests breaking reference.""" + # Create a cycle where one edge is "dependency" and one is "reference" + # IIS site -> app pool (dependency), app pool -> IIS site (reference) + app_pool = make_resource( + resource_type="windows_iis_app_pool", + unique_id="pool-a", + name="pool-a", + raw_references=["site-a"], # This creates a "reference" relationship + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + iis_site = make_resource( + resource_type="windows_iis_site", + unique_id="site-a", + name="site-a", + raw_references=["pool-a"], # This creates a "dependency" relationship + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + + scan_result = make_scan_result([app_pool, iis_site]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycle_reports) == 1 + report = graph.cycle_reports[0] + # Should suggest breaking the "reference" relationship (pool -> site) + assert report.break_relationship_type == "reference" + + def test_prefers_breaking_dependency_over_parent_child(self): + """When a cycle has dependency and parent-child edges, suggests breaking dependency.""" + # Create a cycle: namespace -> deployment (parent-child back-ref), + # deployment -> namespace (dependency) + # We need to craft this carefully: + # - deployment references namespace -> classified as "parent-child" (namespace is in _NAMESPACE_RESOURCE_TYPES) + # - namespace references deployment -> classified as "reference" (deployment is not special) + # Actually let's use a different setup to get dependency vs parent-child + + # Use harvester: VM -> network (dependency), network -> VM (reference) + network = make_resource( + resource_type="harvester_network", + unique_id="net-a", + name="net-a", + raw_references=["vm-a"], # reference (VM is not a namespace type) + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + ) + vm = make_resource( + resource_type="harvester_virtualmachine", + unique_id="vm-a", + name="vm-a", + raw_references=["net-a"], # dependency (VM depends on network) + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + ) + + scan_result = make_scan_result([network, vm]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycle_reports) == 1 + report = graph.cycle_reports[0] + # The network -> VM edge is "reference", VM -> network is "parent-child" + # (because harvester_network is in _NAMESPACE_RESOURCE_TYPES) + # So it should prefer breaking "reference" + assert report.break_relationship_type == "reference" + + def test_resolution_strategy_mentions_data_source(self): + """Resolution strategy suggests data source lookup as alternative.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-b"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.cycle_reports) == 1 + report = graph.cycle_reports[0] + assert "data source" in report.resolution_strategy.lower() + + def test_suggested_break_edge_is_in_cycle(self): + """The suggested edge to break is actually part of the cycle.""" + resource_a = make_resource( + resource_type="kubernetes_service", + unique_id="svc-a", + name="service-a", + raw_references=["svc-b"], + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="svc-b", + name="service-b", + raw_references=["svc-a"], + ) + + scan_result = make_scan_result([resource_a, resource_b]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + report = graph.cycle_reports[0] + # The suggested break edge nodes should be in the cycle + source, target = report.suggested_break + assert source in report.cycle + assert target in report.cycle + + +# --------------------------------------------------------------------------- +# Tests: No cycles in acyclic graph +# --------------------------------------------------------------------------- + + +class TestNoCycles: + """Tests that acyclic graphs report no cycles.""" + + def test_linear_chain_has_no_cycles(self): + """A simple linear chain has no cycles.""" + resource_a = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + resource_b = make_resource( + resource_type="kubernetes_service", + unique_id="default/svc/app", + name="app", + raw_references=["ns/default"], + ) + resource_c = make_resource( + resource_type="kubernetes_ingress", + unique_id="default/ingress/app", + name="app-ingress", + raw_references=["default/svc/app"], + ) + + scan_result = make_scan_result([resource_a, resource_b, resource_c]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.cycles == [] + assert graph.cycle_reports == [] + + def test_empty_graph_has_no_cycles(self): + """An empty graph has no cycles.""" + scan_result = make_scan_result([]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.cycles == [] + assert graph.cycle_reports == [] + + def test_standalone_resources_have_no_cycles(self): + """Resources with no references have no cycles.""" + resources = [ + make_resource(unique_id=f"res-{i}", name=f"res-{i}") + for i in range(5) + ] + + scan_result = make_scan_result(resources) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.cycles == [] + assert graph.cycle_reports == [] diff --git a/tests/unit/test_resolver_unresolved.py b/tests/unit/test_resolver_unresolved.py new file mode 100644 index 0000000..b6abc97 --- /dev/null +++ b/tests/unit/test_resolver_unresolved.py @@ -0,0 +1,445 @@ +"""Unit tests for unresolved reference handling in the DependencyResolver.""" + +import logging + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) +from iac_reverse.resolver import DependencyResolver + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "default/deployments/nginx", + name: str = "nginx", + raw_references: list[str] | None = None, + attributes: dict | None = None, + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {}, + raw_references=raw_references or [], + ) + + +def make_scan_result(resources: list[DiscoveredResource]) -> ScanResult: + """Create a ScanResult from a list of resources.""" + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="test-hash", + ) + + +# --------------------------------------------------------------------------- +# Tests: Single unresolved reference +# --------------------------------------------------------------------------- + + +class TestSingleUnresolvedReference: + """Tests for a single unresolved reference creating an UnresolvedReference entry.""" + + def test_single_unresolved_reference_creates_entry(self): + """A raw_reference pointing to an ID not in the inventory creates an UnresolvedReference.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["nonexistent/resource/id"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.unresolved_references) == 1 + unresolved = graph.unresolved_references[0] + assert unresolved.source_resource_id == "default/deployments/app" + assert unresolved.referenced_id == "nonexistent/resource/id" + + def test_unresolved_reference_source_attribute_from_attributes(self): + """The source_attribute is identified from the resource's attributes dict.""" + resource = make_resource( + resource_type="windows_iis_site", + unique_id="win/iis/sites/MySite", + name="MySite", + raw_references=["missing-pool-id"], + attributes={"app_pool": "missing-pool-id", "state": "Started"}, + provider=ProviderType.WINDOWS, + platform_category=PlatformCategory.WINDOWS, + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.unresolved_references) == 1 + assert graph.unresolved_references[0].source_attribute == "app_pool" + + def test_unresolved_reference_fallback_to_raw_references(self): + """Falls back to 'raw_references' when the ref ID isn't in attributes.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["some/external/resource"], + attributes={"replicas": "3"}, + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].source_attribute == "raw_references" + + +# --------------------------------------------------------------------------- +# Tests: Multiple unresolved references from same resource +# --------------------------------------------------------------------------- + + +class TestMultipleUnresolvedFromSameResource: + """Tests for multiple unresolved references from the same resource.""" + + def test_multiple_unresolved_from_same_resource(self): + """Multiple unresolved references from one resource create multiple entries.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=[ + "missing/namespace/id", + "missing/configmap/id", + "missing/secret/id", + ], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.unresolved_references) == 3 + referenced_ids = [u.referenced_id for u in graph.unresolved_references] + assert "missing/namespace/id" in referenced_ids + assert "missing/configmap/id" in referenced_ids + assert "missing/secret/id" in referenced_ids + + def test_all_entries_share_same_source_resource_id(self): + """All unresolved entries from the same resource share the source_resource_id.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["missing/ref-a", "missing/ref-b"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + for unresolved in graph.unresolved_references: + assert unresolved.source_resource_id == "default/deployments/app" + + +# --------------------------------------------------------------------------- +# Tests: Mix of resolved and unresolved references +# --------------------------------------------------------------------------- + + +class TestMixedResolvedAndUnresolved: + """Tests for a mix of resolved and unresolved references.""" + + def test_resolved_creates_relationship_unresolved_creates_entry(self): + """Resolved references create relationships; unresolved create entries.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["ns/default", "missing/configmap/app-config"], + attributes={"namespace": "default"}, + ) + + scan_result = make_scan_result([namespace, deployment]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # One resolved relationship + assert len(graph.relationships) == 1 + assert graph.relationships[0].target_id == "ns/default" + + # One unresolved reference + assert len(graph.unresolved_references) == 1 + assert ( + graph.unresolved_references[0].referenced_id + == "missing/configmap/app-config" + ) + + def test_mixed_references_multiple_resources(self): + """Multiple resources with a mix of resolved and unresolved references.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/prod", + name="prod", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="prod/deployments/web", + name="web", + raw_references=["ns/prod", "external/database/postgres"], + ) + service = make_resource( + resource_type="kubernetes_service", + unique_id="prod/services/web-svc", + name="web-svc", + raw_references=["ns/prod", "missing/endpoint"], + ) + + scan_result = make_scan_result([namespace, deployment, service]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Two resolved relationships (deployment->ns, service->ns) + assert len(graph.relationships) == 2 + + # Two unresolved references + assert len(graph.unresolved_references) == 2 + unresolved_ids = [u.referenced_id for u in graph.unresolved_references] + assert "external/database/postgres" in unresolved_ids + assert "missing/endpoint" in unresolved_ids + + +# --------------------------------------------------------------------------- +# Tests: suggested_resolution is "data_source" for ID-like references +# --------------------------------------------------------------------------- + + +class TestSuggestedResolutionDataSource: + """Tests that suggested_resolution is 'data_source' for ID-like references.""" + + def test_reference_with_slash_suggests_data_source(self): + """A reference containing '/' is suggested as 'data_source'.""" + resource = make_resource( + raw_references=["external/vpc/vpc-12345"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].suggested_resolution == "data_source" + + def test_reference_with_colon_suggests_data_source(self): + """A reference containing ':' is suggested as 'data_source'.""" + resource = make_resource( + raw_references=["arn:aws:ec2:us-east-1:123456:instance/i-abc123"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].suggested_resolution == "data_source" + + def test_reference_with_both_slash_and_colon_suggests_data_source(self): + """A reference containing both '/' and ':' is suggested as 'data_source'.""" + resource = make_resource( + raw_references=["provider:type/resource-name"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].suggested_resolution == "data_source" + + +# --------------------------------------------------------------------------- +# Tests: suggested_resolution is "variable" for simple name references +# --------------------------------------------------------------------------- + + +class TestSuggestedResolutionVariable: + """Tests that suggested_resolution is 'variable' for simple name references.""" + + def test_simple_name_suggests_variable(self): + """A simple name without '/' or ':' is suggested as 'variable'.""" + resource = make_resource( + raw_references=["my-environment"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].suggested_resolution == "variable" + + def test_alphanumeric_name_suggests_variable(self): + """An alphanumeric name is suggested as 'variable'.""" + resource = make_resource( + raw_references=["production123"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].suggested_resolution == "variable" + + def test_name_with_dashes_and_underscores_suggests_variable(self): + """A name with dashes and underscores (but no / or :) is 'variable'.""" + resource = make_resource( + raw_references=["my_app-pool-name"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert graph.unresolved_references[0].suggested_resolution == "variable" + + +# --------------------------------------------------------------------------- +# Tests: Unresolved references don't create graph edges or relationships +# --------------------------------------------------------------------------- + + +class TestUnresolvedDontCreateEdges: + """Tests that unresolved references don't create graph edges or relationships.""" + + def test_unresolved_reference_creates_no_relationship(self): + """An unresolved reference does not create a ResourceRelationship.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["nonexistent/resource"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + assert len(graph.relationships) == 0 + + def test_unresolved_reference_does_not_affect_topological_order(self): + """Unresolved references don't add extra nodes to the topological order.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["nonexistent/resource"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # Only the actual resource should be in the topological order + assert graph.topological_order == ["default/deployments/app"] + + def test_unresolved_does_not_block_resolved_relationships(self): + """Unresolved references don't prevent resolved references from working.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["ns/default", "nonexistent/configmap"], + ) + + scan_result = make_scan_result([namespace, deployment]) + resolver = DependencyResolver(scan_result) + graph = resolver.resolve() + + # The resolved relationship still works + assert len(graph.relationships) == 1 + assert graph.relationships[0].source_id == "default/deployments/app" + assert graph.relationships[0].target_id == "ns/default" + + # Topological order is correct + order = graph.topological_order + assert order.index("ns/default") < order.index("default/deployments/app") + + # Unresolved reference is tracked + assert len(graph.unresolved_references) == 1 + + +# --------------------------------------------------------------------------- +# Tests: Warning logging for unresolved references +# --------------------------------------------------------------------------- + + +class TestUnresolvedReferenceLogging: + """Tests that warnings are logged for unresolved references.""" + + def test_warning_logged_for_unresolved_reference(self, caplog): + """A warning is logged for each unresolved reference.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["missing/resource/id"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + + with caplog.at_level(logging.WARNING, logger="iac_reverse.resolver.resolver"): + graph = resolver.resolve() + + assert len(caplog.records) == 1 + record = caplog.records[0] + assert record.levelname == "WARNING" + assert "missing/resource/id" in record.message + assert "default/deployments/app" in record.message + + def test_multiple_warnings_logged_for_multiple_unresolved(self, caplog): + """A warning is logged for each unresolved reference.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="default/deployments/app", + name="app", + raw_references=["missing/ref-a", "missing/ref-b"], + ) + + scan_result = make_scan_result([resource]) + resolver = DependencyResolver(scan_result) + + with caplog.at_level(logging.WARNING, logger="iac_reverse.resolver.resolver"): + graph = resolver.resolve() + + warning_messages = [r.message for r in caplog.records if r.levelname == "WARNING"] + assert len(warning_messages) == 2 + assert any("missing/ref-a" in msg for msg in warning_messages) + assert any("missing/ref-b" in msg for msg in warning_messages) diff --git a/tests/unit/test_resource_merger.py b/tests/unit/test_resource_merger.py new file mode 100644 index 0000000..68fdf76 --- /dev/null +++ b/tests/unit/test_resource_merger.py @@ -0,0 +1,343 @@ +"""Unit tests for the ResourceMerger.""" + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) +from iac_reverse.generator import ResourceMerger + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_resource( + name: str = "nginx", + resource_type: str = "kubernetes_deployment", + unique_id: str = "default/deployments/nginx", + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, + architecture: CpuArchitecture = CpuArchitecture.AARCH64, + attributes: dict | None = None, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=architecture, + endpoint="https://api.local:6443", + attributes=attributes or {}, + raw_references=[], + ) + + +def make_scan_result(resources: list[DiscoveredResource]) -> ScanResult: + """Create a ScanResult wrapping the given resources.""" + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash="abc123", + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestResourceMergerSingleResult: + """Single scan result passes through unchanged.""" + + def test_single_scan_result_passes_through_unchanged(self): + """Resources from a single scan result are returned as-is.""" + resources = [ + make_resource(name="nginx", unique_id="k8s/nginx"), + make_resource(name="redis", unique_id="k8s/redis"), + ] + scan_result = make_scan_result(resources) + + merger = ResourceMerger() + merged = merger.merge([scan_result]) + + assert len(merged) == 2 + assert merged[0].name == "nginx" + assert merged[1].name == "redis" + + def test_empty_scan_result_returns_empty_list(self): + """An empty scan result produces an empty merged list.""" + scan_result = make_scan_result([]) + + merger = ResourceMerger() + merged = merger.merge([scan_result]) + + assert merged == [] + + +class TestResourceMergerNoConflicts: + """Two scan results with no conflicts merge cleanly.""" + + def test_two_results_different_names_merge_cleanly(self): + """Resources with different names from different providers merge without prefixing.""" + k8s_resources = [ + make_resource( + name="nginx", + unique_id="k8s/nginx", + provider=ProviderType.KUBERNETES, + ), + ] + docker_resources = [ + make_resource( + name="redis", + unique_id="docker/redis", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(docker_resources), + ]) + + assert len(merged) == 2 + names = {r.name for r in merged} + assert names == {"nginx", "redis"} + + def test_same_name_same_provider_no_conflict(self): + """Resources with the same name from the same provider are not conflicts.""" + resources = [ + make_resource(name="nginx", unique_id="k8s/ns1/nginx"), + make_resource(name="nginx", unique_id="k8s/ns2/nginx"), + ] + scan_result = make_scan_result(resources) + + merger = ResourceMerger() + merged = merger.merge([scan_result]) + + assert len(merged) == 2 + # Both keep original name since they're from the same provider + assert all(r.name == "nginx" for r in merged) + + +class TestResourceMergerConflictResolution: + """Two scan results with name conflicts get provider-prefixed names.""" + + def test_conflicting_names_get_provider_prefix(self): + """Resources with the same name from different providers get prefixed.""" + k8s_resources = [ + make_resource( + name="nginx", + unique_id="k8s/nginx", + provider=ProviderType.KUBERNETES, + ), + ] + docker_resources = [ + make_resource( + name="nginx", + unique_id="docker/nginx", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(docker_resources), + ]) + + assert len(merged) == 2 + names = {r.name for r in merged} + assert "kubernetes_nginx" in names + assert "docker_swarm_nginx" in names + + def test_non_conflicting_names_unchanged_alongside_conflicts(self): + """Non-conflicting resources keep their original names even when conflicts exist.""" + k8s_resources = [ + make_resource(name="nginx", unique_id="k8s/nginx", provider=ProviderType.KUBERNETES), + make_resource(name="postgres", unique_id="k8s/postgres", provider=ProviderType.KUBERNETES), + ] + docker_resources = [ + make_resource( + name="nginx", + unique_id="docker/nginx", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(docker_resources), + ]) + + assert len(merged) == 3 + names = {r.name for r in merged} + assert "kubernetes_nginx" in names + assert "docker_swarm_nginx" in names + assert "postgres" in names + + +class TestResourceMergerPreservesAttributes: + """Provider-specific attributes are preserved.""" + + def test_attributes_preserved_after_merge(self): + """Provider-specific attributes remain unchanged after merging.""" + k8s_attrs = {"namespace": "default", "replicas": 3, "image": "nginx:1.25"} + docker_attrs = {"mode": "replicated", "replicas": 2, "network": "overlay"} + + k8s_resources = [ + make_resource( + name="nginx", + unique_id="k8s/nginx", + provider=ProviderType.KUBERNETES, + attributes=k8s_attrs, + ), + ] + docker_resources = [ + make_resource( + name="nginx", + unique_id="docker/nginx", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + attributes=docker_attrs, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(docker_resources), + ]) + + k8s_merged = next(r for r in merged if r.name == "kubernetes_nginx") + docker_merged = next(r for r in merged if r.name == "docker_swarm_nginx") + + assert k8s_merged.attributes == k8s_attrs + assert docker_merged.attributes == docker_attrs + + def test_provider_and_metadata_preserved(self): + """Provider type, platform category, and architecture are preserved.""" + k8s_resources = [ + make_resource( + name="app", + unique_id="k8s/app", + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + ), + ] + harvester_resources = [ + make_resource( + name="app", + unique_id="harvester/app", + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + architecture=CpuArchitecture.AMD64, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(harvester_resources), + ]) + + k8s_merged = next(r for r in merged if r.name == "kubernetes_app") + harvester_merged = next(r for r in merged if r.name == "harvester_app") + + assert k8s_merged.provider == ProviderType.KUBERNETES + assert k8s_merged.platform_category == PlatformCategory.CONTAINER_ORCHESTRATION + assert k8s_merged.architecture == CpuArchitecture.AARCH64 + + assert harvester_merged.provider == ProviderType.HARVESTER + assert harvester_merged.platform_category == PlatformCategory.HCI + assert harvester_merged.architecture == CpuArchitecture.AMD64 + + +class TestResourceMergerBothAppearInOutput: + """Resources from different providers with same name both appear in output.""" + + def test_both_conflicting_resources_present(self): + """Both resources with the same name from different providers appear in output.""" + k8s_resources = [ + make_resource( + name="webserver", + unique_id="k8s/webserver", + provider=ProviderType.KUBERNETES, + attributes={"replicas": 3}, + ), + ] + docker_resources = [ + make_resource( + name="webserver", + unique_id="docker/webserver", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + attributes={"mode": "global"}, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(docker_resources), + ]) + + assert len(merged) == 2 + unique_ids = {r.unique_id for r in merged} + assert "k8s/webserver" in unique_ids + assert "docker/webserver" in unique_ids + + def test_three_providers_same_name_all_appear(self): + """Three providers with the same resource name all appear with prefixes.""" + k8s_resources = [ + make_resource( + name="monitor", + unique_id="k8s/monitor", + provider=ProviderType.KUBERNETES, + ), + ] + docker_resources = [ + make_resource( + name="monitor", + unique_id="docker/monitor", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ), + ] + harvester_resources = [ + make_resource( + name="monitor", + unique_id="harvester/monitor", + provider=ProviderType.HARVESTER, + platform_category=PlatformCategory.HCI, + architecture=CpuArchitecture.AMD64, + ), + ] + + merger = ResourceMerger() + merged = merger.merge([ + make_scan_result(k8s_resources), + make_scan_result(docker_resources), + make_scan_result(harvester_resources), + ]) + + assert len(merged) == 3 + names = {r.name for r in merged} + assert "kubernetes_monitor" in names + assert "docker_swarm_monitor" in names + assert "harvester_monitor" in names diff --git a/tests/unit/test_sanitize.py b/tests/unit/test_sanitize.py new file mode 100644 index 0000000..bccd0eb --- /dev/null +++ b/tests/unit/test_sanitize.py @@ -0,0 +1,115 @@ +"""Unit tests for identifier sanitization.""" + +import re + +import pytest + +from iac_reverse.generator.sanitize import sanitize_identifier + + +TERRAFORM_IDENTIFIER_RE = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + +class TestSanitizeIdentifierNormalNames: + def test_simple_name_passes_through(self): + assert sanitize_identifier("nginx") == "nginx" + + def test_name_with_underscores_passes_through(self): + assert sanitize_identifier("my_app") == "my_app" + + def test_alphanumeric_name_passes_through(self): + assert sanitize_identifier("app123") == "app123" + + +class TestSanitizeIdentifierHyphens: + def test_hyphens_replaced_with_underscore(self): + assert sanitize_identifier("my-app") == "my_app" + + def test_multiple_hyphens(self): + assert sanitize_identifier("my-cool-app") == "my_cool_app" + + +class TestSanitizeIdentifierLeadingDigits: + def test_leading_digit_gets_underscore_prefix(self): + assert sanitize_identifier("123abc") == "_123abc" + + def test_all_digits(self): + result = sanitize_identifier("12345") + assert result == "_12345" + assert TERRAFORM_IDENTIFIER_RE.match(result) + + +class TestSanitizeIdentifierSpaces: + def test_spaces_replaced_with_underscore(self): + assert sanitize_identifier("my app") == "my_app" + + def test_multiple_spaces_collapse(self): + assert sanitize_identifier("my app") == "my_app" + + +class TestSanitizeIdentifierUnicode: + def test_unicode_replaced(self): + # é is replaced with underscore, trailing underscore preserved + assert sanitize_identifier("café") == "caf_" + + def test_all_unicode(self): + result = sanitize_identifier("日本語") + assert result == "_resource" + + def test_emoji_replaced(self): + result = sanitize_identifier("app🚀name") + assert result == "app_name" + + +class TestSanitizeIdentifierEmptyAndSpecial: + def test_empty_string_returns_resource(self): + assert sanitize_identifier("") == "_resource" + + def test_all_special_chars_returns_resource(self): + assert sanitize_identifier("@#$%^&*") == "_resource" + + def test_single_special_char(self): + assert sanitize_identifier("!") == "_resource" + + +class TestSanitizeIdentifierConsecutiveSpecialChars: + def test_multiple_consecutive_special_chars_collapse(self): + assert sanitize_identifier("a---b") == "a_b" + + def test_mixed_special_chars_collapse(self): + assert sanitize_identifier("a-.-b") == "a_b" + + def test_leading_special_chars_collapse(self): + # Leading hyphens become underscore (valid identifier start) + result = sanitize_identifier("---abc") + assert result == "_abc" + + +class TestSanitizeIdentifierAlwaysValid: + """Verify the result always matches the Terraform identifier regex.""" + + @pytest.mark.parametrize("name", [ + "nginx", + "my-app", + "123abc", + "my app", + "café", + "", + "@#$%^&*", + "a---b", + "___", + "hello_world_123", + "日本語テスト", + " leading spaces", + "trailing ", + "MixedCase-Name_123", + "a", + "_", + "0", + ]) + def test_result_matches_terraform_regex(self, name): + result = sanitize_identifier(name) + assert TERRAFORM_IDENTIFIER_RE.match(result), ( + f"sanitize_identifier({name!r}) = {result!r} does not match " + f"Terraform identifier pattern" + ) diff --git a/tests/unit/test_scan_profile_validation.py b/tests/unit/test_scan_profile_validation.py new file mode 100644 index 0000000..1000ebc --- /dev/null +++ b/tests/unit/test_scan_profile_validation.py @@ -0,0 +1,193 @@ +"""Unit tests for ScanProfile validation logic.""" + +import pytest + +from iac_reverse.models import ( + MAX_RESOURCE_TYPE_FILTERS, + PROVIDER_SUPPORTED_RESOURCE_TYPES, + ProviderType, + ScanProfile, +) + + +class TestScanProfileValidationCredentials: + """Tests for credentials validation.""" + + def test_empty_credentials_returns_error(self): + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={}, + ) + errors = profile.validate() + assert any("credentials must not be empty" in e for e in errors) + + def test_non_empty_credentials_passes(self): + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={"kubeconfig_path": "/home/user/.kube/config"}, + ) + errors = profile.validate() + assert errors == [] + + +class TestScanProfileValidationResourceTypeFilters: + """Tests for resource_type_filters count limit.""" + + def test_none_filters_passes(self): + profile = ScanProfile( + provider=ProviderType.DOCKER_SWARM, + credentials={"host": "localhost"}, + resource_type_filters=None, + ) + errors = profile.validate() + assert errors == [] + + def test_empty_filters_passes(self): + profile = ScanProfile( + provider=ProviderType.DOCKER_SWARM, + credentials={"host": "localhost"}, + resource_type_filters=[], + ) + errors = profile.validate() + assert errors == [] + + def test_filters_at_max_limit_passes(self): + # Use valid resource types repeated to reach the limit + valid_types = PROVIDER_SUPPORTED_RESOURCE_TYPES[ProviderType.WINDOWS] + filters = (valid_types * (MAX_RESOURCE_TYPE_FILTERS // len(valid_types) + 1))[ + :MAX_RESOURCE_TYPE_FILTERS + ] + profile = ScanProfile( + provider=ProviderType.WINDOWS, + credentials={"host": "win01"}, + resource_type_filters=filters, + ) + errors = profile.validate() + assert errors == [] + + def test_filters_exceeding_max_limit_returns_error(self): + # Use valid resource types repeated to exceed the limit + valid_types = PROVIDER_SUPPORTED_RESOURCE_TYPES[ProviderType.WINDOWS] + filters = (valid_types * (MAX_RESOURCE_TYPE_FILTERS // len(valid_types) + 1))[ + : MAX_RESOURCE_TYPE_FILTERS + 1 + ] + profile = ScanProfile( + provider=ProviderType.WINDOWS, + credentials={"host": "win01"}, + resource_type_filters=filters, + ) + errors = profile.validate() + assert any("at most" in e and "200" in e for e in errors) + + def test_max_limit_is_200(self): + assert MAX_RESOURCE_TYPE_FILTERS == 200 + + +class TestScanProfileValidationResourceTypeSupport: + """Tests for resource type validation against provider's supported types.""" + + def test_valid_resource_types_for_kubernetes(self): + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={"kubeconfig_path": "/path"}, + resource_type_filters=["kubernetes_deployment", "kubernetes_service"], + ) + errors = profile.validate() + assert errors == [] + + def test_valid_resource_types_for_docker_swarm(self): + profile = ScanProfile( + provider=ProviderType.DOCKER_SWARM, + credentials={"host": "localhost"}, + resource_type_filters=["docker_service", "docker_network"], + ) + errors = profile.validate() + assert errors == [] + + def test_unsupported_resource_type_returns_error(self): + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={"kubeconfig_path": "/path"}, + resource_type_filters=["kubernetes_deployment", "invalid_type"], + ) + errors = profile.validate() + assert any("unsupported resource types" in e for e in errors) + assert any("invalid_type" in e for e in errors) + + def test_cross_provider_resource_type_returns_error(self): + """A resource type valid for one provider is invalid for another.""" + profile = ScanProfile( + provider=ProviderType.KUBERNETES, + credentials={"kubeconfig_path": "/path"}, + resource_type_filters=["docker_service"], + ) + errors = profile.validate() + assert any("unsupported resource types" in e for e in errors) + assert any("docker_service" in e for e in errors) + + def test_multiple_unsupported_types_listed_in_error(self): + profile = ScanProfile( + provider=ProviderType.SYNOLOGY, + credentials={"host": "nas01"}, + resource_type_filters=["fake_type_a", "fake_type_b"], + ) + errors = profile.validate() + assert any("fake_type_a" in e and "fake_type_b" in e for e in errors) + + @pytest.mark.parametrize("provider", list(ProviderType)) + def test_all_providers_have_supported_types(self, provider): + """Every provider must have at least one supported resource type.""" + assert provider in PROVIDER_SUPPORTED_RESOURCE_TYPES + assert len(PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]) > 0 + + @pytest.mark.parametrize("provider", list(ProviderType)) + def test_all_supported_types_pass_validation(self, provider): + """All listed supported types for a provider should pass validation.""" + supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider] + profile = ScanProfile( + provider=provider, + credentials={"key": "value"}, + resource_type_filters=supported, + ) + errors = profile.validate() + assert errors == [] + + +class TestScanProfileValidationMultipleErrors: + """Tests that all validation errors are returned in a single response.""" + + def test_empty_credentials_and_too_many_filters(self): + filters = ["invalid_type"] * (MAX_RESOURCE_TYPE_FILTERS + 1) + profile = ScanProfile( + provider=ProviderType.HARVESTER, + credentials={}, + resource_type_filters=filters, + ) + errors = profile.validate() + # Should have at least: credentials error, too many filters, unsupported types + assert len(errors) >= 3 + assert any("credentials" in e for e in errors) + assert any("at most" in e for e in errors) + assert any("unsupported" in e for e in errors) + + def test_empty_credentials_and_unsupported_types(self): + profile = ScanProfile( + provider=ProviderType.BARE_METAL, + credentials={}, + resource_type_filters=["nonexistent_type"], + ) + errors = profile.validate() + assert len(errors) >= 2 + assert any("credentials" in e for e in errors) + assert any("unsupported" in e for e in errors) + + def test_no_short_circuit_on_first_error(self): + """Validation must not stop at the first error found.""" + profile = ScanProfile( + provider=ProviderType.WINDOWS, + credentials={}, + resource_type_filters=["totally_fake_resource"], + ) + errors = profile.validate() + # Both credentials and unsupported type errors should be present + assert len(errors) >= 2 diff --git a/tests/unit/test_scanner.py b/tests/unit/test_scanner.py new file mode 100644 index 0000000..d199974 --- /dev/null +++ b/tests/unit/test_scanner.py @@ -0,0 +1,480 @@ +"""Unit tests for the Scanner orchestrator.""" + +import time +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import ( + AuthenticationError, + ConnectionLostError, + Scanner, + ScanTimeoutError, +) + + +# --------------------------------------------------------------------------- +# Helpers / Fixtures +# --------------------------------------------------------------------------- + + +def make_profile(**overrides) -> ScanProfile: + """Create a valid ScanProfile with sensible defaults.""" + defaults = { + "provider": ProviderType.KUBERNETES, + "credentials": {"kubeconfig_path": "/home/user/.kube/config"}, + "endpoints": ["https://k8s-api.local:6443"], + "resource_type_filters": None, + } + defaults.update(overrides) + return ScanProfile(**defaults) + + +def make_resource(resource_type: str = "kubernetes_deployment", name: str = "nginx") -> DiscoveredResource: + """Create a sample DiscoveredResource.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=f"apps/v1/{resource_type}/{name}", + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes={"replicas": 3}, + raw_references=[], + ) + + +class MockPlugin(ProviderPlugin): + """A mock provider plugin for testing.""" + + def __init__( + self, + supported_types=None, + endpoints=None, + auth_error=None, + discover_result=None, + discover_side_effect=None, + ): + self._supported_types = supported_types or [ + "kubernetes_deployment", + "kubernetes_service", + "kubernetes_ingress", + ] + self._endpoints = endpoints or ["https://k8s-api.local:6443"] + self._auth_error = auth_error + self._discover_result = discover_result + self._discover_side_effect = discover_side_effect + + def authenticate(self, credentials: dict[str, str]) -> None: + if self._auth_error: + raise self._auth_error + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return self._endpoints + + def list_supported_resource_types(self) -> list[str]: + return self._supported_types + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AARCH64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback=None, + ) -> ScanResult: + if self._discover_side_effect: + raise self._discover_side_effect + + if self._discover_result: + return self._discover_result + + # Default: return one resource per type and invoke progress callback + resources = [] + for i, rt in enumerate(resource_types): + resources.append(make_resource(resource_type=rt, name=f"{rt}_instance")) + if progress_callback: + progress_callback( + ScanProgress( + current_resource_type=rt, + resources_discovered=len(resources), + resource_types_completed=i + 1, + total_resource_types=len(resource_types), + ) + ) + + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + +# --------------------------------------------------------------------------- +# Tests: Successful scan flow +# --------------------------------------------------------------------------- + + +class TestSuccessfulScan: + """Tests for the happy path scan flow.""" + + def test_scan_returns_all_discovered_resources(self): + profile = make_profile() + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 3 + assert result.is_partial is False + assert result.scan_timestamp != "" + assert result.profile_hash != "" + + def test_scan_with_resource_type_filters(self): + profile = make_profile( + resource_type_filters=["kubernetes_deployment", "kubernetes_service"] + ) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 2 + resource_types = [r.resource_type for r in result.resources] + assert "kubernetes_deployment" in resource_types + assert "kubernetes_service" in resource_types + + def test_scan_uses_plugin_endpoints_when_profile_has_none(self): + profile = make_profile(endpoints=None) + plugin = MockPlugin(endpoints=["https://fallback.local:6443"]) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 3 + + def test_scan_uses_profile_endpoints_when_provided(self): + profile = make_profile(endpoints=["https://custom.local:6443"]) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + # The plugin's discover_resources will be called with profile endpoints + result = scanner.scan() + assert result is not None + + +# --------------------------------------------------------------------------- +# Tests: Authentication failure handling +# --------------------------------------------------------------------------- + + +class TestAuthenticationFailure: + """Tests for authentication error handling.""" + + def test_auth_failure_raises_authentication_error(self): + profile = make_profile() + plugin = MockPlugin(auth_error=RuntimeError("Invalid token")) + scanner = Scanner(profile, plugin) + + with pytest.raises(AuthenticationError) as exc_info: + scanner.scan() + + assert "kubernetes" in exc_info.value.provider_name + assert "Invalid token" in exc_info.value.reason + + def test_auth_error_contains_provider_name(self): + profile = make_profile(provider=ProviderType.DOCKER_SWARM) + plugin = MockPlugin(auth_error=RuntimeError("Connection refused")) + scanner = Scanner(profile, plugin) + + with pytest.raises(AuthenticationError) as exc_info: + scanner.scan() + + assert exc_info.value.provider_name == "docker_swarm" + + def test_auth_error_contains_reason(self): + profile = make_profile() + plugin = MockPlugin(auth_error=RuntimeError("Certificate expired")) + scanner = Scanner(profile, plugin) + + with pytest.raises(AuthenticationError) as exc_info: + scanner.scan() + + assert "Certificate expired" in exc_info.value.reason + + def test_invalid_profile_raises_value_error(self): + profile = make_profile(credentials={}) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + with pytest.raises(ValueError, match="Invalid scan profile"): + scanner.scan() + + def test_no_plugin_raises_value_error(self): + profile = make_profile() + scanner = Scanner(profile, plugin=None) + + with pytest.raises(ValueError, match="No provider plugin"): + scanner.scan() + + +# --------------------------------------------------------------------------- +# Tests: Progress callback invocation +# --------------------------------------------------------------------------- + + +class TestProgressCallback: + """Tests for progress reporting.""" + + def test_progress_callback_invoked_per_resource_type(self): + profile = make_profile() + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + progress_updates = [] + scanner.scan(progress_callback=progress_updates.append) + + # Should have one progress update per resource type (3 types) + assert len(progress_updates) == 3 + + def test_progress_callback_contains_correct_counts(self): + profile = make_profile() + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + progress_updates = [] + scanner.scan(progress_callback=progress_updates.append) + + # Last update should show all types completed + last = progress_updates[-1] + assert last.resource_types_completed == 3 + assert last.total_resource_types == 3 + + def test_progress_callback_shows_incremental_completion(self): + profile = make_profile() + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + progress_updates = [] + scanner.scan(progress_callback=progress_updates.append) + + for i, update in enumerate(progress_updates): + assert update.resource_types_completed == i + 1 + + def test_scan_works_without_progress_callback(self): + profile = make_profile() + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + # Should not raise + result = scanner.scan(progress_callback=None) + assert result is not None + + +# --------------------------------------------------------------------------- +# Tests: Retry logic on transient errors +# --------------------------------------------------------------------------- + + +class TestRetryLogic: + """Tests for retry behavior on transient errors.""" + + def test_retries_on_transient_error_then_succeeds(self): + profile = make_profile() + call_count = {"n": 0} + expected_result = ScanResult( + resources=[make_resource()], + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + class RetryPlugin(MockPlugin): + def discover_resources(self, endpoints, resource_types, progress_callback=None): + call_count["n"] += 1 + if call_count["n"] < 3: + raise RuntimeError("Transient network error") + return expected_result + + plugin = RetryPlugin() + scanner = Scanner(profile, plugin) + + with patch("iac_reverse.scanner.scanner.time.sleep"): + result = scanner.scan() + + assert len(result.resources) == 1 + assert call_count["n"] == 3 + + def test_returns_error_result_after_max_retries_exhausted(self): + profile = make_profile() + + class AlwaysFailPlugin(MockPlugin): + def discover_resources(self, endpoints, resource_types, progress_callback=None): + raise RuntimeError("Persistent failure") + + plugin = AlwaysFailPlugin() + scanner = Scanner(profile, plugin) + + with patch("iac_reverse.scanner.scanner.time.sleep"): + result = scanner.scan() + + assert result.is_partial is True + assert len(result.errors) > 0 + assert "Persistent failure" in result.errors[0] + + def test_exponential_backoff_timing(self): + profile = make_profile() + sleep_calls = [] + + class FailPlugin(MockPlugin): + def discover_resources(self, endpoints, resource_types, progress_callback=None): + raise RuntimeError("Transient error") + + plugin = FailPlugin() + scanner = Scanner(profile, plugin) + + with patch("iac_reverse.scanner.scanner.time.sleep", side_effect=lambda s: sleep_calls.append(s)): + scanner.scan() + + # Should have 3 sleep calls (retries 0, 1, 2 → backoff 1, 2, 4) + assert len(sleep_calls) == 3 + assert sleep_calls[0] == 1.0 + assert sleep_calls[1] == 2.0 + assert sleep_calls[2] == 4.0 + + +# --------------------------------------------------------------------------- +# Tests: Partial inventory on connection loss +# --------------------------------------------------------------------------- + + +class TestConnectionLoss: + """Tests for partial inventory return on connection loss.""" + + def test_connection_error_returns_partial_result(self): + profile = make_profile() + + class ConnectionLossPlugin(MockPlugin): + def discover_resources(self, endpoints, resource_types, progress_callback=None): + raise ConnectionError("Connection reset by peer") + + plugin = ConnectionLossPlugin() + scanner = Scanner(profile, plugin) + + with pytest.raises(ConnectionLostError) as exc_info: + scanner.scan() + + partial = exc_info.value.partial_result + assert partial.is_partial is True + assert len(partial.errors) > 0 + + def test_connection_lost_error_has_partial_result(self): + profile = make_profile() + partial = ScanResult( + resources=[make_resource()], + warnings=["partial"], + errors=["lost connection"], + scan_timestamp="", + profile_hash="", + is_partial=True, + ) + + class RaisesConnectionLost(MockPlugin): + def discover_resources(self, endpoints, resource_types, progress_callback=None): + raise ConnectionLostError(partial_result=partial) + + plugin = RaisesConnectionLost() + scanner = Scanner(profile, plugin) + + with pytest.raises(ConnectionLostError) as exc_info: + scanner.scan() + + assert exc_info.value.partial_result.is_partial is True + assert len(exc_info.value.partial_result.resources) == 1 + + +# --------------------------------------------------------------------------- +# Tests: Warning for unsupported resource types +# --------------------------------------------------------------------------- + + +class TestUnsupportedResourceTypes: + """Tests for warning logging on unsupported resource types.""" + + def test_unsupported_types_generate_warnings(self): + profile = make_profile( + resource_type_filters=[ + "kubernetes_deployment", + "nonexistent_type", + "another_fake_type", + ] + ) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + # Should have warnings for the 2 unsupported types + unsupported_warnings = [ + w for w in result.warnings if "Unsupported resource type" in w + ] + assert len(unsupported_warnings) == 2 + assert "nonexistent_type" in unsupported_warnings[0] + assert "another_fake_type" in unsupported_warnings[1] + + def test_unsupported_types_do_not_block_supported_types(self): + profile = make_profile( + resource_type_filters=[ + "kubernetes_deployment", + "nonexistent_type", + ] + ) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + # Should still discover the supported type + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "kubernetes_deployment" + + def test_all_unsupported_types_returns_empty_resources(self): + profile = make_profile( + resource_type_filters=["fake_type_1", "fake_type_2"] + ) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 0 + assert len(result.warnings) == 2 + + def test_warning_includes_provider_name(self): + profile = make_profile( + resource_type_filters=["unsupported_thing"] + ) + plugin = MockPlugin() + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert "kubernetes" in result.warnings[0] diff --git a/tests/unit/test_scanner_filtering.py b/tests/unit/test_scanner_filtering.py new file mode 100644 index 0000000..56356df --- /dev/null +++ b/tests/unit/test_scanner_filtering.py @@ -0,0 +1,234 @@ +"""Unit tests for Scanner resource type filtering behavior. + +Validates Requirements 6.2 and 6.3: +- 6.2: When resource type filters are specified, discover only those types +- 6.3: When no resource type filters are specified, discover all supported types +""" + +import pytest +from unittest.mock import patch + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanProfile, + ScanProgress, + ScanResult, +) +from iac_reverse.plugin_base import ProviderPlugin +from iac_reverse.scanner.scanner import Scanner + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_profile(**overrides) -> ScanProfile: + """Create a valid ScanProfile with sensible defaults.""" + defaults = { + "provider": ProviderType.KUBERNETES, + "credentials": {"kubeconfig_path": "/home/user/.kube/config"}, + "endpoints": ["https://k8s-api.local:6443"], + "resource_type_filters": None, + } + defaults.update(overrides) + return ScanProfile(**defaults) + + +def make_resource(resource_type: str, name: str = "instance") -> DiscoveredResource: + """Create a sample DiscoveredResource.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=f"{resource_type}/{name}", + name=name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes={}, + raw_references=[], + ) + + +class FakePlugin(ProviderPlugin): + """A fake plugin that tracks which resource types were requested.""" + + def __init__(self, supported_types: list[str]): + self._supported_types = supported_types + self.discovered_types: list[str] = [] + + def authenticate(self, credentials: dict[str, str]) -> None: + pass + + def get_platform_category(self) -> PlatformCategory: + return PlatformCategory.CONTAINER_ORCHESTRATION + + def list_endpoints(self) -> list[str]: + return ["https://k8s-api.local:6443"] + + def list_supported_resource_types(self) -> list[str]: + return self._supported_types + + def detect_architecture(self, endpoint: str) -> CpuArchitecture: + return CpuArchitecture.AARCH64 + + def discover_resources( + self, + endpoints: list[str], + resource_types: list[str], + progress_callback=None, + ) -> ScanResult: + self.discovered_types = resource_types + resources = [make_resource(rt) for rt in resource_types] + return ScanResult( + resources=resources, + warnings=[], + errors=[], + scan_timestamp="", + profile_hash="", + ) + + +# --------------------------------------------------------------------------- +# Tests: Requirement 6.2 - Filters specified, only those types discovered +# --------------------------------------------------------------------------- + + +class TestFilteredResourceTypes: + """Requirement 6.2: When filters are specified, discover only listed types.""" + + def test_single_filter_discovers_only_that_type(self): + """Only the single filtered type should be discovered.""" + supported = ["kubernetes_deployment", "kubernetes_service", "kubernetes_ingress"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile(resource_type_filters=["kubernetes_deployment"]) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "kubernetes_deployment" + assert plugin.discovered_types == ["kubernetes_deployment"] + + def test_multiple_filters_discovers_only_those_types(self): + """Only the filtered types should be discovered, not all supported.""" + supported = ["kubernetes_deployment", "kubernetes_service", "kubernetes_ingress"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile( + resource_type_filters=["kubernetes_deployment", "kubernetes_service"] + ) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 2 + discovered_types = {r.resource_type for r in result.resources} + assert discovered_types == {"kubernetes_deployment", "kubernetes_service"} + assert "kubernetes_ingress" not in plugin.discovered_types + + def test_unsupported_types_in_filter_are_skipped_with_warnings(self): + """Unsupported types in the filter list produce warnings and are skipped.""" + supported = ["kubernetes_deployment", "kubernetes_service"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile( + resource_type_filters=[ + "kubernetes_deployment", + "nonexistent_type", + "another_fake", + ] + ) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + # Only the supported type should be discovered + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "kubernetes_deployment" + + # Warnings for unsupported types + assert len(result.warnings) == 2 + warning_text = " ".join(result.warnings) + assert "nonexistent_type" in warning_text + assert "another_fake" in warning_text + + def test_empty_filter_list_results_in_no_resources(self): + """An empty filter list means no types are requested, so nothing is discovered.""" + supported = ["kubernetes_deployment", "kubernetes_service", "kubernetes_ingress"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile(resource_type_filters=[]) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 0 + assert plugin.discovered_types == [] + + def test_all_filters_unsupported_results_in_no_resources(self): + """When all filtered types are unsupported, no resources are discovered.""" + supported = ["kubernetes_deployment", "kubernetes_service"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile( + resource_type_filters=["fake_type_a", "fake_type_b"] + ) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 0 + assert len(result.warnings) == 2 + + +# --------------------------------------------------------------------------- +# Tests: Requirement 6.3 - No filters, discover all supported types +# --------------------------------------------------------------------------- + + +class TestNoFilterAllTypes: + """Requirement 6.3: When no filters specified, discover all supported types.""" + + def test_no_filters_discovers_all_supported_types(self): + """With resource_type_filters=None, all supported types are discovered.""" + supported = ["kubernetes_deployment", "kubernetes_service", "kubernetes_ingress"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile(resource_type_filters=None) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 3 + discovered_types = {r.resource_type for r in result.resources} + assert discovered_types == set(supported) + assert plugin.discovered_types == supported + + def test_no_filters_with_many_supported_types(self): + """All supported types are passed to discover_resources when no filter.""" + supported = [ + "kubernetes_deployment", + "kubernetes_service", + "kubernetes_ingress", + "kubernetes_config_map", + "kubernetes_persistent_volume", + "kubernetes_namespace", + ] + plugin = FakePlugin(supported_types=supported) + profile = make_profile(resource_type_filters=None) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert len(result.resources) == 6 + assert plugin.discovered_types == supported + + def test_no_filters_produces_no_warnings(self): + """When no filters are specified, there should be no filtering warnings.""" + supported = ["kubernetes_deployment", "kubernetes_service"] + plugin = FakePlugin(supported_types=supported) + profile = make_profile(resource_type_filters=None) + scanner = Scanner(profile, plugin) + + result = scanner.scan() + + assert result.warnings == [] diff --git a/tests/unit/test_snapshot_store.py b/tests/unit/test_snapshot_store.py new file mode 100644 index 0000000..3a9f354 --- /dev/null +++ b/tests/unit/test_snapshot_store.py @@ -0,0 +1,245 @@ +"""Unit tests for the SnapshotStore class.""" + +import json +import time +from pathlib import Path + +import pytest + +from iac_reverse.incremental.snapshot_store import SnapshotStore +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + PlatformCategory, + ProviderType, + ScanResult, +) + + +def _make_scan_result( + profile_hash: str = "abc123", + resource_name: str = "test-resource", +) -> ScanResult: + """Create a sample ScanResult for testing.""" + resource = DiscoveredResource( + resource_type="kubernetes_deployment", + unique_id="apps/v1/deployments/default/nginx", + name=resource_name, + provider=ProviderType.KUBERNETES, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.internal.lab:6443", + attributes={"replicas": 3, "image": "nginx:1.25"}, + raw_references=["default/services/nginx-svc"], + ) + return ScanResult( + resources=[resource], + warnings=["some warning"], + errors=[], + scan_timestamp="2024-01-15T10:30:00Z", + profile_hash=profile_hash, + is_partial=False, + ) + + +class TestStoreSnapshot: + """Tests for storing snapshots.""" + + def test_store_creates_file_in_correct_directory(self, tmp_path: Path) -> None: + """Storing a snapshot creates a JSON file in the snapshot directory.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + result = _make_scan_result(profile_hash="prof1") + + store.store_snapshot(result, "prof1") + + files = list(snapshot_dir.iterdir()) + assert len(files) == 1 + assert files[0].name.startswith("prof1_") + assert files[0].name.endswith(".json") + + def test_store_creates_directory_if_not_exists(self, tmp_path: Path) -> None: + """Storing a snapshot creates the snapshot directory if it doesn't exist.""" + snapshot_dir = tmp_path / "nested" / "deep" / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + result = _make_scan_result() + + store.store_snapshot(result, "abc123") + + assert snapshot_dir.exists() + assert len(list(snapshot_dir.iterdir())) == 1 + + def test_stored_file_contains_valid_json(self, tmp_path: Path) -> None: + """The stored snapshot file contains valid JSON with expected fields.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + result = _make_scan_result(profile_hash="prof1") + + store.store_snapshot(result, "prof1") + + files = list(snapshot_dir.iterdir()) + with open(files[0], "r") as f: + data = json.load(f) + + assert data["profile_hash"] == "prof1" + assert data["scan_timestamp"] == "2024-01-15T10:30:00Z" + assert len(data["resources"]) == 1 + assert data["resources"][0]["resource_type"] == "kubernetes_deployment" + assert data["resources"][0]["provider"] == "kubernetes" + assert data["resources"][0]["architecture"] == "aarch64" + + +class TestLoadPrevious: + """Tests for loading previous snapshots.""" + + def test_load_returns_correct_scan_result(self, tmp_path: Path) -> None: + """Loading a previous snapshot returns the correct ScanResult.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + original = _make_scan_result(profile_hash="prof1", resource_name="nginx") + + store.store_snapshot(original, "prof1") + loaded = store.load_previous("prof1") + + assert loaded is not None + assert loaded.profile_hash == "prof1" + assert loaded.scan_timestamp == "2024-01-15T10:30:00Z" + assert loaded.is_partial is False + assert loaded.warnings == ["some warning"] + assert loaded.errors == [] + assert len(loaded.resources) == 1 + + resource = loaded.resources[0] + assert resource.resource_type == "kubernetes_deployment" + assert resource.unique_id == "apps/v1/deployments/default/nginx" + assert resource.name == "nginx" + assert resource.provider == ProviderType.KUBERNETES + assert resource.platform_category == PlatformCategory.CONTAINER_ORCHESTRATION + assert resource.architecture == CpuArchitecture.AARCH64 + assert resource.endpoint == "https://k8s-api.internal.lab:6443" + assert resource.attributes == {"replicas": 3, "image": "nginx:1.25"} + assert resource.raw_references == ["default/services/nginx-svc"] + + def test_load_returns_none_when_no_snapshot_exists(self, tmp_path: Path) -> None: + """Loading when no snapshot exists returns None.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + result = store.load_previous("nonexistent") + + assert result is None + + def test_load_returns_none_when_directory_does_not_exist( + self, tmp_path: Path + ) -> None: + """Loading when the snapshot directory doesn't exist returns None.""" + snapshot_dir = tmp_path / "does_not_exist" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + result = store.load_previous("prof1") + + assert result is None + + def test_load_returns_most_recent_snapshot(self, tmp_path: Path) -> None: + """When multiple snapshots exist, load returns the most recent one.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + # Store first snapshot + result1 = _make_scan_result(profile_hash="prof1", resource_name="first") + store.store_snapshot(result1, "prof1") + time.sleep(1.1) # Ensure different timestamp + + # Store second snapshot + result2 = _make_scan_result(profile_hash="prof1", resource_name="second") + store.store_snapshot(result2, "prof1") + + loaded = store.load_previous("prof1") + assert loaded is not None + assert loaded.resources[0].name == "second" + + +class TestRetention: + """Tests for snapshot retention/pruning.""" + + def test_retains_at_least_two_most_recent_snapshots(self, tmp_path: Path) -> None: + """Only the 2 most recent snapshots are kept per profile.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + # Store 4 snapshots with different timestamps + for i in range(4): + result = _make_scan_result( + profile_hash="prof1", resource_name=f"resource-{i}" + ) + store.store_snapshot(result, "prof1") + time.sleep(1.1) # Ensure different timestamps + + # Should only have 2 files remaining + files = list(snapshot_dir.iterdir()) + assert len(files) == 2 + + # The most recent should be loadable + loaded = store.load_previous("prof1") + assert loaded is not None + assert loaded.resources[0].name == "resource-3" + + def test_two_snapshots_are_not_pruned(self, tmp_path: Path) -> None: + """Exactly 2 snapshots are retained without pruning.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + store.store_snapshot(_make_scan_result(profile_hash="prof1"), "prof1") + time.sleep(1.1) + store.store_snapshot(_make_scan_result(profile_hash="prof1"), "prof1") + + files = list(snapshot_dir.iterdir()) + assert len(files) == 2 + + +class TestMultipleProfiles: + """Tests for multiple profile isolation.""" + + def test_multiple_profiles_do_not_interfere(self, tmp_path: Path) -> None: + """Snapshots from different profiles don't interfere with each other.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + result_a = _make_scan_result(profile_hash="profile_a", resource_name="res-a") + result_b = _make_scan_result(profile_hash="profile_b", resource_name="res-b") + + store.store_snapshot(result_a, "profile_a") + store.store_snapshot(result_b, "profile_b") + + loaded_a = store.load_previous("profile_a") + loaded_b = store.load_previous("profile_b") + + assert loaded_a is not None + assert loaded_a.resources[0].name == "res-a" + assert loaded_b is not None + assert loaded_b.resources[0].name == "res-b" + + def test_pruning_only_affects_matching_profile(self, tmp_path: Path) -> None: + """Pruning for one profile does not remove snapshots from another.""" + snapshot_dir = tmp_path / "snapshots" + store = SnapshotStore(base_dir=str(snapshot_dir)) + + # Store 4 snapshots for profile_a (should prune to 2) + for i in range(4): + result = _make_scan_result( + profile_hash="profile_a", resource_name=f"a-{i}" + ) + store.store_snapshot(result, "profile_a") + time.sleep(1.1) + + # Store 1 snapshot for profile_b + result_b = _make_scan_result(profile_hash="profile_b", resource_name="b-0") + store.store_snapshot(result_b, "profile_b") + + # profile_a should have 2 files, profile_b should have 1 + all_files = list(snapshot_dir.iterdir()) + a_files = [f for f in all_files if f.name.startswith("profile_a_")] + b_files = [f for f in all_files if f.name.startswith("profile_b_")] + + assert len(a_files) == 2 + assert len(b_files) == 1 diff --git a/tests/unit/test_state_builder.py b/tests/unit/test_state_builder.py new file mode 100644 index 0000000..0cd17a4 --- /dev/null +++ b/tests/unit/test_state_builder.py @@ -0,0 +1,681 @@ +"""Unit tests for the StateBuilder.""" + +import json +import uuid + +import pytest + +from iac_reverse.models import ( + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + GeneratedFile, + PlatformCategory, + ProviderType, + ResourceRelationship, + StateEntry, + StateFile, +) +from iac_reverse.state_builder import StateBuilder + + +# --------------------------------------------------------------------------- +# Helpers / Fixtures +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "apps/v1/deployments/default/nginx", + name: str = "nginx", + raw_references: list[str] | None = None, + attributes: dict | None = None, + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {"replicas": 3, "image": "nginx:1.25"}, + raw_references=raw_references or [], + ) + + +def make_code_generation_result() -> CodeGenerationResult: + """Create a minimal CodeGenerationResult for testing.""" + return CodeGenerationResult( + resource_files=[ + GeneratedFile( + filename="kubernetes_deployment.tf", + content='resource "kubernetes_deployment" "nginx" {}', + resource_count=1, + ) + ], + variables_file=GeneratedFile( + filename="variables.tf", content="", resource_count=0 + ), + provider_file=GeneratedFile( + filename="providers.tf", content="", resource_count=0 + ), + ) + + +def make_dependency_graph( + resources: list[DiscoveredResource], + relationships: list[ResourceRelationship] | None = None, +) -> DependencyGraph: + """Create a DependencyGraph from resources and optional relationships.""" + return DependencyGraph( + resources=resources, + relationships=relationships or [], + topological_order=[r.unique_id for r in resources], + cycles=[], + unresolved_references=[], + ) + + +# --------------------------------------------------------------------------- +# Tests: Single resource produces valid state entry +# --------------------------------------------------------------------------- + + +class TestSingleResource: + """Tests for state generation with a single resource.""" + + def test_single_resource_produces_one_state_entry(self): + """A single resource in the graph produces exactly one state entry.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="3.2.1") + + assert len(state.resources) == 1 + + def test_state_entry_has_correct_resource_type(self): + """State entry resource_type matches the discovered resource type.""" + resource = make_resource(resource_type="kubernetes_deployment") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.resources[0].resource_type == "kubernetes_deployment" + + def test_state_entry_has_sanitized_resource_name(self): + """State entry resource_name is a valid Terraform identifier.""" + resource = make_resource(name="my-nginx-app") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + # Hyphens should be replaced with underscores + assert state.resources[0].resource_name == "my_nginx_app" + + def test_state_entry_provider_id_is_unique_id(self): + """State entry provider_id is the live infrastructure unique_id.""" + resource = make_resource(unique_id="apps/v1/deployments/default/nginx") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.resources[0].provider_id == "apps/v1/deployments/default/nginx" + + def test_state_entry_attributes_from_discovery(self): + """State entry attributes contain the full discovery attribute set.""" + attrs = {"replicas": 3, "image": "nginx:1.25", "namespace": "default"} + resource = make_resource(attributes=attrs) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.resources[0].attributes == attrs + + +# --------------------------------------------------------------------------- +# Tests: Multiple resources produce multiple state entries +# --------------------------------------------------------------------------- + + +class TestMultipleResources: + """Tests for state generation with multiple resources.""" + + def test_multiple_resources_produce_multiple_entries(self): + """Each resource in the graph produces a corresponding state entry.""" + resources = [ + make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ), + make_resource( + resource_type="kubernetes_service", + unique_id="svc/nginx-svc", + name="nginx-svc", + ), + make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ), + ] + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="2.0.0") + + assert len(state.resources) == 3 + + def test_multiple_resources_have_distinct_entries(self): + """Each state entry corresponds to a different resource.""" + resources = [ + make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ), + make_resource( + resource_type="kubernetes_service", + unique_id="svc/redis", + name="redis", + ), + ] + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + types = {e.resource_type for e in state.resources} + assert "kubernetes_deployment" in types + assert "kubernetes_service" in types + + +# --------------------------------------------------------------------------- +# Tests: Lineage is a valid UUID +# --------------------------------------------------------------------------- + + +class TestLineage: + """Tests for state file lineage UUID generation.""" + + def test_lineage_is_valid_uuid(self): + """The state file lineage is a valid UUID string.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + # Should not raise ValueError + parsed = uuid.UUID(state.lineage) + assert str(parsed) == state.lineage + + def test_lineage_is_unique_per_build(self): + """Each build produces a different lineage UUID.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state1 = builder.build(code_result, graph, provider_version="1.0.0") + state2 = builder.build(code_result, graph, provider_version="1.0.0") + + assert state1.lineage != state2.lineage + + +# --------------------------------------------------------------------------- +# Tests: Dependencies are included as Terraform resource addresses +# --------------------------------------------------------------------------- + + +class TestDependencies: + """Tests for dependency references in state entries.""" + + def test_dependencies_as_terraform_addresses(self): + """Dependencies are formatted as resource_type.resource_name.""" + namespace = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + raw_references=["ns/default"], + ) + relationship = ResourceRelationship( + source_id="deploy/nginx", + target_id="ns/default", + relationship_type="dependency", + source_attribute="namespace", + ) + graph = make_dependency_graph( + [namespace, deployment], relationships=[relationship] + ) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + # Find the deployment entry + deploy_entry = next( + e for e in state.resources if e.resource_type == "kubernetes_deployment" + ) + assert "kubernetes_namespace.default" in deploy_entry.dependencies + + def test_resource_without_dependencies_has_empty_list(self): + """A resource with no relationships has an empty dependencies list.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.resources[0].dependencies == [] + + def test_multiple_dependencies_all_included(self): + """All dependency relationships are included in the state entry.""" + ns = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + svc = make_resource( + resource_type="kubernetes_service", + unique_id="svc/nginx-svc", + name="nginx-svc", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ) + relationships = [ + ResourceRelationship( + source_id="deploy/nginx", + target_id="ns/default", + relationship_type="dependency", + source_attribute="namespace", + ), + ResourceRelationship( + source_id="deploy/nginx", + target_id="svc/nginx-svc", + relationship_type="reference", + source_attribute="service", + ), + ] + graph = make_dependency_graph([ns, svc, deployment], relationships=relationships) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + deploy_entry = next( + e for e in state.resources if e.resource_type == "kubernetes_deployment" + ) + assert len(deploy_entry.dependencies) == 2 + assert "kubernetes_namespace.default" in deploy_entry.dependencies + assert "kubernetes_service.nginx_svc" in deploy_entry.dependencies + + +# --------------------------------------------------------------------------- +# Tests: Sensitive attributes are marked +# --------------------------------------------------------------------------- + + +class TestSensitiveAttributes: + """Tests for sensitive attribute detection.""" + + def test_password_attribute_marked_sensitive(self): + """Attributes containing 'password' are marked sensitive.""" + resource = make_resource( + attributes={"db_password": "secret123", "name": "mydb"} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert "db_password" in state.resources[0].sensitive_attributes + + def test_token_attribute_marked_sensitive(self): + """Attributes containing 'token' are marked sensitive.""" + resource = make_resource( + attributes={"api_token": "abc123", "endpoint": "https://api.local"} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert "api_token" in state.resources[0].sensitive_attributes + + def test_secret_attribute_marked_sensitive(self): + """Attributes containing 'secret' are marked sensitive.""" + resource = make_resource( + attributes={"client_secret": "xyz", "client_id": "app1"} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert "client_secret" in state.resources[0].sensitive_attributes + + def test_key_attribute_marked_sensitive(self): + """Attributes containing 'key' are marked sensitive.""" + resource = make_resource( + attributes={"private_key": "-----BEGIN RSA KEY-----", "name": "cert1"} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert "private_key" in state.resources[0].sensitive_attributes + + def test_certificate_attribute_marked_sensitive(self): + """Attributes containing 'certificate' are marked sensitive.""" + resource = make_resource( + attributes={"tls_certificate": "cert-data", "port": 443} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert "tls_certificate" in state.resources[0].sensitive_attributes + + def test_non_sensitive_attributes_not_marked(self): + """Attributes without sensitive patterns are not marked.""" + resource = make_resource( + attributes={"replicas": 3, "image": "nginx:1.25", "namespace": "default"} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.resources[0].sensitive_attributes == [] + + def test_nested_sensitive_attributes_detected(self): + """Sensitive attributes in nested dicts are detected.""" + resource = make_resource( + attributes={ + "config": {"database_password": "secret", "host": "localhost"} + } + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert "config.database_password" in state.resources[0].sensitive_attributes + + +# --------------------------------------------------------------------------- +# Tests: Schema version is set from provider_version +# --------------------------------------------------------------------------- + + +class TestSchemaVersion: + """Tests for schema_version setting from provider_version.""" + + def test_schema_version_from_major_version(self): + """Schema version is the major version number from provider_version.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="3.2.1") + + assert state.resources[0].schema_version == 3 + + def test_schema_version_single_digit(self): + """Schema version works with a single digit version string.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1") + + assert state.resources[0].schema_version == 1 + + def test_schema_version_defaults_to_zero_on_invalid(self): + """Schema version defaults to 0 if provider_version is unparseable.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="invalid") + + assert state.resources[0].schema_version == 0 + + +# --------------------------------------------------------------------------- +# Tests: to_json() produces valid JSON with correct structure +# --------------------------------------------------------------------------- + + +class TestToJson: + """Tests for state file JSON serialization.""" + + def test_to_json_produces_valid_json(self): + """to_json() output is valid JSON.""" + resource = make_resource( + attributes={"replicas": 3, "image": "nginx:1.25"} + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + json_str = state.to_json() + parsed = json.loads(json_str) + assert isinstance(parsed, dict) + + def test_to_json_has_version_4(self): + """JSON output has version field set to 4.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + parsed = json.loads(state.to_json()) + assert parsed["version"] == 4 + + def test_to_json_has_serial_1(self): + """JSON output has serial field set to 1.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + parsed = json.loads(state.to_json()) + assert parsed["serial"] == 1 + + def test_to_json_has_terraform_version(self): + """JSON output includes the terraform_version.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder(terraform_version="1.7.0") + state = builder.build(code_result, graph, provider_version="1.0.0") + + parsed = json.loads(state.to_json()) + assert parsed["terraform_version"] == "1.7.0" + + def test_to_json_has_valid_lineage_uuid(self): + """JSON output lineage is a valid UUID.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + parsed = json.loads(state.to_json()) + # Should not raise ValueError + uuid.UUID(parsed["lineage"]) + + def test_to_json_resources_have_correct_structure(self): + """JSON resources have mode, type, name, provider, and instances.""" + resource = make_resource( + resource_type="kubernetes_deployment", + name="nginx", + attributes={"replicas": 3}, + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="2.0.0") + + parsed = json.loads(state.to_json()) + assert len(parsed["resources"]) == 1 + + res = parsed["resources"][0] + assert res["mode"] == "managed" + assert res["type"] == "kubernetes_deployment" + assert res["name"] == "nginx" + assert "provider" in res + assert len(res["instances"]) == 1 + + def test_to_json_instance_has_attributes_with_id(self): + """JSON instance attributes include the provider_id as 'id'.""" + resource = make_resource( + unique_id="apps/v1/deployments/default/nginx", + attributes={"replicas": 3}, + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + parsed = json.loads(state.to_json()) + instance = parsed["resources"][0]["instances"][0] + assert instance["attributes"]["id"] == "apps/v1/deployments/default/nginx" + assert instance["attributes"]["replicas"] == 3 + + def test_to_json_instance_has_schema_version(self): + """JSON instance includes schema_version.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="3.0.0") + + parsed = json.loads(state.to_json()) + instance = parsed["resources"][0]["instances"][0] + assert instance["schema_version"] == 3 + + def test_to_json_instance_has_dependencies(self): + """JSON instance includes dependencies list.""" + ns = make_resource( + resource_type="kubernetes_namespace", + unique_id="ns/default", + name="default", + ) + deployment = make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ) + relationship = ResourceRelationship( + source_id="deploy/nginx", + target_id="ns/default", + relationship_type="dependency", + source_attribute="namespace", + ) + graph = make_dependency_graph([ns, deployment], relationships=[relationship]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + parsed = json.loads(state.to_json()) + deploy_res = next( + r for r in parsed["resources"] if r["type"] == "kubernetes_deployment" + ) + instance = deploy_res["instances"][0] + assert "kubernetes_namespace.default" in instance["dependencies"] + + +# --------------------------------------------------------------------------- +# Tests: State file metadata +# --------------------------------------------------------------------------- + + +class TestStateFileMetadata: + """Tests for state file top-level metadata.""" + + def test_version_is_4(self): + """State file version is always 4.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.version == 4 + + def test_serial_is_1(self): + """State file serial is 1 for initial generation.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.serial == 1 + + def test_custom_terraform_version(self): + """StateBuilder accepts a custom terraform_version.""" + resource = make_resource() + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder(terraform_version="1.8.0") + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert state.terraform_version == "1.8.0" diff --git a/tests/unit/test_state_builder_unmapped.py b/tests/unit/test_state_builder_unmapped.py new file mode 100644 index 0000000..bc5fb51 --- /dev/null +++ b/tests/unit/test_state_builder_unmapped.py @@ -0,0 +1,468 @@ +"""Unit tests for unmapped resource handling in StateBuilder. + +Tests that resources with missing provider-assigned identifiers or +unrecognized resource types are excluded from the state file, warnings +are logged, and the unmapped resources list is correctly populated. +""" + +import logging + +import pytest + +from iac_reverse.models import ( + CodeGenerationResult, + CpuArchitecture, + DependencyGraph, + DiscoveredResource, + GeneratedFile, + PlatformCategory, + ProviderType, + ResourceRelationship, +) +from iac_reverse.state_builder import StateBuilder + + +# --------------------------------------------------------------------------- +# Helpers / Fixtures +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "apps/v1/deployments/default/nginx", + name: str = "nginx", + attributes: dict | None = None, + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=CpuArchitecture.AARCH64, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {"replicas": 3, "image": "nginx:1.25"}, + raw_references=[], + ) + + +def make_code_generation_result() -> CodeGenerationResult: + """Create a minimal CodeGenerationResult for testing.""" + return CodeGenerationResult( + resource_files=[ + GeneratedFile( + filename="kubernetes_deployment.tf", + content='resource "kubernetes_deployment" "nginx" {}', + resource_count=1, + ) + ], + variables_file=GeneratedFile( + filename="variables.tf", content="", resource_count=0 + ), + provider_file=GeneratedFile( + filename="providers.tf", content="", resource_count=0 + ), + ) + + +def make_dependency_graph( + resources: list[DiscoveredResource], + relationships: list[ResourceRelationship] | None = None, +) -> DependencyGraph: + """Create a DependencyGraph from resources and optional relationships.""" + return DependencyGraph( + resources=resources, + relationships=relationships or [], + topological_order=[r.unique_id for r in resources], + cycles=[], + unresolved_references=[], + ) + + +# --------------------------------------------------------------------------- +# Tests: Resource with empty unique_id is excluded from state +# --------------------------------------------------------------------------- + + +class TestEmptyUniqueIdExcluded: + """Resources with empty unique_id are excluded from state.""" + + def test_empty_string_unique_id_excluded(self): + """A resource with empty string unique_id produces no state entry.""" + resource = make_resource(unique_id="") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + def test_whitespace_only_unique_id_excluded(self): + """A resource with whitespace-only unique_id produces no state entry.""" + resource = make_resource(unique_id=" ") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + def test_tabs_and_newlines_unique_id_excluded(self): + """A resource with tabs/newlines-only unique_id is excluded.""" + resource = make_resource(unique_id="\t\n") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + +# --------------------------------------------------------------------------- +# Tests: Resource with None-like identifier is excluded +# --------------------------------------------------------------------------- + + +class TestNoneLikeIdentifierExcluded: + """Resources with None-like identifiers are excluded from state.""" + + def test_none_unique_id_excluded(self): + """A resource with None unique_id produces no state entry.""" + resource = make_resource() + # Manually set unique_id to None (bypassing type hints) + resource.unique_id = None # type: ignore[assignment] + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + def test_empty_string_is_falsy_excluded(self): + """Empty string is falsy and should be excluded.""" + resource = make_resource(unique_id="") + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + +# --------------------------------------------------------------------------- +# Tests: Unrecognized resource type is excluded +# --------------------------------------------------------------------------- + + +class TestUnrecognizedResourceTypeExcluded: + """Resources with unrecognized resource types are excluded.""" + + def test_unknown_resource_type_excluded(self): + """A resource with an unrecognized type produces no state entry.""" + resource = make_resource( + resource_type="totally_unknown_type", + unique_id="some/valid/id", + name="mystery", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + def test_misspelled_resource_type_excluded(self): + """A misspelled resource type is excluded.""" + resource = make_resource( + resource_type="kuberntes_deployment", # typo + unique_id="deploy/nginx", + name="nginx", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 0 + + +# --------------------------------------------------------------------------- +# Tests: Warning is logged for unmapped resources +# --------------------------------------------------------------------------- + + +class TestWarningLogged: + """Warnings are logged for unmapped resources.""" + + def test_warning_logged_for_empty_unique_id(self, caplog): + """A warning is logged when a resource has empty unique_id.""" + resource = make_resource( + unique_id="", + name="orphan-resource", + resource_type="kubernetes_deployment", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + with caplog.at_level(logging.WARNING): + builder.build(code_result, graph, provider_version="1.0.0") + + assert any("orphan-resource" in record.message for record in caplog.records) + assert any( + "missing provider-assigned resource identifier" in record.message + for record in caplog.records + ) + + def test_warning_logged_for_unrecognized_type(self, caplog): + """A warning is logged when a resource has unrecognized type.""" + resource = make_resource( + resource_type="alien_resource", + unique_id="some/id", + name="alien", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + with caplog.at_level(logging.WARNING): + builder.build(code_result, graph, provider_version="1.0.0") + + assert any("alien" in record.message for record in caplog.records) + assert any( + "not recognized" in record.message for record in caplog.records + ) + + def test_warning_includes_resource_type_and_name(self, caplog): + """Warning message identifies the resource by type and name.""" + resource = make_resource( + resource_type="unknown_type", + unique_id="id/123", + name="my-resource", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + with caplog.at_level(logging.WARNING): + builder.build(code_result, graph, provider_version="1.0.0") + + # The warning should contain the resource identifier (type.name) + assert any( + "unknown_type.my-resource" in record.message + for record in caplog.records + ) + + +# --------------------------------------------------------------------------- +# Tests: Mapped resources still produce valid state entries alongside unmapped +# --------------------------------------------------------------------------- + + +class TestMappedAlongsideUnmapped: + """Mapped resources produce valid entries even when unmapped ones exist.""" + + def test_valid_resource_produces_entry_alongside_unmapped(self): + """A valid resource still gets a state entry when others are unmapped.""" + valid_resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ) + unmapped_resource = make_resource( + resource_type="kubernetes_service", + unique_id="", # empty - unmappable + name="orphan-svc", + ) + graph = make_dependency_graph([valid_resource, unmapped_resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 1 + assert state.resources[0].resource_type == "kubernetes_deployment" + assert state.resources[0].provider_id == "deploy/nginx" + + def test_multiple_valid_resources_with_one_unmapped(self): + """Multiple valid resources produce entries; unmapped one is excluded.""" + resources = [ + make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ), + make_resource( + resource_type="kubernetes_service", + unique_id="svc/redis", + name="redis", + ), + make_resource( + resource_type="unknown_type", + unique_id="id/mystery", + name="mystery", + ), + ] + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="1.0.0") + + assert len(state.resources) == 2 + types = {e.resource_type for e in state.resources} + assert "kubernetes_deployment" in types + assert "kubernetes_service" in types + assert "unknown_type" not in types + + def test_valid_entries_have_correct_attributes(self): + """Valid entries retain full attributes even when unmapped exist.""" + valid_resource = make_resource( + resource_type="docker_service", + unique_id="svc/web", + name="web", + attributes={"replicas": 2, "image": "web:latest"}, + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ) + unmapped_resource = make_resource( + resource_type="docker_service", + unique_id="", + name="broken", + provider=ProviderType.DOCKER_SWARM, + platform_category=PlatformCategory.CONTAINER_ORCHESTRATION, + ) + graph = make_dependency_graph([valid_resource, unmapped_resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + state = builder.build(code_result, graph, provider_version="2.0.0") + + assert len(state.resources) == 1 + assert state.resources[0].attributes == {"replicas": 2, "image": "web:latest"} + + +# --------------------------------------------------------------------------- +# Tests: Unmapped resources list contains correct entries +# --------------------------------------------------------------------------- + + +class TestUnmappedResourcesList: + """The unmapped_resources property contains correct entries.""" + + def test_unmapped_list_empty_when_all_mapped(self): + """When all resources are mappable, unmapped list is empty.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + builder.build(code_result, graph, provider_version="1.0.0") + + assert builder.unmapped_resources == [] + + def test_unmapped_list_contains_empty_id_resource(self): + """Resource with empty unique_id appears in unmapped list.""" + resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="", + name="orphan", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + builder.build(code_result, graph, provider_version="1.0.0") + + assert len(builder.unmapped_resources) == 1 + identifier, reason = builder.unmapped_resources[0] + assert "kubernetes_deployment.orphan" == identifier + assert "missing provider-assigned resource identifier" in reason + + def test_unmapped_list_contains_unrecognized_type_resource(self): + """Resource with unrecognized type appears in unmapped list.""" + resource = make_resource( + resource_type="alien_widget", + unique_id="widget/123", + name="my-widget", + ) + graph = make_dependency_graph([resource]) + code_result = make_code_generation_result() + + builder = StateBuilder() + builder.build(code_result, graph, provider_version="1.0.0") + + assert len(builder.unmapped_resources) == 1 + identifier, reason = builder.unmapped_resources[0] + assert "alien_widget.my-widget" == identifier + assert "not recognized" in reason + + def test_unmapped_list_contains_multiple_entries(self): + """Multiple unmapped resources all appear in the list.""" + resources = [ + make_resource( + resource_type="kubernetes_deployment", + unique_id="", + name="no-id", + ), + make_resource( + resource_type="fake_type", + unique_id="fake/id", + name="fake", + ), + make_resource( + resource_type="kubernetes_service", + unique_id="svc/valid", + name="valid", + ), + ] + graph = make_dependency_graph(resources) + code_result = make_code_generation_result() + + builder = StateBuilder() + builder.build(code_result, graph, provider_version="1.0.0") + + assert len(builder.unmapped_resources) == 2 + identifiers = [entry[0] for entry in builder.unmapped_resources] + assert "kubernetes_deployment.no-id" in identifiers + assert "fake_type.fake" in identifiers + + def test_unmapped_list_resets_on_new_build(self): + """The unmapped list is reset on each new build call.""" + unmapped_resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="", + name="orphan", + ) + valid_resource = make_resource( + resource_type="kubernetes_deployment", + unique_id="deploy/nginx", + name="nginx", + ) + code_result = make_code_generation_result() + + builder = StateBuilder() + + # First build with unmapped resource + graph1 = make_dependency_graph([unmapped_resource]) + builder.build(code_result, graph1, provider_version="1.0.0") + assert len(builder.unmapped_resources) == 1 + + # Second build with valid resource - unmapped list should reset + graph2 = make_dependency_graph([valid_resource]) + builder.build(code_result, graph2, provider_version="1.0.0") + assert len(builder.unmapped_resources) == 0 diff --git a/tests/unit/test_synology_plugin.py b/tests/unit/test_synology_plugin.py new file mode 100644 index 0000000..d1d92ae --- /dev/null +++ b/tests/unit/test_synology_plugin.py @@ -0,0 +1,499 @@ +"""Unit tests for the Synology DSM provider plugin.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + PlatformCategory, + ScanProgress, +) +from iac_reverse.scanner import AuthenticationError, SynologyPlugin + + +class TestSynologyPluginAuthentication: + """Tests for SynologyPlugin.authenticate().""" + + def test_authenticate_missing_host_raises(self): + """Authentication fails when host is not provided.""" + plugin = SynologyPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"port": "5001", "username": "admin", "password": "pass"}) + assert "synology" in str(exc_info.value).lower() + assert "host" in str(exc_info.value).lower() + + def test_authenticate_missing_username_raises(self): + """Authentication fails when username is not provided.""" + plugin = SynologyPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"host": "nas01", "port": "5001", "password": "pass"}) + assert "username" in str(exc_info.value).lower() + + def test_authenticate_missing_password_raises(self): + """Authentication fails when password is not provided.""" + plugin = SynologyPlugin() + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({"host": "nas01", "port": "5001", "username": "admin"}) + assert "password" in str(exc_info.value).lower() + + @patch("iac_reverse.scanner.synology_plugin.SynologyDSM") + def test_authenticate_success(self, mock_dsm_class): + """Successful authentication sets internal state.""" + mock_api = MagicMock() + mock_api.login.return_value = True + mock_dsm_class.return_value = mock_api + + plugin = SynologyPlugin() + plugin.authenticate({ + "host": "nas01.local", + "port": "5001", + "username": "admin", + "password": "secret", + "use_ssl": "true", + }) + + assert plugin._authenticated is True + mock_dsm_class.assert_called_once_with( + "nas01.local", 5001, "admin", "secret", use_https=True, verify_ssl=False + ) + + @patch("iac_reverse.scanner.synology_plugin.SynologyDSM") + def test_authenticate_login_failure(self, mock_dsm_class): + """Authentication raises when login returns False.""" + mock_api = MagicMock() + mock_api.login.return_value = False + mock_dsm_class.return_value = mock_api + + plugin = SynologyPlugin() + + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({ + "host": "nas01.local", + "port": "5001", + "username": "admin", + "password": "wrong", + }) + assert "synology" in str(exc_info.value).lower() + assert "login failed" in str(exc_info.value).lower() + + @patch("iac_reverse.scanner.synology_plugin.SynologyDSM") + def test_authenticate_connection_error(self, mock_dsm_class): + """Authentication raises on connection error.""" + mock_dsm_class.side_effect = ConnectionError("Connection refused") + + plugin = SynologyPlugin() + + with pytest.raises(AuthenticationError) as exc_info: + plugin.authenticate({ + "host": "nas01.local", + "port": "5001", + "username": "admin", + "password": "secret", + }) + assert "synology" in str(exc_info.value).lower() + + @patch("iac_reverse.scanner.synology_plugin.SynologyDSM") + def test_authenticate_use_ssl_false(self, mock_dsm_class): + """Authentication respects use_ssl=false.""" + mock_api = MagicMock() + mock_api.login.return_value = True + mock_dsm_class.return_value = mock_api + + plugin = SynologyPlugin() + plugin.authenticate({ + "host": "nas01.local", + "port": "5000", + "username": "admin", + "password": "secret", + "use_ssl": "false", + }) + + mock_dsm_class.assert_called_once_with( + "nas01.local", 5000, "admin", "secret", use_https=False, verify_ssl=False + ) + + +class TestSynologyPluginMetadata: + """Tests for SynologyPlugin metadata methods.""" + + def test_get_platform_category(self): + """Returns STORAGE_APPLIANCE category.""" + plugin = SynologyPlugin() + assert plugin.get_platform_category() == PlatformCategory.STORAGE_APPLIANCE + + def test_list_supported_resource_types(self): + """Returns all five Synology resource types.""" + plugin = SynologyPlugin() + types = plugin.list_supported_resource_types() + assert types == [ + "synology_shared_folder", + "synology_volume", + "synology_storage_pool", + "synology_replication_task", + "synology_user", + ] + + def test_list_endpoints_default(self): + """Returns HTTPS endpoint by default.""" + plugin = SynologyPlugin() + plugin._host = "nas01.local" + plugin._port = "5001" + plugin._use_ssl = True + assert plugin.list_endpoints() == ["https://nas01.local:5001"] + + def test_list_endpoints_no_ssl(self): + """Returns HTTP endpoint when SSL is disabled.""" + plugin = SynologyPlugin() + plugin._host = "nas01.local" + plugin._port = "5000" + plugin._use_ssl = False + assert plugin.list_endpoints() == ["http://nas01.local:5000"] + + +class TestSynologyPluginArchitecture: + """Tests for SynologyPlugin.detect_architecture().""" + + def test_detect_architecture_no_api(self): + """Returns AMD64 when no API is connected.""" + plugin = SynologyPlugin() + assert plugin.detect_architecture("https://nas01:5001") == CpuArchitecture.AMD64 + + def test_detect_architecture_arm(self): + """Detects ARM architecture from model info.""" + plugin = SynologyPlugin() + mock_info = MagicMock() + mock_info.model = "DS220j" + mock_info.cpu_hardware_name = "ARM Realtek RTD1296" + plugin._api = MagicMock() + plugin._api.information = mock_info + + result = plugin.detect_architecture("https://nas01:5001") + assert result == CpuArchitecture.ARM + + def test_detect_architecture_aarch64(self): + """Detects AArch64 architecture from model info.""" + plugin = SynologyPlugin() + mock_info = MagicMock() + mock_info.model = "DS923+" + mock_info.cpu_hardware_name = "aarch64 Cortex-A55" + plugin._api = MagicMock() + plugin._api.information = mock_info + + result = plugin.detect_architecture("https://nas01:5001") + assert result == CpuArchitecture.AARCH64 + + def test_detect_architecture_amd64(self): + """Detects AMD64 architecture from model info.""" + plugin = SynologyPlugin() + mock_info = MagicMock() + mock_info.model = "DS1621+" + mock_info.cpu_hardware_name = "AMD Ryzen V1500B" + plugin._api = MagicMock() + plugin._api.information = mock_info + + result = plugin.detect_architecture("https://nas01:5001") + assert result == CpuArchitecture.AMD64 + + def test_detect_architecture_alpine_is_arm(self): + """Detects Alpine (Marvell ARM) as ARM architecture.""" + plugin = SynologyPlugin() + mock_info = MagicMock() + mock_info.model = "DS218j" + mock_info.cpu_hardware_name = "Alpine AL-212" + plugin._api = MagicMock() + plugin._api.information = mock_info + + result = plugin.detect_architecture("https://nas01:5001") + assert result == CpuArchitecture.ARM + + def test_detect_architecture_exception_returns_amd64(self): + """Returns AMD64 on exception during detection.""" + plugin = SynologyPlugin() + plugin._api = MagicMock() + plugin._api.information = property(lambda self: (_ for _ in ()).throw(RuntimeError("fail"))) + # Simulate attribute access raising + type(plugin._api).information = property(lambda self: (_ for _ in ()).throw(RuntimeError("fail"))) + + result = plugin.detect_architecture("https://nas01:5001") + assert result == CpuArchitecture.AMD64 + + +class TestSynologyPluginDiscovery: + """Tests for SynologyPlugin.discover_resources().""" + + def _make_authenticated_plugin(self): + """Create a plugin with mocked API.""" + plugin = SynologyPlugin() + plugin._api = MagicMock() + plugin._authenticated = True + plugin._host = "nas01.local" + plugin._port = "5001" + plugin._use_ssl = True + + # Default: no architecture info + mock_info = MagicMock() + mock_info.model = "DS920+" + mock_info.cpu_hardware_name = "Intel Celeron J4125" + plugin._api.information = mock_info + + return plugin + + def test_discover_shared_folders(self): + """Discovers shared folders from storage API.""" + plugin = self._make_authenticated_plugin() + plugin._api.storage.shares = [ + { + "name": "photos", + "path": "/volume1/photos", + "desc": "Photo library", + "is_encrypted": False, + "enable_recycle_bin": True, + "vol_path": "/volume1", + }, + { + "name": "backups", + "path": "/volume1/backups", + "desc": "Backup storage", + "is_encrypted": True, + "enable_recycle_bin": False, + "vol_path": "/volume1", + }, + ] + + progress_updates = [] + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_shared_folder"], + progress_callback=lambda p: progress_updates.append(p), + ) + + assert len(result.resources) == 2 + assert result.resources[0].resource_type == "synology_shared_folder" + assert result.resources[0].name == "photos" + assert result.resources[0].unique_id == "synology/shared_folder/photos" + assert result.resources[0].provider.value == "synology" + assert result.resources[0].platform_category == PlatformCategory.STORAGE_APPLIANCE + assert result.resources[0].architecture == CpuArchitecture.AMD64 + assert result.resources[0].attributes["encryption"] is False + assert result.resources[1].name == "backups" + assert result.resources[1].attributes["encryption"] is True + + def test_discover_volumes(self): + """Discovers volumes from storage API.""" + plugin = self._make_authenticated_plugin() + plugin._api.storage.volumes = [ + { + "id": "volume_1", + "display_name": "Volume 1", + "status": "normal", + "fs_type": "btrfs", + "size": {"total": "4TB", "used": "2TB"}, + "pool_path": "pool_1", + }, + ] + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_volume"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + vol = result.resources[0] + assert vol.resource_type == "synology_volume" + assert vol.unique_id == "synology/volume/volume_1" + assert vol.name == "Volume 1" + assert vol.attributes["fs_type"] == "btrfs" + assert vol.attributes["size_total"] == "4TB" + assert vol.raw_references == ["synology/storage_pool/pool_1"] + + def test_discover_storage_pools(self): + """Discovers storage pools from storage API.""" + plugin = self._make_authenticated_plugin() + plugin._api.storage.storage_pools = [ + { + "id": "pool_1", + "display_name": "Storage Pool 1", + "status": "normal", + "raid_type": "SHR-2", + "size": {"total": "8TB", "used": "4TB"}, + "disks": ["disk1", "disk2", "disk3", "disk4"], + }, + ] + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_storage_pool"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + pool = result.resources[0] + assert pool.resource_type == "synology_storage_pool" + assert pool.unique_id == "synology/storage_pool/pool_1" + assert pool.attributes["raid_type"] == "SHR-2" + assert pool.attributes["disk_count"] == 4 + + def test_discover_replication_tasks(self): + """Discovers replication tasks from replication API.""" + plugin = self._make_authenticated_plugin() + plugin._api.replication.tasks = [ + { + "id": "repl_1", + "name": "Offsite Backup", + "status": "active", + "type": "snapshot_replication", + "destination": "remote-nas.local", + "schedule": {"frequency": "daily"}, + "shared_folder": "backups", + }, + ] + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_replication_task"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 1 + task = result.resources[0] + assert task.resource_type == "synology_replication_task" + assert task.unique_id == "synology/replication_task/repl_1" + assert task.name == "Offsite Backup" + assert task.attributes["destination"] == "remote-nas.local" + assert task.raw_references == ["synology/shared_folder/backups"] + + def test_discover_users(self): + """Discovers users from users API.""" + plugin = self._make_authenticated_plugin() + plugin._api.users.users = [ + { + "name": "admin", + "description": "System administrator", + "email": "admin@example.com", + "expired": False, + "groups": ["administrators"], + }, + { + "name": "john", + "description": "Regular user", + "email": "john@example.com", + "expired": False, + "groups": ["users"], + }, + ] + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_user"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 2 + assert result.resources[0].resource_type == "synology_user" + assert result.resources[0].name == "admin" + assert result.resources[0].unique_id == "synology/user/admin" + assert result.resources[0].attributes["groups"] == ["administrators"] + assert result.resources[1].name == "john" + + def test_discover_multiple_resource_types(self): + """Discovers multiple resource types in one call.""" + plugin = self._make_authenticated_plugin() + plugin._api.storage.shares = [ + {"name": "data", "path": "/volume1/data", "desc": "", "is_encrypted": False, + "enable_recycle_bin": True, "vol_path": "/volume1"}, + ] + plugin._api.users.users = [ + {"name": "admin", "description": "", "email": "", "expired": False, "groups": []}, + ] + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_shared_folder", "synology_user"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 2 + types = {r.resource_type for r in result.resources} + assert types == {"synology_shared_folder", "synology_user"} + + def test_discover_unsupported_type_adds_warning(self): + """Unsupported resource types produce warnings.""" + plugin = self._make_authenticated_plugin() + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_nonexistent"], + progress_callback=lambda p: None, + ) + + assert len(result.resources) == 0 + assert len(result.warnings) == 1 + assert "synology_nonexistent" in result.warnings[0] + + def test_discover_progress_callback_called(self): + """Progress callback is invoked for each resource type.""" + plugin = self._make_authenticated_plugin() + plugin._api.storage.shares = [] + + progress_updates: list[ScanProgress] = [] + plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_shared_folder", "synology_volume"], + progress_callback=lambda p: progress_updates.append(p), + ) + + # Should have initial + per-type + final updates + assert len(progress_updates) >= 3 + assert progress_updates[0].total_resource_types == 2 + + def test_discover_handles_api_error_gracefully(self): + """API errors for one type don't prevent other types from being discovered.""" + plugin = self._make_authenticated_plugin() + + # Storage raises an error + type(plugin._api).storage = property( + lambda self: (_ for _ in ()).throw(RuntimeError("API error")) + ) + plugin._api.users = MagicMock() + plugin._api.users.users = [ + {"name": "admin", "description": "", "email": "", "expired": False, "groups": []}, + ] + + result = plugin.discover_resources( + endpoints=["https://nas01.local:5001"], + resource_types=["synology_shared_folder", "synology_user"], + progress_callback=lambda p: None, + ) + + # Users should still be discovered even though shared folders errored + assert len(result.errors) == 1 + assert "synology_shared_folder" in result.errors[0] + user_resources = [r for r in result.resources if r.resource_type == "synology_user"] + assert len(user_resources) == 1 + + def test_discover_empty_endpoints_uses_default(self): + """When endpoints list is empty, uses list_endpoints().""" + plugin = self._make_authenticated_plugin() + plugin._api.storage.shares = [] + + result = plugin.discover_resources( + endpoints=[], + resource_types=["synology_shared_folder"], + progress_callback=lambda p: None, + ) + + # Should not raise - uses default endpoint + assert result is not None + + +class TestSynologyPluginIsProviderPlugin: + """Verify SynologyPlugin properly implements ProviderPlugin ABC.""" + + def test_is_instance_of_provider_plugin(self): + """SynologyPlugin is a ProviderPlugin.""" + from iac_reverse.plugin_base import ProviderPlugin + + plugin = SynologyPlugin() + assert isinstance(plugin, ProviderPlugin) diff --git a/tests/unit/test_validator.py b/tests/unit/test_validator.py new file mode 100644 index 0000000..7429dca --- /dev/null +++ b/tests/unit/test_validator.py @@ -0,0 +1,689 @@ +"""Unit tests for the Terraform Validator. + +Tests cover: +- Successful validation (init + validate + plan all pass) +- Missing terraform binary +- terraform init failure +- terraform validate failure with errors +- terraform plan showing drift (planned changes) +- Error parsing from terraform JSON output +""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import PlannedChange, ValidationError, ValidationResult +from iac_reverse.validator import Validator + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_completed_process(returncode=0, stdout="", stderr=""): + """Create a mock CompletedProcess-like object.""" + mock = MagicMock() + mock.returncode = returncode + mock.stdout = stdout + mock.stderr = stderr + return mock + + +VALIDATE_SUCCESS_JSON = json.dumps( + {"valid": True, "error_count": 0, "diagnostics": []} +) + +PLAN_NO_CHANGES_JSON = "\n".join( + [ + json.dumps({"type": "version", "terraform": "1.7.0"}), + json.dumps( + { + "type": "change_summary", + "changes": {"add": 0, "change": 0, "remove": 0}, + } + ), + ] +) + + +# --------------------------------------------------------------------------- +# Tests: missing terraform binary +# --------------------------------------------------------------------------- + + +class TestMissingTerraformBinary: + def test_returns_failure_result_when_terraform_not_found(self, tmp_path): + """When terraform binary is absent, all success flags are False and + a descriptive error is included.""" + with patch("shutil.which", return_value=None): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert isinstance(result, ValidationResult) + assert result.init_success is False + assert result.validate_success is False + assert result.plan_success is False + assert len(result.errors) == 1 + error = result.errors[0] + assert "Terraform" in error.message + assert "required" in error.message.lower() or "PATH" in error.message + + def test_no_planned_changes_when_terraform_not_found(self, tmp_path): + with patch("shutil.which", return_value=None): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.planned_changes == [] + + def test_correction_attempts_zero_when_terraform_not_found(self, tmp_path): + with patch("shutil.which", return_value=None): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.correction_attempts == 0 + + +# --------------------------------------------------------------------------- +# Tests: terraform init failure +# --------------------------------------------------------------------------- + + +class TestTerraformInitFailure: + def test_init_failure_returns_correct_flags(self, tmp_path): + """When terraform init fails, init_success is False and subsequent + stages are not run.""" + init_result = _make_completed_process( + returncode=1, stderr="Error: Failed to install provider" + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", return_value=init_result + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.init_success is False + assert result.validate_success is False + assert result.plan_success is False + + def test_init_failure_includes_error_message(self, tmp_path): + init_result = _make_completed_process( + returncode=1, stderr="Error: Failed to install provider" + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", return_value=init_result + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.errors) >= 1 + assert "terraform init failed" in result.errors[0].message.lower() + assert "Failed to install provider" in result.errors[0].message + + def test_init_failure_stops_pipeline(self, tmp_path): + """After init failure, validate and plan should not be called.""" + init_result = _make_completed_process(returncode=1, stderr="init error") + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", return_value=init_result + ) as mock_run: + validator = Validator() + validator.validate(str(tmp_path)) + + # Only one subprocess.run call (for init) + assert mock_run.call_count == 1 + + +# --------------------------------------------------------------------------- +# Tests: successful validation +# --------------------------------------------------------------------------- + + +class TestSuccessfulValidation: + def test_all_flags_true_on_success(self, tmp_path): + """When init, validate, and plan all succeed with zero changes, + all success flags are True.""" + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.init_success is True + assert result.validate_success is True + assert result.plan_success is True + + def test_no_errors_on_success(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.errors == [] + + def test_no_planned_changes_on_success(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.planned_changes == [] + + def test_correction_attempts_zero_on_success(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.correction_attempts == 0 + + +# --------------------------------------------------------------------------- +# Tests: terraform validate failure with errors +# --------------------------------------------------------------------------- + + +class TestTerraformValidateFailure: + def _make_validate_error_json( + self, filename="main.tf", line=10, summary="Invalid attribute", detail="No such attribute" + ): + return json.dumps( + { + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": summary, + "detail": detail, + "range": { + "filename": filename, + "start": {"line": line, "column": 1}, + "end": {"line": line, "column": 20}, + }, + } + ], + } + ) + + def test_validate_failure_sets_correct_flags(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, stdout=self._make_validate_error_json() + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.init_success is True + assert result.validate_success is False + assert result.plan_success is False + + def test_validate_failure_parses_file_name(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, + stdout=self._make_validate_error_json(filename="kubernetes_deployment.tf"), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.errors) == 1 + assert result.errors[0].file == "kubernetes_deployment.tf" + + def test_validate_failure_parses_line_number(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, + stdout=self._make_validate_error_json(line=42), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.errors[0].line == 42 + + def test_validate_failure_parses_error_message(self, tmp_path): + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, + stdout=self._make_validate_error_json( + summary="Unsupported argument", + detail="An argument named 'foo' is not expected here.", + ), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert "Unsupported argument" in result.errors[0].message + assert "foo" in result.errors[0].message + + def test_validate_failure_multiple_errors(self, tmp_path): + validate_json = json.dumps( + { + "valid": False, + "error_count": 2, + "diagnostics": [ + { + "severity": "error", + "summary": "Error one", + "detail": "", + "range": { + "filename": "main.tf", + "start": {"line": 5, "column": 1}, + }, + }, + { + "severity": "error", + "summary": "Error two", + "detail": "", + "range": { + "filename": "variables.tf", + "start": {"line": 12, "column": 3}, + }, + }, + ], + } + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, stdout=validate_json + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.errors) == 2 + assert result.errors[0].file == "main.tf" + assert result.errors[1].file == "variables.tf" + + def test_validate_ignores_warning_diagnostics(self, tmp_path): + """Only error-severity diagnostics should be included in errors.""" + validate_json = json.dumps( + { + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "warning", + "summary": "Deprecated attribute", + "detail": "Use new_attr instead.", + "range": { + "filename": "main.tf", + "start": {"line": 3, "column": 1}, + }, + }, + { + "severity": "error", + "summary": "Real error", + "detail": "", + "range": { + "filename": "main.tf", + "start": {"line": 7, "column": 1}, + }, + }, + ], + } + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, stdout=validate_json + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.errors) == 1 + assert result.errors[0].message == "Real error" + + def test_validate_failure_stops_plan(self, tmp_path): + """When validate fails, plan should not be run.""" + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, + stdout=self._make_validate_error_json(), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ) as mock_run: + validator = Validator() + validator.validate(str(tmp_path)) + + # Only init and validate calls, no plan + assert mock_run.call_count == 2 + + +# --------------------------------------------------------------------------- +# Tests: terraform plan showing drift +# --------------------------------------------------------------------------- + + +class TestTerraformPlanDrift: + def _make_plan_with_changes(self, changes): + """Build a terraform plan JSON stream with the given changes. + + changes: list of (addr, action) tuples + """ + lines = [json.dumps({"type": "version", "terraform": "1.7.0"})] + for addr, action in changes: + lines.append( + json.dumps( + { + "type": "planned_change", + "change": { + "resource": {"addr": addr}, + "action": action, + }, + } + ) + ) + total_add = sum(1 for _, a in changes if a == "create") + total_change = sum(1 for _, a in changes if a == "update") + total_remove = sum(1 for _, a in changes if a == "delete") + lines.append( + json.dumps( + { + "type": "change_summary", + "changes": { + "add": total_add, + "change": total_change, + "remove": total_remove, + }, + } + ) + ) + return "\n".join(lines) + + def test_plan_with_add_sets_plan_success_false(self, tmp_path): + plan_output = self._make_plan_with_changes( + [("kubernetes_deployment.nginx", "create")] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.plan_success is False + + def test_plan_with_add_reports_change_type_add(self, tmp_path): + plan_output = self._make_plan_with_changes( + [("kubernetes_deployment.nginx", "create")] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.planned_changes) == 1 + change = result.planned_changes[0] + assert change.resource_address == "kubernetes_deployment.nginx" + assert change.change_type == "add" + + def test_plan_with_update_reports_change_type_modify(self, tmp_path): + plan_output = self._make_plan_with_changes( + [("docker_service.web", "update")] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.planned_changes[0].change_type == "modify" + + def test_plan_with_delete_reports_change_type_destroy(self, tmp_path): + plan_output = self._make_plan_with_changes( + [("harvester_virtualmachine.dev_vm", "delete")] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.planned_changes[0].change_type == "destroy" + + def test_plan_with_multiple_changes(self, tmp_path): + plan_output = self._make_plan_with_changes( + [ + ("kubernetes_deployment.nginx", "create"), + ("docker_service.web", "update"), + ("harvester_virtualmachine.old_vm", "delete"), + ] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.planned_changes) == 3 + addresses = {c.resource_address for c in result.planned_changes} + assert "kubernetes_deployment.nginx" in addresses + assert "docker_service.web" in addresses + assert "harvester_virtualmachine.old_vm" in addresses + + def test_plan_with_changes_sets_validate_success_true(self, tmp_path): + """Drift does not affect validate_success — only plan_success.""" + plan_output = self._make_plan_with_changes( + [("kubernetes_deployment.nginx", "create")] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=2, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.init_success is True + assert result.validate_success is True + assert result.plan_success is False + + +# --------------------------------------------------------------------------- +# Tests: JSON parsing edge cases +# --------------------------------------------------------------------------- + + +class TestJsonParsing: + def test_validate_with_invalid_json_output(self, tmp_path): + """When terraform validate returns non-JSON, a parse error is reported.""" + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, stdout="not valid json" + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is False + assert len(result.errors) >= 1 + assert any("parse" in e.message.lower() or "json" in e.message.lower() for e in result.errors) + + def test_validate_error_without_range(self, tmp_path): + """Errors without range info should still be parsed with empty file and no line.""" + validate_json = json.dumps( + { + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": "No range error", + "detail": "Something went wrong", + } + ], + } + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=1, stdout=validate_json + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", side_effect=[init_result, validate_result] + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert len(result.errors) == 1 + assert result.errors[0].file == "" + assert result.errors[0].line is None + assert "No range error" in result.errors[0].message + + def test_plan_with_malformed_lines_skipped(self, tmp_path): + """Malformed JSON lines in plan output should be skipped gracefully.""" + plan_output = "\n".join( + [ + "not json", + json.dumps({"type": "version", "terraform": "1.7.0"}), + "also not json", + json.dumps( + { + "type": "change_summary", + "changes": {"add": 0, "change": 0, "remove": 0}, + } + ), + ] + ) + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process(returncode=0, stdout=plan_output) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + # Should not raise; plan_success True because no changes + assert result.plan_success is True + assert result.planned_changes == [] + + +# --------------------------------------------------------------------------- +# Tests: Validator export +# --------------------------------------------------------------------------- + + +class TestValidatorExport: + def test_validator_importable_from_package(self): + from iac_reverse.validator import Validator as V + + assert V is Validator + + def test_validator_is_instantiable(self): + v = Validator() + assert v is not None diff --git a/tests/unit/test_validator_autocorrect.py b/tests/unit/test_validator_autocorrect.py new file mode 100644 index 0000000..e13838f --- /dev/null +++ b/tests/unit/test_validator_autocorrect.py @@ -0,0 +1,766 @@ +"""Unit tests for the Validator auto-correction loop. + +Tests cover: +- Successful correction on first attempt +- Correction after multiple attempts +- Failure after max attempts exhausted +- correction_attempts count is accurate +- Original errors are preserved when correction fails +""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import ValidationError, ValidationResult +from iac_reverse.validator import Validator + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_completed_process(returncode=0, stdout="", stderr=""): + """Create a mock CompletedProcess-like object.""" + mock = MagicMock() + mock.returncode = returncode + mock.stdout = stdout + mock.stderr = stderr + return mock + + +VALIDATE_SUCCESS_JSON = json.dumps( + {"valid": True, "error_count": 0, "diagnostics": []} +) + +PLAN_NO_CHANGES_JSON = "\n".join( + [ + json.dumps({"type": "version", "terraform": "1.7.0"}), + json.dumps( + { + "type": "change_summary", + "changes": {"add": 0, "change": 0, "remove": 0}, + } + ), + ] +) + + +def _make_validate_error_json( + filename="main.tf", line=10, summary="Unsupported argument", detail="An argument named 'bad_attr' is not expected here." +): + return json.dumps( + { + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": summary, + "detail": detail, + "range": { + "filename": filename, + "start": {"line": line, "column": 1}, + "end": {"line": line, "column": 20}, + }, + } + ], + } + ) + + +def _make_missing_provider_error_json(provider_name="kubernetes"): + return json.dumps( + { + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": "Missing required provider", + "detail": f"provider \"{provider_name}\" configuration not present", + "range": { + "filename": "main.tf", + "start": {"line": 1, "column": 1}, + "end": {"line": 1, "column": 20}, + }, + } + ], + } + ) + + +# --------------------------------------------------------------------------- +# Tests: Successful correction on first attempt +# --------------------------------------------------------------------------- + + +class TestSuccessfulCorrectionFirstAttempt: + """Validation fails initially but succeeds after one correction attempt.""" + + def test_removes_unknown_attribute_and_passes(self, tmp_path): + """When an unknown attribute error occurs, the offending line is removed + and re-validation succeeds on the first attempt.""" + # Create a .tf file with a bad attribute on line 3 + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' ami = "ami-123"\n' + ' bad_attr = "should be removed"\n' + ' instance_type = "t2.micro"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + # First validate fails with unknown attribute on line 3 + validate_fail = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=3, + summary="Unsupported argument", + detail="An argument named 'bad_attr' is not expected here.", + ), + ) + # Second validate succeeds after correction + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_fail, validate_success, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.init_success is True + assert result.validate_success is True + assert result.correction_attempts == 1 + + # Verify the file was corrected + content = tf_file.read_text() + assert "bad_attr" not in content + assert "ami" in content + assert "instance_type" in content + + def test_adds_missing_provider_block_and_passes(self, tmp_path): + """When a missing provider error occurs, an empty provider block is added + and re-validation succeeds.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "kubernetes_deployment" "app" {\n' + ' metadata {\n' + ' name = "app"\n' + ' }\n' + '}\n' + ) + + init_result = _make_completed_process(returncode=0) + validate_fail = _make_completed_process( + returncode=1, + stdout=_make_missing_provider_error_json("kubernetes"), + ) + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_fail, validate_success, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is True + assert result.correction_attempts == 1 + + # Verify provider block was added + providers_file = tmp_path / "providers.tf" + assert providers_file.exists() + content = providers_file.read_text() + assert 'provider "kubernetes"' in content + + +# --------------------------------------------------------------------------- +# Tests: Correction after multiple attempts +# --------------------------------------------------------------------------- + + +class TestCorrectionAfterMultipleAttempts: + """Validation requires multiple correction attempts before succeeding.""" + + def test_succeeds_after_two_correction_attempts(self, tmp_path): + """Two different errors require two correction passes.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' ami = "ami-123"\n' + ' bad_attr1 = "remove me"\n' + ' instance_type = "t2.micro"\n' + ' bad_attr2 = "also remove me"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + + # First validate: error on line 3 (bad_attr1) + validate_fail_1 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=3, + summary="Unsupported argument", + detail="An argument named 'bad_attr1' is not expected here.", + ), + ) + # Second validate: error on line 4 (bad_attr2 is now on line 4 after removal) + validate_fail_2 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=4, + summary="Unsupported argument", + detail="An argument named 'bad_attr2' is not expected here.", + ), + ) + # Third validate succeeds + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + validate_fail_1, + validate_fail_2, + validate_success, + plan_result, + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is True + assert result.correction_attempts == 2 + + def test_succeeds_on_third_attempt(self, tmp_path): + """Validation succeeds on the third (max) correction attempt.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' ami = "ami-123"\n' + ' bad1 = "x"\n' + ' bad2 = "y"\n' + ' bad3 = "z"\n' + ' instance_type = "t2.micro"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + + validate_fail_1 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", line=3, + summary="Unsupported argument", + detail="An argument named 'bad1' is not expected here.", + ), + ) + validate_fail_2 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", line=3, + summary="Unsupported argument", + detail="An argument named 'bad2' is not expected here.", + ), + ) + validate_fail_3 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", line=3, + summary="Unsupported argument", + detail="An argument named 'bad3' is not expected here.", + ), + ) + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + validate_fail_1, + validate_fail_2, + validate_fail_3, + validate_success, + plan_result, + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is True + assert result.correction_attempts == 3 + + +# --------------------------------------------------------------------------- +# Tests: Failure after max attempts exhausted +# --------------------------------------------------------------------------- + + +class TestFailureAfterMaxAttempts: + """Validation still fails after all correction attempts are exhausted.""" + + def test_fails_after_max_attempts_with_uncorrectable_error(self, tmp_path): + """When errors cannot be corrected, fails after max attempts.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' ami = "ami-123"\n' + ' bad_attr = "remove me"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + + # Each validate returns the same error (correction removes line but + # new errors keep appearing) + validate_fail = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=3, + summary="Unsupported argument", + detail="An argument named 'bad_attr' is not expected here.", + ), + ) + # After first correction removes line 3, subsequent validates still fail + # with a different uncorrectable error (no file/line info) + validate_fail_no_fix = _make_completed_process( + returncode=1, + stdout=json.dumps({ + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": "Some complex error", + "detail": "Cannot be auto-corrected", + } + ], + }), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + validate_fail, + validate_fail_no_fix, # After first correction, new error with no file + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is False + # Stopped after 1 attempt because the second error has no file info + # and cannot be corrected + assert result.correction_attempts >= 1 + + def test_fails_with_max_3_attempts_default(self, tmp_path): + """With default max_correction_attempts=3, stops after 3 attempts.""" + tf_file = tmp_path / "main.tf" + # Write a file where the error line keeps being valid for removal + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' attr1 = "a"\n' + ' attr2 = "b"\n' + ' attr3 = "c"\n' + ' attr4 = "d"\n' + ' attr5 = "e"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + + def make_fail(line, attr): + return _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=line, + summary="Unsupported argument", + detail=f"An argument named '{attr}' is not expected here.", + ), + ) + + # 4 failures: first 3 get corrected, 4th is returned as final failure + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + make_fail(2, "attr1"), + make_fail(2, "attr2"), + make_fail(2, "attr3"), + make_fail(2, "attr4"), # This one exceeds max attempts + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is False + assert result.correction_attempts == 3 + + def test_custom_max_attempts_respected(self, tmp_path): + """Custom max_correction_attempts value is respected.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' attr1 = "a"\n' + ' attr2 = "b"\n' + ' attr3 = "c"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + + def make_fail(line, attr): + return _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=line, + summary="Unsupported argument", + detail=f"An argument named '{attr}' is not expected here.", + ), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + make_fail(2, "attr1"), + make_fail(2, "attr2"), # Exceeds max_correction_attempts=1 + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path), max_correction_attempts=1) + + assert result.validate_success is False + assert result.correction_attempts == 1 + + +# --------------------------------------------------------------------------- +# Tests: correction_attempts count is accurate +# --------------------------------------------------------------------------- + + +class TestCorrectionAttemptsCount: + """The correction_attempts field accurately reflects the number of attempts.""" + + def test_zero_attempts_when_validation_passes_immediately(self, tmp_path): + """No correction attempts when validation passes on first try.""" + init_result = _make_completed_process(returncode=0) + validate_result = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_result, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.correction_attempts == 0 + + def test_one_attempt_when_first_correction_fixes(self, tmp_path): + """correction_attempts is 1 when first correction resolves the issue.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "null_resource" "test" {\n' + ' unknown_field = "value"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + validate_fail = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=2, + summary="Unsupported argument", + detail="An argument named 'unknown_field' is not expected here.", + ), + ) + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_fail, validate_success, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.correction_attempts == 1 + + def test_attempts_count_matches_actual_corrections_applied(self, tmp_path): + """correction_attempts matches the number of correction loops executed.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "null_resource" "test" {\n' + ' bad1 = "x"\n' + ' bad2 = "y"\n' + ' good = "keep"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + validate_fail_1 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", line=2, + summary="Unsupported argument", + detail="An argument named 'bad1' is not expected here.", + ), + ) + validate_fail_2 = _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", line=2, + summary="Unsupported argument", + detail="An argument named 'bad2' is not expected here.", + ), + ) + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + validate_fail_1, + validate_fail_2, + validate_success, + plan_result, + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.correction_attempts == 2 + + +# --------------------------------------------------------------------------- +# Tests: Original errors preserved when correction fails +# --------------------------------------------------------------------------- + + +class TestOriginalErrorsPreserved: + """When correction fails, the remaining errors are reported accurately.""" + + def test_errors_from_final_validation_are_returned(self, tmp_path): + """After exhausting attempts, the errors from the last validation run + are returned in the result.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "aws_instance" "example" {\n' + ' attr1 = "a"\n' + ' attr2 = "b"\n' + ' attr3 = "c"\n' + ' attr4 = "d"\n' + "}\n" + ) + + init_result = _make_completed_process(returncode=0) + + def make_fail(attr): + return _make_completed_process( + returncode=1, + stdout=_make_validate_error_json( + filename="main.tf", + line=2, + summary="Unsupported argument", + detail=f"An argument named '{attr}' is not expected here.", + ), + ) + + # 3 corrections + final failure + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[ + init_result, + make_fail("attr1"), + make_fail("attr2"), + make_fail("attr3"), + make_fail("attr4"), + ], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is False + assert len(result.errors) >= 1 + # The final error should be about attr4 + assert "attr4" in result.errors[0].message + + def test_uncorrectable_errors_preserved_with_file_info(self, tmp_path): + """Errors that cannot be corrected retain their file and line info.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text('invalid terraform content\n') + + init_result = _make_completed_process(returncode=0) + + # Error that doesn't match any correction pattern + validate_fail = _make_completed_process( + returncode=1, + stdout=json.dumps({ + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": "Completely unknown error type", + "detail": "This cannot be auto-corrected at all", + "range": { + "filename": "main.tf", + "start": {"line": 1, "column": 1}, + "end": {"line": 1, "column": 10}, + }, + } + ], + }), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_fail], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is False + assert result.correction_attempts == 0 # No correction was possible + assert len(result.errors) == 1 + assert result.errors[0].file == "main.tf" + assert result.errors[0].line == 1 + assert "Completely unknown error type" in result.errors[0].message + + def test_no_correction_when_error_has_no_file(self, tmp_path): + """Errors without file information cannot be corrected.""" + init_result = _make_completed_process(returncode=0) + + validate_fail = _make_completed_process( + returncode=1, + stdout=json.dumps({ + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": "Unsupported argument", + "detail": "An argument named 'foo' is not expected here.", + } + ], + }), + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_fail], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is False + # No file info means no correction can be applied + assert result.correction_attempts == 0 + + +# --------------------------------------------------------------------------- +# Tests: Syntax error correction +# --------------------------------------------------------------------------- + + +class TestSyntaxErrorCorrection: + """Tests for syntax error auto-correction.""" + + def test_trailing_comma_fixed(self, tmp_path): + """Trailing commas before closing braces are removed.""" + tf_file = tmp_path / "main.tf" + tf_file.write_text( + 'resource "null_resource" "test" {\n' + ' triggers = {\n' + ' key = "value",\n' + ' }\n' + '}\n' + ) + + init_result = _make_completed_process(returncode=0) + validate_fail = _make_completed_process( + returncode=1, + stdout=json.dumps({ + "valid": False, + "error_count": 1, + "diagnostics": [ + { + "severity": "error", + "summary": "Invalid character", + "detail": "trailing comma not allowed", + "range": { + "filename": "main.tf", + "start": {"line": 3, "column": 18}, + "end": {"line": 3, "column": 19}, + }, + } + ], + }), + ) + validate_success = _make_completed_process( + returncode=0, stdout=VALIDATE_SUCCESS_JSON + ) + plan_result = _make_completed_process( + returncode=0, stdout=PLAN_NO_CHANGES_JSON + ) + + with patch("shutil.which", return_value="/usr/bin/terraform"), patch( + "subprocess.run", + side_effect=[init_result, validate_fail, validate_success, plan_result], + ): + validator = Validator() + result = validator.validate(str(tmp_path)) + + assert result.validate_success is True + assert result.correction_attempts == 1 + + content = tf_file.read_text() + assert ",\n }" not in content diff --git a/tests/unit/test_variable_extractor.py b/tests/unit/test_variable_extractor.py new file mode 100644 index 0000000..08e8ca8 --- /dev/null +++ b/tests/unit/test_variable_extractor.py @@ -0,0 +1,405 @@ +"""Unit tests for the VariableExtractor.""" + +import pytest + +from iac_reverse.models import ( + CpuArchitecture, + DiscoveredResource, + ExtractedVariable, + PlatformCategory, + ProviderType, +) +from iac_reverse.generator import VariableExtractor + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_resource( + resource_type: str = "kubernetes_deployment", + unique_id: str = "default/deployments/nginx", + name: str = "nginx", + provider: ProviderType = ProviderType.KUBERNETES, + platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION, + architecture: CpuArchitecture = CpuArchitecture.AARCH64, + attributes: dict | None = None, +) -> DiscoveredResource: + """Create a sample DiscoveredResource for testing.""" + return DiscoveredResource( + resource_type=resource_type, + unique_id=unique_id, + name=name, + provider=provider, + platform_category=platform_category, + architecture=architecture, + endpoint="https://k8s-api.local:6443", + attributes=attributes or {}, + raw_references=[], + ) + + +# --------------------------------------------------------------------------- +# Tests: No shared values produces no variables +# --------------------------------------------------------------------------- + + +class TestNoSharedValues: + """Tests that no variables are produced when values are not shared.""" + + def test_empty_resources_produces_no_variables(self): + """An empty resource list produces no variables.""" + extractor = VariableExtractor() + result = extractor.extract_variables([]) + assert result == [] + + def test_single_resource_produces_no_variables(self): + """A single resource cannot have shared values.""" + resource = make_resource(attributes={"namespace": "default", "replicas": 3}) + extractor = VariableExtractor() + result = extractor.extract_variables([resource]) + assert result == [] + + def test_two_resources_no_common_values_produces_no_variables(self): + """Two resources with completely different attribute values produce no variables.""" + r1 = make_resource( + unique_id="r1", + name="app-a", + attributes={"namespace": "alpha", "replicas": 1}, + ) + r2 = make_resource( + unique_id="r2", + name="app-b", + attributes={"namespace": "beta", "replicas": 2}, + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + assert result == [] + + def test_two_resources_same_key_different_values_no_variable(self): + """Two resources with the same key but different values produce no variable.""" + r1 = make_resource( + unique_id="r1", + name="app-a", + attributes={"environment": "staging"}, + ) + r2 = make_resource( + unique_id="r2", + name="app-b", + attributes={"environment": "production"}, + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + assert result == [] + + +# --------------------------------------------------------------------------- +# Tests: Value appearing in 2 resources produces a variable +# --------------------------------------------------------------------------- + + +class TestSharedValueExtraction: + """Tests that shared values are correctly extracted as variables.""" + + def test_same_value_in_two_resources_produces_variable(self): + """A value appearing in 2 resources for the same key produces a variable.""" + r1 = make_resource( + unique_id="r1", + name="app-a", + attributes={"namespace": "production"}, + ) + r2 = make_resource( + unique_id="r2", + name="app-b", + attributes={"namespace": "production"}, + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert len(result) == 1 + assert result[0].name == "var_namespace" + + def test_same_value_in_three_resources_produces_one_variable(self): + """A value appearing in 3 resources still produces exactly one variable.""" + resources = [ + make_resource( + unique_id=f"r{i}", + name=f"app-{i}", + attributes={"region": "us-east-1"}, + ) + for i in range(3) + ] + extractor = VariableExtractor() + result = extractor.extract_variables(resources) + + assert len(result) == 1 + assert result[0].name == "var_region" + + def test_multiple_shared_keys_produce_multiple_variables(self): + """Multiple shared attribute keys each produce their own variable.""" + r1 = make_resource( + unique_id="r1", + name="app-a", + attributes={"namespace": "default", "environment": "prod"}, + ) + r2 = make_resource( + unique_id="r2", + name="app-b", + attributes={"namespace": "default", "environment": "prod"}, + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + var_names = {v.name for v in result} + assert "var_namespace" in var_names + assert "var_environment" in var_names + + +# --------------------------------------------------------------------------- +# Tests: Default is set to most common value +# --------------------------------------------------------------------------- + + +class TestDefaultValue: + """Tests that the default value is set to the most common value.""" + + def test_default_is_most_common_value(self): + """When only one value is shared (2+ resources), default is that value.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"namespace": "production"} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"namespace": "production"} + ) + r3 = make_resource( + unique_id="r3", name="app-c", attributes={"namespace": "staging"} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2, r3]) + + # "production" appears in 2 resources, "staging" in 1 (not shared) + # The variable for "production" should have default = "production" + assert len(result) == 1 + assert result[0].default_value == '"production"' + + def test_default_with_equal_counts_picks_one(self): + """When values have equal counts, each variable gets its own value as default.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"namespace": "alpha"} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"namespace": "alpha"} + ) + r3 = make_resource( + unique_id="r3", name="app-c", attributes={"namespace": "beta"} + ) + r4 = make_resource( + unique_id="r4", name="app-d", attributes={"namespace": "beta"} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2, r3, r4]) + + # Both "alpha" and "beta" appear in 2 resources each + # Both should produce variables; each with its own value as default + assert len(result) == 2 + defaults = {v.default_value for v in result} + assert '"alpha"' in defaults + assert '"beta"' in defaults + + +# --------------------------------------------------------------------------- +# Tests: Type expression matches value type +# --------------------------------------------------------------------------- + + +class TestTypeExpression: + """Tests that type expressions match the Python value type.""" + + def test_string_value_has_string_type(self): + """A string attribute value produces type = string.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"namespace": "default"} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"namespace": "default"} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert result[0].type_expr == "string" + + def test_integer_value_has_number_type(self): + """An integer attribute value produces type = number.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"replicas": 3} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"replicas": 3} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert result[0].type_expr == "number" + + def test_boolean_value_has_bool_type(self): + """A boolean attribute value produces type = bool.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"enabled": True} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"enabled": True} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert result[0].type_expr == "bool" + + def test_number_default_format(self): + """A number default is formatted without quotes.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"replicas": 3} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"replicas": 3} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert result[0].default_value == "3" + + def test_bool_default_format(self): + """A boolean default is formatted as true/false.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"enabled": False} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"enabled": False} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert result[0].default_value == "false" + + +# --------------------------------------------------------------------------- +# Tests: used_by lists all resources using the variable +# --------------------------------------------------------------------------- + + +class TestUsedBy: + """Tests that used_by correctly lists all resources using the variable.""" + + def test_used_by_contains_both_resource_ids(self): + """used_by lists both resource unique_ids that share the value.""" + r1 = make_resource( + unique_id="ns/deployments/app-a", + name="app-a", + attributes={"namespace": "production"}, + ) + r2 = make_resource( + unique_id="ns/deployments/app-b", + name="app-b", + attributes={"namespace": "production"}, + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2]) + + assert len(result) == 1 + assert "ns/deployments/app-a" in result[0].used_by + assert "ns/deployments/app-b" in result[0].used_by + + def test_used_by_contains_all_three_resource_ids(self): + """used_by lists all three resource unique_ids when value is shared by 3.""" + resources = [ + make_resource( + unique_id=f"ns/deployments/app-{i}", + name=f"app-{i}", + attributes={"environment": "prod"}, + ) + for i in range(3) + ] + extractor = VariableExtractor() + result = extractor.extract_variables(resources) + + assert len(result[0].used_by) == 3 + for i in range(3): + assert f"ns/deployments/app-{i}" in result[0].used_by + + def test_used_by_excludes_resources_with_different_value(self): + """used_by does not include resources that have a different value for the key.""" + r1 = make_resource( + unique_id="r1", name="app-a", attributes={"namespace": "production"} + ) + r2 = make_resource( + unique_id="r2", name="app-b", attributes={"namespace": "production"} + ) + r3 = make_resource( + unique_id="r3", name="app-c", attributes={"namespace": "staging"} + ) + extractor = VariableExtractor() + result = extractor.extract_variables([r1, r2, r3]) + + # Only the "production" variable should exist (2 resources) + assert len(result) == 1 + assert "r1" in result[0].used_by + assert "r2" in result[0].used_by + assert "r3" not in result[0].used_by + + +# --------------------------------------------------------------------------- +# Tests: generate_variables_tf output +# --------------------------------------------------------------------------- + + +class TestGenerateVariablesTf: + """Tests for the variables.tf file content generation.""" + + def test_empty_variables_produces_empty_string(self): + """No variables produces an empty string.""" + extractor = VariableExtractor() + result = extractor.generate_variables_tf([]) + assert result == "" + + def test_single_variable_produces_valid_block(self): + """A single variable produces a valid Terraform variable block.""" + var = ExtractedVariable( + name="var_namespace", + type_expr="string", + default_value='"production"', + description="Shared namespace value extracted from 2 resources", + used_by=["r1", "r2"], + ) + extractor = VariableExtractor() + result = extractor.generate_variables_tf([var]) + + assert 'variable "var_namespace"' in result + assert "type = string" in result + assert 'description = "Shared namespace value extracted from 2 resources"' in result + assert 'default = "production"' in result + + def test_multiple_variables_separated_by_blank_line(self): + """Multiple variables are separated by blank lines.""" + vars_list = [ + ExtractedVariable( + name="var_namespace", + type_expr="string", + default_value='"default"', + description="Shared namespace", + used_by=["r1", "r2"], + ), + ExtractedVariable( + name="var_replicas", + type_expr="number", + default_value="3", + description="Shared replicas", + used_by=["r1", "r2"], + ), + ] + extractor = VariableExtractor() + result = extractor.generate_variables_tf(vars_list) + + assert 'variable "var_namespace"' in result + assert 'variable "var_replicas"' in result + # Two blocks separated by a blank line + assert "\n\n" in result diff --git a/tests/unit/test_windows_plugin.py b/tests/unit/test_windows_plugin.py new file mode 100644 index 0000000..1bd6935 --- /dev/null +++ b/tests/unit/test_windows_plugin.py @@ -0,0 +1,529 @@ +"""Unit tests for the Windows Discovery Plugin. + +Tests use mocks for the winrm session to avoid requiring actual +Windows hosts for testing. +""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from iac_reverse.models import CpuArchitecture, PlatformCategory, ProviderType +from iac_reverse.scanner.scanner import AuthenticationError +from iac_reverse.scanner.windows_plugin import ( + InsufficientPrivilegesError, + WindowsDiscoveryPlugin, + WinRMNotEnabledError, + WMIQueryError, + WINDOWS_RESOURCE_TYPES, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def plugin(): + """Create a fresh WindowsDiscoveryPlugin instance.""" + return WindowsDiscoveryPlugin() + + +@pytest.fixture +def credentials(): + """Standard test credentials.""" + return { + "host": "192.168.1.100", + "username": "admin", + "password": "secret", + "transport": "ntlm", + "port": "5986", + "use_ssl": "true", + } + + +@pytest.fixture +def mock_session(): + """Create a mock WinRM session.""" + session = MagicMock() + return session + + +def make_ps_result(stdout: str = "", stderr: str = "", status_code: int = 0): + """Helper to create a mock PowerShell result.""" + result = MagicMock() + result.std_out = stdout.encode("utf-8") + result.std_err = stderr.encode("utf-8") + result.status_code = status_code + return result + + +# --------------------------------------------------------------------------- +# Authentication Tests +# --------------------------------------------------------------------------- + + +class TestAuthenticate: + """Tests for WindowsDiscoveryPlugin.authenticate().""" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_authenticate_success(self, mock_session_cls, plugin, credentials): + """Successful authentication creates a session.""" + mock_session = MagicMock() + mock_session.run_ps.return_value = make_ps_result("WIN-SERVER01") + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + + mock_session_cls.assert_called_once_with( + "https://192.168.1.100:5986/wsman", + auth=("admin", "secret"), + transport="ntlm", + server_cert_validation="ignore", + ) + assert plugin._host == "192.168.1.100" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_authenticate_http_no_ssl(self, mock_session_cls, plugin): + """Authentication with use_ssl=false uses HTTP.""" + creds = { + "host": "myhost", + "username": "user", + "password": "pass", + "transport": "ntlm", + "port": "5985", + "use_ssl": "false", + } + mock_session = MagicMock() + mock_session.run_ps.return_value = make_ps_result("MYHOST") + mock_session_cls.return_value = mock_session + + plugin.authenticate(creds) + + mock_session_cls.assert_called_once_with( + "http://myhost:5985/wsman", + auth=("user", "pass"), + transport="ntlm", + server_cert_validation="validate", + ) + + def test_authenticate_missing_host(self, plugin): + """Missing host raises AuthenticationError.""" + creds = {"username": "user", "password": "pass"} + with pytest.raises(AuthenticationError, match="host is required"): + plugin.authenticate(creds) + + def test_authenticate_missing_username(self, plugin): + """Missing username raises AuthenticationError.""" + creds = {"host": "myhost", "password": "pass"} + with pytest.raises(AuthenticationError, match="username is required"): + plugin.authenticate(creds) + + def test_authenticate_missing_password(self, plugin): + """Missing password raises AuthenticationError.""" + creds = {"host": "myhost", "username": "user"} + with pytest.raises(AuthenticationError, match="password is required"): + plugin.authenticate(creds) + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_authenticate_connection_refused(self, mock_session_cls, plugin, credentials): + """Connection refused raises WinRMNotEnabledError.""" + mock_session_cls.side_effect = Exception("connection refused") + + with pytest.raises(WinRMNotEnabledError): + plugin.authenticate(credentials) + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_authenticate_access_denied(self, mock_session_cls, plugin, credentials): + """Access denied during auth test raises AuthenticationError.""" + mock_session = MagicMock() + mock_session.run_ps.return_value = make_ps_result( + stderr="Access is denied", status_code=1 + ) + mock_session_cls.return_value = mock_session + + with pytest.raises(AuthenticationError): + plugin.authenticate(credentials) + + +# --------------------------------------------------------------------------- +# Platform Category and Resource Types Tests +# --------------------------------------------------------------------------- + + +class TestPlatformInfo: + """Tests for platform category and resource type listing.""" + + def test_get_platform_category(self, plugin): + """Returns PlatformCategory.WINDOWS.""" + assert plugin.get_platform_category() == PlatformCategory.WINDOWS + + def test_list_supported_resource_types(self, plugin): + """Returns all 13 Windows resource types.""" + types = plugin.list_supported_resource_types() + assert len(types) == 13 + assert "windows_service" in types + assert "windows_scheduled_task" in types + assert "windows_iis_site" in types + assert "windows_iis_app_pool" in types + assert "windows_network_adapter" in types + assert "windows_firewall_rule" in types + assert "windows_installed_software" in types + assert "windows_feature" in types + assert "windows_hyperv_vm" in types + assert "windows_hyperv_switch" in types + assert "windows_dns_record" in types + assert "windows_local_user" in types + assert "windows_local_group" in types + + def test_list_endpoints_before_auth(self, plugin): + """Returns empty list before authentication.""" + assert plugin.list_endpoints() == [] + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_list_endpoints_after_auth(self, mock_session_cls, plugin, credentials): + """Returns host after authentication.""" + mock_session = MagicMock() + mock_session.run_ps.return_value = make_ps_result("SERVER") + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + assert plugin.list_endpoints() == ["192.168.1.100"] + + +# --------------------------------------------------------------------------- +# Architecture Detection Tests +# --------------------------------------------------------------------------- + + +class TestDetectArchitecture: + """Tests for CPU architecture detection via WMI.""" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_detect_amd64(self, mock_session_cls, plugin, credentials): + """Architecture code 9 maps to AMD64.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), # auth test + make_ps_result("9"), # architecture query + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + arch = plugin.detect_architecture("192.168.1.100") + assert arch == CpuArchitecture.AMD64 + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_detect_arm(self, mock_session_cls, plugin, credentials): + """Architecture code 5 maps to ARM.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("5"), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + arch = plugin.detect_architecture("192.168.1.100") + assert arch == CpuArchitecture.ARM + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_detect_aarch64(self, mock_session_cls, plugin, credentials): + """Architecture code 12 maps to AARCH64.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("12"), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + arch = plugin.detect_architecture("192.168.1.100") + assert arch == CpuArchitecture.AARCH64 + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_detect_architecture_wmi_failure(self, mock_session_cls, plugin, credentials): + """WMI query failure raises WMIQueryError.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result(stderr="WMI error", status_code=1), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + with pytest.raises(WMIQueryError): + plugin.detect_architecture("192.168.1.100") + + +# --------------------------------------------------------------------------- +# Resource Discovery Tests +# --------------------------------------------------------------------------- + + +class TestDiscoverResources: + """Tests for resource discovery via WinRM.""" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_services(self, mock_session_cls, plugin, credentials): + """Discovers Windows services.""" + services_json = json.dumps([ + {"Name": "wuauserv", "DisplayName": "Windows Update", "Status": 4, "StartType": 3}, + {"Name": "Spooler", "DisplayName": "Print Spooler", "Status": 4, "StartType": 2}, + ]) + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), # auth + make_ps_result("9"), # architecture + make_ps_result("false"), # hyperv check + make_ps_result(services_json), # services + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_service"], + progress_callback=callback, + ) + + assert len(result.resources) == 2 + assert result.resources[0].resource_type == "windows_service" + assert result.resources[0].name == "wuauserv" + assert result.resources[0].provider == ProviderType.WINDOWS + assert result.resources[0].platform_category == PlatformCategory.WINDOWS + assert result.resources[0].architecture == CpuArchitecture.AMD64 + callback.assert_called() + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_scheduled_tasks(self, mock_session_cls, plugin, credentials): + """Discovers scheduled tasks.""" + tasks_json = json.dumps([ + {"TaskName": "Backup", "TaskPath": "\\Custom\\", "State": 3}, + ]) + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("false"), + make_ps_result(tasks_json), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_scheduled_task"], + progress_callback=callback, + ) + + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "windows_scheduled_task" + assert result.resources[0].name == "Backup" + assert result.resources[0].attributes["task_path"] == "\\Custom\\" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_hyperv_skipped_when_not_installed( + self, mock_session_cls, plugin, credentials + ): + """Hyper-V resources are skipped when role is not installed.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("false"), # hyperv NOT installed + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_hyperv_vm"], + progress_callback=callback, + ) + + assert len(result.resources) == 0 + assert any("Hyper-V role not installed" in w for w in result.warnings) + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_hyperv_vms_when_installed( + self, mock_session_cls, plugin, credentials + ): + """Hyper-V VMs are discovered when role is installed.""" + vms_json = json.dumps([ + { + "Name": "TestVM", + "VMId": "abc-123", + "State": 2, + "MemoryAssigned": 4294967296, + "ProcessorCount": 4, + "Generation": 2, + } + ]) + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("true"), # hyperv IS installed + make_ps_result(vms_json), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_hyperv_vm"], + progress_callback=callback, + ) + + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "windows_hyperv_vm" + assert result.resources[0].name == "TestVM" + assert result.resources[0].attributes["vm_id"] == "abc-123" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_local_users(self, mock_session_cls, plugin, credentials): + """Discovers local user accounts.""" + users_json = json.dumps([ + {"Name": "Administrator", "Enabled": True, "Description": "Built-in admin", "LastLogon": "2024-01-01"}, + ]) + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("false"), + make_ps_result(users_json), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_local_user"], + progress_callback=callback, + ) + + assert len(result.resources) == 1 + assert result.resources[0].resource_type == "windows_local_user" + assert result.resources[0].name == "Administrator" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_insufficient_privileges( + self, mock_session_cls, plugin, credentials + ): + """Insufficient privileges are captured as errors.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("false"), + make_ps_result(stderr="Access is denied", status_code=1), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_service"], + progress_callback=callback, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 1 + assert "Insufficient privileges" in result.errors[0] + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_wmi_query_failure( + self, mock_session_cls, plugin, credentials + ): + """WMI query failures are captured as errors.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("false"), + make_ps_result(stderr="Invalid class", status_code=1), + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_feature"], + progress_callback=callback, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 1 + assert "WMI query failed" in result.errors[0] + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_discover_empty_result(self, mock_session_cls, plugin, credentials): + """Empty PowerShell output returns no resources.""" + mock_session = MagicMock() + mock_session.run_ps.side_effect = [ + make_ps_result("SERVER"), + make_ps_result("9"), + make_ps_result("false"), + make_ps_result(""), # empty output + ] + mock_session_cls.return_value = mock_session + + plugin.authenticate(credentials) + callback = MagicMock() + result = plugin.discover_resources( + endpoints=["192.168.1.100"], + resource_types=["windows_local_group"], + progress_callback=callback, + ) + + assert len(result.resources) == 0 + assert len(result.errors) == 0 + + +# --------------------------------------------------------------------------- +# Error Handling Tests +# --------------------------------------------------------------------------- + + +class TestErrorHandling: + """Tests for WinRM-specific error handling.""" + + def test_winrm_not_enabled_error(self): + """WinRMNotEnabledError contains host info.""" + err = WinRMNotEnabledError("myhost", "connection refused") + assert "myhost" in str(err) + assert "connection refused" in str(err) + assert err.host == "myhost" + + def test_wmi_query_error(self): + """WMIQueryError contains query info.""" + err = WMIQueryError("Win32_Processor", "invalid class") + assert "Win32_Processor" in str(err) + assert "invalid class" in str(err) + assert err.query == "Win32_Processor" + + def test_insufficient_privileges_error(self): + """InsufficientPrivilegesError contains operation info.""" + err = InsufficientPrivilegesError("Get-Service", "access denied") + assert "Get-Service" in str(err) + assert "access denied" in str(err) + assert err.operation == "Get-Service" + + @patch("iac_reverse.scanner.windows_plugin.winrm.Session") + def test_no_session_raises_winrm_not_enabled( + self, mock_session_cls, plugin + ): + """Running PowerShell without session raises WinRMNotEnabledError.""" + plugin._host = "testhost" + with pytest.raises(WinRMNotEnabledError): + plugin._run_powershell("Get-Service")