Created IAC reverse generator

This commit is contained in:
p2913020
2026-05-22 00:19:30 -04:00
parent d04c2c6e4b
commit 1a11244fff
161 changed files with 26806 additions and 51 deletions

View File

@@ -6,15 +6,15 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
## Tasks ## Tasks
- [ ] 1. Set up project structure and core data models - [x] 1. Set up project structure and core data models
- [ ] 1.1 Create project directory structure, pyproject.toml, and install dependencies - [x] 1.1 Create project directory structure, pyproject.toml, and install dependencies
- Create `src/iac_reverse/` package with `__init__.py` - Create `src/iac_reverse/` package with `__init__.py`
- Create subdirectories: `scanner/`, `resolver/`, `generator/`, `state_builder/`, `validator/`, `incremental/`, `auth/`, `cli/` - Create subdirectories: `scanner/`, `resolver/`, `generator/`, `state_builder/`, `validator/`, `incremental/`, `auth/`, `cli/`
- Set up `pyproject.toml` with dependencies: kubernetes, docker, pywinrm, hypothesis, pytest, click, jinja2, networkx, pyyaml, python-synology - Set up `pyproject.toml` with dependencies: kubernetes, docker, pywinrm, hypothesis, pytest, click, jinja2, networkx, pyyaml, python-synology
- Create `tests/` directory with `unit/`, `property/`, `integration/` subdirectories - Create `tests/` directory with `unit/`, `property/`, `integration/` subdirectories
- _Requirements: 1.1, 5.1, 5.2_ - _Requirements: 1.1, 5.1, 5.2_
- [ ] 1.2 Define core enums, data classes, and interfaces - [x] 1.2 Define core enums, data classes, and interfaces
- Implement `ProviderType` enum (docker_swarm, kubernetes, synology, harvester, bare_metal, windows) - Implement `ProviderType` enum (docker_swarm, kubernetes, synology, harvester, bare_metal, windows)
- Implement `PlatformCategory` enum (container_orchestration, storage_appliance, hci, bare_metal, windows) and `PROVIDER_PLATFORM_MAP` - Implement `PlatformCategory` enum (container_orchestration, storage_appliance, hci, bare_metal, windows) and `PROVIDER_PLATFORM_MAP`
- Implement `CpuArchitecture` enum (amd64, arm, aarch64) - Implement `CpuArchitecture` enum (amd64, arm, aarch64)
@@ -27,19 +27,19 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Define `ProviderPlugin` abstract base class with all abstract methods - Define `ProviderPlugin` abstract base class with all abstract methods
- _Requirements: 1.1, 1.2, 2.1, 3.1, 4.1, 5.1, 5.2, 8.1_ - _Requirements: 1.1, 1.2, 2.1, 3.1, 4.1, 5.1, 5.2, 8.1_
- [ ] 1.3 Implement ScanProfile validation logic - [x] 1.3 Implement ScanProfile validation logic
- Validate mandatory fields: provider type and non-empty credentials - Validate mandatory fields: provider type and non-empty credentials
- Validate optional fields: resource_type_filters max 200 entries, endpoints list - Validate optional fields: resource_type_filters max 200 entries, endpoints list
- Validate resource types against provider's supported types - Validate resource types against provider's supported types
- Return all validation errors in a single response - Return all validation errors in a single response
- _Requirements: 6.1, 6.6, 6.7_ - _Requirements: 6.1, 6.6, 6.7_
- [ ]* 1.4 Write property test for scan profile validation (Property 20) - [x] 1.4 Write property test for scan profile validation (Property 20)
- **Property 20: Scan profile validation completeness** - **Property 20: Scan profile validation completeness**
- **Validates: Requirements 6.1, 6.6, 6.7** - **Validates: Requirements 6.1, 6.6, 6.7**
- [ ] 2. Implement Scanner core and provider plugin system - [x] 2. Implement Scanner core and provider plugin system
- [ ] 2.1 Implement Scanner orchestrator with progress reporting and error handling - [x] 2.1 Implement Scanner orchestrator with progress reporting and error handling
- Create `Scanner` class that accepts a `ScanProfile` and orchestrates discovery - Create `Scanner` class that accepts a `ScanProfile` and orchestrates discovery
- Implement connection timeout (30 seconds) and authentication error handling with descriptive messages - Implement connection timeout (30 seconds) and authentication error handling with descriptive messages
- Implement progress callback invocation per resource type completion - Implement progress callback invocation per resource type completion
@@ -48,44 +48,44 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Implement warning logging for unsupported resource types while continuing scan - Implement warning logging for unsupported resource types while continuing scan
- _Requirements: 1.1, 1.3, 1.4, 1.5, 1.6, 1.7_ - _Requirements: 1.1, 1.3, 1.4, 1.5, 1.6, 1.7_
- [ ]* 2.2 Write property tests for Scanner behavior (Properties 2, 3, 4, 5) - [x] 2.2 Write property tests for Scanner behavior (Properties 2, 3, 4, 5)
- **Property 2: Authentication error descriptiveness** - **Property 2: Authentication error descriptiveness**
- **Property 3: Graceful degradation on unsupported resource types** - **Property 3: Graceful degradation on unsupported resource types**
- **Property 4: Progress reporting frequency** - **Property 4: Progress reporting frequency**
- **Property 5: Partial inventory preservation on failure** - **Property 5: Partial inventory preservation on failure**
- **Validates: Requirements 1.3, 1.4, 1.5, 1.7** - **Validates: Requirements 1.3, 1.4, 1.5, 1.7**
- [ ] 2.3 Implement Docker Swarm provider plugin - [x] 2.3 Implement Docker Swarm provider plugin
- Implement `DockerSwarmPlugin` using docker-sdk-python - Implement `DockerSwarmPlugin` using docker-sdk-python
- Discover services, networks, volumes, configs, secrets (metadata only) - Discover services, networks, volumes, configs, secrets (metadata only)
- Detect architecture from node info - Detect architecture from node info
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ] 2.4 Implement Kubernetes provider plugin - [x] 2.4 Implement Kubernetes provider plugin
- Implement `KubernetesPlugin` using kubernetes-client - Implement `KubernetesPlugin` using kubernetes-client
- Discover deployments, services, ingresses, config maps, persistent volumes, namespaces - Discover deployments, services, ingresses, config maps, persistent volumes, namespaces
- Detect architecture from node labels - Detect architecture from node labels
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ] 2.5 Implement Synology provider plugin - [x] 2.5 Implement Synology provider plugin
- Implement `SynologyPlugin` using Synology DSM API - Implement `SynologyPlugin` using Synology DSM API
- Discover shared folders, volumes, storage pools, replication tasks, users - Discover shared folders, volumes, storage pools, replication tasks, users
- Detect architecture from system info (ARM vs AMD64) - Detect architecture from system info (ARM vs AMD64)
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ] 2.6 Implement Harvester provider plugin - [x] 2.6 Implement Harvester provider plugin
- Implement `HarvesterPlugin` using Harvester/K8s-based API - Implement `HarvesterPlugin` using Harvester/K8s-based API
- Discover VMs, volumes, images, networks (HCI combined resources) - Discover VMs, volumes, images, networks (HCI combined resources)
- Detect architecture from node info - Detect architecture from node info
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ] 2.7 Implement Bare Metal provider plugin - [x] 2.7 Implement Bare Metal provider plugin
- Implement `BareMetalPlugin` using IPMI/Redfish API - Implement `BareMetalPlugin` using IPMI/Redfish API
- Discover hardware inventory, BMC configs, network interfaces, RAID configurations - Discover hardware inventory, BMC configs, network interfaces, RAID configurations
- Detect architecture from system hardware info - Detect architecture from system hardware info
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ] 2.8 Implement Windows provider plugin - [x] 2.8 Implement Windows provider plugin
- Implement `WindowsDiscoveryPlugin` using pywinrm library - Implement `WindowsDiscoveryPlugin` using pywinrm library
- Authenticate via WinRM using NTLM or Kerberos (configurable transport, port, SSL) - Authenticate via WinRM using NTLM or Kerberos (configurable transport, port, SSL)
- Discover Windows services, scheduled tasks, IIS sites, IIS app pools, network adapters, firewall rules, installed software, Windows features, Hyper-V VMs, Hyper-V switches, DNS records, local users, local groups - Discover Windows services, scheduled tasks, IIS sites, IIS app pools, network adapters, firewall rules, installed software, Windows features, Hyper-V VMs, Hyper-V switches, DNS records, local users, local groups
@@ -94,21 +94,21 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Handle WinRM-specific errors: WinRM not enabled, WMI query failure, insufficient privileges - Handle WinRM-specific errors: WinRM not enabled, WMI query failure, insufficient privileges
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ] 2.9 Implement Authentik integration (SSO + discovery plugin) - [x] 2.9 Implement Authentik integration (SSO + discovery plugin)
- Implement `AuthentikAuthProvider` for OAuth2/OIDC SSO flow (authenticate, refresh, validate) - Implement `AuthentikAuthProvider` for OAuth2/OIDC SSO flow (authenticate, refresh, validate)
- Implement `AuthentikDiscoveryPlugin` conforming to `ProviderPlugin` - Implement `AuthentikDiscoveryPlugin` conforming to `ProviderPlugin`
- Discover flows, stages, providers, applications, outposts, property mappings, certificates, groups, sources - Discover flows, stages, providers, applications, outposts, property mappings, certificates, groups, sources
- _Requirements: 1.1, 1.2, 5.2_ - _Requirements: 1.1, 1.2, 5.2_
- [ ]* 2.10 Write property test for resource inventory completeness (Property 1) - [x] 2.10 Write property test for resource inventory completeness (Property 1)
- **Property 1: Resource inventory completeness** - **Property 1: Resource inventory completeness**
- **Validates: Requirements 1.2** - **Validates: Requirements 1.2**
- [ ] 3. Checkpoint - Ensure all tests pass - [x] 3. Checkpoint - Ensure all tests pass
- Ensure all tests pass, ask the user if questions arise. - Ensure all tests pass, ask the user if questions arise.
- [ ] 4. Implement Dependency Resolver - [x] 4. Implement Dependency Resolver
- [ ] 4.1 Implement dependency resolution and graph building - [x] 4.1 Implement dependency resolution and graph building
- Create `DependencyResolver` class - Create `DependencyResolver` class
- Analyze resource `raw_references` to identify parent-child, reference, and dependency relationships - Analyze resource `raw_references` to identify parent-child, reference, and dependency relationships
- Build dependency graph using networkx - Build dependency graph using networkx
@@ -116,27 +116,27 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Represent relationships as explicit Terraform references (not hardcoded IDs) - Represent relationships as explicit Terraform references (not hardcoded IDs)
- _Requirements: 2.1, 2.2, 2.4_ - _Requirements: 2.1, 2.2, 2.4_
- [ ] 4.2 Implement cycle detection and resolution suggestions - [x] 4.2 Implement cycle detection and resolution suggestions
- Detect circular dependencies in the graph - Detect circular dependencies in the graph
- Report cycles listing all involved resources - Report cycles listing all involved resources
- Suggest resolution strategies (which relationship to break, data source lookup alternatives) - Suggest resolution strategies (which relationship to break, data source lookup alternatives)
- _Requirements: 2.3_ - _Requirements: 2.3_
- [ ] 4.3 Implement unresolved reference handling - [x] 4.3 Implement unresolved reference handling
- Identify references to IDs not in the current inventory - Identify references to IDs not in the current inventory
- Log warnings for unresolved references - Log warnings for unresolved references
- Represent unresolved references as data source lookups or variables in output - Represent unresolved references as data source lookups or variables in output
- _Requirements: 2.5_ - _Requirements: 2.5_
- [ ]* 4.4 Write property tests for Dependency Resolver (Properties 6, 7, 8, 9) - [x] 4.4 Write property tests for Dependency Resolver (Properties 6, 7, 8, 9)
- **Property 6: Dependency relationship identification** - **Property 6: Dependency relationship identification**
- **Property 7: Cycle detection correctness** - **Property 7: Cycle detection correctness**
- **Property 8: Topological order validity** - **Property 8: Topological order validity**
- **Property 9: Unresolved references become data sources or variables** - **Property 9: Unresolved references become data sources or variables**
- **Validates: Requirements 2.1, 2.3, 2.4, 2.5** - **Validates: Requirements 2.1, 2.3, 2.4, 2.5**
- [ ] 5. Implement Code Generator - [x] 5. Implement Code Generator
- [ ] 5.1 Implement HCL code generation with Jinja2 templates - [x] 5.1 Implement HCL code generation with Jinja2 templates
- Create `CodeGenerator` class - Create `CodeGenerator` class
- Create Jinja2 templates for Terraform resource blocks per provider/resource type - Create Jinja2 templates for Terraform resource blocks per provider/resource type
- Generate syntactically valid HCL files from dependency graph - Generate syntactically valid HCL files from dependency graph
@@ -146,31 +146,31 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Generate architecture-specific tags/labels on resources - Generate architecture-specific tags/labels on resources
- _Requirements: 3.1, 3.2, 3.5, 3.6_ - _Requirements: 3.1, 3.2, 3.5, 3.6_
- [ ] 5.2 Implement identifier sanitization - [x] 5.2 Implement identifier sanitization
- Create `sanitize_identifier()` function - Create `sanitize_identifier()` function
- Convert resource names to valid Terraform identifiers: `^[a-zA-Z_][a-zA-Z0-9_]*$` - Convert resource names to valid Terraform identifiers: `^[a-zA-Z_][a-zA-Z0-9_]*$`
- Handle special characters, unicode, leading digits, spaces by replacing with underscores - Handle special characters, unicode, leading digits, spaces by replacing with underscores
- Ensure non-empty output for any input - Ensure non-empty output for any input
- _Requirements: 3.4_ - _Requirements: 3.4_
- [ ] 5.3 Implement variable extraction logic - [x] 5.3 Implement variable extraction logic
- Identify attribute values appearing in 2+ resources - Identify attribute values appearing in 2+ resources
- Extract shared values into `variables.tf` with defaults set to most common value - Extract shared values into `variables.tf` with defaults set to most common value
- Generate variable declarations with type expressions and descriptions - Generate variable declarations with type expressions and descriptions
- _Requirements: 3.3_ - _Requirements: 3.3_
- [ ] 5.4 Implement provider configuration block generation - [x] 5.4 Implement provider configuration block generation
- Generate separate provider blocks for each distinct provider used - Generate separate provider blocks for each distinct provider used
- Include platform-specific configuration (endpoints, certificate settings) - Include platform-specific configuration (endpoints, certificate settings)
- _Requirements: 5.4_ - _Requirements: 5.4_
- [ ] 5.5 Implement multi-provider resource merging with conflict resolution - [x] 5.5 Implement multi-provider resource merging with conflict resolution
- Merge resources from multiple scan profiles into unified inventory - Merge resources from multiple scan profiles into unified inventory
- Resolve naming conflicts by prefixing with provider identifier - Resolve naming conflicts by prefixing with provider identifier
- Preserve provider-specific attributes - Preserve provider-specific attributes
- _Requirements: 5.3_ - _Requirements: 5.3_
- [ ]* 5.6 Write property tests for Code Generator (Properties 10, 11, 12, 13, 14, 15) - [x] 5.6 Write property tests for Code Generator (Properties 10, 11, 12, 13, 14, 15)
- **Property 10: References in generated output use Terraform syntax** - **Property 10: References in generated output use Terraform syntax**
- **Property 11: Generated HCL syntactic validity** - **Property 11: Generated HCL syntactic validity**
- **Property 12: File organization by resource type** - **Property 12: File organization by resource type**
@@ -179,8 +179,8 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- **Property 15: Traceability comments in generated code** - **Property 15: Traceability comments in generated code**
- **Validates: Requirements 2.2, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6** - **Validates: Requirements 2.2, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6**
- [ ] 6. Implement State Builder - [x] 6. Implement State Builder
- [ ] 6.1 Implement Terraform state file generation (format v4) - [x] 6.1 Implement Terraform state file generation (format v4)
- Create `StateBuilder` class - Create `StateBuilder` class
- Generate state JSON with version=4, unique UUID lineage, serial number - Generate state JSON with version=4, unique UUID lineage, serial number
- Create state entries binding each resource block to its live infrastructure ID - Create state entries binding each resource block to its live infrastructure ID
@@ -190,22 +190,22 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Include dependency references in state entries - Include dependency references in state entries
- _Requirements: 4.1, 4.2, 4.4, 4.5_ - _Requirements: 4.1, 4.2, 4.4, 4.5_
- [ ] 6.2 Implement unmapped resource handling in state builder - [x] 6.2 Implement unmapped resource handling in state builder
- Log warnings for resources that cannot be mapped to state entries - Log warnings for resources that cannot be mapped to state entries
- Handle missing provider-assigned resource identifiers - Handle missing provider-assigned resource identifiers
- Exclude unmapped resources from state file - Exclude unmapped resources from state file
- _Requirements: 4.3, 4.6_ - _Requirements: 4.3, 4.6_
- [ ]* 6.3 Write property tests for State Builder (Properties 16, 17) - [x] 6.3 Write property tests for State Builder (Properties 16, 17)
- **Property 16: State file structural validity** - **Property 16: State file structural validity**
- **Property 17: State entry completeness and schema correctness** - **Property 17: State entry completeness and schema correctness**
- **Validates: Requirements 4.1, 4.2, 4.4, 4.5** - **Validates: Requirements 4.1, 4.2, 4.4, 4.5**
- [ ] 7. Checkpoint - Ensure all tests pass - [x] 7. Checkpoint - Ensure all tests pass
- Ensure all tests pass, ask the user if questions arise. - Ensure all tests pass, ask the user if questions arise.
- [ ] 8. Implement Validator - [x] 8. Implement Validator
- [ ] 8.1 Implement Terraform validation runner - [x] 8.1 Implement Terraform validation runner
- Create `Validator` class - Create `Validator` class
- Run `terraform init` and `terraform validate` against generated output - Run `terraform init` and `terraform validate` against generated output
- Run `terraform plan` and check for zero planned changes - Run `terraform plan` and check for zero planned changes
@@ -214,46 +214,46 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Handle missing Terraform binary with descriptive error - Handle missing Terraform binary with descriptive error
- _Requirements: 7.1, 7.2, 7.3, 7.5_ - _Requirements: 7.1, 7.2, 7.3, 7.5_
- [ ] 8.2 Implement auto-correction loop for validation errors - [x] 8.2 Implement auto-correction loop for validation errors
- Attempt to correct validation errors (up to 3 attempts) - Attempt to correct validation errors (up to 3 attempts)
- Re-validate after each correction - Re-validate after each correction
- Report failure with remaining error details if corrections exhausted - Report failure with remaining error details if corrections exhausted
- _Requirements: 7.4_ - _Requirements: 7.4_
- [ ]* 8.3 Write property test for drift report correctness (Property 22) - [x] 8.3 Write property test for drift report correctness (Property 22)
- **Property 22: Drift report correctness** - **Property 22: Drift report correctness**
- **Validates: Requirements 7.3** - **Validates: Requirements 7.3**
- [ ] 9. Implement Incremental Scan Engine - [x] 9. Implement Incremental Scan Engine
- [ ] 9.1 Implement scan snapshot storage and retrieval - [x] 9.1 Implement scan snapshot storage and retrieval
- Store scan results as timestamped JSON in `.iac-reverse/snapshots/` - Store scan results as timestamped JSON in `.iac-reverse/snapshots/`
- Use profile_hash for matching scans to profiles - Use profile_hash for matching scans to profiles
- Retain at least 2 most recent snapshots per profile - Retain at least 2 most recent snapshots per profile
- Load previous snapshot for comparison - Load previous snapshot for comparison
- _Requirements: 8.4, 8.6_ - _Requirements: 8.4, 8.6_
- [ ] 9.2 Implement change detection and classification - [x] 9.2 Implement change detection and classification
- Compare current scan against previous snapshot - Compare current scan against previous snapshot
- Classify resources as added, removed, or modified - Classify resources as added, removed, or modified
- Produce change summary with counts and resource details - Produce change summary with counts and resource details
- Handle first scan (no previous) as full initial scan - Handle first scan (no previous) as full initial scan
- _Requirements: 8.1, 8.4, 8.5_ - _Requirements: 8.1, 8.4, 8.5_
- [ ] 9.3 Implement incremental code and state updates - [x] 9.3 Implement incremental code and state updates
- Update only IaC files containing changed resources (not full regeneration) - Update only IaC files containing changed resources (not full regeneration)
- Remove resource blocks and state entries for removed resources - Remove resource blocks and state entries for removed resources
- Add/update blocks for added/modified resources - Add/update blocks for added/modified resources
- _Requirements: 8.2, 8.3_ - _Requirements: 8.2, 8.3_
- [ ]* 9.4 Write property tests for Incremental Scan (Properties 23, 24, 25, 26) - [x] 9.4 Write property tests for Incremental Scan (Properties 23, 24, 25, 26)
- **Property 23: Change classification correctness** - **Property 23: Change classification correctness**
- **Property 24: Incremental update scope** - **Property 24: Incremental update scope**
- **Property 25: Removed resource exclusion** - **Property 25: Removed resource exclusion**
- **Property 26: Snapshot retention** - **Property 26: Snapshot retention**
- **Validates: Requirements 8.1, 8.2, 8.3, 8.5, 8.6** - **Validates: Requirements 8.1, 8.2, 8.3, 8.5, 8.6**
- [ ] 10. Implement CLI and wire pipeline together - [x] 10. Implement CLI and wire pipeline together
- [ ] 10.1 Implement CLI entry point with Click - [x] 10.1 Implement CLI entry point with Click
- Create `cli.py` with Click command group - Create `cli.py` with Click command group
- Implement `scan` command accepting scan profile YAML path - Implement `scan` command accepting scan profile YAML path
- Implement `generate` command to run full pipeline (scan → resolve → generate → state → validate) - Implement `generate` command to run full pipeline (scan → resolve → generate → state → validate)
@@ -264,32 +264,32 @@ Build a Python CLI tool that reverse-engineers existing on-premises infrastructu
- Add progress bars and formatted output for scan progress - Add progress bars and formatted output for scan progress
- _Requirements: 1.1, 1.5, 6.1, 6.2, 6.3, 6.4, 6.5_ - _Requirements: 1.1, 1.5, 6.1, 6.2, 6.3, 6.4, 6.5_
- [ ] 10.2 Implement scan profile YAML loading and environment variable expansion - [x] 10.2 Implement scan profile YAML loading and environment variable expansion
- Parse YAML scan profiles - Parse YAML scan profiles
- Expand `${ENV_VAR}` references in credential fields - Expand `${ENV_VAR}` references in credential fields
- Support multi-profile YAML for multi-provider scans - Support multi-profile YAML for multi-provider scans
- _Requirements: 6.1, 5.3_ - _Requirements: 6.1, 5.3_
- [ ]* 10.3 Write property tests for multi-provider and filtering (Properties 18, 19, 20, 21) - [x] 10.3 Write property tests for multi-provider and filtering (Properties 18, 19, 20, 21)
- **Property 18: Multi-provider merge with naming conflict resolution** - **Property 18: Multi-provider merge with naming conflict resolution**
- **Property 19: Provider block generation** - **Property 19: Provider block generation**
- **Property 20: Scan profile validation completeness** (additional coverage) - **Property 20: Scan profile validation completeness** (additional coverage)
- **Property 21: Filtering correctness** - **Property 21: Filtering correctness**
- **Validates: Requirements 5.3, 5.4, 6.1, 6.2, 6.4, 6.6, 6.7** - **Validates: Requirements 5.3, 5.4, 6.1, 6.2, 6.4, 6.6, 6.7**
- [ ] 11. Implement resource type filter and multi-provider failure handling - [x] 11. Implement resource type filter and multi-provider failure handling
- [ ] 11.1 Implement resource type filtering in scanner - [x] 11.1 Implement resource type filtering in scanner
- When filters specified, discover only listed resource types - When filters specified, discover only listed resource types
- When no filters specified, discover all supported types for provider - When no filters specified, discover all supported types for provider
- _Requirements: 6.2, 6.3_ - _Requirements: 6.2, 6.3_
- [ ] 11.2 Implement multi-provider partial failure handling - [x] 11.2 Implement multi-provider partial failure handling
- Complete scanning for all remaining providers when one fails - Complete scanning for all remaining providers when one fails
- Include successfully discovered resources in inventory - Include successfully discovered resources in inventory
- Report which providers failed with error details - Report which providers failed with error details
- _Requirements: 5.5_ - _Requirements: 5.5_
- [ ] 12. Final checkpoint - Ensure all tests pass - [x] 12. Final checkpoint - Ensure all tests pass
- Ensure all tests pass, ask the user if questions arise. - Ensure all tests pass, ask the user if questions arise.
## Notes ## Notes

42
pyproject.toml Normal file
View File

@@ -0,0 +1,42 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "iac-reverse"
version = "0.1.0"
description = "Reverse engineer existing on-premises infrastructure into Terraform HCL code and state files"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
dependencies = [
"click>=8.1.7",
"jinja2>=3.1.3",
"networkx>=3.2.1",
"pyyaml>=6.0.1",
"kubernetes>=28.1.0",
"docker>=7.0.0",
"pywinrm>=0.4.3",
"python-synology>=1.0.0",
]
[project.optional-dependencies]
dev = [
"hypothesis>=6.92.0",
"pytest>=7.4.4",
"pytest-cov>=4.1.0",
]
[project.scripts]
iac-reverse = "iac_reverse.cli:main"
[tool.setuptools.packages.find]
where = ["src"]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["src"]
[tool.hypothesis]
max_examples = 100

View File

@@ -0,0 +1,23 @@
Metadata-Version: 2.4
Name: iac-reverse
Version: 0.1.0
Summary: Reverse engineer existing on-premises infrastructure into Terraform HCL code and state files
License: MIT
Requires-Python: >=3.11
Description-Content-Type: text/markdown
Requires-Dist: click>=8.1.7
Requires-Dist: jinja2>=3.1.3
Requires-Dist: networkx>=3.2.1
Requires-Dist: pyyaml>=6.0.1
Requires-Dist: kubernetes>=28.1.0
Requires-Dist: docker>=7.0.0
Requires-Dist: pywinrm>=0.4.3
Requires-Dist: python-synology>=1.0.0
Provides-Extra: dev
Requires-Dist: hypothesis>=6.92.0; extra == "dev"
Requires-Dist: pytest>=7.4.4; extra == "dev"
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
# SnarfCode
# I added this line to test the syncing

View File

@@ -0,0 +1,17 @@
README.md
pyproject.toml
src/iac_reverse/__init__.py
src/iac_reverse.egg-info/PKG-INFO
src/iac_reverse.egg-info/SOURCES.txt
src/iac_reverse.egg-info/dependency_links.txt
src/iac_reverse.egg-info/entry_points.txt
src/iac_reverse.egg-info/requires.txt
src/iac_reverse.egg-info/top_level.txt
src/iac_reverse/auth/__init__.py
src/iac_reverse/cli/__init__.py
src/iac_reverse/generator/__init__.py
src/iac_reverse/incremental/__init__.py
src/iac_reverse/resolver/__init__.py
src/iac_reverse/scanner/__init__.py
src/iac_reverse/state_builder/__init__.py
src/iac_reverse/validator/__init__.py

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,2 @@
[console_scripts]
iac-reverse = iac_reverse.cli:main

View File

@@ -0,0 +1,13 @@
click>=8.1.7
jinja2>=3.1.3
networkx>=3.2.1
pyyaml>=6.0.1
kubernetes>=28.1.0
docker>=7.0.0
pywinrm>=0.4.3
python-synology>=1.0.0
[dev]
hypothesis>=6.92.0
pytest>=7.4.4
pytest-cov>=4.1.0

View File

@@ -0,0 +1 @@
iac_reverse

View File

@@ -0,0 +1,58 @@
"""IaC Reverse Engineering Tool.
Reverse engineer existing on-premises infrastructure into Terraform HCL code and state files.
"""
__version__ = "0.1.0"
from iac_reverse.models import (
ChangeType,
ChangeSummary,
CodeGenerationResult,
CpuArchitecture,
DependencyGraph,
DiscoveredResource,
ExtractedVariable,
GeneratedFile,
PlannedChange,
PlatformCategory,
PROVIDER_PLATFORM_MAP,
ProviderType,
ResourceChange,
ResourceRelationship,
ScanProfile,
ScanProgress,
ScanResult,
StateEntry,
StateFile,
UnresolvedReference,
ValidationError,
ValidationResult,
)
from iac_reverse.plugin_base import ProviderPlugin
__all__ = [
"ChangeType",
"ChangeSummary",
"CodeGenerationResult",
"CpuArchitecture",
"DependencyGraph",
"DiscoveredResource",
"ExtractedVariable",
"GeneratedFile",
"PlannedChange",
"PlatformCategory",
"PROVIDER_PLATFORM_MAP",
"ProviderPlugin",
"ProviderType",
"ResourceChange",
"ResourceRelationship",
"ScanProfile",
"ScanProgress",
"ScanResult",
"StateEntry",
"StateFile",
"UnresolvedReference",
"ValidationError",
"ValidationResult",
]

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,21 @@
"""Authentication module for Authentik SSO integration."""
from iac_reverse.auth.authentik_auth import (
AuthenticationError,
AuthentikAuthProvider,
AuthentikConfig,
AuthentikSession,
)
from iac_reverse.auth.authentik_discovery import (
AuthentikDiscoveryError,
AuthentikDiscoveryPlugin,
)
__all__ = [
"AuthenticationError",
"AuthentikAuthProvider",
"AuthentikConfig",
"AuthentikSession",
"AuthentikDiscoveryError",
"AuthentikDiscoveryPlugin",
]

View File

@@ -0,0 +1,204 @@
"""Authentik SSO authentication provider.
Handles OAuth2/OIDC authentication flow with an Authentik instance,
including token refresh and validation.
"""
from dataclasses import dataclass, field
from urllib.parse import urljoin
import requests
@dataclass
class AuthentikConfig:
"""Configuration for connecting to an Authentik instance."""
base_url: str # Authentik instance URL (e.g., "https://auth.internal.lab")
client_id: str # OAuth2 client ID for this tool
client_secret: str # OAuth2 client secret
@dataclass
class AuthentikSession:
"""Active session from Authentik SSO authentication."""
access_token: str
refresh_token: str
user_id: str
groups: list[str] = field(default_factory=list)
class AuthenticationError(Exception):
"""Raised when Authentik authentication fails."""
pass
class AuthentikAuthProvider:
"""Handles SSO authentication for the tool via Authentik OAuth2/OIDC.
Provides methods to authenticate users, refresh expired sessions,
and validate existing tokens against the Authentik instance.
"""
def authenticate_user(self, config: AuthentikConfig) -> AuthentikSession:
"""Initiate OAuth2/OIDC flow with Authentik and return a session.
Uses the client credentials or resource owner password grant to obtain
an access token from Authentik's token endpoint.
Args:
config: Authentik connection configuration.
Returns:
An AuthentikSession with access/refresh tokens and user info.
Raises:
AuthenticationError: If authentication fails for any reason.
"""
token_url = urljoin(config.base_url.rstrip("/") + "/", "application/o/token/")
try:
response = requests.post(
token_url,
data={
"grant_type": "client_credentials",
"client_id": config.client_id,
"client_secret": config.client_secret,
"scope": "openid profile email",
},
timeout=30,
)
except requests.RequestException as e:
raise AuthenticationError(
f"Authentik: failed to connect to {config.base_url} - {e}"
)
if response.status_code != 200:
raise AuthenticationError(
f"Authentik: authentication failed with status {response.status_code} "
f"- {response.text}"
)
token_data = response.json()
access_token = token_data.get("access_token", "")
refresh_token = token_data.get("refresh_token", "")
# Fetch user info to get user_id and groups
user_id, groups = self._fetch_user_info(config.base_url, access_token)
return AuthentikSession(
access_token=access_token,
refresh_token=refresh_token,
user_id=user_id,
groups=groups,
)
def refresh_session(
self, config: AuthentikConfig, session: AuthentikSession
) -> AuthentikSession:
"""Refresh an expired session token.
Args:
config: Authentik connection configuration.
session: The current session with a valid refresh token.
Returns:
A new AuthentikSession with refreshed tokens.
Raises:
AuthenticationError: If the refresh fails.
"""
token_url = urljoin(config.base_url.rstrip("/") + "/", "application/o/token/")
try:
response = requests.post(
token_url,
data={
"grant_type": "refresh_token",
"refresh_token": session.refresh_token,
"client_id": config.client_id,
"client_secret": config.client_secret,
},
timeout=30,
)
except requests.RequestException as e:
raise AuthenticationError(
f"Authentik: failed to refresh session - {e}"
)
if response.status_code != 200:
raise AuthenticationError(
f"Authentik: token refresh failed with status {response.status_code} "
f"- {response.text}"
)
token_data = response.json()
access_token = token_data.get("access_token", "")
refresh_token = token_data.get("refresh_token", session.refresh_token)
user_id, groups = self._fetch_user_info(config.base_url, access_token)
return AuthentikSession(
access_token=access_token,
refresh_token=refresh_token,
user_id=user_id,
groups=groups,
)
def validate_token(self, config: AuthentikConfig, token: str) -> bool:
"""Validate an existing token is still valid.
Checks the token against Authentik's userinfo endpoint.
Args:
config: Authentik connection configuration.
token: The access token to validate.
Returns:
True if the token is valid, False otherwise.
"""
userinfo_url = urljoin(
config.base_url.rstrip("/") + "/", "application/o/userinfo/"
)
try:
response = requests.get(
userinfo_url,
headers={"Authorization": f"Bearer {token}"},
timeout=10,
)
return response.status_code == 200
except requests.RequestException:
return False
def _fetch_user_info(
self, base_url: str, access_token: str
) -> tuple[str, list[str]]:
"""Fetch user info from Authentik's userinfo endpoint.
Args:
base_url: Authentik instance base URL.
access_token: Valid access token.
Returns:
Tuple of (user_id, groups list).
"""
userinfo_url = urljoin(base_url.rstrip("/") + "/", "application/o/userinfo/")
try:
response = requests.get(
userinfo_url,
headers={"Authorization": f"Bearer {access_token}"},
timeout=10,
)
if response.status_code == 200:
data = response.json()
user_id = data.get("sub", "")
groups = data.get("groups", [])
return user_id, groups
except requests.RequestException:
pass
return "", []

View File

@@ -0,0 +1,384 @@
"""Authentik discovery plugin.
Discovers Authentik configurations as infrastructure resources, including
flows, stages, providers, applications, outposts, property mappings,
certificates, groups, and sources.
"""
from typing import Callable
from urllib.parse import urljoin
import requests
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
class AuthentikDiscoveryError(Exception):
"""Raised when Authentik discovery encounters an error."""
pass
# Mapping of resource types to their Authentik API endpoints
_RESOURCE_TYPE_API_MAP: dict[str, str] = {
"authentik_flow": "api/v3/flows/instances/",
"authentik_stage": "api/v3/stages/all/",
"authentik_provider": "api/v3/providers/all/",
"authentik_application": "api/v3/core/applications/",
"authentik_outpost": "api/v3/outposts/instances/",
"authentik_property_mapping": "api/v3/propertymappings/all/",
"authentik_certificate": "api/v3/crypto/certificatekeypairs/",
"authentik_group": "api/v3/core/groups/",
"authentik_source": "api/v3/sources/all/",
}
class AuthentikDiscoveryPlugin(ProviderPlugin):
"""Discovers Authentik configurations as infrastructure resources.
Connects to an Authentik instance via its REST API and enumerates
flows, stages, providers, applications, outposts, property mappings,
certificates, groups, and sources.
Since Authentik is an identity provider (not a traditional infrastructure
platform), it uses PlatformCategory.CONTAINER_ORCHESTRATION as a
categorization convenience — Authentik typically runs as a containerized
service within the orchestration layer.
"""
def __init__(self) -> None:
self._base_url: str = ""
self._api_token: str = ""
self._authenticated: bool = False
def authenticate(self, credentials: dict[str, str]) -> None:
"""Authenticate with the Authentik REST API.
Expected credentials:
- base_url: Authentik instance URL (e.g., "https://auth.internal.lab")
- api_token: Authentik API token for administrative access
Args:
credentials: Dictionary with base_url and api_token.
Raises:
AuthentikDiscoveryError: If authentication fails.
"""
base_url = credentials.get("base_url", "")
api_token = credentials.get("api_token", "")
if not base_url:
raise AuthentikDiscoveryError(
"Authentik: 'base_url' is required in credentials"
)
if not api_token:
raise AuthentikDiscoveryError(
"Authentik: 'api_token' is required in credentials"
)
self._base_url = base_url.rstrip("/")
self._api_token = api_token
# Verify connectivity by hitting the core API
try:
response = requests.get(
self._build_url("api/v3/core/applications/"),
headers=self._auth_headers(),
params={"page_size": 1},
timeout=30,
)
except requests.RequestException as e:
raise AuthentikDiscoveryError(
f"Authentik: failed to connect to {base_url} - {e}"
)
if response.status_code == 401:
raise AuthentikDiscoveryError(
"Authentik: authentication failed - invalid API token"
)
if response.status_code == 403:
raise AuthentikDiscoveryError(
"Authentik: authentication failed - insufficient permissions"
)
if response.status_code not in (200, 201):
raise AuthentikDiscoveryError(
f"Authentik: unexpected status {response.status_code} "
f"during authentication check"
)
self._authenticated = True
def get_platform_category(self) -> PlatformCategory:
"""Return the platform category for Authentik.
Authentik is an identity provider that typically runs as a containerized
service, so it is categorized under CONTAINER_ORCHESTRATION.
"""
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
"""Return the Authentik instance endpoint.
Returns:
List containing the configured Authentik base URL.
"""
if not self._base_url:
return []
return [self._base_url]
def list_supported_resource_types(self) -> list[str]:
"""Return all Authentik resource types this plugin can discover.
Returns:
List of Authentik resource type strings.
"""
return [
"authentik_flow",
"authentik_stage",
"authentik_provider",
"authentik_application",
"authentik_outpost",
"authentik_property_mapping",
"authentik_certificate",
"authentik_group",
"authentik_source",
]
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect the CPU architecture of the Authentik host.
Authentik is a web service; architecture detection is not directly
applicable. Defaults to AMD64 as the most common deployment target.
Args:
endpoint: The Authentik endpoint URL.
Returns:
CpuArchitecture.AMD64 as the default.
"""
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover Authentik resources via the REST API.
Connects to the Authentik API and enumerates all resources of the
requested types. Reports progress via the callback function.
Args:
endpoints: List of Authentik endpoint URLs (typically one).
resource_types: List of resource type strings to discover.
progress_callback: Callable that receives ScanProgress updates.
Returns:
ScanResult containing all discovered Authentik resources.
Raises:
AuthentikDiscoveryError: If not authenticated.
"""
if not self._authenticated:
raise AuthentikDiscoveryError(
"Authentik: must authenticate before discovering resources"
)
import datetime
resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
endpoint = endpoints[0] if endpoints else self._base_url
total_types = len(resource_types)
for idx, resource_type in enumerate(resource_types):
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(resources),
resource_types_completed=idx,
total_resource_types=total_types,
)
)
if resource_type not in _RESOURCE_TYPE_API_MAP:
warnings.append(
f"Unsupported Authentik resource type: {resource_type}"
)
continue
try:
discovered = self._discover_resource_type(
resource_type, endpoint
)
resources.extend(discovered)
except Exception as e:
errors.append(
f"Error discovering {resource_type}: {e}"
)
# Final progress update
progress_callback(
ScanProgress(
current_resource_type="complete",
resources_discovered=len(resources),
resource_types_completed=total_types,
total_resource_types=total_types,
)
)
scan_timestamp = datetime.datetime.now(datetime.timezone.utc).isoformat()
return ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp=scan_timestamp,
profile_hash="",
is_partial=len(errors) > 0,
)
def _discover_resource_type(
self, resource_type: str, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all resources of a specific type from Authentik API.
Handles pagination to retrieve all results.
Args:
resource_type: The Authentik resource type to discover.
endpoint: The Authentik endpoint URL.
Returns:
List of DiscoveredResource objects.
"""
api_path = _RESOURCE_TYPE_API_MAP[resource_type]
results: list[DiscoveredResource] = []
page = 1
while True:
response = requests.get(
self._build_url(api_path),
headers=self._auth_headers(),
params={"page": page, "page_size": 100},
timeout=30,
)
if response.status_code != 200:
raise AuthentikDiscoveryError(
f"API request failed for {resource_type}: "
f"status {response.status_code}"
)
data = response.json()
items = data.get("results", [])
for item in items:
resource = self._map_to_resource(resource_type, item, endpoint)
results.append(resource)
# Check for next page
if data.get("pagination", {}).get("next", 0) > 0:
page += 1
else:
break
return results
def _map_to_resource(
self, resource_type: str, item: dict, endpoint: str
) -> DiscoveredResource:
"""Map an Authentik API response item to a DiscoveredResource.
Args:
resource_type: The resource type string.
item: The API response dictionary for a single resource.
endpoint: The Authentik endpoint URL.
Returns:
A DiscoveredResource instance.
"""
# Extract common fields with sensible defaults
unique_id = str(item.get("pk", item.get("uuid", item.get("id", ""))))
name = item.get("name", item.get("slug", item.get("title", unique_id)))
return DiscoveredResource(
resource_type=resource_type,
unique_id=f"authentik/{resource_type}/{unique_id}",
name=name,
provider=ProviderType.DOCKER_SWARM, # Closest match for containerized identity provider
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AMD64,
endpoint=endpoint,
attributes=item,
raw_references=self._extract_references(item),
)
def _extract_references(self, item: dict) -> list[str]:
"""Extract references to other resources from an API item.
Looks for common reference fields in Authentik API responses.
Args:
item: The API response dictionary.
Returns:
List of reference ID strings.
"""
references: list[str] = []
# Common reference fields in Authentik API
ref_fields = [
"flow",
"provider",
"application",
"outpost",
"group",
"source",
"certificate",
"stages",
"policies",
]
for field_name in ref_fields:
value = item.get(field_name)
if value is None:
continue
if isinstance(value, str) and value:
references.append(value)
elif isinstance(value, list):
for v in value:
if isinstance(v, str) and v:
references.append(v)
return references
def _build_url(self, path: str) -> str:
"""Build a full URL from the base URL and a relative path.
Args:
path: Relative API path.
Returns:
Full URL string.
"""
return urljoin(self._base_url + "/", path)
def _auth_headers(self) -> dict[str, str]:
"""Return authorization headers for API requests.
Returns:
Dictionary with Authorization header.
"""
return {"Authorization": f"Bearer {self._api_token}"}

View File

@@ -0,0 +1,6 @@
"""CLI module for command-line interface."""
from iac_reverse.cli.cli import cli, main
from iac_reverse.cli.profile_loader import ProfileLoader, ProfileLoaderError
__all__ = ["cli", "main", "ProfileLoader", "ProfileLoaderError"]

Binary file not shown.

444
src/iac_reverse/cli/cli.py Normal file
View File

@@ -0,0 +1,444 @@
"""CLI entry point for the IaC Reverse Engineering tool.
Provides commands for scanning infrastructure, generating Terraform code,
running incremental diffs, validating output, and authenticating via Authentik SSO.
"""
import sys
from pathlib import Path
from typing import Optional
import click
import yaml
from iac_reverse.models import (
ProviderType,
ScanProfile,
ScanProgress,
)
def _load_scan_profile(profile_path: str) -> ScanProfile:
"""Load a ScanProfile from a YAML file.
Args:
profile_path: Path to the YAML scan profile file.
Returns:
A ScanProfile instance.
Raises:
click.ClickException: If the file cannot be read or parsed.
"""
path = Path(profile_path)
if not path.exists():
raise click.ClickException(f"Profile not found: {profile_path}")
try:
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise click.ClickException(f"Invalid YAML in profile: {e}")
if not isinstance(data, dict):
raise click.ClickException("Profile must be a YAML mapping")
provider_str = data.get("provider", "")
try:
provider = ProviderType(provider_str)
except ValueError:
raise click.ClickException(
f"Unknown provider '{provider_str}'. "
f"Supported: {[p.value for p in ProviderType]}"
)
return ScanProfile(
provider=provider,
credentials=data.get("credentials", {}),
endpoints=data.get("endpoints"),
resource_type_filters=data.get("resource_type_filters"),
authentik_token=data.get("authentik_token"),
)
def _create_plugin(profile: ScanProfile):
"""Create the appropriate provider plugin for a scan profile.
Args:
profile: The ScanProfile specifying the provider.
Returns:
A ProviderPlugin instance for the profile's provider.
Raises:
click.ClickException: If the provider plugin cannot be created.
"""
from iac_reverse.scanner.docker_swarm_plugin import DockerSwarmPlugin
from iac_reverse.scanner.kubernetes_plugin import KubernetesPlugin
from iac_reverse.scanner.synology_plugin import SynologyPlugin
from iac_reverse.scanner.harvester_plugin import HarvesterPlugin
from iac_reverse.scanner.bare_metal_plugin import BareMetalPlugin
from iac_reverse.scanner.windows_plugin import WindowsPlugin
plugin_map = {
ProviderType.DOCKER_SWARM: DockerSwarmPlugin,
ProviderType.KUBERNETES: KubernetesPlugin,
ProviderType.SYNOLOGY: SynologyPlugin,
ProviderType.HARVESTER: HarvesterPlugin,
ProviderType.BARE_METAL: BareMetalPlugin,
ProviderType.WINDOWS: WindowsPlugin,
}
plugin_class = plugin_map.get(profile.provider)
if plugin_class is None:
raise click.ClickException(
f"No plugin available for provider '{profile.provider.value}'"
)
return plugin_class()
def _progress_callback(progress: ScanProgress) -> None:
"""Display scan progress to the user."""
click.echo(
f" [{progress.resource_types_completed}/{progress.total_resource_types}] "
f"Scanning {progress.current_resource_type}... "
f"({progress.resources_discovered} resources found)"
)
@click.group()
@click.version_option(version="0.1.0", prog_name="iac-reverse")
def cli():
"""IaC Reverse Engineering Tool.
Reverse-engineer on-premises infrastructure into Terraform HCL code and state files.
"""
pass
@cli.command()
@click.option(
"--profile",
required=True,
type=click.Path(exists=True),
help="Path to YAML scan profile.",
)
def scan(profile: str):
"""Scan infrastructure and display discovered resources.
Loads the scan profile, connects to the provider, and discovers
all matching resources.
"""
from iac_reverse.scanner.scanner import Scanner
click.echo(f"Loading scan profile: {profile}")
scan_profile = _load_scan_profile(profile)
click.echo(f"Provider: {scan_profile.provider.value}")
click.echo("Creating plugin...")
plugin = _create_plugin(scan_profile)
click.echo("Starting scan...")
scanner = Scanner(profile=scan_profile, plugin=plugin)
try:
result = scanner.scan(progress_callback=_progress_callback)
except Exception as e:
raise click.ClickException(f"Scan failed: {e}")
click.echo("")
click.echo(f"Scan complete: {len(result.resources)} resources discovered")
if result.warnings:
click.echo(f"Warnings: {len(result.warnings)}")
for w in result.warnings:
click.echo(f"{w}")
if result.errors:
click.echo(f"Errors: {len(result.errors)}")
for e in result.errors:
click.echo(f"{e}")
@cli.command()
@click.option(
"--profile",
required=True,
type=click.Path(exists=True),
help="Path to YAML scan profile.",
)
@click.option(
"--output-dir",
required=True,
type=click.Path(),
help="Output directory for generated Terraform files.",
)
def generate(profile: str, output_dir: str):
"""Run full pipeline: scan → resolve → generate → state → validate.
Scans infrastructure, resolves dependencies, generates Terraform HCL
code, builds state file, and validates the output.
"""
from iac_reverse.scanner.scanner import Scanner
from iac_reverse.resolver.resolver import DependencyResolver
from iac_reverse.generator.code_generator import CodeGenerator
from iac_reverse.state_builder.state_builder import StateBuilder
from iac_reverse.validator.validator import Validator
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Step 1: Scan
click.echo(f"Loading scan profile: {profile}")
scan_profile = _load_scan_profile(profile)
plugin = _create_plugin(scan_profile)
click.echo("Step 1/5: Scanning infrastructure...")
scanner = Scanner(profile=scan_profile, plugin=plugin)
try:
scan_result = scanner.scan(progress_callback=_progress_callback)
except Exception as e:
raise click.ClickException(f"Scan failed: {e}")
click.echo(f" Found {len(scan_result.resources)} resources")
# Step 2: Resolve dependencies
click.echo("Step 2/5: Resolving dependencies...")
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
click.echo(
f" Resolved {len(graph.relationships)} relationships, "
f"{len(graph.cycles)} cycles detected"
)
# Step 3: Generate code
click.echo("Step 3/5: Generating Terraform code...")
generator = CodeGenerator()
code_result = generator.generate(graph, [scan_profile])
click.echo(f" Generated {len(code_result.resource_files)} resource files")
# Write generated files to output directory
for gen_file in code_result.resource_files:
file_path = output_path / gen_file.filename
file_path.write_text(gen_file.content, encoding="utf-8")
if code_result.variables_file.content:
(output_path / code_result.variables_file.filename).write_text(
code_result.variables_file.content, encoding="utf-8"
)
if code_result.provider_file.content:
(output_path / code_result.provider_file.filename).write_text(
code_result.provider_file.content, encoding="utf-8"
)
# Step 4: Build state
click.echo("Step 4/5: Building Terraform state...")
state_builder = StateBuilder()
state_file = state_builder.build(code_result, graph, provider_version="1.0.0")
state_json = state_file.to_json()
(output_path / "terraform.tfstate").write_text(state_json, encoding="utf-8")
click.echo(f" State file: {len(state_file.resources)} entries")
if state_builder.unmapped_resources:
click.echo(f" Unmapped: {len(state_builder.unmapped_resources)} resources")
# Step 5: Validate
click.echo("Step 5/5: Validating output...")
validator = Validator()
validation = validator.validate(str(output_path))
if validation.validate_success:
click.echo(" ✓ Validation passed")
else:
click.echo(" ✗ Validation failed")
for err in validation.errors:
click.echo(f" {err.file}:{err.line} - {err.message}")
# Summary
click.echo("")
click.echo("Generation complete:")
click.echo(f" Output directory: {output_dir}")
click.echo(f" Resource files: {len(code_result.resource_files)}")
click.echo(f" Total resources: {len(scan_result.resources)}")
@cli.command()
@click.option(
"--profile",
required=True,
type=click.Path(exists=True),
help="Path to YAML scan profile.",
)
def diff(profile: str):
"""Run incremental scan and display changes.
Loads the previous snapshot, runs a new scan, compares results,
and displays the change summary.
"""
from iac_reverse.scanner.scanner import Scanner
from iac_reverse.incremental.snapshot_store import SnapshotStore
from iac_reverse.incremental.change_detector import ChangeDetector
click.echo(f"Loading scan profile: {profile}")
scan_profile = _load_scan_profile(profile)
plugin = _create_plugin(scan_profile)
# Load previous snapshot
click.echo("Loading previous snapshot...")
snapshot_store = SnapshotStore()
scanner = Scanner(profile=scan_profile, plugin=plugin)
profile_hash = scanner._compute_profile_hash()
previous = snapshot_store.load_previous(profile_hash)
if previous is None:
click.echo(" No previous snapshot found (first scan)")
# Run current scan
click.echo("Scanning infrastructure...")
try:
current = scanner.scan(progress_callback=_progress_callback)
except Exception as e:
raise click.ClickException(f"Scan failed: {e}")
# Compare
click.echo("Comparing with previous scan...")
detector = ChangeDetector()
summary = detector.compare(current, previous)
# Store new snapshot
snapshot_store.store_snapshot(current, profile_hash)
click.echo(" Snapshot saved")
# Display results
click.echo("")
click.echo("Change Summary:")
click.echo(f" Added: {summary.added_count}")
click.echo(f" Removed: {summary.removed_count}")
click.echo(f" Modified: {summary.modified_count}")
if summary.changes:
click.echo("")
for change in summary.changes:
symbol = {"added": "+", "removed": "-", "modified": "~"}
s = symbol.get(change.change_type.value, "?")
click.echo(
f" {s} {change.resource_type}/{change.resource_name}"
)
@cli.command()
@click.option(
"--dir",
"output_dir",
required=True,
type=click.Path(exists=True),
help="Path to directory containing Terraform output to validate.",
)
def validate(output_dir: str):
"""Validate existing Terraform output.
Runs terraform init, validate, and plan against the specified
directory and reports results.
"""
from iac_reverse.validator.validator import Validator
click.echo(f"Validating: {output_dir}")
validator = Validator()
result = validator.validate(output_dir)
click.echo("")
click.echo("Validation Results:")
click.echo(f" terraform init: {'' if result.init_success else ''}")
click.echo(f" terraform validate: {'' if result.validate_success else ''}")
click.echo(f" terraform plan: {'' if result.plan_success else ''}")
if result.correction_attempts > 0:
click.echo(f" Auto-corrections: {result.correction_attempts}")
if result.errors:
click.echo("")
click.echo("Errors:")
for err in result.errors:
location = f"{err.file}:{err.line}" if err.file else "(general)"
click.echo(f"{location} - {err.message}")
if result.planned_changes:
click.echo("")
click.echo(f"Planned Changes ({len(result.planned_changes)}):")
for change in result.planned_changes:
click.echo(
f" {change.change_type}: {change.resource_address}"
)
if result.validate_success and result.plan_success:
click.echo("")
click.echo("✓ All validations passed - no drift detected")
elif result.validate_success and not result.plan_success:
click.echo("")
click.echo("⚠ Validation passed but drift detected")
@cli.command()
@click.option(
"--url",
required=True,
help="Authentik instance URL (e.g., https://auth.internal.lab).",
)
@click.option(
"--client-id",
required=True,
help="OAuth2 client ID for this tool.",
)
@click.option(
"--client-secret",
prompt=True,
hide_input=True,
help="OAuth2 client secret (prompted if not provided).",
)
def login(url: str, client_id: str, client_secret: str):
"""Authenticate with Authentik SSO.
Performs OAuth2/OIDC authentication and stores the token
for use by subsequent commands.
"""
from iac_reverse.auth.authentik_auth import (
AuthentikAuthProvider,
AuthentikConfig,
AuthenticationError,
)
click.echo(f"Authenticating with Authentik at {url}...")
config = AuthentikConfig(
base_url=url,
client_id=client_id,
client_secret=client_secret,
)
provider = AuthentikAuthProvider()
try:
session = provider.authenticate_user(config)
except AuthenticationError as e:
raise click.ClickException(f"Authentication failed: {e}")
# Store token in local config directory
token_dir = Path(".iac-reverse")
token_dir.mkdir(parents=True, exist_ok=True)
token_file = token_dir / "token"
token_file.write_text(session.access_token, encoding="utf-8")
click.echo(f"✓ Authenticated as user: {session.user_id}")
click.echo(f" Groups: {', '.join(session.groups) if session.groups else 'none'}")
click.echo(f" Token stored in {token_file}")
def main():
"""Main entry point for the CLI."""
cli()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,184 @@
"""Profile loader for YAML scan profiles with environment variable expansion.
Handles loading single and multi-profile YAML files, expanding ${ENV_VAR}
and ${ENV_VAR:-default} patterns in credential field values.
"""
import os
import re
from pathlib import Path
from typing import Any
import yaml
from iac_reverse.models import ProviderType, ScanProfile
# Pattern matches ${VAR_NAME} or ${VAR_NAME:-default_value}
_ENV_VAR_PATTERN = re.compile(r"\$\{([^}:]+)(?::-([^}]*))?\}")
class ProfileLoaderError(Exception):
"""Raised when profile loading or env var expansion fails."""
pass
class ProfileLoader:
"""Loads scan profiles from YAML files with environment variable expansion.
Supports:
- Single profile YAML (a dict with provider, credentials, etc.)
- Multi-profile YAML (a list of profile dicts)
- ${ENV_VAR} expansion in credential values
- ${ENV_VAR:-default} syntax for defaults when env var is unset
"""
def load(self, path: str) -> list[ScanProfile]:
"""Load one or more ScanProfiles from a YAML file.
Args:
path: Path to the YAML scan profile file.
Returns:
A list of ScanProfile instances.
Raises:
ProfileLoaderError: If the file cannot be read, parsed, or contains
invalid profile data.
"""
file_path = Path(path)
if not file_path.exists():
raise ProfileLoaderError(f"Profile not found: {path}")
try:
with open(file_path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
except yaml.YAMLError as e:
raise ProfileLoaderError(f"Invalid YAML in profile: {e}")
if data is None:
raise ProfileLoaderError("Profile file is empty")
if isinstance(data, list):
# Multi-profile YAML
profiles = []
for i, item in enumerate(data):
if not isinstance(item, dict):
raise ProfileLoaderError(
f"Profile at index {i} must be a YAML mapping"
)
profiles.append(self._parse_profile(item, index=i))
return profiles
elif isinstance(data, dict):
# Single profile YAML
return [self._parse_profile(data)]
else:
raise ProfileLoaderError(
"Profile must be a YAML mapping or a list of mappings"
)
def expand_env_vars(self, value: str) -> str:
"""Expand ${ENV_VAR} and ${ENV_VAR:-default} patterns in a string.
Args:
value: String potentially containing env var references.
Returns:
The string with all env var references replaced by their values.
Raises:
ProfileLoaderError: If an env var is not set and no default is provided.
"""
def _replace(match: re.Match) -> str:
var_name = match.group(1)
default_value = match.group(2)
env_value = os.environ.get(var_name)
if env_value is not None:
return env_value
if default_value is not None:
return default_value
raise ProfileLoaderError(
f"Environment variable '{var_name}' is not set and no default provided"
)
return _ENV_VAR_PATTERN.sub(_replace, value)
def _parse_profile(
self, data: dict[str, Any], index: int | None = None
) -> ScanProfile:
"""Parse a single profile dict into a ScanProfile.
Args:
data: Dictionary with profile configuration.
index: Optional index for error messages in multi-profile files.
Returns:
A ScanProfile instance with env vars expanded in credentials.
Raises:
ProfileLoaderError: If required fields are missing or invalid.
"""
context = f" at index {index}" if index is not None else ""
provider_str = data.get("provider")
if not provider_str:
raise ProfileLoaderError(f"Missing 'provider' field{context}")
try:
provider = ProviderType(provider_str)
except ValueError:
raise ProfileLoaderError(
f"Unknown provider '{provider_str}'{context}. "
f"Supported: {[p.value for p in ProviderType]}"
)
credentials = data.get("credentials", {})
if not isinstance(credentials, dict):
raise ProfileLoaderError(
f"'credentials' must be a mapping{context}"
)
# Expand env vars in credential values recursively
expanded_credentials = self._expand_credentials(credentials)
endpoints = data.get("endpoints")
resource_type_filters = data.get("resource_type_filters")
authentik_token = data.get("authentik_token")
# Expand env vars in authentik_token if it's a string
if isinstance(authentik_token, str):
authentik_token = self.expand_env_vars(authentik_token)
return ScanProfile(
provider=provider,
credentials=expanded_credentials,
endpoints=endpoints,
resource_type_filters=resource_type_filters,
authentik_token=authentik_token,
)
def _expand_credentials(self, credentials: dict[str, Any]) -> dict[str, str]:
"""Recursively expand environment variables in credential values.
Args:
credentials: Dictionary of credential key-value pairs.
Returns:
Dictionary with all string values having env vars expanded.
"""
expanded: dict[str, str] = {}
for key, value in credentials.items():
if isinstance(value, str):
expanded[key] = self.expand_env_vars(value)
elif isinstance(value, dict):
# Recursively expand nested dicts
expanded[key] = self._expand_credentials(value)
else:
# Keep non-string values as-is (numbers, booleans, etc.)
expanded[key] = value
return expanded

View File

@@ -0,0 +1,15 @@
"""Code generator module for Terraform HCL output."""
from iac_reverse.generator.code_generator import CodeGenerator
from iac_reverse.generator.provider_block import ProviderBlockGenerator
from iac_reverse.generator.resource_merger import ResourceMerger
from iac_reverse.generator.sanitize import sanitize_identifier
from iac_reverse.generator.variable_extractor import VariableExtractor
__all__ = [
"CodeGenerator",
"ProviderBlockGenerator",
"ResourceMerger",
"VariableExtractor",
"sanitize_identifier",
]

View File

@@ -0,0 +1,304 @@
"""HCL code generator using Jinja2 templates.
Produces Terraform HCL files from a DependencyGraph and list of ScanProfiles.
Organizes output by resource type (one .tf file per type), includes traceability
comments, architecture-specific tags/labels, and uses Terraform resource
references for inter-resource dependencies.
"""
import logging
from collections import defaultdict
from jinja2 import Environment, BaseLoader
from iac_reverse.generator.sanitize import sanitize_identifier
from iac_reverse.models import (
CodeGenerationResult,
CpuArchitecture,
DependencyGraph,
DiscoveredResource,
GeneratedFile,
ResourceRelationship,
ScanProfile,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Jinja2 HCL Templates
# ---------------------------------------------------------------------------
_RESOURCE_BLOCK_TEMPLATE = """\
{% for resource in resources %}
# Source: {{ resource.unique_id }}
resource "{{ resource.resource_type }}" "{{ resource.tf_name }}" {
{% for key, value in resource.attributes.items() %}
{{ key }} = {{ value }}
{% endfor %}
{% if resource.tags %}
tags = {
{% for tag_key, tag_value in resource.tags.items() %}
"{{ tag_key }}" = "{{ tag_value }}"
{% endfor %}
}
{% endif %}
{% if resource.dependencies %}
{% for dep in resource.dependencies %}
depends_on = [{{ dep }}]
{% endfor %}
{% endif %}
}
{% endfor %}
"""
_RESOURCE_BLOCK_TEMPLATE_V2 = """\
{% for resource in resources %}
# Source: {{ resource.unique_id }}
resource "{{ resource.resource_type }}" "{{ resource.tf_name }}" {
{% for key, value in resource.rendered_attributes %}
{{ key }} = {{ value }}
{% endfor %}
{% if resource.tags %}
tags = {
{% for tag_key, tag_value in resource.tags.items() %}
"{{ tag_key }}" = "{{ tag_value }}"
{% endfor %}
}
{% endif %}
}
{% endfor %}
"""
# ---------------------------------------------------------------------------
# Helper: format HCL attribute values
# ---------------------------------------------------------------------------
def _format_hcl_value(value: object) -> str:
"""Format a Python value as an HCL literal string."""
if isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, int):
return str(value)
elif isinstance(value, float):
return str(value)
elif isinstance(value, str):
# Escape quotes in strings
escaped = value.replace("\\", "\\\\").replace('"', '\\"')
return f'"{escaped}"'
elif isinstance(value, list):
items = [_format_hcl_value(item) for item in value]
return "[" + ", ".join(items) + "]"
elif isinstance(value, dict):
lines = []
lines.append("{")
for k, v in value.items():
lines.append(f' "{k}" = {_format_hcl_value(v)}')
lines.append(" }")
return "\n".join(lines)
else:
return f'"{value}"'
# ---------------------------------------------------------------------------
# Internal data structure for template rendering
# ---------------------------------------------------------------------------
class _RenderableResource:
"""Internal representation of a resource ready for template rendering."""
def __init__(
self,
resource_type: str,
tf_name: str,
unique_id: str,
rendered_attributes: list[tuple[str, str]],
tags: dict[str, str],
):
self.resource_type = resource_type
self.tf_name = tf_name
self.unique_id = unique_id
self.rendered_attributes = rendered_attributes
self.tags = tags
# ---------------------------------------------------------------------------
# CodeGenerator
# ---------------------------------------------------------------------------
class CodeGenerator:
"""Generates Terraform HCL files from a dependency graph.
Accepts a DependencyGraph and list of ScanProfiles, produces one .tf file
per resource type with traceability comments, architecture tags, and
Terraform resource references for dependencies.
"""
def __init__(self) -> None:
"""Initialize the code generator with Jinja2 environment."""
self._env = Environment(
loader=BaseLoader(),
trim_blocks=True,
lstrip_blocks=True,
)
self._template = self._env.from_string(_RESOURCE_BLOCK_TEMPLATE_V2)
def generate(
self, graph: DependencyGraph, profiles: list[ScanProfile]
) -> CodeGenerationResult:
"""Generate Terraform HCL from a dependency graph.
Args:
graph: The DependencyGraph containing resources and relationships.
profiles: List of ScanProfiles used during scanning.
Returns:
CodeGenerationResult with resource_files, variables_file, and
provider_file. Variables and provider files are empty placeholders
(implemented in tasks 5.3 and 5.4).
"""
# Build lookup maps
resource_map: dict[str, DiscoveredResource] = {
r.unique_id: r for r in graph.resources
}
# Map from source_id -> list of (target_id, relationship)
relationships_by_source: dict[str, list[ResourceRelationship]] = defaultdict(
list
)
for rel in graph.relationships:
relationships_by_source[rel.source_id].append(rel)
# Group resources by type
resources_by_type: dict[str, list[DiscoveredResource]] = defaultdict(list)
for resource in graph.resources:
resources_by_type[resource.resource_type].append(resource)
# Generate one file per resource type
resource_files: list[GeneratedFile] = []
for resource_type, resources in sorted(resources_by_type.items()):
renderable_resources = []
for resource in resources:
renderable = self._build_renderable(
resource, relationships_by_source, resource_map
)
renderable_resources.append(renderable)
content = self._template.render(resources=renderable_resources)
filename = f"{resource_type}.tf"
resource_files.append(
GeneratedFile(
filename=filename,
content=content,
resource_count=len(resources),
)
)
# Placeholder files for tasks 5.3 and 5.4
variables_file = GeneratedFile(
filename="variables.tf",
content="",
resource_count=0,
)
provider_file = GeneratedFile(
filename="providers.tf",
content="",
resource_count=0,
)
return CodeGenerationResult(
resource_files=resource_files,
variables_file=variables_file,
provider_file=provider_file,
)
def _build_renderable(
self,
resource: DiscoveredResource,
relationships_by_source: dict[str, list[ResourceRelationship]],
resource_map: dict[str, DiscoveredResource],
) -> _RenderableResource:
"""Build a renderable resource with formatted attributes and references.
For attributes that reference other resources in the graph, replaces
the hardcoded ID with a Terraform resource reference expression.
"""
tf_name = sanitize_identifier(resource.name)
# Build a set of target IDs this resource references
target_ids_for_resource: dict[str, ResourceRelationship] = {}
for rel in relationships_by_source.get(resource.unique_id, []):
target_ids_for_resource[rel.target_id] = rel
# Render attributes, replacing references with Terraform expressions
rendered_attributes: list[tuple[str, str]] = []
for attr_key, attr_value in resource.attributes.items():
resolved_value = self._resolve_attribute_value(
attr_value, target_ids_for_resource, resource_map
)
rendered_attributes.append((attr_key, resolved_value))
# Generate architecture-specific tags
tags = self._generate_architecture_tags(resource)
return _RenderableResource(
resource_type=resource.resource_type,
tf_name=tf_name,
unique_id=resource.unique_id,
rendered_attributes=rendered_attributes,
tags=tags,
)
def _resolve_attribute_value(
self,
value: object,
target_ids: dict[str, ResourceRelationship],
resource_map: dict[str, DiscoveredResource],
) -> str:
"""Resolve an attribute value, replacing resource IDs with Terraform references.
If the value matches a target resource's unique_id or name, returns a
Terraform resource reference expression. Otherwise formats as HCL literal.
"""
if isinstance(value, str):
# Check if this string matches a target resource's unique_id
if value in target_ids:
target_resource = resource_map[value]
return self._make_terraform_reference(target_resource)
# Check if this string matches a target resource's name
for target_id, rel in target_ids.items():
target_resource = resource_map[target_id]
if value == target_resource.name:
return self._make_terraform_reference(target_resource)
# Default: format as HCL literal
return _format_hcl_value(value)
def _make_terraform_reference(self, target_resource: DiscoveredResource) -> str:
"""Create a Terraform resource reference expression.
Example: kubernetes_namespace.default.id
"""
target_tf_name = sanitize_identifier(target_resource.name)
return f"{target_resource.resource_type}.{target_tf_name}.id"
def _generate_architecture_tags(
self, resource: DiscoveredResource
) -> dict[str, str]:
"""Generate architecture-specific tags/labels for a resource.
Returns a dict of tag key-value pairs including the CPU architecture.
"""
tags: dict[str, str] = {
"arch": resource.architecture.value,
"managed_by": "iac-reverse",
}
return tags

View File

@@ -0,0 +1,197 @@
"""Provider block generator for Terraform HCL output.
Generates a providers.tf file containing:
- A terraform { required_providers { ... } } block listing all providers used
- Individual provider configuration blocks with platform-specific settings
(endpoints, certificates, credentials) for each distinct provider type.
"""
from __future__ import annotations
from iac_reverse.models import ProviderType, ScanProfile, GeneratedFile
# ---------------------------------------------------------------------------
# Provider metadata: maps ProviderType to Terraform provider details
# ---------------------------------------------------------------------------
# Each entry: (terraform_provider_name, source, version_constraint)
_PROVIDER_METADATA: dict[ProviderType, tuple[str, str, str]] = {
ProviderType.KUBERNETES: (
"kubernetes",
"hashicorp/kubernetes",
"~> 2.0",
),
ProviderType.DOCKER_SWARM: (
"docker",
"kreuzwerker/docker",
"~> 3.0",
),
ProviderType.SYNOLOGY: (
"synology",
"synology-community/synology",
"~> 0.2",
),
ProviderType.HARVESTER: (
"harvester",
"harvester/harvester",
"~> 0.6",
),
ProviderType.BARE_METAL: (
"redfish",
"dell/redfish",
"~> 1.0",
),
ProviderType.WINDOWS: (
"windows",
"hashicorp/windows",
"~> 0.1",
),
}
def _generate_provider_config(
provider_type: ProviderType, profile: ScanProfile
) -> str:
"""Generate the provider configuration block for a given provider type.
Uses credentials and endpoints from the ScanProfile to populate
platform-specific configuration attributes.
"""
tf_name = _PROVIDER_METADATA[provider_type][0]
lines: list[str] = []
lines.append(f'provider "{tf_name}" {{')
if provider_type == ProviderType.KUBERNETES:
host = profile.credentials.get("host", "")
cluster_ca = profile.credentials.get("cluster_ca_certificate", "")
token = profile.credentials.get("token", "")
lines.append(f' host = "{host}"')
lines.append(f' cluster_ca_certificate = "{cluster_ca}"')
lines.append(f' token = "{token}"')
elif provider_type == ProviderType.DOCKER_SWARM:
host = profile.credentials.get("host", "")
cert_path = profile.credentials.get("cert_path", "")
lines.append(f' host = "{host}"')
lines.append(f' cert_path = "{cert_path}"')
elif provider_type == ProviderType.SYNOLOGY:
url = profile.credentials.get("url", "")
username = profile.credentials.get("username", "")
password = profile.credentials.get("password", "")
lines.append(f' url = "{url}"')
lines.append(f' username = "{username}"')
lines.append(f' password = "{password}"')
elif provider_type == ProviderType.HARVESTER:
kubeconfig = profile.credentials.get("kubeconfig", "")
lines.append(f' kubeconfig = "{kubeconfig}"')
elif provider_type == ProviderType.BARE_METAL:
endpoint = profile.credentials.get("endpoint", "")
username = profile.credentials.get("username", "")
password = profile.credentials.get("password", "")
lines.append(f' endpoint = "{endpoint}"')
lines.append(f' username = "{username}"')
lines.append(f' password = "{password}"')
elif provider_type == ProviderType.WINDOWS:
host = profile.credentials.get("host", "")
username = profile.credentials.get("username", "")
password = profile.credentials.get("password", "")
lines.append(f' host = "{host}"')
lines.append(f' username = "{username}"')
lines.append(f' password = "{password}"')
lines.append("")
lines.append(" winrm {")
winrm_port = profile.credentials.get("winrm_port", "5985")
winrm_use_ssl = profile.credentials.get("winrm_use_ssl", "false")
lines.append(f" port = {winrm_port}")
lines.append(f" use_ssl = {winrm_use_ssl}")
lines.append(" }")
lines.append("}")
return "\n".join(lines)
def _generate_required_providers_block(
provider_types: set[ProviderType],
) -> str:
"""Generate the terraform { required_providers { ... } } block."""
lines: list[str] = []
lines.append("terraform {")
lines.append(" required_providers {")
for provider_type in sorted(provider_types, key=lambda p: p.value):
tf_name, source, version = _PROVIDER_METADATA[provider_type]
lines.append(f" {tf_name} = {{")
lines.append(f' source = "{source}"')
lines.append(f' version = "{version}"')
lines.append(" }")
lines.append(" }")
lines.append("}")
return "\n".join(lines)
# ---------------------------------------------------------------------------
# ProviderBlockGenerator
# ---------------------------------------------------------------------------
class ProviderBlockGenerator:
"""Generates Terraform provider configuration blocks.
Accepts a list of ScanProfiles and a set of ProviderTypes used in the
generated code, and produces a providers.tf file containing:
- A terraform { required_providers { ... } } block
- Individual provider blocks with platform-specific configuration
"""
def generate(
self,
profiles: list[ScanProfile],
provider_types: set[ProviderType],
) -> GeneratedFile:
"""Generate the providers.tf file content.
Args:
profiles: List of ScanProfiles providing credentials/endpoints.
provider_types: Set of distinct ProviderTypes used in the code.
Returns:
A GeneratedFile with filename "providers.tf" and the HCL content.
"""
# Build a map from ProviderType -> first matching profile
profile_map: dict[ProviderType, ScanProfile] = {}
for profile in profiles:
if profile.provider not in profile_map:
profile_map[profile.provider] = profile
sections: list[str] = []
# 1. required_providers block
sections.append(_generate_required_providers_block(provider_types))
# 2. Individual provider configuration blocks
for provider_type in sorted(provider_types, key=lambda p: p.value):
profile = profile_map.get(provider_type)
if profile is not None:
sections.append(
_generate_provider_config(provider_type, profile)
)
else:
# Generate a placeholder block if no profile matches
tf_name = _PROVIDER_METADATA[provider_type][0]
sections.append(
f'provider "{tf_name}" {{\n # No profile provided\n}}'
)
content = "\n\n".join(sections) + "\n"
return GeneratedFile(
filename="providers.tf",
content=content,
resource_count=0,
)

View File

@@ -0,0 +1,59 @@
"""Multi-provider resource merging with conflict resolution.
Merges resources from multiple scan profiles into a unified inventory,
resolving naming conflicts by prefixing with the provider identifier.
"""
from dataclasses import replace
from collections import defaultdict
from iac_reverse.models import DiscoveredResource, ScanResult
class ResourceMerger:
"""Merges resources from multiple ScanResult objects into a unified list.
When resources from different providers share the same name, the merger
resolves the conflict by prefixing each conflicting resource's name with
its provider identifier (e.g., "kubernetes_nginx", "docker_swarm_nginx").
Provider-specific attributes are preserved unchanged.
"""
def merge(self, scan_results: list[ScanResult]) -> list[DiscoveredResource]:
"""Merge resources from multiple scan results into a unified list.
Args:
scan_results: List of ScanResult objects, one per provider/scan profile.
Returns:
A unified list of DiscoveredResource with naming conflicts resolved
by prefixing conflicting names with the provider identifier.
"""
# Collect all resources from all scan results
all_resources: list[DiscoveredResource] = []
for result in scan_results:
all_resources.extend(result.resources)
# Group resources by name to detect conflicts
resources_by_name: dict[str, list[DiscoveredResource]] = defaultdict(list)
for resource in all_resources:
resources_by_name[resource.name].append(resource)
# Identify conflicting names: same name from different providers
conflicting_names: set[str] = set()
for name, resources in resources_by_name.items():
providers = {r.provider for r in resources}
if len(providers) > 1:
conflicting_names.add(name)
# Build the merged list, resolving conflicts
merged: list[DiscoveredResource] = []
for resource in all_resources:
if resource.name in conflicting_names:
prefixed_name = f"{resource.provider.value}_{resource.name}"
merged.append(replace(resource, name=prefixed_name))
else:
merged.append(resource)
return merged

View File

@@ -0,0 +1,41 @@
"""Identifier sanitization for Terraform resource names.
Converts arbitrary resource names into valid Terraform identifiers
matching the pattern: ^[a-zA-Z_][a-zA-Z0-9_]*$
"""
import re
def sanitize_identifier(name: str) -> str:
"""Convert a resource name to a valid Terraform identifier.
Terraform identifiers must match: ^[a-zA-Z_][a-zA-Z0-9_]*$
Rules applied:
- Replace any character not in [a-zA-Z0-9_] with underscore
- Collapse multiple consecutive underscores into one
- If result starts with a digit, prepend an underscore
- If result is empty or only underscores, return "_resource"
Args:
name: Any string resource name.
Returns:
A valid Terraform identifier derived from the input.
"""
# Replace any non-alphanumeric/underscore character with underscore
result = re.sub(r"[^a-zA-Z0-9_]", "_", name)
# Collapse multiple consecutive underscores into one
result = re.sub(r"_+", "_", result)
# If result is only underscores or empty, return fallback
if not result or result.strip("_") == "":
return "_resource"
# If starts with a digit, prepend underscore
if result[0].isdigit():
result = "_" + result
return result

View File

@@ -0,0 +1,203 @@
"""Variable extraction logic for Terraform code generation.
Identifies attribute values that appear in 2+ resources and extracts them
into Terraform variables with appropriate type expressions and defaults.
"""
import logging
from collections import defaultdict
from iac_reverse.models import DiscoveredResource, ExtractedVariable
logger = logging.getLogger(__name__)
def _infer_type_expr(value: object) -> str:
"""Infer a Terraform type expression from a Python value.
Args:
value: The Python value to infer a type for.
Returns:
A Terraform type expression string (e.g., "string", "number", "bool").
"""
if isinstance(value, bool):
return "bool"
elif isinstance(value, int) or isinstance(value, float):
return "number"
elif isinstance(value, str):
return "string"
elif isinstance(value, list):
return "list(string)"
elif isinstance(value, dict):
return "map(string)"
else:
return "string"
def _format_default_value(value: object) -> str:
"""Format a Python value as a Terraform default value literal.
Args:
value: The Python value to format.
Returns:
A string representation suitable for a Terraform variable default.
"""
if isinstance(value, bool):
return "true" if value else "false"
elif isinstance(value, int) or isinstance(value, float):
return str(value)
elif isinstance(value, str):
return f'"{value}"'
elif isinstance(value, list):
items = ", ".join(f'"{item}"' if isinstance(item, str) else str(item) for item in value)
return f"[{items}]"
elif isinstance(value, dict):
entries = ", ".join(f'"{k}" = "{v}"' for k, v in value.items())
return "{" + entries + "}"
else:
return f'"{value}"'
def _make_hashable(value: object) -> object:
"""Convert a value to a hashable representation for counting.
Args:
value: Any Python value from resource attributes.
Returns:
A hashable version of the value.
"""
if isinstance(value, dict):
return tuple(sorted(value.items()))
elif isinstance(value, list):
return tuple(value)
else:
return value
class VariableExtractor:
"""Extracts shared attribute values into Terraform variables.
Scans a list of DiscoveredResource objects, identifies attribute values
that appear in 2 or more resources, and creates ExtractedVariable instances
for each shared value.
"""
def extract_variables(
self, resources: list[DiscoveredResource]
) -> list[ExtractedVariable]:
"""Identify shared attribute values and extract them as variables.
For each attribute key, collects all values across all resources.
If a value appears in 2+ resources for the same attribute key,
it becomes a variable with the most common value as the default.
Args:
resources: List of discovered resources to analyze.
Returns:
List of ExtractedVariable instances for shared values.
"""
if len(resources) < 2:
return []
# Collect attribute values grouped by attribute key
# key -> {hashable_value -> [list of (resource_unique_id, original_value)]}
attr_values: dict[str, dict[object, list[tuple[str, object]]]] = defaultdict(
lambda: defaultdict(list)
)
for resource in resources:
for attr_key, attr_value in resource.attributes.items():
# Skip complex nested structures (dicts/lists) for variable extraction
# as they are less likely to be meaningfully shared
if isinstance(attr_value, (dict, list)):
continue
hashable = _make_hashable(attr_value)
attr_values[attr_key][hashable].append(
(resource.unique_id, attr_value)
)
# Build extracted variables for values appearing in 2+ resources
variables: list[ExtractedVariable] = []
for attr_key, value_groups in sorted(attr_values.items()):
# Find all values that appear in 2+ resources for this key
shared_values = [
(hv, entries)
for hv, entries in value_groups.items()
if len(entries) >= 2
]
if not shared_values:
continue
# If only one shared value exists for this key, use the key as the var name
# If multiple shared values exist, disambiguate with a suffix
for idx, (hashable_value, resource_entries) in enumerate(shared_values):
original_value = resource_entries[0][1]
used_by = [entry[0] for entry in resource_entries]
# Determine the most common value among the shared values for this key
# The default is set to the most common value overall
most_common_entries = max(shared_values, key=lambda x: len(x[1]))
most_common_value = most_common_entries[1][0][1]
# Use the most common value as default for the primary variable,
# but each variable's default is its own value
default_value = _format_default_value(original_value)
if len(shared_values) == 1:
var_name = f"var_{attr_key}"
else:
# Disambiguate when multiple shared values exist for same key
var_name = f"var_{attr_key}_{idx}"
type_expr = _infer_type_expr(original_value)
description = (
f"Shared {attr_key} value extracted from "
f"{len(resource_entries)} resources"
)
variables.append(
ExtractedVariable(
name=var_name,
type_expr=type_expr,
default_value=default_value,
description=description,
used_by=used_by,
)
)
return variables
def generate_variables_tf(
self, variables: list[ExtractedVariable]
) -> str:
"""Generate Terraform variables.tf file content.
Produces variable blocks with type, description, and default values.
Args:
variables: List of extracted variables to render.
Returns:
String content for a variables.tf file.
"""
if not variables:
return ""
blocks: list[str] = []
for var in variables:
block = (
f'variable "{var.name}" {{\n'
f' type = {var.type_expr}\n'
f' description = "{var.description}"\n'
f' default = {var.default_value}\n'
f'}}'
)
blocks.append(block)
return "\n\n".join(blocks) + "\n"

View File

@@ -0,0 +1,7 @@
"""Incremental scan engine for change detection."""
from iac_reverse.incremental.change_detector import ChangeDetector
from iac_reverse.incremental.incremental_updater import IncrementalUpdater
from iac_reverse.incremental.snapshot_store import SnapshotStore
__all__ = ["ChangeDetector", "IncrementalUpdater", "SnapshotStore"]

View File

@@ -0,0 +1,144 @@
"""Change detection and classification for incremental scans.
Compares current scan results against previous snapshots to identify
added, removed, and modified resources.
"""
from typing import Optional
from iac_reverse.models import (
ChangeSummary,
ChangeType,
DiscoveredResource,
ResourceChange,
ScanResult,
)
class ChangeDetector:
"""Detects and classifies changes between scan results.
Compares resources by unique_id to determine which resources
have been added, removed, or modified between scans.
"""
def compare(
self, current: ScanResult, previous: Optional[ScanResult]
) -> ChangeSummary:
"""Compare current scan against a previous scan result.
Args:
current: The current scan result.
previous: The previous scan result, or None for first scan.
Returns:
A ChangeSummary with counts and list of ResourceChange objects.
If previous is None, all current resources are classified as ADDED.
"""
if previous is None:
return self._handle_first_scan(current)
current_map = {r.unique_id: r for r in current.resources}
previous_map = {r.unique_id: r for r in previous.resources}
changes: list[ResourceChange] = []
# Detect ADDED resources: in current but not in previous
for uid, resource in current_map.items():
if uid not in previous_map:
changes.append(
ResourceChange(
resource_id=resource.unique_id,
resource_type=resource.resource_type,
resource_name=resource.name,
change_type=ChangeType.ADDED,
changed_attributes=None,
)
)
# Detect REMOVED resources: in previous but not in current
for uid, resource in previous_map.items():
if uid not in current_map:
changes.append(
ResourceChange(
resource_id=resource.unique_id,
resource_type=resource.resource_type,
resource_name=resource.name,
change_type=ChangeType.REMOVED,
changed_attributes=None,
)
)
# Detect MODIFIED resources: same unique_id but attributes differ
for uid in current_map:
if uid in previous_map:
changed_attrs = self._diff_attributes(
current_map[uid], previous_map[uid]
)
if changed_attrs:
resource = current_map[uid]
changes.append(
ResourceChange(
resource_id=resource.unique_id,
resource_type=resource.resource_type,
resource_name=resource.name,
change_type=ChangeType.MODIFIED,
changed_attributes=changed_attrs,
)
)
added_count = sum(1 for c in changes if c.change_type == ChangeType.ADDED)
removed_count = sum(1 for c in changes if c.change_type == ChangeType.REMOVED)
modified_count = sum(1 for c in changes if c.change_type == ChangeType.MODIFIED)
return ChangeSummary(
added_count=added_count,
removed_count=removed_count,
modified_count=modified_count,
changes=changes,
)
def _handle_first_scan(self, current: ScanResult) -> ChangeSummary:
"""Handle first scan with no previous snapshot.
All resources in the current scan are classified as ADDED.
"""
changes = [
ResourceChange(
resource_id=resource.unique_id,
resource_type=resource.resource_type,
resource_name=resource.name,
change_type=ChangeType.ADDED,
changed_attributes=None,
)
for resource in current.resources
]
return ChangeSummary(
added_count=len(changes),
removed_count=0,
modified_count=0,
changes=changes,
)
def _diff_attributes(
self, current: DiscoveredResource, previous: DiscoveredResource
) -> Optional[dict]:
"""Compare attributes between two versions of the same resource.
Returns a dict of changed attributes with 'old' and 'new' values,
or None if no attributes differ.
"""
if current.attributes == previous.attributes:
return None
changed: dict = {}
all_keys = set(current.attributes.keys()) | set(previous.attributes.keys())
for key in all_keys:
old_val = previous.attributes.get(key)
new_val = current.attributes.get(key)
if old_val != new_val:
changed[key] = {"old": old_val, "new": new_val}
return changed if changed else None

View File

@@ -0,0 +1,339 @@
"""Incremental updater for Terraform IaC files and state.
Applies a ChangeSummary to an existing output directory, modifying only
the .tf files and state file that contain changed resources. Supports
adding new resource blocks, removing existing blocks, and updating
modified resource attributes without full regeneration.
"""
import json
import logging
import re
from pathlib import Path
from typing import Optional
from iac_reverse.generator.code_generator import _format_hcl_value
from iac_reverse.generator.sanitize import sanitize_identifier
from iac_reverse.models import ChangeSummary, ChangeType, ResourceChange
logger = logging.getLogger(__name__)
class IncrementalUpdater:
"""Applies incremental changes to Terraform IaC files and state.
Accepts a ChangeSummary and an output directory path. Modifies only
the .tf files containing changed resources (one .tf file per resource
type) and updates the terraform.tfstate file accordingly.
For REMOVED resources: removes the resource block from the .tf file
and the corresponding entry from the state file.
For ADDED resources: appends a new resource block to the appropriate
.tf file (creating the file if it doesn't exist).
For MODIFIED resources: updates the existing resource block with new
attribute values.
"""
def __init__(
self,
change_summary: ChangeSummary,
output_dir: str,
resource_attributes: Optional[dict[str, dict]] = None,
) -> None:
"""Initialize the IncrementalUpdater.
Args:
change_summary: The ChangeSummary describing what changed.
output_dir: Path to the output directory containing .tf and
state files.
resource_attributes: Optional mapping of resource_id to full
attribute dict for ADDED/MODIFIED resources. Required for
generating resource blocks for added resources.
"""
self._change_summary = change_summary
self._output_dir = Path(output_dir)
self._resource_attributes = resource_attributes or {}
self._modified_files: set[str] = set()
@property
def modified_files(self) -> set[str]:
"""Return the set of file paths that were modified during apply."""
return set(self._modified_files)
def apply(self) -> None:
"""Apply all changes from the ChangeSummary to the output directory.
Processes removed, added, and modified resources, updating only
the affected .tf files and the state file.
"""
for change in self._change_summary.changes:
if change.change_type == ChangeType.REMOVED:
self._handle_removed(change)
elif change.change_type == ChangeType.ADDED:
self._handle_added(change)
elif change.change_type == ChangeType.MODIFIED:
self._handle_modified(change)
def _handle_removed(self, change: ResourceChange) -> None:
"""Remove a resource block from its .tf file and state entry.
Args:
change: The ResourceChange describing the removed resource.
"""
tf_file = self._get_tf_file_path(change.resource_type)
if tf_file.exists():
self._remove_resource_block(tf_file, change)
self._modified_files.add(str(tf_file))
self._remove_state_entry(change)
def _handle_added(self, change: ResourceChange) -> None:
"""Add a new resource block to the appropriate .tf file.
Args:
change: The ResourceChange describing the added resource.
"""
tf_file = self._get_tf_file_path(change.resource_type)
attributes = self._resource_attributes.get(change.resource_id, {})
self._add_resource_block(tf_file, change, attributes)
self._modified_files.add(str(tf_file))
def _handle_modified(self, change: ResourceChange) -> None:
"""Update an existing resource block with new attribute values.
Args:
change: The ResourceChange describing the modified resource.
"""
tf_file = self._get_tf_file_path(change.resource_type)
if tf_file.exists():
self._update_resource_block(tf_file, change)
self._modified_files.add(str(tf_file))
def _get_tf_file_path(self, resource_type: str) -> Path:
"""Get the .tf file path for a given resource type.
Each resource type maps to a file named <resource_type>.tf.
Args:
resource_type: The Terraform resource type string.
Returns:
Path to the .tf file for this resource type.
"""
return self._output_dir / f"{resource_type}.tf"
def _remove_resource_block(
self, tf_file: Path, change: ResourceChange
) -> None:
"""Remove a resource block from a .tf file.
Identifies the block by matching the resource type and sanitized
resource name in the resource declaration line.
Args:
tf_file: Path to the .tf file.
change: The ResourceChange identifying the resource to remove.
"""
content = tf_file.read_text(encoding="utf-8")
tf_name = sanitize_identifier(change.resource_name)
# Pattern to match the full resource block including optional comment
# Matches: optional comment line + resource "type" "name" { ... }
pattern = self._build_block_pattern(change.resource_type, tf_name)
new_content = re.sub(pattern, "", content)
# Clean up excessive blank lines
new_content = re.sub(r"\n{3,}", "\n\n", new_content)
new_content = new_content.strip()
if new_content:
new_content += "\n"
tf_file.write_text(new_content, encoding="utf-8")
def _add_resource_block(
self, tf_file: Path, change: ResourceChange, attributes: dict
) -> None:
"""Add a new resource block to a .tf file.
Creates the file if it doesn't exist. Appends the block at the end.
Args:
tf_file: Path to the .tf file.
change: The ResourceChange describing the added resource.
attributes: The full attribute dict for the resource.
"""
tf_name = sanitize_identifier(change.resource_name)
block = self._render_resource_block(
change.resource_type, tf_name, change.resource_id, attributes
)
if tf_file.exists():
content = tf_file.read_text(encoding="utf-8")
if content and not content.endswith("\n"):
content += "\n"
content += "\n" + block
else:
content = block
tf_file.write_text(content, encoding="utf-8")
def _update_resource_block(
self, tf_file: Path, change: ResourceChange
) -> None:
"""Update an existing resource block with changed attributes.
Replaces only the changed attribute lines within the block.
Args:
tf_file: Path to the .tf file.
change: The ResourceChange with changed_attributes dict.
"""
if not change.changed_attributes:
return
content = tf_file.read_text(encoding="utf-8")
tf_name = sanitize_identifier(change.resource_name)
# Find the resource block
pattern = self._build_block_pattern(change.resource_type, tf_name)
match = re.search(pattern, content)
if not match:
logger.warning(
"Could not find resource block for %s.%s in %s",
change.resource_type,
tf_name,
tf_file,
)
return
block = match.group(0)
updated_block = block
for attr_name, attr_change in change.changed_attributes.items():
new_value = attr_change.get("new")
if new_value is None:
# Attribute was removed - remove the line
attr_pattern = re.compile(
rf"^[ \t]*{re.escape(attr_name)}\s*=\s*.*$\n?",
re.MULTILINE,
)
updated_block = attr_pattern.sub("", updated_block)
else:
# Attribute was added or changed - update/add the line
hcl_value = _format_hcl_value(new_value)
attr_pattern = re.compile(
rf"^([ \t]*){re.escape(attr_name)}\s*=\s*.*$",
re.MULTILINE,
)
attr_match = attr_pattern.search(updated_block)
if attr_match:
# Replace existing attribute line
indent = attr_match.group(1)
replacement = f"{indent}{attr_name} = {hcl_value}"
updated_block = attr_pattern.sub(replacement, updated_block)
else:
# Add new attribute before the closing brace
updated_block = re.sub(
r"(\n})",
f"\n {attr_name} = {hcl_value}\\1",
updated_block,
count=1,
)
content = content.replace(block, updated_block)
tf_file.write_text(content, encoding="utf-8")
def _remove_state_entry(self, change: ResourceChange) -> None:
"""Remove a resource entry from the terraform.tfstate file.
Args:
change: The ResourceChange identifying the resource to remove.
"""
state_file = self._output_dir / "terraform.tfstate"
if not state_file.exists():
return
content = state_file.read_text(encoding="utf-8")
try:
state = json.loads(content)
except json.JSONDecodeError:
logger.warning("Could not parse state file: %s", state_file)
return
tf_name = sanitize_identifier(change.resource_name)
resources = state.get("resources", [])
state["resources"] = [
r
for r in resources
if not (
r.get("type") == change.resource_type
and r.get("name") == tf_name
)
]
# Increment serial to indicate state change
state["serial"] = state.get("serial", 0) + 1
state_file.write_text(
json.dumps(state, indent=2), encoding="utf-8"
)
self._modified_files.add(str(state_file))
def _build_block_pattern(
self, resource_type: str, tf_name: str
) -> re.Pattern:
"""Build a regex pattern to match a full resource block.
Matches an optional comment line (# Source: ...) followed by the
resource declaration and its body enclosed in braces.
Args:
resource_type: The Terraform resource type.
tf_name: The sanitized Terraform resource name.
Returns:
A compiled regex pattern matching the full block.
"""
# Match optional comment + resource block with balanced braces
# The block body can contain nested braces (e.g., tags = { ... })
escaped_type = re.escape(resource_type)
escaped_name = re.escape(tf_name)
pattern = (
rf"(?:# Source:.*\n)?"
rf'resource\s+"{escaped_type}"\s+"{escaped_name}"\s*\{{'
rf"[^{{}}]*(?:\{{[^{{}}]*\}}[^{{}}]*)*"
rf"\}}\n?"
)
return re.compile(pattern, re.DOTALL)
def _render_resource_block(
self,
resource_type: str,
tf_name: str,
source_id: str,
attributes: dict,
) -> str:
"""Render a Terraform resource block as HCL text.
Args:
resource_type: The Terraform resource type.
tf_name: The sanitized Terraform resource name.
source_id: The source resource identifier for traceability.
attributes: The attribute dict to render.
Returns:
A string containing the HCL resource block.
"""
lines = [f"# Source: {source_id}"]
lines.append(f'resource "{resource_type}" "{tf_name}" {{')
for key, value in attributes.items():
hcl_value = _format_hcl_value(value)
lines.append(f" {key} = {hcl_value}")
lines.append("}")
lines.append("") # trailing newline
return "\n".join(lines)

View File

@@ -0,0 +1,177 @@
"""Snapshot storage and retrieval for incremental scan comparison.
Stores scan results as timestamped JSON files in `.iac-reverse/snapshots/`
and provides retrieval of previous snapshots for change detection.
"""
import json
import os
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanResult,
)
# Default directory for snapshot storage
SNAPSHOT_DIR = os.path.join(".iac-reverse", "snapshots")
# Minimum number of snapshots to retain per profile
MIN_RETAINED_SNAPSHOTS = 2
def _serialize_scan_result(result: ScanResult) -> dict:
"""Serialize a ScanResult to a JSON-compatible dictionary."""
return {
"scan_timestamp": result.scan_timestamp,
"profile_hash": result.profile_hash,
"is_partial": result.is_partial,
"warnings": result.warnings,
"errors": result.errors,
"resources": [_serialize_resource(r) for r in result.resources],
}
def _serialize_resource(resource: DiscoveredResource) -> dict:
"""Serialize a DiscoveredResource to a JSON-compatible dictionary."""
return {
"resource_type": resource.resource_type,
"unique_id": resource.unique_id,
"name": resource.name,
"provider": resource.provider.value,
"platform_category": resource.platform_category.value,
"architecture": resource.architecture.value,
"endpoint": resource.endpoint,
"attributes": resource.attributes,
"raw_references": resource.raw_references,
}
def _deserialize_scan_result(data: dict) -> ScanResult:
"""Deserialize a dictionary into a ScanResult."""
resources = [_deserialize_resource(r) for r in data["resources"]]
return ScanResult(
resources=resources,
warnings=data["warnings"],
errors=data["errors"],
scan_timestamp=data["scan_timestamp"],
profile_hash=data["profile_hash"],
is_partial=data.get("is_partial", False),
)
def _deserialize_resource(data: dict) -> DiscoveredResource:
"""Deserialize a dictionary into a DiscoveredResource."""
return DiscoveredResource(
resource_type=data["resource_type"],
unique_id=data["unique_id"],
name=data["name"],
provider=ProviderType(data["provider"]),
platform_category=PlatformCategory(data["platform_category"]),
architecture=CpuArchitecture(data["architecture"]),
endpoint=data["endpoint"],
attributes=data["attributes"],
raw_references=data.get("raw_references", []),
)
class SnapshotStore:
"""Manages storage and retrieval of scan result snapshots.
Stores scan results as timestamped JSON files in a configurable
directory (defaults to `.iac-reverse/snapshots/`). Supports
retrieval of the most recent snapshot for a given profile hash
and automatic pruning of old snapshots.
"""
def __init__(self, base_dir: Optional[str] = None) -> None:
"""Initialize the snapshot store.
Args:
base_dir: Base directory for snapshot storage.
Defaults to `.iac-reverse/snapshots/`.
"""
self._snapshot_dir = Path(base_dir) if base_dir else Path(SNAPSHOT_DIR)
@property
def snapshot_dir(self) -> Path:
"""Return the snapshot directory path."""
return self._snapshot_dir
def store_snapshot(self, result: ScanResult, profile_hash: str) -> None:
"""Store a scan result as a timestamped JSON snapshot.
Args:
result: The scan result to store.
profile_hash: Hash identifying the scan profile.
The snapshot is saved with filename format:
{profile_hash}_{timestamp}.json
where timestamp is ISO format with colons replaced by dashes.
After storing, old snapshots are pruned to retain at least
MIN_RETAINED_SNAPSHOTS most recent files per profile_hash.
"""
self._snapshot_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H-%M-%SZ")
filename = f"{profile_hash}_{timestamp}.json"
filepath = self._snapshot_dir / filename
data = _serialize_scan_result(result)
with open(filepath, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
self._prune_snapshots(profile_hash)
def load_previous(self, profile_hash: str) -> Optional[ScanResult]:
"""Load the most recent snapshot for a given profile hash.
Args:
profile_hash: Hash identifying the scan profile.
Returns:
The most recent ScanResult for the profile, or None if
no snapshot exists.
"""
snapshots = self._list_snapshots(profile_hash)
if not snapshots:
return None
# Sort by filename (which includes timestamp) to get most recent
snapshots.sort()
most_recent = snapshots[-1]
with open(most_recent, "r", encoding="utf-8") as f:
data = json.load(f)
return _deserialize_scan_result(data)
def _list_snapshots(self, profile_hash: str) -> list[Path]:
"""List all snapshot files for a given profile hash."""
if not self._snapshot_dir.exists():
return []
prefix = f"{profile_hash}_"
return [
p
for p in self._snapshot_dir.iterdir()
if p.is_file() and p.name.startswith(prefix) and p.name.endswith(".json")
]
def _prune_snapshots(self, profile_hash: str) -> None:
"""Remove old snapshots, keeping at least MIN_RETAINED_SNAPSHOTS most recent."""
snapshots = self._list_snapshots(profile_hash)
if len(snapshots) <= MIN_RETAINED_SNAPSHOTS:
return
# Sort by filename (timestamp is embedded) and remove oldest
snapshots.sort()
to_remove = snapshots[: len(snapshots) - MIN_RETAINED_SNAPSHOTS]
for snapshot_path in to_remove:
snapshot_path.unlink()

425
src/iac_reverse/models.py Normal file
View File

@@ -0,0 +1,425 @@
"""Core data models for the IaC Reverse Engineering tool.
Contains enums, dataclasses, and type definitions used across all components
of the pipeline: Scanner, Dependency Resolver, Code Generator, State Builder,
Validator, and Incremental Scan Engine.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class ProviderType(Enum):
"""Supported on-premises infrastructure provider types."""
DOCKER_SWARM = "docker_swarm"
KUBERNETES = "kubernetes"
SYNOLOGY = "synology"
HARVESTER = "harvester"
BARE_METAL = "bare_metal"
WINDOWS = "windows"
class PlatformCategory(Enum):
"""Categorizes providers by their infrastructure model."""
CONTAINER_ORCHESTRATION = "container" # Docker Swarm, Kubernetes
STORAGE_APPLIANCE = "storage" # Synology Disk Station
HCI = "hci" # SUSE Harvester (Hyper-Converged Infrastructure)
BARE_METAL = "bare_metal" # Physical servers (Linux)
WINDOWS = "windows" # Standalone Windows machines
PROVIDER_PLATFORM_MAP: dict[ProviderType, PlatformCategory] = {
ProviderType.DOCKER_SWARM: PlatformCategory.CONTAINER_ORCHESTRATION,
ProviderType.KUBERNETES: PlatformCategory.CONTAINER_ORCHESTRATION,
ProviderType.SYNOLOGY: PlatformCategory.STORAGE_APPLIANCE,
ProviderType.HARVESTER: PlatformCategory.HCI,
ProviderType.BARE_METAL: PlatformCategory.BARE_METAL,
ProviderType.WINDOWS: PlatformCategory.WINDOWS,
}
class CpuArchitecture(Enum):
"""CPU architecture of the host or resource."""
AMD64 = "amd64"
ARM = "arm"
AARCH64 = "aarch64"
class ChangeType(Enum):
"""Classification of resource changes between scan runs."""
ADDED = "added"
REMOVED = "removed"
MODIFIED = "modified"
# ---------------------------------------------------------------------------
# Provider supported resource types
# ---------------------------------------------------------------------------
PROVIDER_SUPPORTED_RESOURCE_TYPES: dict[ProviderType, list[str]] = {
ProviderType.DOCKER_SWARM: [
"docker_service",
"docker_network",
"docker_volume",
"docker_config",
"docker_secret",
],
ProviderType.KUBERNETES: [
"kubernetes_deployment",
"kubernetes_service",
"kubernetes_ingress",
"kubernetes_config_map",
"kubernetes_persistent_volume",
"kubernetes_namespace",
],
ProviderType.SYNOLOGY: [
"synology_shared_folder",
"synology_volume",
"synology_storage_pool",
"synology_replication_task",
"synology_user",
],
ProviderType.HARVESTER: [
"harvester_virtualmachine",
"harvester_volume",
"harvester_image",
"harvester_network",
],
ProviderType.BARE_METAL: [
"bare_metal_hardware",
"bare_metal_bmc_config",
"bare_metal_network_interface",
"bare_metal_raid_config",
],
ProviderType.WINDOWS: [
"windows_service",
"windows_scheduled_task",
"windows_iis_site",
"windows_iis_app_pool",
"windows_network_adapter",
"windows_firewall_rule",
"windows_installed_software",
"windows_feature",
"windows_hyperv_vm",
"windows_hyperv_switch",
"windows_dns_record",
"windows_local_user",
"windows_local_group",
],
}
MAX_RESOURCE_TYPE_FILTERS = 200
# ---------------------------------------------------------------------------
# Scanner dataclasses
# ---------------------------------------------------------------------------
@dataclass
class ScanProfile:
"""Configuration for a single infrastructure scan."""
provider: ProviderType
credentials: dict[str, str]
endpoints: Optional[list[str]] = None
resource_type_filters: Optional[list[str]] = None
authentik_token: Optional[str] = None
def validate(self) -> list[str]:
"""Returns list of validation errors, empty if valid.
Validates:
- credentials must not be empty
- resource_type_filters must have at most MAX_RESOURCE_TYPE_FILTERS entries
- resource_type_filters entries must be supported by the provider
"""
errors: list[str] = []
if not self.credentials:
errors.append("credentials must not be empty")
if self.resource_type_filters is not None:
if len(self.resource_type_filters) > MAX_RESOURCE_TYPE_FILTERS:
errors.append(
f"resource_type_filters must have at most "
f"{MAX_RESOURCE_TYPE_FILTERS} entries, "
f"got {len(self.resource_type_filters)}"
)
supported = set(PROVIDER_SUPPORTED_RESOURCE_TYPES[self.provider])
unsupported = [
rt for rt in self.resource_type_filters if rt not in supported
]
if unsupported:
errors.append(
f"unsupported resource types for provider "
f"'{self.provider.value}': {unsupported}"
)
return errors
@property
def platform_category(self) -> PlatformCategory:
"""Return the platform category for this profile's provider."""
return PROVIDER_PLATFORM_MAP[self.provider]
@dataclass
class DiscoveredResource:
"""A single resource discovered from an infrastructure provider."""
resource_type: str
unique_id: str
name: str
provider: ProviderType
platform_category: PlatformCategory
architecture: CpuArchitecture
endpoint: str
attributes: dict
raw_references: list[str] = field(default_factory=list)
@dataclass
class ScanResult:
"""Complete result of a scan operation."""
resources: list[DiscoveredResource]
warnings: list[str]
errors: list[str]
scan_timestamp: str
profile_hash: str
is_partial: bool = False
@dataclass
class ScanProgress:
"""Progress update during a scan operation."""
current_resource_type: str
resources_discovered: int
resource_types_completed: int
total_resource_types: int
# ---------------------------------------------------------------------------
# Dependency Resolver dataclasses
# ---------------------------------------------------------------------------
@dataclass
class ResourceRelationship:
"""A relationship between two discovered resources."""
source_id: str
target_id: str
relationship_type: str # "parent-child", "reference", "dependency"
source_attribute: str
@dataclass
class UnresolvedReference:
"""A reference that could not be resolved to a known resource."""
source_resource_id: str
source_attribute: str
referenced_id: str
suggested_resolution: str # "data_source" or "variable"
@dataclass
class CycleReport:
"""Report of a detected circular dependency with resolution suggestions."""
cycle: list[str] # Resource IDs forming the cycle
suggested_break: tuple[str, str] # (source_id, target_id) edge to break
break_relationship_type: str # Type of the relationship to break
resolution_strategy: str # Human-readable suggestion for resolution
@dataclass
class DependencyGraph:
"""Complete dependency graph of discovered resources."""
resources: list[DiscoveredResource]
relationships: list[ResourceRelationship]
topological_order: list[str]
cycles: list[list[str]]
unresolved_references: list[UnresolvedReference]
cycle_reports: list[CycleReport] = field(default_factory=list)
# ---------------------------------------------------------------------------
# Code Generator dataclasses
# ---------------------------------------------------------------------------
@dataclass
class GeneratedFile:
"""A single generated Terraform HCL file."""
filename: str
content: str
resource_count: int
@dataclass
class ExtractedVariable:
"""A Terraform variable extracted from common resource values."""
name: str
type_expr: str
default_value: str
description: str
used_by: list[str] = field(default_factory=list)
@dataclass
class CodeGenerationResult:
"""Complete result of code generation."""
resource_files: list[GeneratedFile]
variables_file: GeneratedFile
provider_file: GeneratedFile
outputs_file: Optional[GeneratedFile] = None
skipped_resources: list[tuple[str, str]] = field(default_factory=list)
# ---------------------------------------------------------------------------
# State Builder dataclasses
# ---------------------------------------------------------------------------
@dataclass
class StateEntry:
"""A single resource entry in the Terraform state file."""
resource_type: str
resource_name: str
provider_id: str
attributes: dict
sensitive_attributes: list[str] = field(default_factory=list)
schema_version: int = 0
dependencies: list[str] = field(default_factory=list)
@dataclass
class StateFile:
"""Terraform state file representation (format version 4)."""
version: int = 4
terraform_version: str = ""
serial: int = 1
lineage: str = ""
resources: list[StateEntry] = field(default_factory=list)
def to_json(self) -> str:
"""Serialize to Terraform state JSON format."""
import json
import uuid
lineage = self.lineage or str(uuid.uuid4())
state_resources = []
for entry in self.resources:
state_resources.append(
{
"mode": "managed",
"type": entry.resource_type,
"name": entry.resource_name,
"provider": f'provider["registry.terraform.io/hashicorp/{entry.resource_type.split("_")[0]}"]',
"instances": [
{
"schema_version": entry.schema_version,
"attributes": {
"id": entry.provider_id,
**entry.attributes,
},
"sensitive_attributes": entry.sensitive_attributes,
"dependencies": entry.dependencies,
}
],
}
)
state = {
"version": self.version,
"terraform_version": self.terraform_version,
"serial": self.serial,
"lineage": lineage,
"outputs": {},
"resources": state_resources,
}
return json.dumps(state, indent=2)
# ---------------------------------------------------------------------------
# Validator dataclasses
# ---------------------------------------------------------------------------
@dataclass
class PlannedChange:
"""A single planned change reported by terraform plan."""
resource_address: str
change_type: str # "add", "modify", "destroy"
details: str
@dataclass
class ValidationError:
"""A validation error from terraform validate or plan."""
file: str
message: str
line: Optional[int] = None
@dataclass
class ValidationResult:
"""Complete result of terraform validation."""
init_success: bool
validate_success: bool
plan_success: bool
planned_changes: list[PlannedChange] = field(default_factory=list)
errors: list[ValidationError] = field(default_factory=list)
correction_attempts: int = 0
# ---------------------------------------------------------------------------
# Incremental Scan dataclasses
# ---------------------------------------------------------------------------
@dataclass
class ResourceChange:
"""A single resource change between scan runs."""
resource_id: str
resource_type: str
resource_name: str
change_type: ChangeType
changed_attributes: Optional[dict] = None
@dataclass
class ChangeSummary:
"""Summary of changes between two scan runs."""
added_count: int
removed_count: int
modified_count: int
changes: list[ResourceChange] = field(default_factory=list)

View File

@@ -0,0 +1,103 @@
"""Provider plugin abstract base class.
Defines the interface that all infrastructure provider plugins must implement
to participate in the scanning pipeline.
"""
from abc import ABC, abstractmethod
from typing import Callable
from iac_reverse.models import (
CpuArchitecture,
PlatformCategory,
ScanProgress,
ScanResult,
)
class ProviderPlugin(ABC):
"""Interface that all provider plugins must implement.
Each on-premises platform (Docker Swarm, Kubernetes, Synology, Harvester,
Bare Metal, Windows) provides a concrete implementation of this class to
handle platform-specific authentication, discovery, and architecture detection.
"""
@abstractmethod
def authenticate(self, credentials: dict[str, str]) -> None:
"""Authenticate with the platform API.
Args:
credentials: Provider-specific authentication parameters
(API tokens, usernames, passwords, kubeconfig paths, etc.)
Raises:
AuthenticationError: If authentication fails, with a descriptive
message including the provider name and failure reason.
"""
...
@abstractmethod
def get_platform_category(self) -> PlatformCategory:
"""Return the platform category for this provider.
Returns:
The PlatformCategory enum value representing this provider's
infrastructure model (container orchestration, storage, HCI, etc.)
"""
...
@abstractmethod
def list_endpoints(self) -> list[str]:
"""Return all reachable endpoints/hosts for this provider.
Returns:
List of endpoint URLs or host addresses that can be scanned.
"""
...
@abstractmethod
def list_supported_resource_types(self) -> list[str]:
"""Return all resource types this plugin can discover.
Returns:
List of resource type strings (e.g., "kubernetes_deployment",
"windows_iis_site", "synology_shared_folder").
"""
...
@abstractmethod
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect the CPU architecture of the target host/node.
Args:
endpoint: The endpoint URL or host address to query.
Returns:
The CpuArchitecture enum value for the target.
"""
...
@abstractmethod
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover resources from the infrastructure provider.
Connects to the specified endpoints and enumerates resources of the
requested types. Reports progress via the callback function.
Args:
endpoints: List of endpoint URLs or host addresses to scan.
resource_types: List of resource type strings to discover.
Should be a subset of list_supported_resource_types().
progress_callback: Callable that receives ScanProgress updates
during the discovery process.
Returns:
ScanResult containing all discovered resources, warnings, and errors.
"""
...

View File

@@ -0,0 +1,5 @@
"""Dependency resolver module for resource relationship mapping."""
from iac_reverse.resolver.resolver import DependencyResolver
__all__ = ["DependencyResolver"]

View File

@@ -0,0 +1,443 @@
"""Dependency resolver for resource relationship mapping.
Analyzes discovered resources and their raw_references to build a dependency
graph with topological ordering. Identifies parent-child, reference, and
dependency relationships between resources. Detects circular dependencies
and suggests resolution strategies.
"""
import logging
import networkx as nx
from iac_reverse.models import (
CycleReport,
DependencyGraph,
DiscoveredResource,
ResourceRelationship,
ScanResult,
UnresolvedReference,
)
logger = logging.getLogger(__name__)
# Resource types that represent namespace/container resources (parent-child targets)
_NAMESPACE_RESOURCE_TYPES = frozenset(
[
"kubernetes_namespace",
"docker_network",
"harvester_network",
]
)
# Resource types that represent infrastructure dependencies (must exist before dependents)
# Maps: dependent resource type -> set of resource types it depends on
_DEPENDENCY_RESOURCE_TYPES: dict[str, frozenset[str]] = {
"windows_iis_site": frozenset(["windows_iis_app_pool"]),
"windows_hyperv_vm": frozenset(["windows_hyperv_switch"]),
"kubernetes_deployment": frozenset(["kubernetes_namespace", "kubernetes_config_map"]),
"kubernetes_service": frozenset(["kubernetes_namespace"]),
"kubernetes_ingress": frozenset(["kubernetes_namespace", "kubernetes_service"]),
"harvester_virtualmachine": frozenset(["harvester_network", "harvester_image"]),
}
# Priority for breaking relationships (lower = prefer to break first)
_RELATIONSHIP_BREAK_PRIORITY: dict[str, int] = {
"reference": 0,
"dependency": 1,
"parent-child": 2,
}
class DependencyResolver:
"""Resolves dependencies between discovered infrastructure resources.
Analyzes raw_references on each DiscoveredResource to identify relationships
and builds a networkx DiGraph for topological ordering. Detects cycles and
suggests resolution strategies.
"""
def __init__(self, scan_result: ScanResult) -> None:
"""Initialize the resolver with a scan result.
Args:
scan_result: The ScanResult containing discovered resources.
"""
self._scan_result = scan_result
self._resource_map: dict[str, DiscoveredResource] = {
r.unique_id: r for r in scan_result.resources
}
def resolve(self) -> DependencyGraph:
"""Analyze relationships and produce a dependency graph.
Builds the graph, detects cycles, suggests resolutions, and produces
a topological ordering (breaking cycle edges if necessary).
Returns:
DependencyGraph with resources, relationships, topological ordering,
cycles, cycle_reports, and unresolved_references.
"""
graph = nx.DiGraph()
relationships: list[ResourceRelationship] = []
unresolved_references: list[UnresolvedReference] = []
# Add all resources as nodes
for resource in self._scan_result.resources:
graph.add_node(resource.unique_id)
# Analyze raw_references to build edges and relationships
for resource in self._scan_result.resources:
for ref_id in resource.raw_references:
if ref_id not in self._resource_map:
# Unresolved reference - track it
source_attribute = self._identify_source_attribute_for_ref(
resource, ref_id
)
suggested_resolution = self._suggest_resolution(ref_id)
unresolved_references.append(
UnresolvedReference(
source_resource_id=resource.unique_id,
source_attribute=source_attribute,
referenced_id=ref_id,
suggested_resolution=suggested_resolution,
)
)
logger.warning(
"Unresolved reference from resource '%s' (attribute: '%s') "
"to '%s' - suggested resolution: %s",
resource.unique_id,
source_attribute,
ref_id,
suggested_resolution,
)
continue
target_resource = self._resource_map[ref_id]
relationship_type = self._classify_relationship(
resource, target_resource
)
# Edge direction: source depends on target
# So target must come before source in topological order
graph.add_edge(ref_id, resource.unique_id)
source_attribute = self._identify_source_attribute(
resource, target_resource
)
relationships.append(
ResourceRelationship(
source_id=resource.unique_id,
target_id=ref_id,
relationship_type=relationship_type,
source_attribute=source_attribute,
)
)
# Detect cycles
cycle_reports = self.detect_cycles(graph, relationships)
cycles = [report.cycle for report in cycle_reports]
# Produce topological ordering by breaking cycle edges if needed
topological_order = self._topological_order_with_cycle_breaking(
graph, cycle_reports
)
return DependencyGraph(
resources=self._scan_result.resources,
relationships=relationships,
topological_order=topological_order,
cycles=cycles,
unresolved_references=unresolved_references,
cycle_reports=cycle_reports,
)
def detect_cycles(
self, graph: nx.DiGraph, relationships: list[ResourceRelationship]
) -> list[CycleReport]:
"""Detect circular dependencies and suggest resolution strategies.
Finds all simple cycles in the graph and for each cycle suggests which
edge to break. Prefers breaking "reference" over "dependency" over
"parent-child" relationships.
Args:
graph: The networkx DiGraph with resource dependencies.
relationships: The list of ResourceRelationship objects.
Returns:
List of CycleReport objects with cycle info and suggestions.
"""
# Build a lookup for relationship types by edge (target_id, source_id)
# Note: graph edges are (target_id, source_id) because edge direction
# means "target must come before source"
edge_relationship_map: dict[tuple[str, str], ResourceRelationship] = {}
for rel in relationships:
# In the graph, edge is (rel.target_id -> rel.source_id)
edge_relationship_map[(rel.target_id, rel.source_id)] = rel
# Find all simple cycles
raw_cycles = list(nx.simple_cycles(graph))
cycle_reports: list[CycleReport] = []
for cycle_nodes in raw_cycles:
if len(cycle_nodes) < 2:
continue
# Find the best edge to break in this cycle
suggested_break, break_type = self._suggest_cycle_break(
cycle_nodes, edge_relationship_map
)
# Build resolution strategy message
source_id, target_id = suggested_break
# The relationship source_id is the resource that holds the reference
# In graph edge (A, B), A is target_id in relationship, B is source_id
resolution_strategy = (
f"Break the '{break_type}' relationship by replacing the direct "
f"reference from '{target_id}' to '{source_id}' with a "
f"data source lookup (e.g., terraform data source) to decouple "
f"the circular dependency."
)
cycle_reports.append(
CycleReport(
cycle=cycle_nodes,
suggested_break=suggested_break,
break_relationship_type=break_type,
resolution_strategy=resolution_strategy,
)
)
return cycle_reports
def _suggest_cycle_break(
self,
cycle_nodes: list[str],
edge_relationship_map: dict[tuple[str, str], ResourceRelationship],
) -> tuple[tuple[str, str], str]:
"""Suggest which edge to break in a cycle.
Prefers breaking "reference" over "dependency" over "parent-child".
Args:
cycle_nodes: List of node IDs forming the cycle.
edge_relationship_map: Map from graph edge to ResourceRelationship.
Returns:
Tuple of ((source_node, target_node) edge to break, relationship_type).
"""
# Build edges in the cycle: each consecutive pair + wrap-around
cycle_edges: list[tuple[str, str]] = []
for i in range(len(cycle_nodes)):
from_node = cycle_nodes[i]
to_node = cycle_nodes[(i + 1) % len(cycle_nodes)]
cycle_edges.append((from_node, to_node))
# Find the edge with lowest break priority (prefer to break "reference" first)
best_edge = cycle_edges[0]
best_type = "reference"
best_priority = _RELATIONSHIP_BREAK_PRIORITY.get("reference", 0)
for edge in cycle_edges:
rel = edge_relationship_map.get(edge)
if rel:
rel_type = rel.relationship_type
else:
# If no relationship found, treat as reference (easiest to break)
rel_type = "reference"
priority = _RELATIONSHIP_BREAK_PRIORITY.get(rel_type, 0)
if priority < best_priority or (
priority == best_priority and edge < best_edge
):
best_priority = priority
best_edge = edge
best_type = rel_type
return best_edge, best_type
def _topological_order_with_cycle_breaking(
self, graph: nx.DiGraph, cycle_reports: list[CycleReport]
) -> list[str]:
"""Produce topological order by temporarily removing cycle-breaking edges.
If the graph has cycles, removes the suggested edges from each cycle
report and attempts topological sort on the resulting DAG.
Args:
graph: The original DiGraph (may contain cycles).
cycle_reports: Cycle reports with suggested edges to break.
Returns:
List of resource IDs in topological order.
"""
if not cycle_reports:
# No cycles - straightforward topological sort
try:
return list(nx.topological_sort(graph))
except nx.NetworkXUnfeasible:
# Shouldn't happen if cycle detection is correct, but be safe
return list(graph.nodes)
# Create a copy and remove suggested break edges
working_graph = graph.copy()
for report in cycle_reports:
edge = report.suggested_break
if working_graph.has_edge(*edge):
working_graph.remove_edge(*edge)
# Try topological sort on the modified graph
try:
return list(nx.topological_sort(working_graph))
except nx.NetworkXUnfeasible:
# Still has cycles (overlapping cycles may need more breaks)
# Fall back to removing all cycle edges iteratively
while True:
try:
return list(nx.topological_sort(working_graph))
except nx.NetworkXUnfeasible:
# Find remaining cycle and break an edge
try:
cycle = nx.find_cycle(working_graph)
# Remove the first edge in the found cycle
working_graph.remove_edge(*cycle[0][:2])
except nx.NetworkXNoCycle:
return list(nx.topological_sort(working_graph))
def _classify_relationship(
self, source: DiscoveredResource, target: DiscoveredResource
) -> str:
"""Classify the relationship type between source and target.
Args:
source: The resource that holds the reference.
target: The resource being referenced.
Returns:
One of "parent-child", "dependency", or "reference".
"""
# Parent-child: target is a namespace/container resource
if target.resource_type in _NAMESPACE_RESOURCE_TYPES:
return "parent-child"
# Dependency: source resource type has a known dependency on target's type
dependent_types = _DEPENDENCY_RESOURCE_TYPES.get(source.resource_type)
if dependent_types and target.resource_type in dependent_types:
return "dependency"
# Default: reference relationship
return "reference"
def _identify_source_attribute(
self, source: DiscoveredResource, target: DiscoveredResource
) -> str:
"""Identify which attribute in the source holds the reference to target.
Searches the source's attributes for values matching the target's unique_id
or name. Falls back to "raw_references" if no specific attribute is found.
Args:
source: The resource holding the reference.
target: The resource being referenced.
Returns:
The attribute name that holds the reference.
"""
# Search attributes for the target's unique_id or name
for attr_name, attr_value in source.attributes.items():
if isinstance(attr_value, str):
if attr_value == target.unique_id or attr_value == target.name:
return attr_name
elif isinstance(attr_value, list):
for item in attr_value:
if isinstance(item, str) and (
item == target.unique_id or item == target.name
):
return attr_name
return "raw_references"
def _identify_source_attribute_for_ref(
self, source: DiscoveredResource, ref_id: str
) -> str:
"""Identify which attribute in the source holds an unresolved reference.
Searches the source's attributes for values matching the given ref_id.
Falls back to "raw_references" if no specific attribute is found.
Args:
source: The resource holding the reference.
ref_id: The unresolved reference ID string.
Returns:
The attribute name that holds the reference.
"""
for attr_name, attr_value in source.attributes.items():
if isinstance(attr_value, str):
if attr_value == ref_id:
return attr_name
elif isinstance(attr_value, list):
for item in attr_value:
if isinstance(item, str) and item == ref_id:
return attr_name
return "raw_references"
def _suggest_resolution(self, ref_id: str) -> str:
"""Suggest a resolution strategy for an unresolved reference.
Args:
ref_id: The unresolved reference ID.
Returns:
Either "data_source" or "variable" as the suggested resolution.
"""
# If the reference looks like a structured resource ID (contains /),
# suggest a data source lookup. Otherwise suggest a variable.
if "/" in ref_id:
return "data_source"
return "variable"
def _identify_source_attribute_for_ref(
self, source: DiscoveredResource, ref_id: str
) -> str:
"""Identify which attribute in the source holds an unresolved reference.
Searches the source's attributes for values matching the referenced ID.
Falls back to "raw_references" if no specific attribute is found.
Args:
source: The resource holding the reference.
ref_id: The unresolved reference ID.
Returns:
The attribute name that holds the reference.
"""
for attr_name, attr_value in source.attributes.items():
if isinstance(attr_value, str):
if attr_value == ref_id:
return attr_name
elif isinstance(attr_value, list):
for item in attr_value:
if isinstance(item, str) and item == ref_id:
return attr_name
return "raw_references"
@staticmethod
def _suggest_resolution(ref_id: str) -> str:
"""Determine the suggested resolution for an unresolved reference.
Args:
ref_id: The unresolved reference ID.
Returns:
"data_source" if the reference looks like a resource ID (contains
"/" or ":"), otherwise "variable" for simple value/name references.
"""
if "/" in ref_id or ":" in ref_id:
return "data_source"
return "variable"

View File

@@ -0,0 +1,45 @@
"""Scanner module for infrastructure discovery."""
from iac_reverse.scanner.bare_metal_plugin import BareMetalPlugin
from iac_reverse.scanner.docker_swarm_plugin import DockerSwarmPlugin
from iac_reverse.scanner.harvester_plugin import HarvesterPlugin
from iac_reverse.scanner.kubernetes_plugin import KubernetesPlugin
from iac_reverse.scanner.multi_provider_scanner import (
MultiProviderScanner,
MultiProviderScanResult,
ProviderFailure,
ProviderScanEntry,
)
from iac_reverse.scanner.scanner import (
AuthenticationError,
ConnectionLostError,
Scanner,
ScanTimeoutError,
)
from iac_reverse.scanner.synology_plugin import SynologyPlugin
from iac_reverse.scanner.windows_plugin import (
InsufficientPrivilegesError,
WindowsDiscoveryPlugin,
WinRMNotEnabledError,
WMIQueryError,
)
__all__ = [
"AuthenticationError",
"BareMetalPlugin",
"ConnectionLostError",
"DockerSwarmPlugin",
"HarvesterPlugin",
"InsufficientPrivilegesError",
"KubernetesPlugin",
"MultiProviderScanner",
"MultiProviderScanResult",
"ProviderFailure",
"ProviderScanEntry",
"Scanner",
"ScanTimeoutError",
"SynologyPlugin",
"WindowsDiscoveryPlugin",
"WinRMNotEnabledError",
"WMIQueryError",
]

View File

@@ -0,0 +1,497 @@
"""Bare Metal provider plugin using Redfish/IPMI API.
Discovers hardware inventory, BMC configurations, network interfaces,
and RAID configurations from physical servers via the Redfish REST API
(standard BMC management interface).
"""
import logging
from typing import Callable
from urllib.parse import urljoin
import requests
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import AuthenticationError
logger = logging.getLogger(__name__)
class BareMetalPlugin(ProviderPlugin):
"""Provider plugin for bare metal servers using Redfish/IPMI API.
Connects to a server's BMC (Baseboard Management Controller) via the
Redfish REST API to discover hardware inventory, BMC configuration,
network interfaces, and RAID configurations.
Expected credentials dict keys:
host: BMC hostname or IP address (required)
username: BMC username (required)
password: BMC password (required)
port: BMC port (optional, default 443)
use_ssl: Whether to use HTTPS (optional, default "true")
"""
SUPPORTED_RESOURCE_TYPES = [
"bare_metal_hardware",
"bare_metal_bmc_config",
"bare_metal_network_interface",
"bare_metal_raid_config",
]
def __init__(self) -> None:
self._session: requests.Session | None = None
self._base_url: str = ""
self._host: str = ""
def authenticate(self, credentials: dict[str, str]) -> None:
"""Authenticate with the BMC via Redfish session creation.
Args:
credentials: Dict with keys: host, username, password,
and optionally port (default 443) and use_ssl (default "true").
Raises:
AuthenticationError: If connection or login fails.
"""
host = credentials.get("host", "")
username = credentials.get("username", "")
password = credentials.get("password", "")
port = credentials.get("port", "443")
use_ssl = credentials.get("use_ssl", "true").lower() == "true"
if not host or not username or not password:
raise AuthenticationError(
provider_name="bare_metal",
reason="Missing required credentials: host, username, and password are required",
)
scheme = "https" if use_ssl else "http"
self._base_url = f"{scheme}://{host}:{port}"
self._host = host
session = requests.Session()
session.verify = False # BMC certs are typically self-signed
session.headers.update({
"Content-Type": "application/json",
"Accept": "application/json",
})
# Attempt Redfish session-based authentication
session_url = f"{self._base_url}/redfish/v1/SessionService/Sessions"
payload = {"UserName": username, "Password": password}
try:
response = session.post(session_url, json=payload, timeout=30)
if response.status_code in (200, 201):
# Extract session token from response headers
token = response.headers.get("X-Auth-Token", "")
if token:
session.headers["X-Auth-Token"] = token
elif response.status_code == 401:
raise AuthenticationError(
provider_name="bare_metal",
reason="Invalid credentials (HTTP 401)",
)
else:
raise AuthenticationError(
provider_name="bare_metal",
reason=f"Unexpected response status {response.status_code}",
)
except requests.exceptions.ConnectionError as exc:
raise AuthenticationError(
provider_name="bare_metal",
reason=f"Cannot connect to BMC at {self._base_url}: {exc}",
) from exc
except requests.exceptions.Timeout as exc:
raise AuthenticationError(
provider_name="bare_metal",
reason=f"Connection to BMC timed out: {exc}",
) from exc
except AuthenticationError:
raise
except Exception as exc:
raise AuthenticationError(
provider_name="bare_metal",
reason=f"Unexpected error during authentication: {exc}",
) from exc
self._session = session
def get_platform_category(self) -> PlatformCategory:
"""Return PlatformCategory.BARE_METAL."""
return PlatformCategory.BARE_METAL
def list_endpoints(self) -> list[str]:
"""Return the BMC host as the single endpoint."""
return [self._host] if self._host else []
def list_supported_resource_types(self) -> list[str]:
"""Return supported bare metal resource types."""
return list(self.SUPPORTED_RESOURCE_TYPES)
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect CPU architecture from Redfish system hardware info.
Queries /redfish/v1/Systems/1/Processors to determine the
processor architecture.
Args:
endpoint: The BMC host address.
Returns:
CpuArchitecture enum value based on processor info.
"""
if self._session is None:
return CpuArchitecture.AMD64
processors_url = f"{self._base_url}/redfish/v1/Systems/1/Processors"
try:
response = self._session.get(processors_url, timeout=30)
if response.status_code == 200:
data = response.json()
members = data.get("Members", [])
if members:
# Query first processor for architecture details
proc_uri = members[0].get("@odata.id", "")
if proc_uri:
proc_url = f"{self._base_url}{proc_uri}"
proc_response = self._session.get(proc_url, timeout=30)
if proc_response.status_code == 200:
proc_data = proc_response.json()
return self._parse_architecture(proc_data)
except Exception as exc:
logger.warning("Failed to detect architecture: %s", exc)
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover bare metal resources via Redfish API.
Args:
endpoints: List of BMC host addresses to scan.
resource_types: Resource types to discover.
progress_callback: Progress reporting callback.
Returns:
ScanResult with discovered resources.
"""
resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
total_types = len(resource_types)
types_completed = 0
for endpoint in endpoints:
architecture = self.detect_architecture(endpoint)
for resource_type in resource_types:
try:
discovered = self._discover_resource_type(
endpoint, resource_type, architecture
)
resources.extend(discovered)
except Exception as exc:
error_msg = (
f"Error discovering {resource_type} on {endpoint}: {exc}"
)
errors.append(error_msg)
logger.error(error_msg)
types_completed += 1
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(resources),
resource_types_completed=types_completed,
total_resource_types=total_types,
)
)
return ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp="",
profile_hash="",
)
# -----------------------------------------------------------------------
# Private helpers
# -----------------------------------------------------------------------
def _discover_resource_type(
self,
endpoint: str,
resource_type: str,
architecture: CpuArchitecture,
) -> list[DiscoveredResource]:
"""Dispatch discovery to the appropriate handler."""
handlers = {
"bare_metal_hardware": self._discover_hardware,
"bare_metal_bmc_config": self._discover_bmc_config,
"bare_metal_network_interface": self._discover_network_interfaces,
"bare_metal_raid_config": self._discover_raid_config,
}
handler = handlers.get(resource_type)
if handler is None:
return []
return handler(endpoint, architecture)
def _discover_hardware(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover hardware inventory via /redfish/v1/Systems/1."""
if self._session is None:
return []
url = f"{self._base_url}/redfish/v1/Systems/1"
try:
response = self._session.get(url, timeout=30)
if response.status_code != 200:
return []
data = response.json()
except Exception as exc:
logger.warning("Failed to discover hardware: %s", exc)
return []
system_id = data.get("Id", "System.1")
return [
DiscoveredResource(
resource_type="bare_metal_hardware",
unique_id=f"{endpoint}:{system_id}",
name=data.get("Name", f"System {system_id}"),
provider=ProviderType.BARE_METAL,
platform_category=PlatformCategory.BARE_METAL,
architecture=architecture,
endpoint=endpoint,
attributes={
"manufacturer": data.get("Manufacturer", ""),
"model": data.get("Model", ""),
"serial_number": data.get("SerialNumber", ""),
"sku": data.get("SKU", ""),
"bios_version": data.get("BiosVersion", ""),
"total_memory_gib": data.get("MemorySummary", {}).get(
"TotalSystemMemoryGiB", 0
),
"processor_count": data.get("ProcessorSummary", {}).get(
"Count", 0
),
"processor_model": data.get("ProcessorSummary", {}).get(
"Model", ""
),
"power_state": data.get("PowerState", ""),
"status": data.get("Status", {}),
},
)
]
def _discover_bmc_config(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover BMC configuration via /redfish/v1/Managers/1."""
if self._session is None:
return []
url = f"{self._base_url}/redfish/v1/Managers/1"
try:
response = self._session.get(url, timeout=30)
if response.status_code != 200:
return []
data = response.json()
except Exception as exc:
logger.warning("Failed to discover BMC config: %s", exc)
return []
manager_id = data.get("Id", "BMC.1")
return [
DiscoveredResource(
resource_type="bare_metal_bmc_config",
unique_id=f"{endpoint}:{manager_id}",
name=data.get("Name", f"BMC {manager_id}"),
provider=ProviderType.BARE_METAL,
platform_category=PlatformCategory.BARE_METAL,
architecture=architecture,
endpoint=endpoint,
attributes={
"manager_type": data.get("ManagerType", ""),
"firmware_version": data.get("FirmwareVersion", ""),
"model": data.get("Model", ""),
"status": data.get("Status", {}),
"uuid": data.get("UUID", ""),
},
)
]
def _discover_network_interfaces(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover network interfaces via /redfish/v1/Systems/1/EthernetInterfaces."""
if self._session is None:
return []
url = f"{self._base_url}/redfish/v1/Systems/1/EthernetInterfaces"
try:
response = self._session.get(url, timeout=30)
if response.status_code != 200:
return []
data = response.json()
except Exception as exc:
logger.warning("Failed to discover network interfaces: %s", exc)
return []
resources: list[DiscoveredResource] = []
for member in data.get("Members", []):
nic_uri = member.get("@odata.id", "")
if not nic_uri:
continue
try:
nic_url = f"{self._base_url}{nic_uri}"
nic_response = self._session.get(nic_url, timeout=30)
if nic_response.status_code != 200:
continue
nic_data = nic_response.json()
except Exception as exc:
logger.warning("Failed to get NIC details at %s: %s", nic_uri, exc)
continue
nic_id = nic_data.get("Id", "")
resources.append(
DiscoveredResource(
resource_type="bare_metal_network_interface",
unique_id=f"{endpoint}:{nic_id}",
name=nic_data.get("Name", f"NIC {nic_id}"),
provider=ProviderType.BARE_METAL,
platform_category=PlatformCategory.BARE_METAL,
architecture=architecture,
endpoint=endpoint,
attributes={
"mac_address": nic_data.get("MACAddress", ""),
"speed_mbps": nic_data.get("SpeedMbps", 0),
"status": nic_data.get("Status", {}),
"ipv4_addresses": nic_data.get("IPv4Addresses", []),
"ipv6_addresses": nic_data.get("IPv6Addresses", []),
"vlan": nic_data.get("VLAN", {}),
"link_status": nic_data.get("LinkStatus", ""),
"auto_neg": nic_data.get("AutoNeg", False),
},
)
)
return resources
def _discover_raid_config(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover RAID configuration via /redfish/v1/Systems/1/Storage."""
if self._session is None:
return []
url = f"{self._base_url}/redfish/v1/Systems/1/Storage"
try:
response = self._session.get(url, timeout=30)
if response.status_code != 200:
return []
data = response.json()
except Exception as exc:
logger.warning("Failed to discover RAID config: %s", exc)
return []
resources: list[DiscoveredResource] = []
for member in data.get("Members", []):
storage_uri = member.get("@odata.id", "")
if not storage_uri:
continue
try:
storage_url = f"{self._base_url}{storage_uri}"
storage_response = self._session.get(storage_url, timeout=30)
if storage_response.status_code != 200:
continue
storage_data = storage_response.json()
except Exception as exc:
logger.warning(
"Failed to get storage details at %s: %s", storage_uri, exc
)
continue
storage_id = storage_data.get("Id", "")
drives = []
for drive in storage_data.get("Drives", []):
drive_uri = drive.get("@odata.id", "")
if drive_uri:
drives.append(drive_uri)
volumes = []
volumes_link = storage_data.get("Volumes", {}).get("@odata.id", "")
if volumes_link:
try:
vol_url = f"{self._base_url}{volumes_link}"
vol_response = self._session.get(vol_url, timeout=30)
if vol_response.status_code == 200:
vol_data = vol_response.json()
for vol_member in vol_data.get("Members", []):
vol_uri = vol_member.get("@odata.id", "")
if vol_uri:
volumes.append(vol_uri)
except Exception as exc:
logger.warning("Failed to get volumes: %s", exc)
resources.append(
DiscoveredResource(
resource_type="bare_metal_raid_config",
unique_id=f"{endpoint}:{storage_id}",
name=storage_data.get("Name", f"Storage {storage_id}"),
provider=ProviderType.BARE_METAL,
platform_category=PlatformCategory.BARE_METAL,
architecture=architecture,
endpoint=endpoint,
attributes={
"storage_controllers": [
ctrl.get("Name", "")
for ctrl in storage_data.get(
"StorageControllers", []
)
],
"drive_count": len(drives),
"drives": drives,
"volumes": volumes,
"status": storage_data.get("Status", {}),
},
)
)
return resources
@staticmethod
def _parse_architecture(proc_data: dict) -> CpuArchitecture:
"""Parse CPU architecture from Redfish processor data.
Examines InstructionSet and Model fields to determine architecture.
"""
instruction_set = proc_data.get("InstructionSet", "").lower()
model = proc_data.get("Model", "").lower()
if "aarch64" in instruction_set or "arm" in instruction_set:
return CpuArchitecture.AARCH64
if "arm" in model:
if "64" in model or "aarch64" in model or "v8" in model:
return CpuArchitecture.AARCH64
return CpuArchitecture.ARM
# Default to AMD64 for x86/x86_64/IA-32e
return CpuArchitecture.AMD64

View File

@@ -0,0 +1,433 @@
"""Docker Swarm provider plugin.
Discovers services, networks, volumes, configs, and secrets from a Docker Swarm
cluster using the docker-sdk-python library.
"""
import logging
from typing import Callable, Optional
import docker
from docker.tls import TLSConfig
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import AuthenticationError
logger = logging.getLogger(__name__)
# Resource types supported by this plugin
SUPPORTED_RESOURCE_TYPES = [
"docker_service",
"docker_network",
"docker_volume",
"docker_config",
"docker_secret",
]
# Mapping from Docker platform architecture strings to CpuArchitecture enum
_ARCH_MAP: dict[str, CpuArchitecture] = {
"x86_64": CpuArchitecture.AMD64,
"amd64": CpuArchitecture.AMD64,
"aarch64": CpuArchitecture.AARCH64,
"arm64": CpuArchitecture.AARCH64,
"armv7l": CpuArchitecture.ARM,
"armhf": CpuArchitecture.ARM,
"arm": CpuArchitecture.ARM,
}
class DockerSwarmPlugin(ProviderPlugin):
"""Provider plugin for Docker Swarm infrastructure discovery.
Connects to a Docker daemon (in Swarm mode) and enumerates services,
networks, volumes, configs, and secrets.
Expected credentials dict keys:
- host: Docker daemon URL (e.g., "tcp://192.168.1.10:2376")
- tls_verify: (optional) "true" or "false" to enable TLS verification
- cert_path: (optional) path to TLS certificates directory
"""
def __init__(self) -> None:
self._client: Optional[docker.DockerClient] = None
self._host: str = ""
def authenticate(self, credentials: dict[str, str]) -> None:
"""Connect to the Docker daemon using the provided credentials.
Args:
credentials: Dict with keys 'host' (required), 'tls_verify' (optional),
and 'cert_path' (optional).
Raises:
AuthenticationError: If connection to the Docker daemon fails.
"""
host = credentials.get("host", "")
if not host:
raise AuthenticationError(
provider_name="docker_swarm",
reason="'host' is required in credentials",
)
tls_verify = credentials.get("tls_verify", "").lower() == "true"
cert_path = credentials.get("cert_path")
tls_config: Optional[TLSConfig] = None
if tls_verify or cert_path:
tls_config = TLSConfig(
verify=tls_verify,
client_cert=(
(f"{cert_path}/cert.pem", f"{cert_path}/key.pem")
if cert_path
else None
),
ca_cert=f"{cert_path}/ca.pem" if cert_path else None,
)
try:
self._client = docker.DockerClient(
base_url=host,
tls=tls_config if tls_config else False,
)
# Verify connection by pinging the daemon
self._client.ping()
except Exception as exc:
raise AuthenticationError(
provider_name="docker_swarm",
reason=str(exc),
) from exc
self._host = host
def get_platform_category(self) -> PlatformCategory:
"""Return CONTAINER_ORCHESTRATION platform category."""
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
"""Return the Docker daemon host as the single endpoint."""
if self._host:
return [self._host]
return []
def list_supported_resource_types(self) -> list[str]:
"""Return supported Docker Swarm resource types."""
return list(SUPPORTED_RESOURCE_TYPES)
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect CPU architecture from Docker node info.
Queries the Docker daemon's system info to determine the architecture
of the Swarm node.
Args:
endpoint: The Docker daemon endpoint (used for context).
Returns:
CpuArchitecture enum value detected from node info.
"""
if self._client is None:
return CpuArchitecture.AMD64
try:
info = self._client.info()
arch_str = info.get("Architecture", "x86_64").lower()
return _ARCH_MAP.get(arch_str, CpuArchitecture.AMD64)
except Exception:
logger.warning(
"Failed to detect architecture for endpoint %s, defaulting to AMD64",
endpoint,
)
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover Docker Swarm resources.
Enumerates services, networks, volumes, configs, and secrets
based on the requested resource_types.
Args:
endpoints: List of Docker daemon endpoints.
resource_types: Resource types to discover.
progress_callback: Callback for progress updates.
Returns:
ScanResult with discovered resources.
"""
resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
if self._client is None:
return ScanResult(
resources=[],
warnings=[],
errors=["Not authenticated. Call authenticate() first."],
scan_timestamp="",
profile_hash="",
)
endpoint = endpoints[0] if endpoints else self._host
architecture = self.detect_architecture(endpoint)
total_types = len(resource_types)
discovery_methods = {
"docker_service": self._discover_services,
"docker_network": self._discover_networks,
"docker_volume": self._discover_volumes,
"docker_config": self._discover_configs,
"docker_secret": self._discover_secrets,
}
for idx, resource_type in enumerate(resource_types):
method = discovery_methods.get(resource_type)
if method is None:
warnings.append(f"Unknown resource type: {resource_type}")
continue
try:
discovered = method(endpoint, architecture)
resources.extend(discovered)
except Exception as exc:
error_msg = f"Error discovering {resource_type}: {exc}"
errors.append(error_msg)
logger.error(error_msg)
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(resources),
resource_types_completed=idx + 1,
total_resource_types=total_types,
)
)
return ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp="",
profile_hash="",
)
# ------------------------------------------------------------------
# Private discovery methods
# ------------------------------------------------------------------
def _discover_services(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Docker Swarm services."""
resources: list[DiscoveredResource] = []
services = self._client.services.list()
for svc in services:
attrs = svc.attrs
spec = attrs.get("Spec", {})
task_template = spec.get("TaskTemplate", {})
container_spec = task_template.get("ContainerSpec", {})
resources.append(
DiscoveredResource(
resource_type="docker_service",
unique_id=attrs.get("ID", ""),
name=spec.get("Name", ""),
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"image": container_spec.get("Image", ""),
"replicas": spec.get("Mode", {})
.get("Replicated", {})
.get("Replicas", 1),
"labels": spec.get("Labels", {}),
},
raw_references=self._extract_service_references(spec),
)
)
return resources
def _discover_networks(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Docker networks."""
resources: list[DiscoveredResource] = []
networks = self._client.networks.list()
for net in networks:
attrs = net.attrs
resources.append(
DiscoveredResource(
resource_type="docker_network",
unique_id=attrs.get("Id", ""),
name=attrs.get("Name", ""),
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"driver": attrs.get("Driver", ""),
"scope": attrs.get("Scope", ""),
"attachable": attrs.get("Attachable", False),
"ingress": attrs.get("Ingress", False),
"labels": attrs.get("Labels", {}),
},
raw_references=[],
)
)
return resources
def _discover_volumes(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Docker volumes."""
resources: list[DiscoveredResource] = []
volumes = self._client.volumes.list()
for vol in volumes:
attrs = vol.attrs
resources.append(
DiscoveredResource(
resource_type="docker_volume",
unique_id=attrs.get("Name", ""),
name=attrs.get("Name", ""),
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"driver": attrs.get("Driver", ""),
"mountpoint": attrs.get("Mountpoint", ""),
"labels": attrs.get("Labels", {}),
},
raw_references=[],
)
)
return resources
def _discover_configs(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Docker configs (metadata only, no data content)."""
resources: list[DiscoveredResource] = []
configs = self._client.configs.list()
for cfg in configs:
attrs = cfg.attrs
spec = attrs.get("Spec", {})
resources.append(
DiscoveredResource(
resource_type="docker_config",
unique_id=attrs.get("ID", ""),
name=spec.get("Name", ""),
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"labels": spec.get("Labels", {}),
"created_at": attrs.get("CreatedAt", ""),
"updated_at": attrs.get("UpdatedAt", ""),
},
raw_references=[],
)
)
return resources
def _discover_secrets(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Docker secrets (metadata only, no secret data)."""
resources: list[DiscoveredResource] = []
secrets = self._client.secrets.list()
for secret in secrets:
attrs = secret.attrs
spec = attrs.get("Spec", {})
resources.append(
DiscoveredResource(
resource_type="docker_secret",
unique_id=attrs.get("ID", ""),
name=spec.get("Name", ""),
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"labels": spec.get("Labels", {}),
"created_at": attrs.get("CreatedAt", ""),
"updated_at": attrs.get("UpdatedAt", ""),
},
raw_references=[],
)
)
return resources
@staticmethod
def _extract_service_references(spec: dict) -> list[str]:
"""Extract resource references from a service spec.
Looks for network attachments, volume mounts, config references,
and secret references.
"""
refs: list[str] = []
# Network references
networks = spec.get("TaskTemplate", {}).get("Networks", [])
for net in networks:
target = net.get("Target", "")
if target:
refs.append(f"network:{target}")
# Volume mount references
mounts = (
spec.get("TaskTemplate", {})
.get("ContainerSpec", {})
.get("Mounts", [])
)
for mount in mounts:
source = mount.get("Source", "")
if source:
refs.append(f"volume:{source}")
# Config references
configs = (
spec.get("TaskTemplate", {})
.get("ContainerSpec", {})
.get("Configs", [])
)
for cfg in configs:
config_id = cfg.get("ConfigID", "")
if config_id:
refs.append(f"config:{config_id}")
# Secret references
secrets = (
spec.get("TaskTemplate", {})
.get("ContainerSpec", {})
.get("Secrets", [])
)
for secret in secrets:
secret_id = secret.get("SecretID", "")
if secret_id:
refs.append(f"secret:{secret_id}")
return refs

View File

@@ -0,0 +1,458 @@
"""Harvester provider plugin for HCI infrastructure discovery.
Uses the Kubernetes Python client to interact with Harvester's K8s-based API,
discovering virtual machines, volumes, images, and networks via custom resources.
"""
import logging
from typing import Callable
from kubernetes import client, config
from kubernetes.client.rest import ApiException
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import AuthenticationError
logger = logging.getLogger(__name__)
# Harvester CRD API groups and versions
HARVESTER_API_GROUP = "kubevirt.io"
HARVESTER_VM_VERSION = "v1"
HARVESTER_VM_PLURAL = "virtualmachines"
HARVESTER_CDI_GROUP = "cdi.kubevirt.io"
HARVESTER_CDI_VERSION = "v1beta1"
HARVESTER_VOLUME_PLURAL = "datavolumes"
HARVESTER_IMAGE_GROUP = "harvesterhci.io"
HARVESTER_IMAGE_VERSION = "v1beta1"
HARVESTER_IMAGE_PLURAL = "virtualmachineimages"
HARVESTER_NETWORK_GROUP = "k8s.cni.cncf.io"
HARVESTER_NETWORK_VERSION = "v1"
HARVESTER_NETWORK_PLURAL = "network-attachment-definitions"
# Default namespace for Harvester resources
DEFAULT_NAMESPACE = "default"
class HarvesterPlugin(ProviderPlugin):
"""Provider plugin for SUSE Harvester HCI platform.
Harvester runs on top of Kubernetes and exposes its resources as CRDs.
This plugin uses the kubernetes Python client to authenticate via kubeconfig
and discover VMs, volumes, images, and networks.
Expected credentials:
kubeconfig_path: Path to the kubeconfig file for the Harvester cluster.
context: (optional) Kubernetes context name to use.
"""
def __init__(self) -> None:
self._api_client: client.ApiClient | None = None
self._custom_api: client.CustomObjectsApi | None = None
self._core_api: client.CoreV1Api | None = None
self._kubeconfig_path: str | None = None
self._context: str | None = None
def authenticate(self, credentials: dict[str, str]) -> None:
"""Authenticate with the Harvester cluster via kubeconfig.
Args:
credentials: Must contain 'kubeconfig_path'. May contain 'context'.
Raises:
AuthenticationError: If kubeconfig cannot be loaded or is invalid.
"""
kubeconfig_path = credentials.get("kubeconfig_path")
if not kubeconfig_path:
raise AuthenticationError(
provider_name="harvester",
reason="'kubeconfig_path' is required in credentials",
)
context = credentials.get("context") or None
self._kubeconfig_path = kubeconfig_path
self._context = context
try:
self._api_client = config.new_client_from_config(
config_file=kubeconfig_path,
context=context,
)
self._custom_api = client.CustomObjectsApi(self._api_client)
self._core_api = client.CoreV1Api(self._api_client)
except Exception as exc:
raise AuthenticationError(
provider_name="harvester",
reason=f"Failed to load kubeconfig: {exc}",
) from exc
def get_platform_category(self) -> PlatformCategory:
"""Return HCI platform category for Harvester."""
return PlatformCategory.HCI
def list_endpoints(self) -> list[str]:
"""Return the Harvester cluster API endpoint.
Extracts the server URL from the loaded kubeconfig.
"""
if self._api_client is None:
return []
host = self._api_client.configuration.host or ""
return [host] if host else []
def list_supported_resource_types(self) -> list[str]:
"""Return resource types supported by the Harvester plugin."""
return [
"harvester_virtualmachine",
"harvester_volume",
"harvester_image",
"harvester_network",
]
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect CPU architecture from Harvester cluster node info.
Queries the Kubernetes node list and inspects the architecture label.
Harvester typically runs on AMD64 (Dell PowerEdge servers).
Args:
endpoint: The cluster API endpoint (used for logging context).
Returns:
CpuArchitecture detected from node info.
"""
if self._core_api is None:
return CpuArchitecture.AMD64
try:
nodes = self._core_api.list_node()
if nodes.items:
node = nodes.items[0]
arch = node.status.node_info.architecture
arch_lower = arch.lower() if arch else ""
if arch_lower in ("arm64", "aarch64"):
return CpuArchitecture.AARCH64
elif arch_lower == "arm":
return CpuArchitecture.ARM
else:
return CpuArchitecture.AMD64
except ApiException as exc:
logger.warning(
"Failed to detect architecture from node info for %s: %s",
endpoint,
exc,
)
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover Harvester resources via Kubernetes CRDs.
Enumerates VMs, volumes, images, and networks from the Harvester cluster.
Args:
endpoints: List of cluster API endpoints.
resource_types: Resource types to discover.
progress_callback: Callback for progress updates.
Returns:
ScanResult with discovered resources.
"""
resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
endpoint = endpoints[0] if endpoints else ""
architecture = self.detect_architecture(endpoint)
total_types = len(resource_types)
completed = 0
discovery_map = {
"harvester_virtualmachine": self._discover_vms,
"harvester_volume": self._discover_volumes,
"harvester_image": self._discover_images,
"harvester_network": self._discover_networks,
}
for resource_type in resource_types:
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(resources),
resource_types_completed=completed,
total_resource_types=total_types,
)
)
discover_fn = discovery_map.get(resource_type)
if discover_fn is None:
warnings.append(f"Unknown resource type: {resource_type}")
completed += 1
continue
try:
discovered = discover_fn(endpoint, architecture)
resources.extend(discovered)
except ApiException as exc:
error_msg = (
f"Failed to discover {resource_type}: "
f"HTTP {exc.status} - {exc.reason}"
)
errors.append(error_msg)
logger.error(error_msg)
except Exception as exc:
error_msg = f"Failed to discover {resource_type}: {exc}"
errors.append(error_msg)
logger.error(error_msg)
completed += 1
# Final progress update
progress_callback(
ScanProgress(
current_resource_type="",
resources_discovered=len(resources),
resource_types_completed=total_types,
total_resource_types=total_types,
)
)
return ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp="",
profile_hash="",
)
# ------------------------------------------------------------------
# Private discovery methods
# ------------------------------------------------------------------
def _discover_vms(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Harvester virtual machines via kubevirt.io CRD."""
items = self._list_cluster_custom_objects(
group=HARVESTER_API_GROUP,
version=HARVESTER_VM_VERSION,
plural=HARVESTER_VM_PLURAL,
)
resources = []
for item in items:
metadata = item.get("metadata", {})
spec = item.get("spec", {})
name = metadata.get("name", "unknown")
namespace = metadata.get("namespace", DEFAULT_NAMESPACE)
uid = metadata.get("uid", f"{namespace}/{name}")
resources.append(
DiscoveredResource(
resource_type="harvester_virtualmachine",
unique_id=uid,
name=name,
provider=ProviderType.HARVESTER,
platform_category=PlatformCategory.HCI,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"running": spec.get("running", False),
"spec": spec,
"labels": metadata.get("labels", {}),
"annotations": metadata.get("annotations", {}),
},
raw_references=self._extract_vm_references(spec),
)
)
return resources
def _discover_volumes(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Harvester data volumes via cdi.kubevirt.io CRD."""
items = self._list_cluster_custom_objects(
group=HARVESTER_CDI_GROUP,
version=HARVESTER_CDI_VERSION,
plural=HARVESTER_VOLUME_PLURAL,
)
resources = []
for item in items:
metadata = item.get("metadata", {})
spec = item.get("spec", {})
name = metadata.get("name", "unknown")
namespace = metadata.get("namespace", DEFAULT_NAMESPACE)
uid = metadata.get("uid", f"{namespace}/{name}")
resources.append(
DiscoveredResource(
resource_type="harvester_volume",
unique_id=uid,
name=name,
provider=ProviderType.HARVESTER,
platform_category=PlatformCategory.HCI,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"spec": spec,
"labels": metadata.get("labels", {}),
},
raw_references=[],
)
)
return resources
def _discover_images(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Harvester VM images via harvesterhci.io CRD."""
items = self._list_cluster_custom_objects(
group=HARVESTER_IMAGE_GROUP,
version=HARVESTER_IMAGE_VERSION,
plural=HARVESTER_IMAGE_PLURAL,
)
resources = []
for item in items:
metadata = item.get("metadata", {})
spec = item.get("spec", {})
name = metadata.get("name", "unknown")
namespace = metadata.get("namespace", DEFAULT_NAMESPACE)
uid = metadata.get("uid", f"{namespace}/{name}")
resources.append(
DiscoveredResource(
resource_type="harvester_image",
unique_id=uid,
name=name,
provider=ProviderType.HARVESTER,
platform_category=PlatformCategory.HCI,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"display_name": spec.get("displayName", name),
"url": spec.get("url", ""),
"spec": spec,
"labels": metadata.get("labels", {}),
},
raw_references=[],
)
)
return resources
def _discover_networks(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Harvester networks via k8s.cni.cncf.io CRD."""
items = self._list_cluster_custom_objects(
group=HARVESTER_NETWORK_GROUP,
version=HARVESTER_NETWORK_VERSION,
plural=HARVESTER_NETWORK_PLURAL,
)
resources = []
for item in items:
metadata = item.get("metadata", {})
spec = item.get("spec", {})
name = metadata.get("name", "unknown")
namespace = metadata.get("namespace", DEFAULT_NAMESPACE)
uid = metadata.get("uid", f"{namespace}/{name}")
resources.append(
DiscoveredResource(
resource_type="harvester_network",
unique_id=uid,
name=name,
provider=ProviderType.HARVESTER,
platform_category=PlatformCategory.HCI,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"config": spec.get("config", ""),
"labels": metadata.get("labels", {}),
},
raw_references=[],
)
)
return resources
def _list_cluster_custom_objects(
self, group: str, version: str, plural: str
) -> list[dict]:
"""List all custom objects across all namespaces.
Args:
group: API group (e.g., 'kubevirt.io').
version: API version (e.g., 'v1').
plural: Resource plural name (e.g., 'virtualmachines').
Returns:
List of resource items as dicts.
"""
if self._custom_api is None:
return []
result = self._custom_api.list_cluster_custom_object(
group=group,
version=version,
plural=plural,
)
return result.get("items", [])
@staticmethod
def _extract_vm_references(spec: dict) -> list[str]:
"""Extract resource references from a VM spec.
Looks for volume and network references in the VM template spec.
"""
references: list[str] = []
template = spec.get("template", {})
template_spec = template.get("spec", {})
# Extract volume references
volumes = template_spec.get("volumes", [])
for volume in volumes:
if "dataVolume" in volume:
dv_name = volume["dataVolume"].get("name", "")
if dv_name:
references.append(f"volume:{dv_name}")
if "persistentVolumeClaim" in volume:
pvc_name = volume["persistentVolumeClaim"].get("claimName", "")
if pvc_name:
references.append(f"volume:{pvc_name}")
# Extract network references
networks = template_spec.get("networks", [])
for network in networks:
if "multus" in network:
net_name = network["multus"].get("networkName", "")
if net_name:
references.append(f"network:{net_name}")
return references

View File

@@ -0,0 +1,454 @@
"""Kubernetes provider plugin for infrastructure discovery.
Uses the official kubernetes-client library to discover deployments, services,
ingresses, config maps, persistent volumes, and namespaces from a Kubernetes
cluster. Detects CPU architecture from node labels.
"""
import logging
from typing import Callable
from kubernetes import client, config
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import AuthenticationError
logger = logging.getLogger(__name__)
# Mapping from kubernetes.io/arch label values to CpuArchitecture enum
_ARCH_LABEL_MAP: dict[str, CpuArchitecture] = {
"amd64": CpuArchitecture.AMD64,
"arm": CpuArchitecture.ARM,
"arm64": CpuArchitecture.AARCH64,
"aarch64": CpuArchitecture.AARCH64,
}
_SUPPORTED_RESOURCE_TYPES = [
"kubernetes_deployment",
"kubernetes_service",
"kubernetes_ingress",
"kubernetes_config_map",
"kubernetes_persistent_volume",
"kubernetes_namespace",
]
class KubernetesPlugin(ProviderPlugin):
"""Kubernetes provider plugin using the official kubernetes-client.
Authenticates via kubeconfig file and discovers cluster resources
including deployments, services, ingresses, config maps, persistent
volumes, and namespaces.
"""
def __init__(self) -> None:
self._api_client: client.ApiClient | None = None
self._core_v1: client.CoreV1Api | None = None
self._apps_v1: client.AppsV1Api | None = None
self._networking_v1: client.NetworkingV1Api | None = None
def authenticate(self, credentials: dict[str, str]) -> None:
"""Load kubeconfig and initialize Kubernetes API clients.
Args:
credentials: Dict with keys:
- kubeconfig_path: Path to the kubeconfig file (required)
- context: Kubernetes context name (optional)
Raises:
AuthenticationError: If kubeconfig cannot be loaded.
"""
kubeconfig_path = credentials.get("kubeconfig_path")
if not kubeconfig_path:
raise AuthenticationError(
provider_name="kubernetes",
reason="kubeconfig_path is required in credentials",
)
context = credentials.get("context") or None
try:
config.load_kube_config(
config_file=kubeconfig_path,
context=context,
)
except Exception as exc:
raise AuthenticationError(
provider_name="kubernetes",
reason=f"Failed to load kubeconfig from '{kubeconfig_path}': {exc}",
) from exc
self._api_client = client.ApiClient()
self._core_v1 = client.CoreV1Api(self._api_client)
self._apps_v1 = client.AppsV1Api(self._api_client)
self._networking_v1 = client.NetworkingV1Api(self._api_client)
def get_platform_category(self) -> PlatformCategory:
"""Return CONTAINER_ORCHESTRATION platform category."""
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
"""Return node addresses as endpoints.
Returns:
List of node internal IP addresses or hostnames.
"""
if self._core_v1 is None:
return []
try:
nodes = self._core_v1.list_node()
endpoints: list[str] = []
for node in nodes.items:
if node.status and node.status.addresses:
for addr in node.status.addresses:
if addr.type == "InternalIP":
endpoints.append(addr.address)
break
else:
# Fallback to first address
endpoints.append(node.status.addresses[0].address)
return endpoints
except Exception as exc:
logger.warning("Failed to list node endpoints: %s", exc)
return []
def list_supported_resource_types(self) -> list[str]:
"""Return all Kubernetes resource types this plugin can discover."""
return list(_SUPPORTED_RESOURCE_TYPES)
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect CPU architecture from node labels.
Queries node labels for 'kubernetes.io/arch' to determine the
CPU architecture. Falls back to AMD64 if the label is not found.
Args:
endpoint: Node IP address or hostname to query.
Returns:
CpuArchitecture enum value for the node.
"""
if self._core_v1 is None:
return CpuArchitecture.AMD64
try:
nodes = self._core_v1.list_node()
for node in nodes.items:
# Match node by address
if node.status and node.status.addresses:
node_addresses = [
addr.address for addr in node.status.addresses
]
if endpoint in node_addresses:
labels = node.metadata.labels or {}
arch_label = labels.get(
"kubernetes.io/arch",
labels.get("beta.kubernetes.io/arch", "amd64"),
)
return _ARCH_LABEL_MAP.get(
arch_label, CpuArchitecture.AMD64
)
except Exception as exc:
logger.warning(
"Failed to detect architecture for endpoint '%s': %s",
endpoint,
exc,
)
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover Kubernetes resources across all namespaces.
Args:
endpoints: List of node addresses (used for architecture detection).
resource_types: List of resource type strings to discover.
progress_callback: Callable for progress updates.
Returns:
ScanResult with all discovered resources.
"""
resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
# Determine architecture from first endpoint
architecture = CpuArchitecture.AMD64
if endpoints:
architecture = self.detect_architecture(endpoints[0])
endpoint_str = endpoints[0] if endpoints else "cluster"
total_types = len(resource_types)
for idx, resource_type in enumerate(resource_types):
try:
discovered = self._discover_resource_type(
resource_type, architecture, endpoint_str
)
resources.extend(discovered)
except Exception as exc:
error_msg = (
f"Error discovering {resource_type}: {exc}"
)
errors.append(error_msg)
logger.error(error_msg)
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(resources),
resource_types_completed=idx + 1,
total_resource_types=total_types,
)
)
return ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp="",
profile_hash="",
)
def _discover_resource_type(
self,
resource_type: str,
architecture: CpuArchitecture,
endpoint: str,
) -> list[DiscoveredResource]:
"""Discover resources of a specific type.
Args:
resource_type: The resource type string to discover.
architecture: Detected CPU architecture.
endpoint: Endpoint string for the resource.
Returns:
List of DiscoveredResource objects.
"""
dispatch = {
"kubernetes_deployment": self._discover_deployments,
"kubernetes_service": self._discover_services,
"kubernetes_ingress": self._discover_ingresses,
"kubernetes_config_map": self._discover_config_maps,
"kubernetes_persistent_volume": self._discover_persistent_volumes,
"kubernetes_namespace": self._discover_namespaces,
}
handler = dispatch.get(resource_type)
if handler is None:
return []
return handler(architecture, endpoint)
def _discover_deployments(
self, architecture: CpuArchitecture, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all deployments across namespaces."""
results: list[DiscoveredResource] = []
deployments = self._apps_v1.list_deployment_for_all_namespaces()
for dep in deployments.items:
name = dep.metadata.name
namespace = dep.metadata.namespace
results.append(
DiscoveredResource(
resource_type="kubernetes_deployment",
unique_id=f"{namespace}/{name}",
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"replicas": dep.spec.replicas if dep.spec else None,
"labels": dict(dep.metadata.labels or {}),
},
raw_references=[
f"kubernetes_namespace:{namespace}",
],
)
)
return results
def _discover_services(
self, architecture: CpuArchitecture, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all services across namespaces."""
results: list[DiscoveredResource] = []
services = self._core_v1.list_service_for_all_namespaces()
for svc in services.items:
name = svc.metadata.name
namespace = svc.metadata.namespace
results.append(
DiscoveredResource(
resource_type="kubernetes_service",
unique_id=f"{namespace}/{name}",
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"type": svc.spec.type if svc.spec else None,
"cluster_ip": svc.spec.cluster_ip if svc.spec else None,
"labels": dict(svc.metadata.labels or {}),
},
raw_references=[
f"kubernetes_namespace:{namespace}",
],
)
)
return results
def _discover_ingresses(
self, architecture: CpuArchitecture, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all ingresses across namespaces."""
results: list[DiscoveredResource] = []
ingresses = self._networking_v1.list_ingress_for_all_namespaces()
for ing in ingresses.items:
name = ing.metadata.name
namespace = ing.metadata.namespace
results.append(
DiscoveredResource(
resource_type="kubernetes_ingress",
unique_id=f"{namespace}/{name}",
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"labels": dict(ing.metadata.labels or {}),
},
raw_references=[
f"kubernetes_namespace:{namespace}",
],
)
)
return results
def _discover_config_maps(
self, architecture: CpuArchitecture, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all config maps across namespaces."""
results: list[DiscoveredResource] = []
config_maps = self._core_v1.list_config_map_for_all_namespaces()
for cm in config_maps.items:
name = cm.metadata.name
namespace = cm.metadata.namespace
results.append(
DiscoveredResource(
resource_type="kubernetes_config_map",
unique_id=f"{namespace}/{name}",
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"namespace": namespace,
"data_keys": list((cm.data or {}).keys()),
"labels": dict(cm.metadata.labels or {}),
},
raw_references=[
f"kubernetes_namespace:{namespace}",
],
)
)
return results
def _discover_persistent_volumes(
self, architecture: CpuArchitecture, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all persistent volumes (cluster-scoped)."""
results: list[DiscoveredResource] = []
pvs = self._core_v1.list_persistent_volume()
for pv in pvs.items:
name = pv.metadata.name
results.append(
DiscoveredResource(
resource_type="kubernetes_persistent_volume",
unique_id=name,
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"capacity": (
dict(pv.spec.capacity)
if pv.spec and pv.spec.capacity
else {}
),
"access_modes": (
list(pv.spec.access_modes)
if pv.spec and pv.spec.access_modes
else []
),
"storage_class": (
pv.spec.storage_class_name if pv.spec else None
),
"labels": dict(pv.metadata.labels or {}),
},
raw_references=[],
)
)
return results
def _discover_namespaces(
self, architecture: CpuArchitecture, endpoint: str
) -> list[DiscoveredResource]:
"""Discover all namespaces."""
results: list[DiscoveredResource] = []
namespaces = self._core_v1.list_namespace()
for ns in namespaces.items:
name = ns.metadata.name
results.append(
DiscoveredResource(
resource_type="kubernetes_namespace",
unique_id=name,
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=architecture,
endpoint=endpoint,
attributes={
"status": (
ns.status.phase if ns.status else None
),
"labels": dict(ns.metadata.labels or {}),
},
raw_references=[],
)
)
return results

View File

@@ -0,0 +1,140 @@
"""Multi-provider scanner for infrastructure discovery.
Coordinates scanning across multiple providers independently, handling
partial failures gracefully. If one provider fails, scanning continues
for all remaining providers. Successfully discovered resources are
collected into a unified inventory, and failed providers are reported
with error details.
Implements Requirement 5.5: IF one or more Provider scans fail during a
multi-provider scan, THEN THE Scanner SHALL complete scanning for all
remaining Providers, include successfully discovered Resources in the
inventory, and report which Providers failed along with the corresponding
error details.
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Optional
from iac_reverse.models import (
DiscoveredResource,
ScanProfile,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import Scanner
logger = logging.getLogger(__name__)
@dataclass
class ProviderFailure:
"""Details about a provider that failed during multi-provider scanning."""
provider_name: str
error_type: str
error_message: str
@dataclass
class MultiProviderScanResult:
"""Result of scanning across multiple providers.
Contains all successfully discovered resources from providers that
completed scanning, plus details about any providers that failed.
"""
resources: list[DiscoveredResource] = field(default_factory=list)
warnings: list[str] = field(default_factory=list)
errors: list[str] = field(default_factory=list)
failed_providers: list[ProviderFailure] = field(default_factory=list)
successful_providers: list[str] = field(default_factory=list)
scan_timestamp: str = ""
@dataclass
class ProviderScanEntry:
"""A pairing of a ScanProfile with its corresponding ProviderPlugin."""
profile: ScanProfile
plugin: ProviderPlugin
class MultiProviderScanner:
"""Orchestrates infrastructure discovery across multiple providers.
Scans each provider independently. If one provider fails (auth error,
connection error, etc.), continues with remaining providers. Collects
all successfully discovered resources into a unified inventory and
reports which providers failed and why.
"""
def __init__(self, entries: list[ProviderScanEntry]):
"""Initialize with a list of provider scan entries.
Args:
entries: List of ProviderScanEntry, each pairing a ScanProfile
with its corresponding ProviderPlugin.
"""
self.entries = entries
def scan(
self,
progress_callback: Optional[Callable[[ScanProgress], None]] = None,
) -> MultiProviderScanResult:
"""Execute scans across all configured providers.
Each provider is scanned independently. If a provider fails for
any reason (authentication, connection, timeout, validation, etc.),
the error is recorded and scanning continues with remaining providers.
Args:
progress_callback: Optional callable invoked with ScanProgress
updates from each provider scan.
Returns:
MultiProviderScanResult containing all successfully discovered
resources and details about any failed providers.
"""
result = MultiProviderScanResult(
scan_timestamp=datetime.now(timezone.utc).isoformat(),
)
for entry in self.entries:
provider_name = entry.profile.provider.value
try:
scanner = Scanner(entry.profile, entry.plugin)
scan_result = scanner.scan(progress_callback=progress_callback)
# Collect successful resources
result.resources.extend(scan_result.resources)
result.warnings.extend(scan_result.warnings)
result.errors.extend(scan_result.errors)
result.successful_providers.append(provider_name)
logger.info(
"Provider '%s' scan completed: %d resources discovered",
provider_name,
len(scan_result.resources),
)
except Exception as exc:
# Record the failure and continue with remaining providers
failure = ProviderFailure(
provider_name=provider_name,
error_type=type(exc).__name__,
error_message=str(exc),
)
result.failed_providers.append(failure)
logger.warning(
"Provider '%s' scan failed (%s): %s",
provider_name,
type(exc).__name__,
exc,
)
return result

View File

@@ -0,0 +1,287 @@
"""Scanner orchestrator for infrastructure discovery.
Coordinates provider plugins to discover infrastructure resources,
handling authentication, retries, progress reporting, and error recovery.
"""
import hashlib
import logging
import time
from datetime import datetime, timezone
from typing import Callable, Optional
from iac_reverse.models import (
ScanProfile,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Custom Exceptions
# ---------------------------------------------------------------------------
class AuthenticationError(Exception):
"""Raised when authentication with a provider fails."""
def __init__(self, provider_name: str, reason: str):
self.provider_name = provider_name
self.reason = reason
super().__init__(
f"Authentication failed for provider '{provider_name}': {reason}"
)
class ConnectionLostError(Exception):
"""Raised when the provider connection is lost during a scan."""
def __init__(self, partial_result: ScanResult):
self.partial_result = partial_result
super().__init__("Connection lost during scan; partial results available")
class ScanTimeoutError(Exception):
"""Raised when a scan operation exceeds the allowed timeout."""
def __init__(self, message: str = "Scan operation timed out"):
super().__init__(message)
# ---------------------------------------------------------------------------
# Scanner Orchestrator
# ---------------------------------------------------------------------------
# Default constants
CONNECTION_TIMEOUT_SECONDS = 30
MAX_RETRIES = 3
INITIAL_BACKOFF_SECONDS = 1.0
class Scanner:
"""Orchestrates infrastructure discovery using a provider plugin.
Accepts a ScanProfile and an optional ProviderPlugin instance.
Handles authentication, progress reporting, retry logic with
exponential backoff, and graceful degradation on errors.
"""
def __init__(
self,
profile: ScanProfile,
plugin: Optional[ProviderPlugin] = None,
):
self.profile = profile
self.plugin = plugin
def scan(
self,
progress_callback: Optional[Callable[[ScanProgress], None]] = None,
) -> ScanResult:
"""Execute a full infrastructure scan.
Args:
progress_callback: Optional callable invoked per resource type
completion with a ScanProgress update.
Returns:
ScanResult containing discovered resources, warnings, and errors.
Raises:
AuthenticationError: If authentication with the provider fails.
ScanTimeoutError: If the connection attempt exceeds 30 seconds.
ValueError: If the scan profile is invalid.
"""
# 1. Validate the scan profile (critical fields only)
validation_errors = self._validate_profile()
if validation_errors:
raise ValueError(
f"Invalid scan profile: {'; '.join(validation_errors)}"
)
if self.plugin is None:
raise ValueError("No provider plugin configured for scanning")
# 2. Authenticate with the provider (30 second timeout)
self._authenticate()
# 3. Determine resource types to scan
supported_types = self.plugin.list_supported_resource_types()
resource_types, warnings = self._resolve_resource_types(supported_types)
# 4. Determine endpoints
endpoints = self.profile.endpoints or self.plugin.list_endpoints()
# 5. Discover resources with retry logic
scan_result = self._discover_with_retries(
endpoints=endpoints,
resource_types=resource_types,
progress_callback=progress_callback,
)
# Merge any warnings from unsupported resource type filtering
scan_result.warnings = warnings + scan_result.warnings
# Set metadata
scan_result.scan_timestamp = datetime.now(timezone.utc).isoformat()
scan_result.profile_hash = self._compute_profile_hash()
return scan_result
def _authenticate(self) -> None:
"""Authenticate with the provider plugin, enforcing a 30s timeout."""
provider_name = self.profile.provider.value
start_time = time.monotonic()
try:
self.plugin.authenticate(self.profile.credentials)
except Exception as exc:
elapsed = time.monotonic() - start_time
if elapsed >= CONNECTION_TIMEOUT_SECONDS:
raise ScanTimeoutError(
f"Authentication with provider '{provider_name}' "
f"timed out after {CONNECTION_TIMEOUT_SECONDS} seconds"
)
# Wrap any auth exception in our AuthenticationError
if isinstance(exc, AuthenticationError):
raise
raise AuthenticationError(
provider_name=provider_name,
reason=str(exc),
) from exc
elapsed = time.monotonic() - start_time
if elapsed >= CONNECTION_TIMEOUT_SECONDS:
raise ScanTimeoutError(
f"Authentication with provider '{provider_name}' "
f"timed out after {CONNECTION_TIMEOUT_SECONDS} seconds"
)
def _resolve_resource_types(
self, supported_types: list[str]
) -> tuple[list[str], list[str]]:
"""Determine which resource types to scan and log warnings for unsupported ones.
Returns:
Tuple of (resource_types_to_scan, warnings_list)
"""
warnings: list[str] = []
if self.profile.resource_type_filters is None:
# No filters: scan all supported types
return supported_types, warnings
# Filter requested types against supported types
valid_types: list[str] = []
for rt in self.profile.resource_type_filters:
if rt in supported_types:
valid_types.append(rt)
else:
warning_msg = (
f"Unsupported resource type '{rt}' for provider "
f"'{self.profile.provider.value}'; skipping"
)
warnings.append(warning_msg)
logger.warning(warning_msg)
return valid_types, warnings
def _discover_with_retries(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Optional[Callable[[ScanProgress], None]],
) -> ScanResult:
"""Call the plugin's discover_resources with retry logic.
Retries up to MAX_RETRIES times with exponential backoff for
transient errors. On connection loss, returns partial inventory.
"""
last_exception: Optional[Exception] = None
for attempt in range(MAX_RETRIES + 1):
try:
result = self.plugin.discover_resources(
endpoints=endpoints,
resource_types=resource_types,
progress_callback=progress_callback or self._noop_callback,
)
return result
except ConnectionLostError:
# Connection lost: return partial results immediately
raise
except ConnectionError as exc:
# Connection lost during scan: build partial result
logger.warning(
"Connection lost during scan (attempt %d/%d): %s",
attempt + 1,
MAX_RETRIES + 1,
exc,
)
partial = ScanResult(
resources=[],
warnings=[f"Connection lost: {exc}"],
errors=[str(exc)],
scan_timestamp=datetime.now(timezone.utc).isoformat(),
profile_hash=self._compute_profile_hash(),
is_partial=True,
)
raise ConnectionLostError(partial_result=partial) from exc
except Exception as exc:
last_exception = exc
if attempt < MAX_RETRIES:
backoff = INITIAL_BACKOFF_SECONDS * (2**attempt)
logger.warning(
"Transient error during scan (attempt %d/%d), "
"retrying in %.1fs: %s",
attempt + 1,
MAX_RETRIES + 1,
backoff,
exc,
)
time.sleep(backoff)
else:
logger.error(
"Scan failed after %d attempts: %s",
MAX_RETRIES + 1,
exc,
)
# All retries exhausted — return error result
return ScanResult(
resources=[],
warnings=[],
errors=[f"Scan failed after {MAX_RETRIES + 1} attempts: {last_exception}"],
scan_timestamp=datetime.now(timezone.utc).isoformat(),
profile_hash=self._compute_profile_hash(),
is_partial=True,
)
def _validate_profile(self) -> list[str]:
"""Validate critical scan profile fields.
Only checks fields that prevent scanning entirely (e.g., missing
credentials). Unsupported resource types are handled as warnings
during the scan per Requirement 1.4.
"""
errors: list[str] = []
if not self.profile.credentials:
errors.append("credentials must not be empty")
return errors
def _compute_profile_hash(self) -> str:
"""Compute a stable hash of the scan profile for snapshot matching."""
content = (
f"{self.profile.provider.value}:"
f"{sorted(self.profile.credentials.items())}:"
f"{self.profile.endpoints}:"
f"{self.profile.resource_type_filters}"
)
return hashlib.sha256(content.encode()).hexdigest()[:16]
@staticmethod
def _noop_callback(progress: ScanProgress) -> None:
"""No-op progress callback used when none is provided."""
pass

View File

@@ -0,0 +1,482 @@
"""Synology DSM provider plugin.
Discovers shared folders, volumes, storage pools, replication tasks, and users
from a Synology DiskStation Manager (DSM) appliance via its HTTP API.
"""
import logging
from datetime import datetime, timezone
from typing import Callable, Optional
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import AuthenticationError
try:
from synology_dsm import SynologyDSM
except ImportError: # pragma: no cover
SynologyDSM = None # type: ignore[assignment,misc]
logger = logging.getLogger(__name__)
# Resource type constants
SYNOLOGY_SHARED_FOLDER = "synology_shared_folder"
SYNOLOGY_VOLUME = "synology_volume"
SYNOLOGY_STORAGE_POOL = "synology_storage_pool"
SYNOLOGY_REPLICATION_TASK = "synology_replication_task"
SYNOLOGY_USER = "synology_user"
SUPPORTED_RESOURCE_TYPES = [
SYNOLOGY_SHARED_FOLDER,
SYNOLOGY_VOLUME,
SYNOLOGY_STORAGE_POOL,
SYNOLOGY_REPLICATION_TASK,
SYNOLOGY_USER,
]
class SynologyPlugin(ProviderPlugin):
"""Provider plugin for Synology DiskStation Manager (DSM).
Connects to the Synology DSM API to discover storage infrastructure
including shared folders, volumes, storage pools, replication tasks,
and local users.
Expected credentials:
- host: DSM hostname or IP address
- port: DSM port (default "5001")
- username: DSM admin username
- password: DSM admin password
- use_ssl: "true" or "false" (default "true")
"""
def __init__(self) -> None:
self._api: Optional[object] = None
self._host: str = ""
self._port: str = "5001"
self._use_ssl: bool = True
self._authenticated: bool = False
def authenticate(self, credentials: dict[str, str]) -> None:
"""Authenticate with the Synology DSM API.
Args:
credentials: Dict with keys: host, port, username, password,
and optionally use_ssl.
Raises:
AuthenticationError: If connection or login fails.
"""
host = credentials.get("host", "")
port = credentials.get("port", "5001")
username = credentials.get("username", "")
password = credentials.get("password", "")
use_ssl = credentials.get("use_ssl", "true").lower() == "true"
if not host:
raise AuthenticationError("synology", "host is required")
if not username:
raise AuthenticationError("synology", "username is required")
if not password:
raise AuthenticationError("synology", "password is required")
self._host = host
self._port = port
self._use_ssl = use_ssl
try:
if SynologyDSM is None:
raise AuthenticationError(
"synology",
"python-synology library is not installed",
)
api = SynologyDSM(
host,
int(port),
username,
password,
use_https=use_ssl,
verify_ssl=False,
)
# Attempt login
if not api.login():
raise AuthenticationError(
"synology",
f"Login failed for user '{username}' on {host}:{port}",
)
self._api = api
self._authenticated = True
logger.info("Authenticated with Synology DSM at %s:%s", host, port)
except AuthenticationError:
raise
except Exception as exc:
raise AuthenticationError(
"synology",
f"Failed to connect to DSM at {host}:{port}: {exc}",
) from exc
def get_platform_category(self) -> PlatformCategory:
"""Return STORAGE_APPLIANCE platform category."""
return PlatformCategory.STORAGE_APPLIANCE
def list_endpoints(self) -> list[str]:
"""Return the DSM endpoint address."""
protocol = "https" if self._use_ssl else "http"
return [f"{protocol}://{self._host}:{self._port}"]
def list_supported_resource_types(self) -> list[str]:
"""Return all Synology resource types this plugin can discover."""
return list(SUPPORTED_RESOURCE_TYPES)
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect CPU architecture from Synology system info.
Queries the DSM information API to determine if the NAS
runs on ARM or AMD64 hardware.
Args:
endpoint: The DSM endpoint (used for context, not connection).
Returns:
CpuArchitecture.ARM for ARM-based models,
CpuArchitecture.AARCH64 for 64-bit ARM models,
CpuArchitecture.AMD64 for x86-64 models.
"""
if self._api is None:
return CpuArchitecture.AMD64
try:
info = self._api.information
if info is None:
return CpuArchitecture.AMD64
# The model name or CPU info can indicate architecture
model = getattr(info, "model", "") or ""
cpu_name = getattr(info, "cpu_hardware_name", "") or ""
# Combine for matching
hw_info = f"{model} {cpu_name}".lower()
if "aarch64" in hw_info or "arm64" in hw_info:
return CpuArchitecture.AARCH64
elif "arm" in hw_info or "rtd" in hw_info or "alpine" in hw_info:
return CpuArchitecture.ARM
else:
return CpuArchitecture.AMD64
except Exception as exc:
logger.warning("Failed to detect architecture: %s", exc)
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover Synology resources from the DSM API.
Enumerates shared folders, volumes, storage pools, replication tasks,
and users based on the requested resource_types.
Args:
endpoints: List of DSM endpoints (typically one).
resource_types: Resource types to discover.
progress_callback: Callback for progress updates.
Returns:
ScanResult with discovered resources.
"""
resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
endpoint = endpoints[0] if endpoints else self.list_endpoints()[0]
architecture = self.detect_architecture(endpoint)
total_types = len(resource_types)
completed = 0
# Discovery dispatch table
discovery_methods = {
SYNOLOGY_SHARED_FOLDER: self._discover_shared_folders,
SYNOLOGY_VOLUME: self._discover_volumes,
SYNOLOGY_STORAGE_POOL: self._discover_storage_pools,
SYNOLOGY_REPLICATION_TASK: self._discover_replication_tasks,
SYNOLOGY_USER: self._discover_users,
}
for rt in resource_types:
progress_callback(
ScanProgress(
current_resource_type=rt,
resources_discovered=len(resources),
resource_types_completed=completed,
total_resource_types=total_types,
)
)
method = discovery_methods.get(rt)
if method is None:
warnings.append(f"Unsupported resource type: {rt}")
completed += 1
continue
try:
discovered = method(endpoint, architecture)
resources.extend(discovered)
except Exception as exc:
error_msg = f"Error discovering {rt}: {exc}"
errors.append(error_msg)
logger.error(error_msg)
completed += 1
# Final progress update
progress_callback(
ScanProgress(
current_resource_type="",
resources_discovered=len(resources),
resource_types_completed=completed,
total_resource_types=total_types,
)
)
return ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp=datetime.now(timezone.utc).isoformat(),
profile_hash="",
)
# ------------------------------------------------------------------
# Private discovery methods
# ------------------------------------------------------------------
def _discover_shared_folders(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover shared folders from DSM."""
resources: list[DiscoveredResource] = []
storage = self._api.storage
if storage is None:
return resources
# Access shared folders via the storage API
shares = getattr(storage, "shares", None)
if shares is None:
return resources
for share in shares:
name = share.get("name", "unknown")
resources.append(
DiscoveredResource(
resource_type=SYNOLOGY_SHARED_FOLDER,
unique_id=f"synology/shared_folder/{name}",
name=name,
provider=ProviderType.SYNOLOGY,
platform_category=PlatformCategory.STORAGE_APPLIANCE,
architecture=architecture,
endpoint=endpoint,
attributes={
"name": name,
"path": share.get("path", ""),
"desc": share.get("desc", ""),
"encryption": share.get("is_encrypted", False),
"recycle_bin": share.get("enable_recycle_bin", False),
"vol_path": share.get("vol_path", ""),
},
raw_references=[],
)
)
return resources
def _discover_volumes(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover volumes from DSM."""
resources: list[DiscoveredResource] = []
storage = self._api.storage
if storage is None:
return resources
volumes = getattr(storage, "volumes", None)
if volumes is None:
return resources
for volume in volumes:
vol_id = volume.get("id", "unknown")
name = volume.get("display_name", vol_id)
resources.append(
DiscoveredResource(
resource_type=SYNOLOGY_VOLUME,
unique_id=f"synology/volume/{vol_id}",
name=name,
provider=ProviderType.SYNOLOGY,
platform_category=PlatformCategory.STORAGE_APPLIANCE,
architecture=architecture,
endpoint=endpoint,
attributes={
"id": vol_id,
"display_name": name,
"status": volume.get("status", ""),
"fs_type": volume.get("fs_type", ""),
"size_total": volume.get("size", {}).get("total", ""),
"size_used": volume.get("size", {}).get("used", ""),
"pool_path": volume.get("pool_path", ""),
},
raw_references=[
f"synology/storage_pool/{volume.get('pool_path', '')}"
]
if volume.get("pool_path")
else [],
)
)
return resources
def _discover_storage_pools(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover storage pools from DSM."""
resources: list[DiscoveredResource] = []
storage = self._api.storage
if storage is None:
return resources
pools = getattr(storage, "storage_pools", None)
if pools is None:
return resources
for pool in pools:
pool_id = pool.get("id", "unknown")
name = pool.get("display_name", pool_id)
resources.append(
DiscoveredResource(
resource_type=SYNOLOGY_STORAGE_POOL,
unique_id=f"synology/storage_pool/{pool_id}",
name=name,
provider=ProviderType.SYNOLOGY,
platform_category=PlatformCategory.STORAGE_APPLIANCE,
architecture=architecture,
endpoint=endpoint,
attributes={
"id": pool_id,
"display_name": name,
"status": pool.get("status", ""),
"raid_type": pool.get("raid_type", ""),
"size_total": pool.get("size", {}).get("total", ""),
"size_used": pool.get("size", {}).get("used", ""),
"disk_count": len(pool.get("disks", [])),
},
raw_references=[],
)
)
return resources
def _discover_replication_tasks(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover replication tasks from DSM."""
resources: list[DiscoveredResource] = []
# Replication tasks are accessed via a separate API module
api = self._api
if api is None:
return resources
# Try to access replication info if available
replication = getattr(api, "replication", None)
if replication is None:
return resources
tasks = getattr(replication, "tasks", None)
if tasks is None:
return resources
for task in tasks:
task_id = task.get("id", "unknown")
name = task.get("name", task_id)
resources.append(
DiscoveredResource(
resource_type=SYNOLOGY_REPLICATION_TASK,
unique_id=f"synology/replication_task/{task_id}",
name=name,
provider=ProviderType.SYNOLOGY,
platform_category=PlatformCategory.STORAGE_APPLIANCE,
architecture=architecture,
endpoint=endpoint,
attributes={
"id": task_id,
"name": name,
"status": task.get("status", ""),
"type": task.get("type", ""),
"destination": task.get("destination", ""),
"schedule": task.get("schedule", {}),
"shared_folder": task.get("shared_folder", ""),
},
raw_references=[
f"synology/shared_folder/{task.get('shared_folder', '')}"
]
if task.get("shared_folder")
else [],
)
)
return resources
def _discover_users(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover local users from DSM."""
resources: list[DiscoveredResource] = []
api = self._api
if api is None:
return resources
# Users are typically accessed via SYNO.Core.User API
users_api = getattr(api, "users", None)
if users_api is None:
return resources
users = getattr(users_api, "users", None)
if users is None:
return resources
for user in users:
username = user.get("name", "unknown")
resources.append(
DiscoveredResource(
resource_type=SYNOLOGY_USER,
unique_id=f"synology/user/{username}",
name=username,
provider=ProviderType.SYNOLOGY,
platform_category=PlatformCategory.STORAGE_APPLIANCE,
architecture=architecture,
endpoint=endpoint,
attributes={
"name": username,
"description": user.get("description", ""),
"email": user.get("email", ""),
"expired": user.get("expired", False),
"groups": user.get("groups", []),
},
raw_references=[],
)
)
return resources

View File

@@ -0,0 +1,825 @@
"""Windows provider plugin for infrastructure discovery via WinRM.
Uses pywinrm to connect to Windows machines and discover services,
scheduled tasks, IIS sites, app pools, network adapters, firewall rules,
installed software, Windows features, Hyper-V VMs, Hyper-V switches,
DNS records, local users, and local groups.
"""
import json
import logging
from typing import Callable
import winrm
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import AuthenticationError
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Custom Exceptions
# ---------------------------------------------------------------------------
class WinRMNotEnabledError(Exception):
"""Raised when WinRM is not enabled on the target host."""
def __init__(self, host: str, reason: str = ""):
self.host = host
self.reason = reason
super().__init__(
f"WinRM is not enabled or unreachable on host '{host}'"
+ (f": {reason}" if reason else "")
)
class WMIQueryError(Exception):
"""Raised when a WMI query fails on the target host."""
def __init__(self, query: str, reason: str = ""):
self.query = query
self.reason = reason
super().__init__(
f"WMI query failed: '{query}'"
+ (f": {reason}" if reason else "")
)
class InsufficientPrivilegesError(Exception):
"""Raised when the authenticated user lacks required privileges."""
def __init__(self, operation: str, reason: str = ""):
self.operation = operation
self.reason = reason
super().__init__(
f"Insufficient privileges for operation '{operation}'"
+ (f": {reason}" if reason else "")
)
# ---------------------------------------------------------------------------
# Windows Discovery Plugin
# ---------------------------------------------------------------------------
WINDOWS_RESOURCE_TYPES = [
"windows_service",
"windows_scheduled_task",
"windows_iis_site",
"windows_iis_app_pool",
"windows_network_adapter",
"windows_firewall_rule",
"windows_installed_software",
"windows_feature",
"windows_hyperv_vm",
"windows_hyperv_switch",
"windows_dns_record",
"windows_local_user",
"windows_local_group",
]
class WindowsDiscoveryPlugin(ProviderPlugin):
"""Provider plugin for discovering Windows infrastructure via WinRM.
Connects to Windows machines using pywinrm and discovers resources
through PowerShell commands and WMI queries executed over WinRM.
Expected credentials dict keys:
host: Target hostname or IP address
username: Windows username (domain\\user or user@domain)
password: Password for authentication
transport: Authentication transport - "ntlm" (default) or "kerberos"
port: WinRM port - "5985" (HTTP) or "5986" (HTTPS, default)
use_ssl: Whether to use SSL - "true" (default) or "false"
"""
def __init__(self) -> None:
self._session: winrm.Session | None = None
self._host: str = ""
self._credentials: dict[str, str] = {}
def authenticate(self, credentials: dict[str, str]) -> None:
"""Authenticate with the Windows host via WinRM.
Args:
credentials: Dict with keys: host, username, password,
transport (default "ntlm"), port (default "5986"),
use_ssl (default "true").
Raises:
AuthenticationError: If authentication fails.
WinRMNotEnabledError: If WinRM is not reachable.
"""
host = credentials.get("host", "")
username = credentials.get("username", "")
password = credentials.get("password", "")
transport = credentials.get("transport", "ntlm")
port = credentials.get("port", "5986")
use_ssl = credentials.get("use_ssl", "true").lower() == "true"
if not host:
raise AuthenticationError("windows", "host is required")
if not username:
raise AuthenticationError("windows", "username is required")
if not password:
raise AuthenticationError("windows", "password is required")
self._host = host
self._credentials = credentials
scheme = "https" if use_ssl else "http"
endpoint = f"{scheme}://{host}:{port}/wsman"
try:
self._session = winrm.Session(
endpoint,
auth=(username, password),
transport=transport,
server_cert_validation="ignore" if use_ssl else "validate",
)
# Test connectivity with a simple command
result = self._session.run_ps("$env:COMPUTERNAME")
if result.status_code != 0:
stderr = result.std_err.decode("utf-8", errors="replace").strip()
if "access" in stderr.lower() or "denied" in stderr.lower():
raise InsufficientPrivilegesError(
"authenticate", stderr
)
raise AuthenticationError("windows", stderr or "Authentication test failed")
except AuthenticationError:
raise
except InsufficientPrivilegesError as exc:
raise AuthenticationError("windows", str(exc)) from exc
except WinRMNotEnabledError:
raise
except Exception as exc:
error_msg = str(exc).lower()
if "connection" in error_msg or "refused" in error_msg or "unreachable" in error_msg:
raise WinRMNotEnabledError(host, str(exc)) from exc
raise AuthenticationError("windows", str(exc)) from exc
def get_platform_category(self) -> PlatformCategory:
"""Return PlatformCategory.WINDOWS."""
return PlatformCategory.WINDOWS
def list_endpoints(self) -> list[str]:
"""Return the single Windows host as the endpoint."""
return [self._host] if self._host else []
def list_supported_resource_types(self) -> list[str]:
"""Return all 13 Windows resource types."""
return list(WINDOWS_RESOURCE_TYPES)
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
"""Detect CPU architecture via WMI Win32_Processor query.
Args:
endpoint: The Windows host to query.
Returns:
CpuArchitecture enum value.
Raises:
WMIQueryError: If the WMI query fails.
"""
query = "Get-WmiObject Win32_Processor | Select-Object -First 1 -ExpandProperty Architecture"
result = self._run_powershell(query)
if result.status_code != 0:
stderr = result.std_err.decode("utf-8", errors="replace").strip()
raise WMIQueryError("Win32_Processor.Architecture", stderr)
arch_code = result.std_out.decode("utf-8", errors="replace").strip()
# WMI Architecture codes:
# 0 = x86, 5 = ARM, 9 = x64, 12 = ARM64
arch_map = {
"0": CpuArchitecture.AMD64, # x86 mapped to amd64 for simplicity
"5": CpuArchitecture.ARM,
"9": CpuArchitecture.AMD64,
"12": CpuArchitecture.AARCH64,
}
return arch_map.get(arch_code, CpuArchitecture.AMD64)
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover Windows resources via WinRM/PowerShell.
Args:
endpoints: List of Windows hosts to scan.
resource_types: List of resource type strings to discover.
progress_callback: Callable for progress updates.
Returns:
ScanResult with discovered resources, warnings, and errors.
"""
all_resources: list[DiscoveredResource] = []
warnings: list[str] = []
errors: list[str] = []
total_types = len(resource_types)
for endpoint in endpoints:
# Detect architecture for this endpoint
try:
architecture = self.detect_architecture(endpoint)
except (WMIQueryError, Exception) as exc:
warnings.append(
f"Could not detect architecture for {endpoint}: {exc}. "
f"Defaulting to AMD64."
)
architecture = CpuArchitecture.AMD64
# Check if Hyper-V is installed (needed for hyperv resource types)
hyperv_installed = self._is_hyperv_installed()
for idx, resource_type in enumerate(resource_types):
try:
# Skip Hyper-V resources if role not installed
if resource_type in ("windows_hyperv_vm", "windows_hyperv_switch"):
if not hyperv_installed:
warnings.append(
f"Skipping {resource_type}: Hyper-V role not installed on {endpoint}"
)
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(all_resources),
resource_types_completed=idx + 1,
total_resource_types=total_types,
)
)
continue
discovered = self._discover_resource_type(
endpoint, resource_type, architecture
)
all_resources.extend(discovered)
except InsufficientPrivilegesError as exc:
errors.append(
f"Insufficient privileges for {resource_type} on {endpoint}: {exc}"
)
except WMIQueryError as exc:
errors.append(
f"WMI query failed for {resource_type} on {endpoint}: {exc}"
)
except Exception as exc:
errors.append(
f"Error discovering {resource_type} on {endpoint}: {exc}"
)
progress_callback(
ScanProgress(
current_resource_type=resource_type,
resources_discovered=len(all_resources),
resource_types_completed=idx + 1,
total_resource_types=total_types,
)
)
return ScanResult(
resources=all_resources,
warnings=warnings,
errors=errors,
scan_timestamp="",
profile_hash="",
)
# -----------------------------------------------------------------------
# Private helpers
# -----------------------------------------------------------------------
def _run_powershell(self, script: str) -> winrm.Response:
"""Execute a PowerShell script via WinRM.
Args:
script: PowerShell script to execute.
Returns:
winrm.Response object.
Raises:
WinRMNotEnabledError: If the session is not established.
"""
if self._session is None:
raise WinRMNotEnabledError(self._host, "No active WinRM session")
return self._session.run_ps(script)
def _run_powershell_json(self, script: str) -> list[dict]:
"""Execute a PowerShell script and parse JSON output.
The script should output ConvertTo-Json formatted data.
Args:
script: PowerShell script that outputs JSON.
Returns:
List of dicts parsed from JSON output.
Raises:
WMIQueryError: If the command fails.
InsufficientPrivilegesError: If access is denied.
"""
result = self._run_powershell(script)
if result.status_code != 0:
stderr = result.std_err.decode("utf-8", errors="replace").strip()
if "access" in stderr.lower() or "denied" in stderr.lower() or "privilege" in stderr.lower():
raise InsufficientPrivilegesError(script, stderr)
raise WMIQueryError(script, stderr)
stdout = result.std_out.decode("utf-8", errors="replace").strip()
if not stdout:
return []
try:
data = json.loads(stdout)
if isinstance(data, dict):
return [data]
return data if isinstance(data, list) else []
except json.JSONDecodeError:
return []
def _is_hyperv_installed(self) -> bool:
"""Check if the Hyper-V role is installed on the target.
Returns:
True if Hyper-V is installed, False otherwise.
"""
script = (
"Get-WindowsFeature -Name Hyper-V | "
"Select-Object -ExpandProperty Installed | "
"ConvertTo-Json"
)
try:
result = self._run_powershell(script)
if result.status_code != 0:
return False
stdout = result.std_out.decode("utf-8", errors="replace").strip()
return stdout.lower() == "true"
except Exception:
return False
def _discover_resource_type(
self,
endpoint: str,
resource_type: str,
architecture: CpuArchitecture,
) -> list[DiscoveredResource]:
"""Discover resources of a specific type.
Args:
endpoint: The Windows host.
resource_type: The resource type to discover.
architecture: Detected CPU architecture.
Returns:
List of DiscoveredResource objects.
"""
discovery_map = {
"windows_service": self._discover_services,
"windows_scheduled_task": self._discover_scheduled_tasks,
"windows_iis_site": self._discover_iis_sites,
"windows_iis_app_pool": self._discover_iis_app_pools,
"windows_network_adapter": self._discover_network_adapters,
"windows_firewall_rule": self._discover_firewall_rules,
"windows_installed_software": self._discover_installed_software,
"windows_feature": self._discover_windows_features,
"windows_hyperv_vm": self._discover_hyperv_vms,
"windows_hyperv_switch": self._discover_hyperv_switches,
"windows_dns_record": self._discover_dns_records,
"windows_local_user": self._discover_local_users,
"windows_local_group": self._discover_local_groups,
}
discover_fn = discovery_map.get(resource_type)
if discover_fn is None:
return []
return discover_fn(endpoint, architecture)
def _discover_services(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Windows services."""
script = (
"Get-Service | Select-Object Name, DisplayName, Status, StartType | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_service",
unique_id=f"{endpoint}/service/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"display_name": item.get("DisplayName", ""),
"status": str(item.get("Status", "")),
"start_type": str(item.get("StartType", "")),
},
)
)
return resources
def _discover_scheduled_tasks(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Windows scheduled tasks."""
script = (
"Get-ScheduledTask | Where-Object {$_.TaskPath -notlike '\\\\Microsoft\\\\*'} | "
"Select-Object TaskName, TaskPath, State | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("TaskName", "")
task_path = item.get("TaskPath", "\\")
resources.append(
DiscoveredResource(
resource_type="windows_scheduled_task",
unique_id=f"{endpoint}/scheduled_task/{task_path}{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"task_path": task_path,
"state": str(item.get("State", "")),
},
)
)
return resources
def _discover_iis_sites(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover IIS websites."""
script = (
"Import-Module WebAdministration; "
"Get-Website | Select-Object Name, ID, State, PhysicalPath | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_iis_site",
unique_id=f"{endpoint}/iis_site/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"site_id": str(item.get("ID", "")),
"state": str(item.get("State", "")),
"physical_path": item.get("PhysicalPath", ""),
},
)
)
return resources
def _discover_iis_app_pools(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover IIS application pools."""
script = (
"Import-Module WebAdministration; "
"Get-ChildItem IIS:\\AppPools | "
"Select-Object Name, State, ManagedRuntimeVersion | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_iis_app_pool",
unique_id=f"{endpoint}/iis_app_pool/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"state": str(item.get("State", "")),
"managed_runtime_version": item.get("ManagedRuntimeVersion", ""),
},
)
)
return resources
def _discover_network_adapters(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover network adapters."""
script = (
"Get-NetAdapter | Select-Object Name, InterfaceDescription, "
"Status, MacAddress, LinkSpeed | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_network_adapter",
unique_id=f"{endpoint}/network_adapter/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"interface_description": item.get("InterfaceDescription", ""),
"status": str(item.get("Status", "")),
"mac_address": item.get("MacAddress", ""),
"link_speed": item.get("LinkSpeed", ""),
},
)
)
return resources
def _discover_firewall_rules(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Windows firewall rules."""
script = (
"Get-NetFirewallRule | Where-Object {$_.Enabled -eq 'True'} | "
"Select-Object Name, DisplayName, Direction, Action, Profile | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_firewall_rule",
unique_id=f"{endpoint}/firewall_rule/{name}",
name=item.get("DisplayName", name),
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"rule_name": name,
"direction": str(item.get("Direction", "")),
"action": str(item.get("Action", "")),
"profile": str(item.get("Profile", "")),
},
)
)
return resources
def _discover_installed_software(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover installed software via registry."""
script = (
"Get-ItemProperty HKLM:\\Software\\Microsoft\\Windows\\CurrentVersion\\Uninstall\\* | "
"Where-Object {$_.DisplayName -ne $null} | "
"Select-Object DisplayName, DisplayVersion, Publisher, InstallDate | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("DisplayName", "")
resources.append(
DiscoveredResource(
resource_type="windows_installed_software",
unique_id=f"{endpoint}/installed_software/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"version": item.get("DisplayVersion", ""),
"publisher": item.get("Publisher", ""),
"install_date": item.get("InstallDate", ""),
},
)
)
return resources
def _discover_windows_features(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover installed Windows features."""
script = (
"Get-WindowsFeature | Where-Object {$_.Installed -eq $true} | "
"Select-Object Name, DisplayName, FeatureType | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_feature",
unique_id=f"{endpoint}/feature/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"display_name": item.get("DisplayName", ""),
"feature_type": item.get("FeatureType", ""),
},
)
)
return resources
def _discover_hyperv_vms(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Hyper-V virtual machines."""
script = (
"Get-VM | Select-Object Name, VMId, State, "
"MemoryAssigned, ProcessorCount, Generation | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
vm_id = str(item.get("VMId", ""))
resources.append(
DiscoveredResource(
resource_type="windows_hyperv_vm",
unique_id=f"{endpoint}/hyperv_vm/{vm_id}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"vm_id": vm_id,
"state": str(item.get("State", "")),
"memory_assigned": str(item.get("MemoryAssigned", "")),
"processor_count": str(item.get("ProcessorCount", "")),
"generation": str(item.get("Generation", "")),
},
)
)
return resources
def _discover_hyperv_switches(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover Hyper-V virtual switches."""
script = (
"Get-VMSwitch | Select-Object Name, Id, SwitchType, "
"NetAdapterInterfaceDescription | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
switch_id = str(item.get("Id", ""))
resources.append(
DiscoveredResource(
resource_type="windows_hyperv_switch",
unique_id=f"{endpoint}/hyperv_switch/{switch_id}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"switch_id": switch_id,
"switch_type": str(item.get("SwitchType", "")),
"net_adapter": item.get("NetAdapterInterfaceDescription", ""),
},
)
)
return resources
def _discover_dns_records(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover DNS records from local DNS server."""
script = (
"Get-DnsServerZone | ForEach-Object { "
"Get-DnsServerResourceRecord -ZoneName $_.ZoneName "
"-ErrorAction SilentlyContinue } | "
"Select-Object HostName, RecordType, "
"@{N='RecordData';E={$_.RecordData.IPv4Address.IPAddressToString}} | "
"ConvertTo-Json -Depth 3"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
hostname = item.get("HostName", "")
record_type = item.get("RecordType", "")
resources.append(
DiscoveredResource(
resource_type="windows_dns_record",
unique_id=f"{endpoint}/dns_record/{hostname}/{record_type}",
name=f"{hostname} ({record_type})",
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"hostname": hostname,
"record_type": record_type,
"record_data": item.get("RecordData", ""),
},
)
)
return resources
def _discover_local_users(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover local user accounts."""
script = (
"Get-LocalUser | Select-Object Name, Enabled, "
"Description, LastLogon | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_local_user",
unique_id=f"{endpoint}/local_user/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"enabled": str(item.get("Enabled", "")),
"description": item.get("Description", ""),
"last_logon": str(item.get("LastLogon", "")),
},
)
)
return resources
def _discover_local_groups(
self, endpoint: str, architecture: CpuArchitecture
) -> list[DiscoveredResource]:
"""Discover local groups."""
script = (
"Get-LocalGroup | Select-Object Name, Description, SID | "
"ConvertTo-Json -Depth 2"
)
items = self._run_powershell_json(script)
resources = []
for item in items:
name = item.get("Name", "")
resources.append(
DiscoveredResource(
resource_type="windows_local_group",
unique_id=f"{endpoint}/local_group/{name}",
name=name,
provider=ProviderType.WINDOWS,
platform_category=PlatformCategory.WINDOWS,
architecture=architecture,
endpoint=endpoint,
attributes={
"description": item.get("Description", ""),
"sid": str(item.get("SID", "")),
},
)
)
return resources

View File

@@ -0,0 +1,5 @@
"""State builder module for Terraform state file generation."""
from iac_reverse.state_builder.state_builder import StateBuilder
__all__ = ["StateBuilder"]

View File

@@ -0,0 +1,332 @@
"""Terraform state file builder (format version 4).
Generates a valid Terraform state file that binds generated resource blocks
to their corresponding live infrastructure resources using provider-assigned
unique identifiers. This enables Terraform to recognize existing resources
without attempting to recreate them.
"""
import logging
import uuid
from iac_reverse.generator.sanitize import sanitize_identifier
from iac_reverse.models import (
CodeGenerationResult,
DependencyGraph,
DiscoveredResource,
PROVIDER_SUPPORTED_RESOURCE_TYPES,
ResourceRelationship,
StateEntry,
StateFile,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# All supported resource types across all providers (for state mapping)
# ---------------------------------------------------------------------------
SUPPORTED_STATE_RESOURCE_TYPES: set[str] = set()
for _types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values():
SUPPORTED_STATE_RESOURCE_TYPES.update(_types)
# ---------------------------------------------------------------------------
# Sensitive attribute patterns
# ---------------------------------------------------------------------------
SENSITIVE_ATTRIBUTE_PATTERNS = [
"password",
"secret",
"token",
"key",
"certificate",
]
# ---------------------------------------------------------------------------
# StateBuilder
# ---------------------------------------------------------------------------
class StateBuilder:
"""Builds Terraform state files (format v4) from code generation results.
Accepts a CodeGenerationResult, DependencyGraph, and provider_version string.
Produces a StateFile with version=4, unique UUID lineage, serial=1, and
state entries for each resource in the dependency graph.
Resources that cannot be mapped (missing provider-assigned identifier or
unrecognized resource type) are excluded from the state file and tracked
in the ``unmapped_resources`` attribute.
"""
def __init__(self, terraform_version: str = "1.7.0") -> None:
"""Initialize the StateBuilder.
Args:
terraform_version: The Terraform version string to embed in the
state file. Defaults to "1.7.0".
"""
self._terraform_version = terraform_version
self._unmapped_resources: list[tuple[str, str]] = []
@property
def unmapped_resources(self) -> list[tuple[str, str]]:
"""Return the list of unmapped resources from the last build.
Each entry is a tuple of (resource_identifier, reason) where
resource_identifier is a string combining type and name, and
reason explains why the resource was excluded.
"""
return list(self._unmapped_resources)
def _is_mappable(self, resource: DiscoveredResource) -> tuple[bool, str]:
"""Check whether a resource can be mapped to a state entry.
A resource is unmappable if:
- Its unique_id is empty, None, or whitespace-only (missing
provider-assigned identifier)
- Its resource_type is not recognized/supported for state mapping
Args:
resource: The DiscoveredResource to check.
Returns:
A tuple of (is_mappable, reason). If mappable, reason is empty.
"""
# Check for missing provider-assigned identifier
if not resource.unique_id or not resource.unique_id.strip():
return (
False,
"missing provider-assigned resource identifier (empty unique_id)",
)
# Check for unrecognized resource type
if resource.resource_type not in SUPPORTED_STATE_RESOURCE_TYPES:
return (
False,
f"resource type '{resource.resource_type}' is not recognized "
f"for state mapping",
)
return (True, "")
def build(
self,
code_result: CodeGenerationResult,
graph: DependencyGraph,
provider_version: str,
) -> StateFile:
"""Build a Terraform state file from generated code and dependency graph.
Resources that cannot be mapped are excluded from the state file.
Warnings are logged for each unmapped resource, and the list of
unmapped resources is available via the ``unmapped_resources`` property.
Args:
code_result: The result of code generation (used for context).
graph: The DependencyGraph containing resources and relationships.
provider_version: The provider version string used to set
schema_version on state entries.
Returns:
A StateFile instance ready for serialization via to_json().
"""
# Reset unmapped resources tracking for this build
self._unmapped_resources = []
# Build lookup maps for dependency resolution
resource_map: dict[str, DiscoveredResource] = {
r.unique_id: r for r in graph.resources if r.unique_id
}
# Build relationships by source for dependency lookup
relationships_by_source: dict[str, list[ResourceRelationship]] = {}
for rel in graph.relationships:
relationships_by_source.setdefault(rel.source_id, []).append(rel)
# Parse schema version from provider_version string
schema_version = self._parse_schema_version(provider_version)
# Build state entries for each resource, skipping unmappable ones
entries: list[StateEntry] = []
for resource in graph.resources:
mappable, reason = self._is_mappable(resource)
if not mappable:
resource_identifier = (
f"{resource.resource_type}.{resource.name}"
)
logger.warning(
"Excluding resource '%s' from state file: %s",
resource_identifier,
reason,
)
self._unmapped_resources.append(
(resource_identifier, reason)
)
continue
entry = self._build_state_entry(
resource=resource,
resource_map=resource_map,
relationships_by_source=relationships_by_source,
schema_version=schema_version,
)
entries.append(entry)
# Generate unique lineage UUID
lineage = str(uuid.uuid4())
return StateFile(
version=4,
terraform_version=self._terraform_version,
serial=1,
lineage=lineage,
resources=entries,
)
def _build_state_entry(
self,
resource: DiscoveredResource,
resource_map: dict[str, DiscoveredResource],
relationships_by_source: dict[str, list[ResourceRelationship]],
schema_version: int,
) -> StateEntry:
"""Build a single state entry for a discovered resource.
Args:
resource: The DiscoveredResource to create a state entry for.
resource_map: Map of unique_id -> DiscoveredResource for lookups.
relationships_by_source: Map of source_id -> relationships.
schema_version: The schema version to set on the entry.
Returns:
A StateEntry binding the resource to its live infrastructure ID.
"""
# Sanitize the resource name for Terraform identifier
resource_name = sanitize_identifier(resource.name)
# Get full attribute set from discovery data
attributes = dict(resource.attributes)
# Identify sensitive attributes
sensitive_attributes = self._identify_sensitive_attributes(attributes)
# Build dependency references as Terraform resource addresses
dependencies = self._build_dependencies(
resource, resource_map, relationships_by_source
)
return StateEntry(
resource_type=resource.resource_type,
resource_name=resource_name,
provider_id=resource.unique_id,
attributes=attributes,
sensitive_attributes=sensitive_attributes,
schema_version=schema_version,
dependencies=dependencies,
)
def _identify_sensitive_attributes(
self, attributes: dict
) -> list[str]:
"""Identify attributes that should be marked as sensitive.
Checks attribute keys against known sensitive patterns:
password, secret, token, key, certificate.
Args:
attributes: The full attribute dictionary.
Returns:
List of attribute key paths that are sensitive.
"""
sensitive: list[str] = []
self._find_sensitive_keys(attributes, "", sensitive)
return sensitive
def _find_sensitive_keys(
self, obj: object, prefix: str, sensitive: list[str]
) -> None:
"""Recursively find sensitive attribute keys in nested structures.
Args:
obj: The current object to inspect (dict, list, or scalar).
prefix: The current key path prefix.
sensitive: Accumulator list for sensitive key paths.
"""
if isinstance(obj, dict):
for key, value in obj.items():
current_path = f"{prefix}.{key}" if prefix else key
key_lower = key.lower()
if any(
pattern in key_lower
for pattern in SENSITIVE_ATTRIBUTE_PATTERNS
):
sensitive.append(current_path)
# Recurse into nested dicts
if isinstance(value, dict):
self._find_sensitive_keys(value, current_path, sensitive)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
self._find_sensitive_keys(
item, f"{current_path}[{i}]", sensitive
)
def _build_dependencies(
self,
resource: DiscoveredResource,
resource_map: dict[str, DiscoveredResource],
relationships_by_source: dict[str, list[ResourceRelationship]],
) -> list[str]:
"""Build Terraform resource address references for dependencies.
Converts relationship targets into Terraform resource addresses
of the form: resource_type.resource_name
Args:
resource: The source resource.
resource_map: Map of unique_id -> DiscoveredResource.
relationships_by_source: Map of source_id -> relationships.
Returns:
List of Terraform resource addresses for dependencies.
"""
dependencies: list[str] = []
rels = relationships_by_source.get(resource.unique_id, [])
for rel in rels:
target = resource_map.get(rel.target_id)
if target is not None:
target_tf_name = sanitize_identifier(target.name)
address = f"{target.resource_type}.{target_tf_name}"
if address not in dependencies:
dependencies.append(address)
return dependencies
def _parse_schema_version(self, provider_version: str) -> int:
"""Parse a schema version integer from the provider version string.
Extracts the major version number from a semver-like string.
For example, "3.2.1" returns 3, "1" returns 1.
Args:
provider_version: A version string (e.g., "3.2.1", "1.0.0").
Returns:
The major version number as an integer, or 0 if parsing fails.
"""
try:
# Take the first numeric segment as the schema version
parts = provider_version.strip().split(".")
return int(parts[0])
except (ValueError, IndexError):
logger.warning(
"Could not parse schema version from '%s', defaulting to 0",
provider_version,
)
return 0

View File

@@ -0,0 +1,5 @@
"""Validator module for Terraform output validation."""
from iac_reverse.validator.validator import Validator
__all__ = ["Validator"]

View File

@@ -0,0 +1,653 @@
"""Terraform validation runner.
Runs terraform init, validate, and plan against generated output
to verify syntactic correctness and detect infrastructure drift.
Includes auto-correction logic that attempts to fix common validation
errors heuristically.
"""
import json
import re
import shutil
import subprocess
from pathlib import Path
from iac_reverse.models import PlannedChange, ValidationError, ValidationResult
class Validator:
"""Runs Terraform commands to validate generated IaC output.
Validates generated .tf and .tfstate files by running terraform init,
terraform validate, and terraform plan. Reports validation errors and
planned changes (drift) back to the caller.
When validation fails, attempts heuristic-based auto-corrections up to
max_correction_attempts times before reporting failure.
"""
def validate(
self, output_dir: str, max_correction_attempts: int = 3
) -> ValidationResult:
"""Run terraform init, validate, and plan against the output directory.
After terraform validate fails, attempts auto-correction of common
errors (unknown attributes, missing required blocks, syntax issues)
up to max_correction_attempts times. Re-validates after each correction.
Args:
output_dir: Path to directory containing generated .tf and .tfstate files.
max_correction_attempts: Maximum number of auto-correction attempts
before reporting failure. Defaults to 3.
Returns:
ValidationResult with init/validate/plan success flags,
any planned changes (drift), validation errors, and the number
of correction attempts made.
"""
# Check terraform binary availability
terraform_bin = shutil.which("terraform")
if terraform_bin is None:
return ValidationResult(
init_success=False,
validate_success=False,
plan_success=False,
errors=[
ValidationError(
file="",
message=(
"Terraform binary not found. "
"Terraform is required for validation. "
"Please install Terraform and ensure it is on your PATH."
),
)
],
correction_attempts=0,
)
output_path = Path(output_dir)
errors: list[ValidationError] = []
planned_changes: list[PlannedChange] = []
# Run terraform init
init_success = self._run_init(output_path, errors)
if not init_success:
return ValidationResult(
init_success=False,
validate_success=False,
plan_success=False,
errors=errors,
correction_attempts=0,
)
# Run terraform validate with auto-correction loop
correction_attempts = 0
validate_success = self._run_validate(output_path, errors)
while not validate_success and correction_attempts < max_correction_attempts:
# Attempt to correct the errors
corrected = self._attempt_correction(output_path, errors)
if not corrected:
# No corrections could be applied, stop trying
break
correction_attempts += 1
# Re-validate after correction
errors = []
validate_success = self._run_validate(output_path, errors)
if not validate_success:
return ValidationResult(
init_success=True,
validate_success=False,
plan_success=False,
errors=errors,
correction_attempts=correction_attempts,
)
# Run terraform plan
plan_success = self._run_plan(output_path, errors, planned_changes)
return ValidationResult(
init_success=True,
validate_success=True,
plan_success=plan_success,
planned_changes=planned_changes,
errors=errors,
correction_attempts=correction_attempts,
)
def _run_init(
self, output_path: Path, errors: list[ValidationError]
) -> bool:
"""Run terraform init in the output directory.
Returns True if init succeeds, False otherwise.
"""
try:
result = subprocess.run(
["terraform", "init", "-no-color"],
cwd=str(output_path),
capture_output=True,
text=True,
timeout=120,
)
if result.returncode != 0:
errors.append(
ValidationError(
file="",
message=f"terraform init failed: {result.stderr.strip()}",
)
)
return False
return True
except subprocess.TimeoutExpired:
errors.append(
ValidationError(
file="",
message="terraform init timed out after 120 seconds",
)
)
return False
except OSError as e:
errors.append(
ValidationError(
file="",
message=f"Failed to execute terraform init: {e}",
)
)
return False
def _run_validate(
self, output_path: Path, errors: list[ValidationError]
) -> bool:
"""Run terraform validate with JSON output and parse errors.
Returns True if validation passes, False otherwise.
"""
try:
result = subprocess.run(
["terraform", "validate", "-json"],
cwd=str(output_path),
capture_output=True,
text=True,
timeout=60,
)
return self._parse_validate_output(result.stdout, errors)
except subprocess.TimeoutExpired:
errors.append(
ValidationError(
file="",
message="terraform validate timed out after 60 seconds",
)
)
return False
except OSError as e:
errors.append(
ValidationError(
file="",
message=f"Failed to execute terraform validate: {e}",
)
)
return False
def _parse_validate_output(
self, stdout: str, errors: list[ValidationError]
) -> bool:
"""Parse terraform validate JSON output.
Expected format:
{
"valid": true/false,
"error_count": N,
"diagnostics": [
{
"severity": "error",
"summary": "...",
"detail": "...",
"range": {
"filename": "main.tf",
"start": {"line": 1, "column": 1},
...
}
}
]
}
"""
try:
data = json.loads(stdout)
except (json.JSONDecodeError, TypeError):
errors.append(
ValidationError(
file="",
message="Failed to parse terraform validate output as JSON",
)
)
return False
if data.get("valid", False):
return True
diagnostics = data.get("diagnostics", [])
for diag in diagnostics:
if diag.get("severity") != "error":
continue
filename = ""
line = None
range_info = diag.get("range")
if range_info:
filename = range_info.get("filename", "")
start = range_info.get("start")
if start:
line = start.get("line")
summary = diag.get("summary", "")
detail = diag.get("detail", "")
message = summary
if detail:
message = f"{summary}: {detail}"
errors.append(
ValidationError(file=filename, message=message, line=line)
)
return False
def _run_plan(
self,
output_path: Path,
errors: list[ValidationError],
planned_changes: list[PlannedChange],
) -> bool:
"""Run terraform plan with JSON output and parse planned changes.
Returns True if zero changes are planned, False otherwise.
"""
try:
result = subprocess.run(
["terraform", "plan", "-json", "-no-color"],
cwd=str(output_path),
capture_output=True,
text=True,
timeout=300,
)
if result.returncode not in (0, 2):
# returncode 2 means changes are planned, which is valid output
errors.append(
ValidationError(
file="",
message=f"terraform plan failed: {result.stderr.strip()}",
)
)
return False
return self._parse_plan_output(
result.stdout, errors, planned_changes
)
except subprocess.TimeoutExpired:
errors.append(
ValidationError(
file="",
message="terraform plan timed out after 300 seconds",
)
)
return False
except OSError as e:
errors.append(
ValidationError(
file="",
message=f"Failed to execute terraform plan: {e}",
)
)
return False
def _parse_plan_output(
self,
stdout: str,
errors: list[ValidationError],
planned_changes: list[PlannedChange],
) -> bool:
"""Parse terraform plan JSON output (streaming JSON lines format).
Terraform plan -json outputs one JSON object per line. We look for
lines with type "resource_drift" or "planned_change" to identify
changes, and "change_summary" for the overall result.
Each resource change line looks like:
{
"type": "planned_change",
"change": {
"resource": {
"addr": "aws_instance.example"
},
"action": "create" | "update" | "delete"
}
}
"""
has_changes = False
for line in stdout.strip().splitlines():
line = line.strip()
if not line:
continue
try:
entry = json.loads(line)
except json.JSONDecodeError:
continue
entry_type = entry.get("type", "")
if entry_type in ("planned_change", "resource_drift"):
change = entry.get("change", {})
resource = change.get("resource", {})
resource_addr = resource.get("addr", "unknown")
action = change.get("action", "unknown")
# Map terraform action names to our change types
change_type = self._map_action_to_change_type(action)
# Build details from before/after if available
details = f"Action: {action}"
planned_changes.append(
PlannedChange(
resource_address=resource_addr,
change_type=change_type,
details=details,
)
)
has_changes = True
elif entry_type == "change_summary":
changes_info = entry.get("changes", {})
add = changes_info.get("add", 0)
change = changes_info.get("change", 0)
remove = changes_info.get("remove", 0)
if add + change + remove > 0:
has_changes = True
# plan_success is True only when there are zero planned changes
return not has_changes
@staticmethod
def _map_action_to_change_type(action: str) -> str:
"""Map terraform plan action to our change type vocabulary."""
action_map = {
"create": "add",
"update": "modify",
"delete": "destroy",
"replace": "modify",
"read": "add",
}
return action_map.get(action, action)
# ------------------------------------------------------------------
# Auto-correction logic
# ------------------------------------------------------------------
def _attempt_correction(
self, output_path: Path, errors: list[ValidationError]
) -> bool:
"""Attempt to auto-correct validation errors using heuristics.
Applies corrections for:
- Unknown/unsupported attributes (removes the offending line)
- Missing required provider blocks (adds empty provider block)
- Common syntax issues (unclosed braces, trailing commas)
Args:
output_path: Path to the directory containing .tf files.
errors: List of validation errors to attempt to correct.
Returns:
True if at least one correction was applied, False otherwise.
"""
any_corrected = False
for error in errors:
corrected = self._correct_single_error(output_path, error)
if corrected:
any_corrected = True
return any_corrected
def _correct_single_error(
self, output_path: Path, error: ValidationError
) -> bool:
"""Attempt to correct a single validation error.
Returns True if a correction was applied.
"""
message = error.message.lower()
# Handle unknown/unsupported attribute errors
if self._is_unknown_attribute_error(message):
return self._remove_attribute_line(output_path, error)
# Handle missing required provider block
if self._is_missing_provider_error(message):
return self._add_missing_provider_block(output_path, error)
# Handle syntax errors (unclosed braces, trailing commas)
if self._is_syntax_error(message):
return self._fix_syntax_error(output_path, error)
return False
@staticmethod
def _is_unknown_attribute_error(message: str) -> bool:
"""Check if the error is about an unknown or unsupported attribute."""
patterns = [
"unsupported argument",
"unsupported attribute",
"unknown attribute",
"an argument named",
"is not expected here",
"no such attribute",
]
return any(p in message for p in patterns)
@staticmethod
def _is_missing_provider_error(message: str) -> bool:
"""Check if the error is about a missing required provider."""
patterns = [
"missing required provider",
"provider configuration not present",
"no provider",
"required provider",
]
return any(p in message for p in patterns)
@staticmethod
def _is_syntax_error(message: str) -> bool:
"""Check if the error is a syntax error that might be fixable."""
patterns = [
"unexpected closing brace",
"unclosed configuration block",
"expected closing brace",
"invalid character",
"trailing comma",
"argument or block definition required",
]
return any(p in message for p in patterns)
def _remove_attribute_line(
self, output_path: Path, error: ValidationError
) -> bool:
"""Remove the line containing an unknown/unsupported attribute.
If the error has file and line info, removes that specific line.
Otherwise, attempts to find and remove the attribute by name from
the error message.
"""
if not error.file:
return False
file_path = output_path / error.file
if not file_path.exists():
return False
try:
lines = file_path.read_text(encoding="utf-8").splitlines()
except OSError:
return False
if error.line is not None and 1 <= error.line <= len(lines):
# Remove the specific line
line_idx = error.line - 1
removed_line = lines[line_idx].strip()
# Only remove if it looks like an attribute assignment
if "=" in removed_line or removed_line.endswith("{"):
lines.pop(line_idx)
try:
file_path.write_text(
"\n".join(lines) + "\n", encoding="utf-8"
)
return True
except OSError:
return False
# Try to find the attribute name from the error message
attr_name = self._extract_attribute_name(error.message)
if attr_name:
return self._remove_attribute_by_name(file_path, attr_name, lines)
return False
@staticmethod
def _extract_attribute_name(message: str) -> str:
"""Extract the attribute name from an error message.
Looks for patterns like:
- "An argument named 'foo' is not expected here"
- "Unsupported argument: foo"
"""
# Pattern: quoted attribute name
match = re.search(r"['\"](\w+)['\"]", message)
if match:
return match.group(1)
# Pattern: "named X is not"
match = re.search(r"named\s+(\w+)\s+is", message)
if match:
return match.group(1)
return ""
@staticmethod
def _remove_attribute_by_name(
file_path: Path, attr_name: str, lines: list[str]
) -> bool:
"""Remove lines containing the given attribute assignment."""
pattern = re.compile(rf"^\s*{re.escape(attr_name)}\s*=")
new_lines = [line for line in lines if not pattern.match(line)]
if len(new_lines) == len(lines):
return False # Nothing was removed
try:
file_path.write_text("\n".join(new_lines) + "\n", encoding="utf-8")
return True
except OSError:
return False
def _add_missing_provider_block(
self, output_path: Path, error: ValidationError
) -> bool:
"""Add a missing provider block to the configuration.
Extracts the provider name from the error message and creates
an empty provider block in a providers.tf file.
"""
provider_name = self._extract_provider_name(error.message)
if not provider_name:
return False
providers_file = output_path / "providers.tf"
provider_block = f'\nprovider "{provider_name}" {{}}\n'
try:
if providers_file.exists():
existing = providers_file.read_text(encoding="utf-8")
# Don't add if already present
if f'provider "{provider_name}"' in existing:
return False
providers_file.write_text(
existing + provider_block, encoding="utf-8"
)
else:
providers_file.write_text(provider_block, encoding="utf-8")
return True
except OSError:
return False
@staticmethod
def _extract_provider_name(message: str) -> str:
"""Extract provider name from a missing provider error message.
Looks for patterns like:
- "Missing required provider 'aws'"
- 'provider "kubernetes" configuration not present'
"""
match = re.search(r"provider\s+['\"](\w+)['\"]", message)
if match:
return match.group(1)
match = re.search(r"['\"](\w+)['\"]", message)
if match:
return match.group(1)
return ""
def _fix_syntax_error(
self, output_path: Path, error: ValidationError
) -> bool:
"""Attempt to fix common syntax errors.
Handles:
- Trailing commas before closing braces
- Missing closing braces
- Lines with 'argument or block definition required' (remove empty/bad lines)
"""
if not error.file:
return False
file_path = output_path / error.file
if not file_path.exists():
return False
try:
content = file_path.read_text(encoding="utf-8")
except OSError:
return False
original_content = content
# Fix trailing commas before closing braces/brackets
content = re.sub(r",(\s*[}\]])", r"\1", content)
# Fix 'argument or block definition required' - remove empty lines
# at the error location
if error.line is not None and "argument or block definition required" in error.message.lower():
lines = content.splitlines()
if 1 <= error.line <= len(lines):
line_idx = error.line - 1
line = lines[line_idx].strip()
# Remove the problematic line if it's empty or just whitespace/punctuation
if not line or line in (",", ";"):
lines.pop(line_idx)
content = "\n".join(lines) + "\n"
if content != original_content:
try:
file_path.write_text(content, encoding="utf-8")
return True
except OSError:
return False
return False

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Test suite for IaC Reverse Engineering Tool."""

Binary file not shown.

View File

@@ -0,0 +1 @@
"""Integration tests for IaC Reverse Engineering Tool."""

View File

@@ -0,0 +1 @@
"""Property-based tests for IaC Reverse Engineering Tool."""

Binary file not shown.

View File

@@ -0,0 +1,719 @@
"""Property-based tests for the Code Generator.
**Validates: Requirements 2.2, 3.1, 3.2, 3.3, 3.4, 3.5, 3.6**
Properties tested:
- Property 10: References in generated output use Terraform syntax
- Property 11: Generated HCL syntactic validity
- Property 12: File organization by resource type
- Property 13: Variable extraction for shared values
- Property 14: Identifier sanitization validity
- Property 15: Traceability comments in generated code
"""
import re
from hypothesis import given, settings, assume, HealthCheck
from hypothesis import strategies as st
from iac_reverse.generator import CodeGenerator, VariableExtractor, sanitize_identifier
from iac_reverse.models import (
CpuArchitecture,
DependencyGraph,
DiscoveredResource,
PlatformCategory,
ProviderType,
ResourceRelationship,
ScanProfile,
)
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
platform_category_strategy = st.sampled_from(list(PlatformCategory))
cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture))
# Strategy for resource names (valid identifiers with some variety)
resource_name_strategy = st.text(
min_size=1,
max_size=20,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"),
).filter(lambda s: s.strip() != "")
# Strategy for resource types (terraform-style: provider_type)
resource_type_strategy = st.sampled_from([
"kubernetes_deployment",
"kubernetes_service",
"kubernetes_namespace",
"docker_service",
"docker_network",
"docker_volume",
"synology_shared_folder",
"synology_volume",
"harvester_virtualmachine",
"harvester_volume",
"bare_metal_hardware",
"windows_service",
"windows_iis_site",
])
# Strategy for simple attribute values (strings, ints, bools)
simple_attr_value_strategy = st.one_of(
st.text(min_size=1, max_size=30, alphabet=st.characters(
whitelist_categories=("L", "N"), whitelist_characters="_-./: "
)).filter(lambda s: s.strip() != ""),
st.integers(min_value=0, max_value=10000),
st.booleans(),
)
# Strategy for attribute dictionaries
attributes_strategy = st.dictionaries(
keys=st.text(
min_size=1,
max_size=15,
alphabet=st.characters(whitelist_categories=("L",), whitelist_characters="_"),
).filter(lambda s: s.strip() != "" and s[0].isalpha()),
values=simple_attr_value_strategy,
min_size=1,
max_size=5,
)
def make_resource(
unique_id: str,
resource_type: str = "kubernetes_deployment",
name: str = "my_resource",
provider: ProviderType = ProviderType.KUBERNETES,
platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION,
architecture: CpuArchitecture = CpuArchitecture.AMD64,
attributes: dict | None = None,
raw_references: list[str] | None = None,
) -> DiscoveredResource:
"""Helper to create a DiscoveredResource with sensible defaults."""
return DiscoveredResource(
resource_type=resource_type,
unique_id=unique_id,
name=name,
provider=provider,
platform_category=platform_category,
architecture=architecture,
endpoint="https://api.internal.lab:6443",
attributes=attributes or {"key": "value"},
raw_references=raw_references or [],
)
def make_dependency_graph(
resources: list[DiscoveredResource],
relationships: list[ResourceRelationship] | None = None,
) -> DependencyGraph:
"""Helper to create a DependencyGraph from resources."""
return DependencyGraph(
resources=resources,
relationships=relationships or [],
topological_order=[r.unique_id for r in resources],
cycles=[],
unresolved_references=[],
)
@st.composite
def resource_with_dependency_strategy(draw):
"""Generate a pair of resources where one depends on the other.
Returns (resources, relationships) where the first resource references the second.
"""
resource_type_a = draw(resource_type_strategy)
resource_type_b = draw(resource_type_strategy)
name_a = draw(resource_name_strategy)
name_b = draw(resource_name_strategy)
arch = draw(cpu_architecture_strategy)
# Ensure unique IDs are different
uid_a = f"ns/{resource_type_a}/{name_a}"
uid_b = f"ns/{resource_type_b}/{name_b}"
assume(uid_a != uid_b)
# Resource B is the dependency target
resource_b = make_resource(
unique_id=uid_b,
resource_type=resource_type_b,
name=name_b,
architecture=arch,
attributes={"port": 8080},
)
# Resource A references resource B's unique_id in its attributes
resource_a = make_resource(
unique_id=uid_a,
resource_type=resource_type_a,
name=name_a,
architecture=arch,
attributes={"target_id": uid_b, "replicas": 3},
raw_references=[uid_b],
)
relationship = ResourceRelationship(
source_id=uid_a,
target_id=uid_b,
relationship_type="reference",
source_attribute="target_id",
)
return [resource_a, resource_b], [relationship]
@st.composite
def multiple_resources_strategy(draw):
"""Generate a list of resources with distinct types for file organization testing."""
num_types = draw(st.integers(min_value=1, max_value=5))
types = draw(
st.lists(
resource_type_strategy,
min_size=num_types,
max_size=num_types,
unique=True,
)
)
resources = []
for i, rtype in enumerate(types):
# Each type gets 1-3 resources
num_resources_of_type = draw(st.integers(min_value=1, max_value=3))
for j in range(num_resources_of_type):
uid = f"{rtype}/instance_{i}_{j}"
name = f"res_{i}_{j}"
attrs = draw(attributes_strategy)
resource = make_resource(
unique_id=uid,
resource_type=rtype,
name=name,
attributes=attrs,
)
resources.append(resource)
return resources
@st.composite
def resources_with_shared_values_strategy(draw):
"""Generate resources where at least one attribute value appears in 2+ resources."""
shared_key = draw(st.sampled_from(["region", "environment", "zone", "cluster"]))
shared_value = draw(st.text(
min_size=1,
max_size=15,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"),
).filter(lambda s: s.strip() != ""))
num_resources = draw(st.integers(min_value=2, max_value=5))
resources = []
for i in range(num_resources):
uid = f"resource_{i}"
name = f"res_{i}"
# All resources share the same key-value pair
attrs = {shared_key: shared_value, "name": f"instance_{i}"}
resource = make_resource(
unique_id=uid,
resource_type="kubernetes_deployment",
name=name,
attributes=attrs,
)
resources.append(resource)
return resources, shared_key, shared_value
# Strategy for arbitrary strings to test sanitize_identifier
arbitrary_string_strategy = st.text(min_size=0, max_size=50)
# ---------------------------------------------------------------------------
# Property 10: References in generated output use Terraform syntax
# ---------------------------------------------------------------------------
class TestReferencesUseTerraformSyntax:
"""Property 10: References in generated output use Terraform syntax.
**Validates: Requirements 2.2, 3.5**
For any resource with dependencies, the generated HCL uses Terraform
resource references (type.name.id) not hardcoded IDs.
"""
@given(data=resource_with_dependency_strategy())
@settings(max_examples=100)
def test_references_use_terraform_resource_syntax(
self, data: tuple[list[DiscoveredResource], list[ResourceRelationship]]
):
"""Generated HCL uses type.name.id references instead of hardcoded IDs."""
resources, relationships = data
graph = make_dependency_graph(resources, relationships)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
# The source resource (resources[0]) references resources[1]
target = resources[1]
target_tf_name = sanitize_identifier(target.name)
expected_ref = f"{target.resource_type}.{target_tf_name}.id"
# Find the file containing the source resource
source = resources[0]
source_file = None
for f in result.resource_files:
if f.filename == f"{source.resource_type}.tf":
source_file = f
break
assert source_file is not None, (
f"Expected file {source.resource_type}.tf not found"
)
# The generated content should contain the Terraform reference
assert expected_ref in source_file.content, (
f"Expected Terraform reference '{expected_ref}' not found in output. "
f"Content: {source_file.content[:500]}"
)
@given(data=resource_with_dependency_strategy())
@settings(max_examples=100)
def test_hardcoded_ids_not_present_for_resolved_references(
self, data: tuple[list[DiscoveredResource], list[ResourceRelationship]]
):
"""The target resource's unique_id should not appear as a hardcoded string in the source resource's block."""
resources, relationships = data
graph = make_dependency_graph(resources, relationships)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
target = resources[1]
source = resources[0]
# Find the file containing the source resource
source_file = None
for f in result.resource_files:
if f.filename == f"{source.resource_type}.tf":
source_file = f
break
assert source_file is not None
# The hardcoded unique_id of the target should NOT appear as a quoted string
hardcoded_pattern = f'"{target.unique_id}"'
assert hardcoded_pattern not in source_file.content, (
f"Hardcoded ID '{hardcoded_pattern}' should not appear in generated HCL. "
f"Should use Terraform reference instead."
)
# ---------------------------------------------------------------------------
# Property 11: Generated HCL syntactic validity
# ---------------------------------------------------------------------------
class TestGeneratedHclSyntacticValidity:
"""Property 11: Generated HCL syntactic validity.
**Validates: Requirements 3.1**
For any set of resources, the generated HCL contains valid resource blocks
with proper structure (resource keyword, type, name, braces).
"""
@given(resources=multiple_resources_strategy())
@settings(max_examples=100, suppress_health_check=[HealthCheck.too_slow])
def test_generated_hcl_has_valid_resource_blocks(
self, resources: list[DiscoveredResource]
):
"""Each generated file contains properly structured resource blocks."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
for gen_file in result.resource_files:
content = gen_file.content
# Each resource block should have the pattern:
# resource "type" "name" {
resource_block_pattern = re.compile(
r'resource\s+"[^"]+"\s+"[^"]+"\s*\{'
)
blocks_found = resource_block_pattern.findall(content)
assert len(blocks_found) == gen_file.resource_count, (
f"Expected {gen_file.resource_count} resource blocks in "
f"{gen_file.filename}, found {len(blocks_found)}"
)
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_generated_hcl_has_balanced_braces(
self, resources: list[DiscoveredResource]
):
"""Generated HCL has balanced opening and closing braces."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
for gen_file in result.resource_files:
content = gen_file.content
open_braces = content.count("{")
close_braces = content.count("}")
assert open_braces == close_braces, (
f"Unbalanced braces in {gen_file.filename}: "
f"{open_braces} opening vs {close_braces} closing"
)
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_generated_hcl_resource_type_matches_filename(
self, resources: list[DiscoveredResource]
):
"""Each resource block's type matches the file it's in (filename = type.tf)."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
for gen_file in result.resource_files:
expected_type = gen_file.filename.replace(".tf", "")
# All resource blocks in this file should be of the expected type
resource_types_in_file = re.findall(
r'resource\s+"([^"]+)"', gen_file.content
)
for rtype in resource_types_in_file:
assert rtype == expected_type, (
f"Resource type '{rtype}' found in {gen_file.filename} "
f"but expected only '{expected_type}'"
)
# ---------------------------------------------------------------------------
# Property 12: File organization by resource type
# ---------------------------------------------------------------------------
class TestFileOrganizationByResourceType:
"""Property 12: File organization by resource type.
**Validates: Requirements 3.2**
For any set of resources, each resource type gets its own .tf file.
"""
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_one_file_per_resource_type(
self, resources: list[DiscoveredResource]
):
"""The number of resource files equals the number of distinct resource types."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
distinct_types = {r.resource_type for r in resources}
assert len(result.resource_files) == len(distinct_types), (
f"Expected {len(distinct_types)} files for {len(distinct_types)} "
f"distinct types, got {len(result.resource_files)}"
)
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_each_file_named_after_resource_type(
self, resources: list[DiscoveredResource]
):
"""Each generated file is named <resource_type>.tf."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
distinct_types = {r.resource_type for r in resources}
expected_filenames = {f"{rt}.tf" for rt in distinct_types}
actual_filenames = {f.filename for f in result.resource_files}
assert actual_filenames == expected_filenames, (
f"Expected filenames {expected_filenames}, got {actual_filenames}"
)
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_every_resource_appears_in_exactly_one_file(
self, resources: list[DiscoveredResource]
):
"""Every resource's unique_id appears in exactly one generated file."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
for resource in resources:
files_containing = [
f.filename
for f in result.resource_files
if resource.unique_id in f.content
]
assert len(files_containing) == 1, (
f"Resource '{resource.unique_id}' found in {len(files_containing)} "
f"files: {files_containing}. Expected exactly 1."
)
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_resource_count_per_file_matches(
self, resources: list[DiscoveredResource]
):
"""Each file's resource_count matches the actual number of resources of that type."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
# Count resources per type
from collections import Counter
type_counts = Counter(r.resource_type for r in resources)
for gen_file in result.resource_files:
expected_type = gen_file.filename.replace(".tf", "")
assert gen_file.resource_count == type_counts[expected_type], (
f"File {gen_file.filename} reports {gen_file.resource_count} resources "
f"but expected {type_counts[expected_type]}"
)
# ---------------------------------------------------------------------------
# Property 13: Variable extraction for shared values
# ---------------------------------------------------------------------------
class TestVariableExtractionForSharedValues:
"""Property 13: Variable extraction for shared values.
**Validates: Requirements 3.3**
For any set of resources where a value appears in 2+ resources,
a variable is extracted.
"""
@given(data=resources_with_shared_values_strategy())
@settings(max_examples=100)
def test_shared_value_produces_extracted_variable(
self, data: tuple[list[DiscoveredResource], str, str]
):
"""A value appearing in 2+ resources results in an extracted variable."""
resources, shared_key, shared_value = data
extractor = VariableExtractor()
variables = extractor.extract_variables(resources)
# There should be at least one variable extracted for the shared key
var_names = [v.name for v in variables]
# The variable name should contain the shared key
matching_vars = [v for v in variables if shared_key in v.name]
assert len(matching_vars) >= 1, (
f"Expected at least one variable for shared key '{shared_key}', "
f"got variables: {var_names}"
)
@given(data=resources_with_shared_values_strategy())
@settings(max_examples=100)
def test_extracted_variable_has_correct_default(
self, data: tuple[list[DiscoveredResource], str, str]
):
"""The extracted variable's default value matches the shared value."""
resources, shared_key, shared_value = data
extractor = VariableExtractor()
variables = extractor.extract_variables(resources)
matching_vars = [v for v in variables if shared_key in v.name]
assert len(matching_vars) >= 1
# The default should be the shared value (formatted as a string literal)
var = matching_vars[0]
assert shared_value in var.default_value, (
f"Expected default to contain '{shared_value}', got '{var.default_value}'"
)
@given(data=resources_with_shared_values_strategy())
@settings(max_examples=100)
def test_extracted_variable_tracks_usage(
self, data: tuple[list[DiscoveredResource], str, str]
):
"""The extracted variable's used_by list contains at least 2 resource IDs."""
resources, shared_key, shared_value = data
extractor = VariableExtractor()
variables = extractor.extract_variables(resources)
matching_vars = [v for v in variables if shared_key in v.name]
assert len(matching_vars) >= 1
var = matching_vars[0]
assert len(var.used_by) >= 2, (
f"Expected variable to be used by 2+ resources, "
f"got {len(var.used_by)}: {var.used_by}"
)
@given(data=resources_with_shared_values_strategy())
@settings(max_examples=100)
def test_extracted_variable_has_type_and_description(
self, data: tuple[list[DiscoveredResource], str, str]
):
"""Each extracted variable has a non-empty type expression and description."""
resources, shared_key, shared_value = data
extractor = VariableExtractor()
variables = extractor.extract_variables(resources)
for var in variables:
assert var.type_expr != "", f"Variable '{var.name}' has empty type_expr"
assert var.description != "", f"Variable '{var.name}' has empty description"
# ---------------------------------------------------------------------------
# Property 14: Identifier sanitization validity
# ---------------------------------------------------------------------------
class TestIdentifierSanitizationValidity:
"""Property 14: Identifier sanitization validity.
**Validates: Requirements 3.4**
For any input string, sanitize_identifier produces a valid Terraform identifier.
"""
TERRAFORM_IDENTIFIER_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
@given(name=arbitrary_string_strategy)
@settings(max_examples=200)
def test_sanitized_identifier_matches_terraform_pattern(self, name: str):
"""The output always matches ^[a-zA-Z_][a-zA-Z0-9_]*$."""
result = sanitize_identifier(name)
assert self.TERRAFORM_IDENTIFIER_REGEX.match(result), (
f"sanitize_identifier({name!r}) = {result!r} does not match "
f"Terraform identifier pattern"
)
@given(name=arbitrary_string_strategy)
@settings(max_examples=200)
def test_sanitized_identifier_is_non_empty(self, name: str):
"""The output is always a non-empty string."""
result = sanitize_identifier(name)
assert len(result) > 0, (
f"sanitize_identifier({name!r}) produced empty string"
)
@given(name=st.text(min_size=1, max_size=30, alphabet="0123456789"))
@settings(max_examples=100)
def test_digit_only_input_produces_valid_identifier(self, name: str):
"""Input consisting only of digits still produces a valid identifier."""
result = sanitize_identifier(name)
assert self.TERRAFORM_IDENTIFIER_REGEX.match(result), (
f"sanitize_identifier({name!r}) = {result!r} is not valid for digit-only input"
)
# Must not start with a digit
assert not result[0].isdigit(), (
f"sanitize_identifier({name!r}) = {result!r} starts with a digit"
)
@given(name=st.text(
min_size=1,
max_size=30,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"),
).filter(lambda s: s[0].isalpha() or s[0] == "_"))
@settings(max_examples=100)
def test_already_valid_identifiers_are_preserved_or_simplified(self, name: str):
"""Input that is already a valid identifier produces a valid result."""
result = sanitize_identifier(name)
assert self.TERRAFORM_IDENTIFIER_REGEX.match(result), (
f"sanitize_identifier({name!r}) = {result!r} is not valid"
)
# ---------------------------------------------------------------------------
# Property 15: Traceability comments in generated code
# ---------------------------------------------------------------------------
class TestTraceabilityCommentsInGeneratedCode:
"""Property 15: Traceability comments in generated code.
**Validates: Requirements 3.6**
For any resource, the generated HCL includes a comment with the original unique_id.
"""
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_each_resource_has_traceability_comment(
self, resources: list[DiscoveredResource]
):
"""Every resource's unique_id appears in a comment in the generated output."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
# Collect all generated content
all_content = "\n".join(f.content for f in result.resource_files)
for resource in resources:
# The unique_id should appear in a comment line
comment_pattern = f"# Source: {resource.unique_id}"
assert comment_pattern in all_content, (
f"Traceability comment for resource '{resource.unique_id}' "
f"not found in generated output"
)
@given(resources=multiple_resources_strategy())
@settings(max_examples=100)
def test_traceability_comment_precedes_resource_block(
self, resources: list[DiscoveredResource]
):
"""The traceability comment appears before its corresponding resource block."""
graph = make_dependency_graph(resources)
profiles: list[ScanProfile] = []
generator = CodeGenerator()
result = generator.generate(graph, profiles)
for resource in resources:
# Find the file containing this resource
target_file = None
for f in result.resource_files:
if resource.unique_id in f.content:
target_file = f
break
assert target_file is not None
content = target_file.content
comment_pos = content.find(f"# Source: {resource.unique_id}")
tf_name = sanitize_identifier(resource.name)
block_pattern = f'resource "{resource.resource_type}" "{tf_name}"'
block_pos = content.find(block_pattern, comment_pos)
assert comment_pos < block_pos, (
f"Comment for '{resource.unique_id}' (pos {comment_pos}) "
f"should precede resource block (pos {block_pos})"
)

View File

@@ -0,0 +1,565 @@
"""Property-based tests for the Dependency Resolver.
**Validates: Requirements 2.1, 2.3, 2.4, 2.5**
Properties tested:
- Property 6: Dependency relationship identification
- Property 7: Cycle detection correctness
- Property 8: Topological order validity
- Property 9: Unresolved references become data sources or variables
"""
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from iac_reverse.models import (
CpuArchitecture,
DependencyGraph,
DiscoveredResource,
PlatformCategory,
ProviderType,
ResourceRelationship,
ScanResult,
UnresolvedReference,
)
from iac_reverse.resolver import DependencyResolver
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
platform_category_strategy = st.sampled_from(list(PlatformCategory))
cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture))
# Strategy for generating valid resource IDs
resource_id_strategy = st.text(
min_size=3,
max_size=50,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-/"),
).filter(lambda s: s.strip() != "" and len(s) >= 3)
# Strategy for resource names
resource_name_strategy = st.text(
min_size=1,
max_size=30,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"),
).filter(lambda s: s.strip() != "")
# Strategy for resource types (simple identifiers)
resource_type_strategy = st.text(
min_size=3,
max_size=40,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"),
).filter(lambda s: s.strip() != "" and len(s) >= 3)
# Strategy for endpoint strings
endpoint_strategy = st.text(
min_size=5,
max_size=50,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters=".-:/"),
).filter(lambda s: s.strip() != "")
def make_resource(
unique_id: str,
resource_type: str = "generic_resource",
name: str = "resource",
raw_references: list[str] | None = None,
attributes: dict | None = None,
) -> DiscoveredResource:
"""Helper to create a DiscoveredResource with sensible defaults."""
return DiscoveredResource(
resource_type=resource_type,
unique_id=unique_id,
name=name,
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AMD64,
endpoint="https://api.internal.lab:6443",
attributes=attributes or {"key": "value"},
raw_references=raw_references or [],
)
def make_scan_result(resources: list[DiscoveredResource]) -> ScanResult:
"""Helper to create a ScanResult from a list of resources."""
return ScanResult(
resources=resources,
warnings=[],
errors=[],
scan_timestamp="2024-01-15T10:30:00Z",
profile_hash="test_hash",
is_partial=False,
)
# Strategy to generate a list of resources with unique IDs and controlled references
@st.composite
def acyclic_resource_graph_strategy(draw):
"""Generate a set of resources forming an acyclic dependency graph.
Resources are created in order, and each resource can only reference
resources that were created before it (ensuring no cycles).
"""
num_resources = draw(st.integers(min_value=2, max_value=8))
resources = []
ids = []
for i in range(num_resources):
uid = f"resource_{i}"
ids.append(uid)
# Each resource can only reference earlier resources (ensures acyclic)
if i > 0:
num_refs = draw(st.integers(min_value=0, max_value=min(i, 3)))
refs = draw(
st.lists(
st.sampled_from(ids[:i]),
min_size=num_refs,
max_size=num_refs,
unique=True,
)
)
else:
refs = []
resource = make_resource(
unique_id=uid,
name=f"res_{i}",
raw_references=refs,
)
resources.append(resource)
return resources
@st.composite
def cyclic_resource_graph_strategy(draw):
"""Generate a set of resources that contain at least one cycle.
Creates a base set of resources and then adds references to form a cycle.
"""
num_resources = draw(st.integers(min_value=2, max_value=6))
resources = []
ids = []
for i in range(num_resources):
uid = f"resource_{i}"
ids.append(uid)
resource = make_resource(
unique_id=uid,
name=f"res_{i}",
raw_references=[],
)
resources.append(resource)
# Create a cycle: pick a subset of at least 2 resources and form a ring
cycle_size = draw(st.integers(min_value=2, max_value=num_resources))
cycle_indices = draw(
st.lists(
st.sampled_from(list(range(num_resources))),
min_size=cycle_size,
max_size=cycle_size,
unique=True,
)
)
# Form a ring: each resource in the cycle references the next one
for j in range(len(cycle_indices)):
src_idx = cycle_indices[j]
tgt_idx = cycle_indices[(j + 1) % len(cycle_indices)]
target_id = ids[tgt_idx]
if target_id not in resources[src_idx].raw_references:
resources[src_idx].raw_references.append(target_id)
return resources
@st.composite
def resources_with_unresolved_refs_strategy(draw):
"""Generate resources where some raw_references point to IDs not in the inventory."""
num_resources = draw(st.integers(min_value=1, max_value=5))
resources = []
ids = []
for i in range(num_resources):
uid = f"resource_{i}"
ids.append(uid)
# Generate unresolved reference IDs (not in the inventory)
num_unresolved = draw(st.integers(min_value=1, max_value=4))
unresolved_ids = []
for i in range(num_unresolved):
# Mix of IDs with "/" (should suggest data_source) and without (should suggest variable)
if draw(st.booleans()):
unresolved_id = f"external/resource/{i}"
else:
unresolved_id = f"external_var_{i}"
unresolved_ids.append(unresolved_id)
# Create resources, some referencing unresolved IDs
for i in range(num_resources):
# Pick some unresolved refs for this resource
num_ext_refs = draw(st.integers(min_value=0, max_value=min(num_unresolved, 2)))
ext_refs = draw(
st.lists(
st.sampled_from(unresolved_ids),
min_size=num_ext_refs,
max_size=num_ext_refs,
unique=True,
)
)
resource = make_resource(
unique_id=ids[i],
name=f"res_{i}",
raw_references=ext_refs,
)
resources.append(resource)
return resources, unresolved_ids
# ---------------------------------------------------------------------------
# Property 6: Dependency relationship identification
# ---------------------------------------------------------------------------
class TestDependencyRelationshipIdentification:
"""Property 6: Dependency relationship identification.
**Validates: Requirements 2.1**
For any resource with raw_references pointing to other resources in the
inventory, the resolver SHALL create a ResourceRelationship for each
resolved reference.
"""
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_relationship_created_for_each_resolved_reference(
self, resources: list[DiscoveredResource]
):
"""For each raw_reference pointing to a known resource, a relationship is created."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
# Count expected relationships: each raw_reference that points to a resource in inventory
resource_ids = {r.unique_id for r in resources}
expected_relationships = 0
for resource in resources:
for ref in resource.raw_references:
if ref in resource_ids:
expected_relationships += 1
assert len(graph.relationships) == expected_relationships
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_relationship_source_and_target_are_correct(
self, resources: list[DiscoveredResource]
):
"""Each relationship has source_id as the referencing resource and target_id as the referenced."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
resource_ids = {r.unique_id for r in resources}
for rel in graph.relationships:
# source_id is the resource that holds the reference
assert rel.source_id in resource_ids
# target_id is the resource being referenced
assert rel.target_id in resource_ids
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_relationship_type_is_valid(
self, resources: list[DiscoveredResource]
):
"""Each relationship has a valid relationship_type."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
valid_types = {"parent-child", "reference", "dependency"}
for rel in graph.relationships:
assert rel.relationship_type in valid_types
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_relationship_source_attribute_is_non_empty(
self, resources: list[DiscoveredResource]
):
"""Each relationship has a non-empty source_attribute."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for rel in graph.relationships:
assert isinstance(rel.source_attribute, str)
assert len(rel.source_attribute) > 0
# ---------------------------------------------------------------------------
# Property 7: Cycle detection correctness
# ---------------------------------------------------------------------------
class TestCycleDetectionCorrectness:
"""Property 7: Cycle detection correctness.
**Validates: Requirements 2.3**
For any graph containing a cycle, the resolver SHALL detect and report it
in the cycles list. For any acyclic dependency graph, the resolver SHALL
report zero cycles.
"""
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_acyclic_graph_reports_zero_cycles(
self, resources: list[DiscoveredResource]
):
"""An acyclic graph should have no cycles reported."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
assert len(graph.cycles) == 0
@given(resources=cyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_cyclic_graph_reports_at_least_one_cycle(
self, resources: list[DiscoveredResource]
):
"""A graph with a cycle should have at least one cycle reported."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
assert len(graph.cycles) >= 1
@given(resources=cyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_cycle_contains_valid_resource_ids(
self, resources: list[DiscoveredResource]
):
"""Each reported cycle contains only valid resource IDs from the inventory."""
scan_result = make_scan_result(resources)
resource_ids = {r.unique_id for r in resources}
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for cycle in graph.cycles:
assert len(cycle) >= 2, "A cycle must involve at least 2 resources"
for resource_id in cycle:
assert resource_id in resource_ids
@given(resources=cyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_cycle_reports_have_resolution_suggestions(
self, resources: list[DiscoveredResource]
):
"""Each cycle report includes a suggested break edge and resolution strategy."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for report in graph.cycle_reports:
assert report.suggested_break is not None
assert len(report.suggested_break) == 2
assert report.break_relationship_type in {"parent-child", "reference", "dependency"}
assert isinstance(report.resolution_strategy, str)
assert len(report.resolution_strategy) > 0
# ---------------------------------------------------------------------------
# Property 8: Topological order validity
# ---------------------------------------------------------------------------
class TestTopologicalOrderValidity:
"""Property 8: Topological order validity.
**Validates: Requirements 2.4**
For any acyclic dependency graph, no resource SHALL appear before any
resource it depends on in the topological order.
"""
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_topological_order_contains_all_resources(
self, resources: list[DiscoveredResource]
):
"""The topological order must contain all resource IDs."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
resource_ids = {r.unique_id for r in resources}
assert set(graph.topological_order) == resource_ids
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_dependencies_appear_before_dependents(
self, resources: list[DiscoveredResource]
):
"""For every dependency edge (A depends on B), B appears before A in topological order.
In the resolver, if resource A has B in raw_references, then A depends on B,
meaning B must appear before A in the topological order.
"""
scan_result = make_scan_result(resources)
resource_ids = {r.unique_id for r in resources}
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
# Build position map
position = {rid: idx for idx, rid in enumerate(graph.topological_order)}
# For each resource, its referenced resources (that are in inventory) must come before it
for resource in resources:
for ref_id in resource.raw_references:
if ref_id in resource_ids:
assert position[ref_id] < position[resource.unique_id], (
f"Resource '{ref_id}' (dependency) should appear before "
f"'{resource.unique_id}' (dependent) in topological order"
)
@given(resources=acyclic_resource_graph_strategy())
@settings(max_examples=100)
def test_topological_order_has_no_duplicates(
self, resources: list[DiscoveredResource]
):
"""The topological order must not contain duplicate entries."""
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
assert len(graph.topological_order) == len(set(graph.topological_order))
# ---------------------------------------------------------------------------
# Property 9: Unresolved references become data sources or variables
# ---------------------------------------------------------------------------
class TestUnresolvedReferences:
"""Property 9: Unresolved references become data sources or variables.
**Validates: Requirements 2.5**
For any raw_reference pointing to an ID not in the inventory, the resolver
SHALL create an UnresolvedReference with suggested_resolution of either
"data_source" or "variable".
"""
@given(data=resources_with_unresolved_refs_strategy())
@settings(max_examples=100)
def test_unresolved_references_are_tracked(
self, data: tuple[list[DiscoveredResource], list[str]]
):
"""Each reference to an ID not in inventory creates an UnresolvedReference."""
resources, unresolved_ids = data
scan_result = make_scan_result(resources)
resource_ids = {r.unique_id for r in resources}
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
# Count expected unresolved references
expected_unresolved = 0
for resource in resources:
for ref in resource.raw_references:
if ref not in resource_ids:
expected_unresolved += 1
assert len(graph.unresolved_references) == expected_unresolved
@given(data=resources_with_unresolved_refs_strategy())
@settings(max_examples=100)
def test_unresolved_references_suggest_data_source_or_variable(
self, data: tuple[list[DiscoveredResource], list[str]]
):
"""Each UnresolvedReference has suggested_resolution of 'data_source' or 'variable'."""
resources, _ = data
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for unresolved in graph.unresolved_references:
assert unresolved.suggested_resolution in {"data_source", "variable"}, (
f"Expected 'data_source' or 'variable', got '{unresolved.suggested_resolution}'"
)
@given(data=resources_with_unresolved_refs_strategy())
@settings(max_examples=100)
def test_unresolved_references_have_valid_source_resource(
self, data: tuple[list[DiscoveredResource], list[str]]
):
"""Each UnresolvedReference has a source_resource_id that exists in the inventory."""
resources, _ = data
scan_result = make_scan_result(resources)
resource_ids = {r.unique_id for r in resources}
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for unresolved in graph.unresolved_references:
assert unresolved.source_resource_id in resource_ids
@given(data=resources_with_unresolved_refs_strategy())
@settings(max_examples=100)
def test_unresolved_references_have_non_empty_fields(
self, data: tuple[list[DiscoveredResource], list[str]]
):
"""Each UnresolvedReference has non-empty source_attribute and referenced_id."""
resources, _ = data
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for unresolved in graph.unresolved_references:
assert isinstance(unresolved.source_attribute, str)
assert len(unresolved.source_attribute) > 0
assert isinstance(unresolved.referenced_id, str)
assert len(unresolved.referenced_id) > 0
@given(data=resources_with_unresolved_refs_strategy())
@settings(max_examples=100)
def test_ids_with_slash_or_colon_suggest_data_source(
self, data: tuple[list[DiscoveredResource], list[str]]
):
"""References containing '/' or ':' should suggest 'data_source' resolution."""
resources, _ = data
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for unresolved in graph.unresolved_references:
if "/" in unresolved.referenced_id or ":" in unresolved.referenced_id:
assert unresolved.suggested_resolution == "data_source", (
f"Reference '{unresolved.referenced_id}' contains '/' or ':' "
f"and should suggest 'data_source', got '{unresolved.suggested_resolution}'"
)
@given(data=resources_with_unresolved_refs_strategy())
@settings(max_examples=100)
def test_ids_without_slash_or_colon_suggest_variable(
self, data: tuple[list[DiscoveredResource], list[str]]
):
"""References without '/' or ':' should suggest 'variable' resolution."""
resources, _ = data
scan_result = make_scan_result(resources)
resolver = DependencyResolver(scan_result)
graph = resolver.resolve()
for unresolved in graph.unresolved_references:
if "/" not in unresolved.referenced_id and ":" not in unresolved.referenced_id:
assert unresolved.suggested_resolution == "variable", (
f"Reference '{unresolved.referenced_id}' has no '/' or ':' "
f"and should suggest 'variable', got '{unresolved.suggested_resolution}'"
)

View File

@@ -0,0 +1,308 @@
"""Property-based tests for drift report correctness.
**Validates: Requirements 7.3**
Properties tested:
- Property 22: Drift report correctness — For any terraform plan output
containing planned changes, the Validator SHALL report each change with
the correct resource address and change type (add, modify, destroy).
"""
import json
import tempfile
from unittest.mock import MagicMock, patch
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from iac_reverse.models import PlannedChange, ValidationResult
from iac_reverse.validator import Validator
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
# Terraform action types that map to our change types
TERRAFORM_ACTIONS = ["create", "update", "delete"]
# Expected mapping from terraform actions to our change types
ACTION_TO_CHANGE_TYPE = {
"create": "add",
"update": "modify",
"delete": "destroy",
}
# Strategy for valid terraform resource addresses
# Format: <resource_type>.<resource_name> or <module>.<resource_type>.<name>
resource_type_prefix_strategy = st.sampled_from([
"aws_instance",
"kubernetes_deployment",
"docker_service",
"harvester_virtualmachine",
"synology_shared_folder",
"windows_service",
"bare_metal_hardware",
"null_resource",
"local_file",
"random_id",
])
resource_name_suffix_strategy = st.text(
min_size=1,
max_size=20,
alphabet=st.characters(whitelist_categories=("Ll",), whitelist_characters="_"),
).filter(lambda s: s[0].isalpha() or s[0] == "_")
@st.composite
def resource_address_strategy(draw):
"""Generate a valid terraform resource address like 'aws_instance.my_server'."""
prefix = draw(resource_type_prefix_strategy)
suffix = draw(resource_name_suffix_strategy)
# Optionally add a module prefix
use_module = draw(st.booleans())
if use_module:
module_name = draw(st.text(
min_size=1,
max_size=10,
alphabet=st.characters(whitelist_categories=("Ll",), whitelist_characters="_"),
).filter(lambda s: s[0].isalpha()))
return f"module.{module_name}.{prefix}.{suffix}"
return f"{prefix}.{suffix}"
terraform_action_strategy = st.sampled_from(TERRAFORM_ACTIONS)
@st.composite
def planned_change_entry_strategy(draw):
"""Generate a single planned change entry as it appears in terraform plan JSON output."""
addr = draw(resource_address_strategy())
action = draw(terraform_action_strategy)
return (addr, action)
@st.composite
def planned_changes_list_strategy(draw):
"""Generate a list of planned changes with unique resource addresses."""
num_changes = draw(st.integers(min_value=1, max_value=10))
changes = []
seen_addrs = set()
for _ in range(num_changes):
entry = draw(planned_change_entry_strategy())
addr, action = entry
# Ensure unique addresses
if addr in seen_addrs:
continue
seen_addrs.add(addr)
changes.append((addr, action))
assume(len(changes) >= 1)
return changes
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
VALIDATE_SUCCESS_JSON = json.dumps(
{"valid": True, "error_count": 0, "diagnostics": []}
)
def _make_completed_process(returncode=0, stdout="", stderr=""):
"""Create a mock CompletedProcess-like object."""
mock = MagicMock()
mock.returncode = returncode
mock.stdout = stdout
mock.stderr = stderr
return mock
def build_plan_output(changes: list[tuple[str, str]]) -> str:
"""Build terraform plan JSON streaming output from a list of (addr, action) tuples."""
lines = [json.dumps({"type": "version", "terraform": "1.7.0"})]
for addr, action in changes:
lines.append(
json.dumps(
{
"type": "planned_change",
"change": {
"resource": {"addr": addr},
"action": action,
},
}
)
)
# Add change_summary
total_add = sum(1 for _, a in changes if a == "create")
total_change = sum(1 for _, a in changes if a == "update")
total_remove = sum(1 for _, a in changes if a == "delete")
lines.append(
json.dumps(
{
"type": "change_summary",
"changes": {
"add": total_add,
"change": total_change,
"remove": total_remove,
},
}
)
)
return "\n".join(lines)
def run_validator_with_plan(plan_output: str) -> ValidationResult:
"""Run the Validator with mocked subprocess calls, returning the result."""
init_result = _make_completed_process(returncode=0)
validate_result = _make_completed_process(
returncode=0, stdout=VALIDATE_SUCCESS_JSON
)
plan_result = _make_completed_process(returncode=2, stdout=plan_output)
with tempfile.TemporaryDirectory() as tmp_dir:
with patch("shutil.which", return_value="/usr/bin/terraform"), patch(
"subprocess.run",
side_effect=[init_result, validate_result, plan_result],
):
validator = Validator()
return validator.validate(tmp_dir)
# ---------------------------------------------------------------------------
# Property 22: Drift report correctness
# ---------------------------------------------------------------------------
class TestDriftReportCorrectness:
"""Property 22: Drift report correctness.
**Validates: Requirements 7.3**
For any terraform plan output containing N planned changes, the drift
report SHALL list exactly N entries, each with the correct resource
address and change type (add, modify, or destroy).
"""
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_count_matches_planned_changes(
self, changes: list[tuple[str, str]]
):
"""The number of reported planned changes equals the number in the plan output."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
assert len(result.planned_changes) == len(changes), (
f"Expected {len(changes)} planned changes, "
f"got {len(result.planned_changes)}. "
f"Input changes: {changes}"
)
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_resource_addresses_match(
self, changes: list[tuple[str, str]]
):
"""Each reported change has the correct resource address from the plan."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
expected_addrs = {addr for addr, _ in changes}
actual_addrs = {c.resource_address for c in result.planned_changes}
assert actual_addrs == expected_addrs, (
f"Resource address mismatch.\n"
f"Expected: {sorted(expected_addrs)}\n"
f"Actual: {sorted(actual_addrs)}"
)
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_change_types_correct(
self, changes: list[tuple[str, str]]
):
"""Each reported change has the correct change type mapping."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
# Build expected mapping: addr -> change_type
expected_map = {
addr: ACTION_TO_CHANGE_TYPE[action] for addr, action in changes
}
for planned_change in result.planned_changes:
addr = planned_change.resource_address
assert addr in expected_map, (
f"Unexpected resource address '{addr}' in planned changes"
)
expected_type = expected_map[addr]
assert planned_change.change_type == expected_type, (
f"For resource '{addr}': expected change_type='{expected_type}', "
f"got '{planned_change.change_type}'"
)
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_plan_success_is_false(
self, changes: list[tuple[str, str]]
):
"""When there are planned changes, plan_success is always False."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
assert result.plan_success is False, (
f"plan_success should be False when there are {len(changes)} "
f"planned changes, but got True"
)
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_each_change_is_planned_change_instance(
self, changes: list[tuple[str, str]]
):
"""Each entry in the drift report is a PlannedChange instance."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
for i, change in enumerate(result.planned_changes):
assert isinstance(change, PlannedChange), (
f"Entry {i} is {type(change).__name__}, expected PlannedChange"
)
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_change_type_in_valid_set(
self, changes: list[tuple[str, str]]
):
"""Every reported change_type is one of 'add', 'modify', or 'destroy'."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
valid_types = {"add", "modify", "destroy"}
for change in result.planned_changes:
assert change.change_type in valid_types, (
f"Invalid change_type '{change.change_type}' for resource "
f"'{change.resource_address}'. Must be one of {valid_types}"
)
@given(changes=planned_changes_list_strategy())
@settings(max_examples=100)
def test_drift_report_no_duplicate_addresses(
self, changes: list[tuple[str, str]]
):
"""No resource address appears more than once in the drift report."""
plan_output = build_plan_output(changes)
result = run_validator_with_plan(plan_output)
addresses = [c.resource_address for c in result.planned_changes]
assert len(addresses) == len(set(addresses)), (
f"Duplicate resource addresses found in drift report: "
f"{[a for a in addresses if addresses.count(a) > 1]}"
)

View File

@@ -0,0 +1,790 @@
"""Property-based tests for Incremental Scan Engine.
**Validates: Requirements 8.1, 8.2, 8.3, 8.5, 8.6**
Properties tested:
- Property 23: Change classification correctness
- Property 24: Incremental update scope
- Property 25: Removed resource exclusion
- Property 26: Snapshot retention
"""
import json
import tempfile
from pathlib import Path
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from iac_reverse.incremental import ChangeDetector, IncrementalUpdater, SnapshotStore
from iac_reverse.models import (
ChangeSummary,
ChangeType,
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ResourceChange,
ScanResult,
)
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_strategy = st.sampled_from(list(ProviderType))
platform_strategy = st.sampled_from(list(PlatformCategory))
architecture_strategy = st.sampled_from(list(CpuArchitecture))
# Simple attribute values for resources
attribute_value_strategy = st.one_of(
st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz0123456789"),
st.integers(min_value=0, max_value=1000),
st.booleans(),
)
attributes_strategy = st.dictionaries(
keys=st.text(min_size=1, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz_"),
values=attribute_value_strategy,
min_size=1,
max_size=5,
)
# Resource name strategy (valid identifiers)
resource_name_strategy = st.text(
min_size=1,
max_size=15,
alphabet="abcdefghijklmnopqrstuvwxyz_",
).filter(lambda s: s[0].isalpha())
# Resource type strategy
resource_type_strategy = st.sampled_from([
"docker_service",
"kubernetes_deployment",
"synology_shared_folder",
"harvester_virtualmachine",
"bare_metal_hardware",
"windows_service",
])
@st.composite
def discovered_resource_strategy(draw, uid=None):
"""Generate a DiscoveredResource with valid fields."""
resource_type = draw(resource_type_strategy)
unique_id = uid or draw(st.text(
min_size=5, max_size=30,
alphabet="abcdefghijklmnopqrstuvwxyz0123456789_-/",
).filter(lambda s: s[0].isalpha()))
name = draw(resource_name_strategy)
provider = draw(provider_strategy)
platform = draw(platform_strategy)
arch = draw(architecture_strategy)
endpoint = draw(st.text(min_size=3, max_size=20, alphabet="abcdefghijklmnopqrstuvwxyz."))
attributes = draw(attributes_strategy)
return DiscoveredResource(
resource_type=resource_type,
unique_id=unique_id,
name=name,
provider=provider,
platform_category=platform,
architecture=arch,
endpoint=endpoint,
attributes=attributes,
raw_references=[],
)
@st.composite
def scan_result_strategy(draw, min_resources=0, max_resources=8):
"""Generate a ScanResult with unique resource IDs."""
num_resources = draw(st.integers(min_value=min_resources, max_value=max_resources))
resources = []
seen_ids = set()
for i in range(num_resources):
uid = f"resource_{i}_{draw(st.text(min_size=3, max_size=8, alphabet='abcdefghijklmnopqrstuvwxyz'))}"
if uid in seen_ids:
uid = f"resource_{i}_fallback"
seen_ids.add(uid)
resource = draw(discovered_resource_strategy(uid=uid))
resources.append(resource)
return ScanResult(
resources=resources,
warnings=[],
errors=[],
scan_timestamp="2024-01-15T10:30:00Z",
profile_hash="test_profile_hash",
is_partial=False,
)
@st.composite
def scan_result_pair_strategy(draw):
"""Generate a pair of scan results with some overlap for meaningful diffs.
Creates a previous and current scan where:
- Some resources exist in both (potentially modified)
- Some resources only in previous (removed)
- Some resources only in current (added)
"""
# Shared resources (exist in both, may be modified)
num_shared = draw(st.integers(min_value=0, max_value=4))
# Resources only in previous (will be removed)
num_removed = draw(st.integers(min_value=0, max_value=3))
# Resources only in current (will be added)
num_added = draw(st.integers(min_value=0, max_value=3))
assume(num_shared + num_removed + num_added >= 1)
previous_resources = []
current_resources = []
# Generate shared resources
for i in range(num_shared):
uid = f"shared_{i}"
resource_type = draw(resource_type_strategy)
name = draw(resource_name_strategy)
provider = draw(provider_strategy)
platform = draw(platform_strategy)
arch = draw(architecture_strategy)
endpoint = draw(st.text(min_size=3, max_size=10, alphabet="abcdefghijklmnopqrstuvwxyz."))
prev_attrs = draw(attributes_strategy)
prev_resource = DiscoveredResource(
resource_type=resource_type,
unique_id=uid,
name=name,
provider=provider,
platform_category=platform,
architecture=arch,
endpoint=endpoint,
attributes=prev_attrs,
raw_references=[],
)
previous_resources.append(prev_resource)
# Possibly modify attributes for current version
modify = draw(st.booleans())
if modify:
curr_attrs = draw(attributes_strategy)
else:
curr_attrs = dict(prev_attrs)
curr_resource = DiscoveredResource(
resource_type=resource_type,
unique_id=uid,
name=name,
provider=provider,
platform_category=platform,
architecture=arch,
endpoint=endpoint,
attributes=curr_attrs,
raw_references=[],
)
current_resources.append(curr_resource)
# Generate removed resources (only in previous)
for i in range(num_removed):
uid = f"removed_{i}"
resource = draw(discovered_resource_strategy(uid=uid))
previous_resources.append(resource)
# Generate added resources (only in current)
for i in range(num_added):
uid = f"added_{i}"
resource = draw(discovered_resource_strategy(uid=uid))
current_resources.append(resource)
previous = ScanResult(
resources=previous_resources,
warnings=[],
errors=[],
scan_timestamp="2024-01-14T09:00:00Z",
profile_hash="test_profile",
is_partial=False,
)
current = ScanResult(
resources=current_resources,
warnings=[],
errors=[],
scan_timestamp="2024-01-15T10:30:00Z",
profile_hash="test_profile",
is_partial=False,
)
return previous, current
# ---------------------------------------------------------------------------
# Property 23: Change classification correctness
# ---------------------------------------------------------------------------
class TestChangeClassificationCorrectness:
"""Property 23: Change classification correctness.
**Validates: Requirements 8.1, 8.5**
For any pair of scan results (previous and current), every resource
SHALL be classified exactly once as: added, removed, or modified.
The summary counts SHALL equal the actual number of resources in each
category.
"""
@given(data=scan_result_pair_strategy())
@settings(max_examples=100)
def test_every_resource_classified_exactly_once(self, data):
"""Every resource is classified as exactly one of: added, removed, or modified."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
prev_ids = {r.unique_id for r in previous.resources}
curr_ids = {r.unique_id for r in current.resources}
all_ids = prev_ids | curr_ids
# Each change should reference a resource from either scan
change_ids = [c.resource_id for c in summary.changes]
# No duplicates in changes
assert len(change_ids) == len(set(change_ids)), (
f"Duplicate resource IDs in changes: "
f"{[rid for rid in change_ids if change_ids.count(rid) > 1]}"
)
# Every changed resource must be from the union of both scans
for change in summary.changes:
assert change.resource_id in all_ids, (
f"Change references unknown resource: {change.resource_id}"
)
@given(data=scan_result_pair_strategy())
@settings(max_examples=100)
def test_added_resources_in_current_not_previous(self, data):
"""Resources classified as ADDED are in current but not in previous."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
prev_ids = {r.unique_id for r in previous.resources}
curr_ids = {r.unique_id for r in current.resources}
added_changes = [c for c in summary.changes if c.change_type == ChangeType.ADDED]
for change in added_changes:
assert change.resource_id in curr_ids, (
f"ADDED resource {change.resource_id} not in current scan"
)
assert change.resource_id not in prev_ids, (
f"ADDED resource {change.resource_id} exists in previous scan"
)
@given(data=scan_result_pair_strategy())
@settings(max_examples=100)
def test_removed_resources_in_previous_not_current(self, data):
"""Resources classified as REMOVED are in previous but not in current."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
prev_ids = {r.unique_id for r in previous.resources}
curr_ids = {r.unique_id for r in current.resources}
removed_changes = [c for c in summary.changes if c.change_type == ChangeType.REMOVED]
for change in removed_changes:
assert change.resource_id in prev_ids, (
f"REMOVED resource {change.resource_id} not in previous scan"
)
assert change.resource_id not in curr_ids, (
f"REMOVED resource {change.resource_id} exists in current scan"
)
@given(data=scan_result_pair_strategy())
@settings(max_examples=100)
def test_modified_resources_in_both_with_differing_attributes(self, data):
"""Resources classified as MODIFIED exist in both scans with differing attributes."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
prev_map = {r.unique_id: r for r in previous.resources}
curr_map = {r.unique_id: r for r in current.resources}
modified_changes = [c for c in summary.changes if c.change_type == ChangeType.MODIFIED]
for change in modified_changes:
assert change.resource_id in prev_map, (
f"MODIFIED resource {change.resource_id} not in previous scan"
)
assert change.resource_id in curr_map, (
f"MODIFIED resource {change.resource_id} not in current scan"
)
# Attributes must actually differ
assert prev_map[change.resource_id].attributes != curr_map[change.resource_id].attributes, (
f"MODIFIED resource {change.resource_id} has identical attributes"
)
@given(data=scan_result_pair_strategy())
@settings(max_examples=100)
def test_summary_counts_match_actual_changes(self, data):
"""Summary counts equal the actual number of resources in each category."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
actual_added = sum(1 for c in summary.changes if c.change_type == ChangeType.ADDED)
actual_removed = sum(1 for c in summary.changes if c.change_type == ChangeType.REMOVED)
actual_modified = sum(1 for c in summary.changes if c.change_type == ChangeType.MODIFIED)
assert summary.added_count == actual_added, (
f"added_count={summary.added_count} != actual={actual_added}"
)
assert summary.removed_count == actual_removed, (
f"removed_count={summary.removed_count} != actual={actual_removed}"
)
assert summary.modified_count == actual_modified, (
f"modified_count={summary.modified_count} != actual={actual_modified}"
)
@given(data=scan_result_pair_strategy())
@settings(max_examples=100)
def test_change_types_are_valid(self, data):
"""Every change has a valid ChangeType value."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
valid_types = {ChangeType.ADDED, ChangeType.REMOVED, ChangeType.MODIFIED}
for change in summary.changes:
assert change.change_type in valid_types, (
f"Invalid change_type: {change.change_type}"
)
# ---------------------------------------------------------------------------
# Property 24: Incremental update scope
# ---------------------------------------------------------------------------
class TestIncrementalUpdateScope:
"""Property 24: Incremental update scope.
**Validates: Requirements 8.2**
For any change set applied to existing IaC files, only files containing
added, modified, or removed resources SHALL be modified. Files containing
only unchanged resources SHALL remain identical.
"""
@given(data=scan_result_pair_strategy())
@settings(max_examples=100, deadline=None)
def test_only_changed_resource_files_are_modified(self, data):
"""Only .tf files for resource types with changes are modified."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
# Skip if no changes (nothing to test)
assume(len(summary.changes) > 0)
with tempfile.TemporaryDirectory() as tmp_dir:
# Create initial .tf files for all resource types in previous scan
resource_types_in_previous = {r.resource_type for r in previous.resources}
# Also create a file for an "unchanged" resource type
unchanged_type = "unchanged_resource_type"
resource_types_in_previous.add(unchanged_type)
for rt in resource_types_in_previous:
tf_path = Path(tmp_dir) / f"{rt}.tf"
tf_path.write_text(f'# Placeholder for {rt}\n', encoding="utf-8")
# Record original content of the unchanged file
unchanged_path = Path(tmp_dir) / f"{unchanged_type}.tf"
original_unchanged_content = unchanged_path.read_text(encoding="utf-8")
# Build resource_attributes for added resources
resource_attributes = {}
for change in summary.changes:
if change.change_type == ChangeType.ADDED:
# Find the resource in current scan
for r in current.resources:
if r.unique_id == change.resource_id:
resource_attributes[change.resource_id] = r.attributes
break
# Apply incremental update
updater = IncrementalUpdater(
change_summary=summary,
output_dir=tmp_dir,
resource_attributes=resource_attributes,
)
updater.apply()
# The unchanged file should not be modified
assert unchanged_path.read_text(encoding="utf-8") == original_unchanged_content, (
"File for unchanged resource type was modified"
)
# Modified files should only be for resource types with changes
changed_resource_types = {c.resource_type for c in summary.changes}
for modified_file in updater.modified_files:
file_name = Path(modified_file).name
# Modified files should be .tf files for changed resource types
# or the state file
if file_name == "terraform.tfstate":
continue
assert file_name.endswith(".tf"), (
f"Unexpected modified file: {file_name}"
)
rt = file_name[:-3] # strip .tf
assert rt in changed_resource_types, (
f"File {file_name} was modified but resource type "
f"'{rt}' has no changes"
)
# ---------------------------------------------------------------------------
# Property 25: Removed resource exclusion
# ---------------------------------------------------------------------------
class TestRemovedResourceExclusion:
"""Property 25: Removed resource exclusion.
**Validates: Requirements 8.3**
For any resource classified as removed, the updated IaC output SHALL
not contain a resource block for that resource, AND the updated state
file SHALL not contain a state entry for that resource.
"""
@given(data=scan_result_pair_strategy())
@settings(max_examples=100, deadline=None)
def test_removed_resources_not_in_tf_files(self, data):
"""Removed resources do not appear in .tf files after update."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
removed_changes = [c for c in summary.changes if c.change_type == ChangeType.REMOVED]
assume(len(removed_changes) > 0)
with tempfile.TemporaryDirectory() as tmp_dir:
# Create .tf files with resource blocks for previous resources
from iac_reverse.generator.sanitize import sanitize_identifier
resources_by_type: dict[str, list] = {}
for r in previous.resources:
resources_by_type.setdefault(r.resource_type, []).append(r)
for rt, resources in resources_by_type.items():
tf_path = Path(tmp_dir) / f"{rt}.tf"
lines = []
for r in resources:
tf_name = sanitize_identifier(r.name)
lines.append(f'# Source: {r.unique_id}')
lines.append(f'resource "{rt}" "{tf_name}" {{')
for k, v in r.attributes.items():
lines.append(f' {k} = "{v}"')
lines.append("}")
lines.append("")
tf_path.write_text("\n".join(lines), encoding="utf-8")
# Build resource_attributes for added resources
resource_attributes = {}
for change in summary.changes:
if change.change_type == ChangeType.ADDED:
for r in current.resources:
if r.unique_id == change.resource_id:
resource_attributes[change.resource_id] = r.attributes
break
# Apply incremental update
updater = IncrementalUpdater(
change_summary=summary,
output_dir=tmp_dir,
resource_attributes=resource_attributes,
)
updater.apply()
# Verify removed resources are not in any .tf file
for change in removed_changes:
tf_path = Path(tmp_dir) / f"{change.resource_type}.tf"
if tf_path.exists():
content = tf_path.read_text(encoding="utf-8")
tf_name = sanitize_identifier(change.resource_name)
# The resource block should not exist
block_header = f'resource "{change.resource_type}" "{tf_name}"'
assert block_header not in content, (
f"Removed resource {change.resource_id} still has a "
f"resource block in {tf_path.name}"
)
@given(data=scan_result_pair_strategy())
@settings(max_examples=100, deadline=None)
def test_removed_resources_not_in_state_file(self, data):
"""Removed resources do not appear in the state file after update."""
previous, current = data
detector = ChangeDetector()
summary = detector.compare(current, previous)
removed_changes = [c for c in summary.changes if c.change_type == ChangeType.REMOVED]
assume(len(removed_changes) > 0)
with tempfile.TemporaryDirectory() as tmp_dir:
from iac_reverse.generator.sanitize import sanitize_identifier
# Create initial state file with entries for previous resources
state = {
"version": 4,
"terraform_version": "1.7.0",
"serial": 1,
"lineage": "test-lineage",
"outputs": {},
"resources": [],
}
for r in previous.resources:
tf_name = sanitize_identifier(r.name)
state["resources"].append({
"mode": "managed",
"type": r.resource_type,
"name": tf_name,
"provider": f'provider["registry.terraform.io/hashicorp/{r.resource_type.split("_")[0]}"]',
"instances": [{
"schema_version": 0,
"attributes": {"id": r.unique_id, **r.attributes},
"sensitive_attributes": [],
"dependencies": [],
}],
})
state_path = Path(tmp_dir) / "terraform.tfstate"
state_path.write_text(json.dumps(state, indent=2), encoding="utf-8")
# Create .tf files so updater can process removals
resources_by_type: dict[str, list] = {}
for r in previous.resources:
resources_by_type.setdefault(r.resource_type, []).append(r)
for rt, resources in resources_by_type.items():
tf_path = Path(tmp_dir) / f"{rt}.tf"
lines = []
for r in resources:
tf_name = sanitize_identifier(r.name)
lines.append(f'# Source: {r.unique_id}')
lines.append(f'resource "{rt}" "{tf_name}" {{')
for k, v in r.attributes.items():
lines.append(f' {k} = "{v}"')
lines.append("}")
lines.append("")
tf_path.write_text("\n".join(lines), encoding="utf-8")
# Build resource_attributes for added resources
resource_attributes = {}
for change in summary.changes:
if change.change_type == ChangeType.ADDED:
for r in current.resources:
if r.unique_id == change.resource_id:
resource_attributes[change.resource_id] = r.attributes
break
# Apply incremental update
updater = IncrementalUpdater(
change_summary=summary,
output_dir=tmp_dir,
resource_attributes=resource_attributes,
)
updater.apply()
# Verify removed resources are not in state file
updated_state = json.loads(
state_path.read_text(encoding="utf-8")
)
state_entries = updated_state.get("resources", [])
for change in removed_changes:
tf_name = sanitize_identifier(change.resource_name)
matching = [
e for e in state_entries
if e.get("type") == change.resource_type
and e.get("name") == tf_name
]
assert len(matching) == 0, (
f"Removed resource {change.resource_id} still has a "
f"state entry (type={change.resource_type}, name={tf_name})"
)
# ---------------------------------------------------------------------------
# Property 26: Snapshot retention
# ---------------------------------------------------------------------------
class TestSnapshotRetention:
"""Property 26: Snapshot retention.
**Validates: Requirements 8.6**
For any sequence of N scans (N >= 2) for the same Scan_Profile, at
least the two most recent scan results SHALL be retained in storage
after each scan completes.
"""
@given(num_scans=st.integers(min_value=2, max_value=8))
@settings(max_examples=100)
def test_at_least_two_snapshots_retained(self, num_scans):
"""After N scans, at least 2 most recent snapshots are retained."""
from unittest.mock import patch
from datetime import datetime, timezone
with tempfile.TemporaryDirectory() as tmp_dir:
store = SnapshotStore(base_dir=tmp_dir)
profile_hash = "retention_test_profile"
# Store N scan results with mocked timestamps to ensure unique filenames
for i in range(num_scans):
result = ScanResult(
resources=[
DiscoveredResource(
resource_type="docker_service",
unique_id=f"svc_{i}",
name=f"service_{i}",
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AMD64,
endpoint="localhost",
attributes={"version": str(i)},
raw_references=[],
)
],
warnings=[],
errors=[],
scan_timestamp=f"2024-01-{15 + i:02d}T10:00:00Z",
profile_hash=profile_hash,
is_partial=False,
)
# Mock datetime.now to return unique timestamps
mock_time = datetime(2024, 1, 15 + i, 10, 0, 0, tzinfo=timezone.utc)
with patch(
"iac_reverse.incremental.snapshot_store.datetime"
) as mock_dt:
mock_dt.now.return_value = mock_time
mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw)
store.store_snapshot(result, profile_hash)
# Count remaining snapshots
snapshot_files = list(store.snapshot_dir.glob(f"{profile_hash}_*.json"))
assert len(snapshot_files) >= 2, (
f"After {num_scans} scans, only {len(snapshot_files)} "
f"snapshots retained (expected >= 2)"
)
@given(num_scans=st.integers(min_value=2, max_value=8))
@settings(max_examples=100)
def test_most_recent_snapshot_is_loadable(self, num_scans):
"""The most recent snapshot can be loaded after multiple stores."""
from unittest.mock import patch
from datetime import datetime, timezone
with tempfile.TemporaryDirectory() as tmp_dir:
store = SnapshotStore(base_dir=tmp_dir)
profile_hash = "loadable_test_profile"
last_resource_id = None
for i in range(num_scans):
last_resource_id = f"svc_{i}"
result = ScanResult(
resources=[
DiscoveredResource(
resource_type="kubernetes_deployment",
unique_id=last_resource_id,
name=f"deploy_{i}",
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AARCH64,
endpoint="k8s-api.local",
attributes={"replicas": i + 1},
raw_references=[],
)
],
warnings=[],
errors=[],
scan_timestamp=f"2024-01-{15 + i:02d}T10:00:00Z",
profile_hash=profile_hash,
is_partial=False,
)
mock_time = datetime(2024, 1, 15 + i, 10, 0, 0, tzinfo=timezone.utc)
with patch(
"iac_reverse.incremental.snapshot_store.datetime"
) as mock_dt:
mock_dt.now.return_value = mock_time
mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw)
store.store_snapshot(result, profile_hash)
# Load the most recent snapshot
loaded = store.load_previous(profile_hash)
assert loaded is not None, "Could not load most recent snapshot"
assert len(loaded.resources) == 1
assert loaded.resources[0].unique_id == last_resource_id, (
f"Expected most recent resource '{last_resource_id}', "
f"got '{loaded.resources[0].unique_id}'"
)
@given(num_scans=st.integers(min_value=3, max_value=10))
@settings(max_examples=100)
def test_different_profiles_retain_independently(self, num_scans):
"""Snapshots for different profiles are retained independently."""
from unittest.mock import patch
from datetime import datetime, timezone
with tempfile.TemporaryDirectory() as tmp_dir:
store = SnapshotStore(base_dir=tmp_dir)
profile_a = "profile_alpha"
profile_b = "profile_beta"
scan_idx = 0
for i in range(num_scans):
for profile_hash in [profile_a, profile_b]:
result = ScanResult(
resources=[
DiscoveredResource(
resource_type="docker_service",
unique_id=f"{profile_hash}_svc_{i}",
name=f"svc_{i}",
provider=ProviderType.DOCKER_SWARM,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AMD64,
endpoint="localhost",
attributes={"idx": i},
raw_references=[],
)
],
warnings=[],
errors=[],
scan_timestamp=f"2024-01-{15 + i:02d}T10:00:00Z",
profile_hash=profile_hash,
is_partial=False,
)
# Use unique timestamps per store call
mock_time = datetime(2024, 1, 15, 10, scan_idx, 0, tzinfo=timezone.utc)
scan_idx += 1
with patch(
"iac_reverse.incremental.snapshot_store.datetime"
) as mock_dt:
mock_dt.now.return_value = mock_time
mock_dt.side_effect = lambda *a, **kw: datetime(*a, **kw)
store.store_snapshot(result, profile_hash)
# Both profiles should have at least 2 snapshots
snapshots_a = list(store.snapshot_dir.glob(f"{profile_a}_*.json"))
snapshots_b = list(store.snapshot_dir.glob(f"{profile_b}_*.json"))
assert len(snapshots_a) >= 2, (
f"Profile A has {len(snapshots_a)} snapshots (expected >= 2)"
)
assert len(snapshots_b) >= 2, (
f"Profile B has {len(snapshots_b)} snapshots (expected >= 2)"
)

View File

@@ -0,0 +1,803 @@
"""Property-based tests for multi-provider merging and filtering.
**Validates: Requirements 5.3, 5.4, 6.1, 6.2, 6.4, 6.6, 6.7**
Property 18: Multi-provider merge with naming conflict resolution
For any two or more resource inventories from different on-premises providers
where resource names collide, the merged inventory SHALL contain all resources
from all providers, with conflicting names prefixed by the provider identifier,
and no resources lost.
Property 19: Provider block generation
For any resource set spanning N distinct on-premises providers, the generated
provider configuration SHALL contain exactly N provider blocks, one per distinct
provider.
Property 20: Scan profile validation completeness (additional multi-provider scenarios)
Already covered in test_scan_profile_validation_prop.py; this adds multi-provider
scenarios.
Property 21: Filtering correctness
For any scan profile with resource type filters, the discovered resources SHALL
be a subset where every resource's type is in the filter list. No resource outside
the filter criteria shall appear.
"""
from typing import Callable
from hypothesis import given, assume, settings
from hypothesis import strategies as st
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
GeneratedFile,
PlatformCategory,
PROVIDER_PLATFORM_MAP,
PROVIDER_SUPPORTED_RESOURCE_TYPES,
ProviderType,
ScanProfile,
ScanProgress,
ScanResult,
)
from iac_reverse.generator import ProviderBlockGenerator, ResourceMerger
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import Scanner
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
architecture_strategy = st.sampled_from(list(CpuArchitecture))
non_empty_credentials_strategy = st.dictionaries(
keys=st.text(
min_size=1, max_size=20,
alphabet=st.characters(whitelist_categories=("L", "N")),
),
values=st.text(min_size=1, max_size=50),
min_size=1,
max_size=5,
)
# Strategy for resource names that could collide across providers
resource_name_strategy = st.text(
min_size=1,
max_size=30,
alphabet=st.characters(whitelist_categories=("L", "N", "Pd")),
).filter(lambda s: s.strip())
def discovered_resource_strategy(
provider: ProviderType,
name: str | None = None,
) -> st.SearchStrategy[DiscoveredResource]:
"""Generate a DiscoveredResource for a given provider with optional fixed name."""
supported_types = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
platform_category = PROVIDER_PLATFORM_MAP[provider]
return st.builds(
DiscoveredResource,
resource_type=st.sampled_from(supported_types),
unique_id=st.uuids().map(str),
name=st.just(name) if name else resource_name_strategy,
provider=st.just(provider),
platform_category=st.just(platform_category),
architecture=architecture_strategy,
endpoint=st.just("http://localhost:8080"),
attributes=st.just({"key": "value"}),
raw_references=st.just([]),
)
def scan_result_strategy(
provider: ProviderType,
resources: list[DiscoveredResource] | None = None,
) -> st.SearchStrategy[ScanResult]:
"""Generate a ScanResult for a given provider."""
if resources is not None:
return st.just(ScanResult(
resources=resources,
warnings=[],
errors=[],
scan_timestamp="2024-01-01T00:00:00Z",
profile_hash="abc123",
))
return st.builds(
ScanResult,
resources=st.lists(
discovered_resource_strategy(provider),
min_size=1,
max_size=5,
),
warnings=st.just([]),
errors=st.just([]),
scan_timestamp=st.just("2024-01-01T00:00:00Z"),
profile_hash=st.just("abc123"),
)
# ---------------------------------------------------------------------------
# Mock Plugin for Filtering Tests
# ---------------------------------------------------------------------------
class FilteringPlugin(ProviderPlugin):
"""A plugin that discovers resources only for requested resource types."""
def __init__(self, provider: ProviderType):
self._provider = provider
self._supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
def authenticate(self, credentials: dict[str, str]) -> None:
pass
def get_platform_category(self) -> PlatformCategory:
return PROVIDER_PLATFORM_MAP[self._provider]
def list_endpoints(self) -> list[str]:
return ["http://localhost:8080"]
def list_supported_resource_types(self) -> list[str]:
return self._supported
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
"""Discover exactly one resource per requested resource type."""
resources = []
for i, rt in enumerate(resource_types):
resources.append(
DiscoveredResource(
resource_type=rt,
unique_id=f"id-{rt}-{i}",
name=f"resource-{rt}-{i}",
provider=self._provider,
platform_category=PROVIDER_PLATFORM_MAP[self._provider],
architecture=CpuArchitecture.AMD64,
endpoint="http://localhost:8080",
attributes={"key": "value"},
)
)
progress_callback(ScanProgress(
current_resource_type=rt,
resources_discovered=i + 1,
resource_types_completed=i + 1,
total_resource_types=len(resource_types),
))
return ScanResult(
resources=resources,
warnings=[],
errors=[],
scan_timestamp="2024-01-01T00:00:00Z",
profile_hash="abc123",
)
# ---------------------------------------------------------------------------
# Property 18: Multi-provider merge with naming conflict resolution
# ---------------------------------------------------------------------------
class TestMultiProviderMergeConflictResolution:
"""Property 18: Multi-provider merge with naming conflict resolution.
When resources from different providers share the same name, the merger
prefixes with provider identifier.
**Validates: Requirements 5.3**
"""
@given(
provider_a=provider_type_strategy,
provider_b=provider_type_strategy,
shared_name=resource_name_strategy,
)
@settings(max_examples=100)
def test_conflicting_names_are_prefixed(self, provider_a, provider_b, shared_name):
"""Resources with the same name from different providers get prefixed."""
assume(provider_a != provider_b)
resource_a = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0],
unique_id="id-a",
name=shared_name,
provider=provider_a,
platform_category=PROVIDER_PLATFORM_MAP[provider_a],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-a:8080",
attributes={"source": "a"},
)
resource_b = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_b][0],
unique_id="id-b",
name=shared_name,
provider=provider_b,
platform_category=PROVIDER_PLATFORM_MAP[provider_b],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-b:8080",
attributes={"source": "b"},
)
scan_result_a = ScanResult(
resources=[resource_a],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
scan_result_b = ScanResult(
resources=[resource_b],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
merger = ResourceMerger()
merged = merger.merge([scan_result_a, scan_result_b])
# Both resources must be present (no loss)
assert len(merged) == 2
# Conflicting names must be prefixed with provider identifier
merged_names = {r.name for r in merged}
expected_name_a = f"{provider_a.value}_{shared_name}"
expected_name_b = f"{provider_b.value}_{shared_name}"
assert expected_name_a in merged_names, (
f"Expected '{expected_name_a}' in merged names, got: {merged_names}"
)
assert expected_name_b in merged_names, (
f"Expected '{expected_name_b}' in merged names, got: {merged_names}"
)
@given(
provider_a=provider_type_strategy,
provider_b=provider_type_strategy,
name_a=resource_name_strategy,
name_b=resource_name_strategy,
)
@settings(max_examples=100)
def test_non_conflicting_names_unchanged(self, provider_a, provider_b, name_a, name_b):
"""Resources with unique names across providers are not prefixed."""
assume(provider_a != provider_b)
assume(name_a != name_b)
resource_a = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0],
unique_id="id-a",
name=name_a,
provider=provider_a,
platform_category=PROVIDER_PLATFORM_MAP[provider_a],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-a:8080",
attributes={},
)
resource_b = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_b][0],
unique_id="id-b",
name=name_b,
provider=provider_b,
platform_category=PROVIDER_PLATFORM_MAP[provider_b],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-b:8080",
attributes={},
)
scan_result_a = ScanResult(
resources=[resource_a],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
scan_result_b = ScanResult(
resources=[resource_b],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
merger = ResourceMerger()
merged = merger.merge([scan_result_a, scan_result_b])
# No resources lost
assert len(merged) == 2
# Names should remain unchanged (no prefix)
merged_names = {r.name for r in merged}
assert name_a in merged_names, (
f"Expected original name '{name_a}' preserved, got: {merged_names}"
)
assert name_b in merged_names, (
f"Expected original name '{name_b}' preserved, got: {merged_names}"
)
@given(
provider_a=provider_type_strategy,
provider_b=provider_type_strategy,
provider_c=provider_type_strategy,
shared_name=resource_name_strategy,
)
@settings(max_examples=100)
def test_three_provider_conflict_all_prefixed(
self, provider_a, provider_b, provider_c, shared_name
):
"""When 3 providers share a name, all get prefixed."""
assume(len({provider_a, provider_b, provider_c}) == 3)
resources_and_results = []
for provider in [provider_a, provider_b, provider_c]:
resource = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider][0],
unique_id=f"id-{provider.value}",
name=shared_name,
provider=provider,
platform_category=PROVIDER_PLATFORM_MAP[provider],
architecture=CpuArchitecture.AMD64,
endpoint=f"http://{provider.value}:8080",
attributes={},
)
result = ScanResult(
resources=[resource],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
resources_and_results.append(result)
merger = ResourceMerger()
merged = merger.merge(resources_and_results)
# All 3 resources preserved
assert len(merged) == 3
# All must be prefixed
for provider in [provider_a, provider_b, provider_c]:
expected = f"{provider.value}_{shared_name}"
assert any(r.name == expected for r in merged), (
f"Expected prefixed name '{expected}' in merged results"
)
@given(
provider_a=provider_type_strategy,
provider_b=provider_type_strategy,
shared_name=resource_name_strategy,
)
@settings(max_examples=100)
def test_merge_preserves_all_resources_no_loss(
self, provider_a, provider_b, shared_name
):
"""Merging never loses resources regardless of naming conflicts."""
assume(provider_a != provider_b)
resource_a = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0],
unique_id="id-a",
name=shared_name,
provider=provider_a,
platform_category=PROVIDER_PLATFORM_MAP[provider_a],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-a:8080",
attributes={"source": "a"},
)
resource_b = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_b][0],
unique_id="id-b",
name=shared_name,
provider=provider_b,
platform_category=PROVIDER_PLATFORM_MAP[provider_b],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-b:8080",
attributes={"source": "b"},
)
# Also add a non-conflicting resource
resource_c = DiscoveredResource(
resource_type=PROVIDER_SUPPORTED_RESOURCE_TYPES[provider_a][0],
unique_id="id-c",
name="unique_resource_name",
provider=provider_a,
platform_category=PROVIDER_PLATFORM_MAP[provider_a],
architecture=CpuArchitecture.AMD64,
endpoint="http://host-a:8080",
attributes={"source": "c"},
)
scan_result_a = ScanResult(
resources=[resource_a, resource_c],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
scan_result_b = ScanResult(
resources=[resource_b],
warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
merger = ResourceMerger()
merged = merger.merge([scan_result_a, scan_result_b])
# Total resources = 3 (no loss)
assert len(merged) == 3
# Provider-specific attributes preserved
unique_ids = {r.unique_id for r in merged}
assert "id-a" in unique_ids
assert "id-b" in unique_ids
assert "id-c" in unique_ids
# ---------------------------------------------------------------------------
# Property 19: Provider block generation
# ---------------------------------------------------------------------------
class TestProviderBlockGeneration:
"""Property 19: Provider block generation.
For any set of providers used, a provider block is generated for each.
**Validates: Requirements 5.4**
"""
@given(
providers=st.lists(
provider_type_strategy,
min_size=1,
max_size=6,
).map(lambda ps: list(set(ps))), # Deduplicate
)
@settings(max_examples=100)
def test_one_provider_block_per_distinct_provider(self, providers):
"""Generated output contains exactly one provider block per distinct provider."""
assume(len(providers) >= 1)
# Create profiles for each provider
profiles = [
ScanProfile(
provider=p,
credentials={"token": "test-token"},
)
for p in providers
]
provider_types = set(providers)
generator = ProviderBlockGenerator()
result = generator.generate(profiles=profiles, provider_types=provider_types)
# Result should be a GeneratedFile
assert isinstance(result, GeneratedFile)
assert result.filename == "providers.tf"
content = result.content
# Count provider blocks: each provider type should have exactly one
# provider "name" { block
for provider_type in provider_types:
# Get the terraform provider name for this type
from iac_reverse.generator.provider_block import _PROVIDER_METADATA
tf_name = _PROVIDER_METADATA[provider_type][0]
provider_block_marker = f'provider "{tf_name}"'
count = content.count(provider_block_marker)
assert count == 1, (
f"Expected exactly 1 provider block for '{tf_name}', "
f"found {count} in:\n{content}"
)
@given(
providers=st.lists(
provider_type_strategy,
min_size=2,
max_size=6,
).map(lambda ps: list(set(ps))),
)
@settings(max_examples=100)
def test_required_providers_block_lists_all(self, providers):
"""The terraform required_providers block lists all providers used."""
assume(len(providers) >= 2)
profiles = [
ScanProfile(
provider=p,
credentials={"token": "test-token"},
)
for p in providers
]
provider_types = set(providers)
generator = ProviderBlockGenerator()
result = generator.generate(profiles=profiles, provider_types=provider_types)
content = result.content
# The required_providers block must exist
assert "required_providers" in content
# Each provider must appear in the required_providers block
from iac_reverse.generator.provider_block import _PROVIDER_METADATA
for provider_type in provider_types:
tf_name, source, _ = _PROVIDER_METADATA[provider_type]
assert tf_name in content, (
f"Expected provider name '{tf_name}' in required_providers block"
)
assert source in content, (
f"Expected source '{source}' in required_providers block"
)
@given(provider=provider_type_strategy)
@settings(max_examples=100)
def test_single_provider_generates_one_block(self, provider):
"""A single provider generates exactly one provider block."""
profiles = [
ScanProfile(
provider=provider,
credentials={"token": "test-token"},
)
]
generator = ProviderBlockGenerator()
result = generator.generate(
profiles=profiles,
provider_types={provider},
)
content = result.content
from iac_reverse.generator.provider_block import _PROVIDER_METADATA
tf_name = _PROVIDER_METADATA[provider][0]
# Exactly one provider block
provider_block_marker = f'provider "{tf_name}"'
assert content.count(provider_block_marker) == 1
# terraform block with required_providers
assert "terraform {" in content
assert "required_providers" in content
# ---------------------------------------------------------------------------
# Property 20: Scan profile validation completeness (multi-provider scenarios)
# ---------------------------------------------------------------------------
class TestScanProfileValidationMultiProvider:
"""Property 20: Scan profile validation completeness (multi-provider scenarios).
Additional multi-provider scenarios beyond what's in test_scan_profile_validation_prop.py.
**Validates: Requirements 6.1, 6.6, 6.7**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_valid_multi_provider_profiles_all_pass_validation(
self, provider, credentials
):
"""Each valid profile in a multi-provider set passes validation independently."""
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
errors = profile.validate()
assert errors == [], f"Expected no errors for valid profile, got: {errors}"
@given(
provider_a=provider_type_strategy,
provider_b=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_mixed_valid_invalid_profiles_detected_independently(
self, provider_a, provider_b, credentials
):
"""Invalid profiles are detected independently in a multi-provider set."""
# Valid profile
valid_profile = ScanProfile(
provider=provider_a,
credentials=credentials,
resource_type_filters=None,
)
# Invalid profile (empty credentials)
invalid_profile = ScanProfile(
provider=provider_b,
credentials={},
resource_type_filters=None,
)
valid_errors = valid_profile.validate()
invalid_errors = invalid_profile.validate()
assert valid_errors == []
assert len(invalid_errors) >= 1
assert any("credentials" in e for e in invalid_errors)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_cross_provider_resource_types_detected_as_unsupported(
self, provider, credentials
):
"""Resource types from a different provider are flagged as unsupported."""
# Pick a resource type from a different provider
other_providers = [p for p in ProviderType if p != provider]
assume(len(other_providers) > 0)
other_provider = other_providers[0]
other_types = PROVIDER_SUPPORTED_RESOURCE_TYPES[other_provider]
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=other_types[:2],
)
errors = profile.validate()
# Should detect unsupported types (unless they happen to overlap)
supported = set(PROVIDER_SUPPORTED_RESOURCE_TYPES[provider])
unsupported = [t for t in other_types[:2] if t not in supported]
if unsupported:
assert any("unsupported" in e.lower() for e in errors), (
f"Expected unsupported error for cross-provider types, got: {errors}"
)
# ---------------------------------------------------------------------------
# Property 21: Filtering correctness
# ---------------------------------------------------------------------------
class TestFilteringCorrectness:
"""Property 21: Filtering correctness.
When resource type filters are specified, only those types appear in the
scan result.
**Validates: Requirements 6.2, 6.4**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_filtered_scan_returns_only_filtered_types(self, provider, credentials):
"""When filters are specified, only filtered resource types appear in results."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
assume(len(supported) >= 2)
# Pick a subset of supported types as filter
filter_types = supported[:2]
plugin = FilteringPlugin(provider=provider)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=filter_types,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# All discovered resources must have types in the filter list
for resource in result.resources:
assert resource.resource_type in filter_types, (
f"Resource type '{resource.resource_type}' not in filter list "
f"{filter_types}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_no_filter_returns_all_supported_types(self, provider, credentials):
"""When no filters are specified, all supported types are discovered."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = FilteringPlugin(provider=provider)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# All supported types should be discovered
discovered_types = {r.resource_type for r in result.resources}
for rt in supported:
assert rt in discovered_types, (
f"Expected type '{rt}' to be discovered when no filter, "
f"got: {discovered_types}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_single_type_filter_returns_only_that_type(self, provider, credentials):
"""A single-type filter returns only resources of that type."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
assume(len(supported) >= 1)
single_filter = [supported[0]]
plugin = FilteringPlugin(provider=provider)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=single_filter,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# Only the single filtered type should appear
for resource in result.resources:
assert resource.resource_type == single_filter[0], (
f"Expected only '{single_filter[0]}', got '{resource.resource_type}'"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_empty_filter_returns_no_resources(self, provider, credentials):
"""An empty filter list results in no resources discovered."""
plugin = FilteringPlugin(provider=provider)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=[],
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# Empty filter means nothing to discover
assert len(result.resources) == 0, (
f"Expected 0 resources with empty filter, got {len(result.resources)}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_filter_subset_excludes_non_filtered_types(self, provider, credentials):
"""Types not in the filter list do not appear in results."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
assume(len(supported) >= 3)
# Filter to first 2 types only
filter_types = supported[:2]
excluded_types = supported[2:]
plugin = FilteringPlugin(provider=provider)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=filter_types,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# None of the excluded types should appear
discovered_types = {r.resource_type for r in result.resources}
for excluded in excluded_types:
assert excluded not in discovered_types, (
f"Excluded type '{excluded}' should not appear in filtered results"
)

View File

@@ -0,0 +1,222 @@
"""Property-based tests for resource inventory completeness.
**Validates: Requirements 1.2**
Property 1: Resource inventory completeness
For any discovered resource from any on-premises provider (Docker Swarm, Kubernetes,
Synology, Harvester, Bare Metal, Windows), the resulting inventory entry SHALL contain
non-empty values for resource_type, unique_id, name, provider, platform_category,
architecture, and attributes fields.
"""
from hypothesis import given, settings
from hypothesis import strategies as st
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
ProviderType,
ScanResult,
)
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
platform_category_strategy = st.sampled_from(list(PlatformCategory))
cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture))
non_empty_text_strategy = st.text(
min_size=1,
max_size=100,
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
).filter(lambda s: s.strip() != "")
resource_type_strategy = st.text(
min_size=1,
max_size=50,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"),
).filter(lambda s: len(s) > 0 and s.strip() != "")
unique_id_strategy = st.text(
min_size=1,
max_size=200,
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
).filter(lambda s: s.strip() != "")
name_strategy = st.text(
min_size=1,
max_size=100,
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
).filter(lambda s: s.strip() != "")
endpoint_strategy = st.text(
min_size=1,
max_size=200,
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
).filter(lambda s: s.strip() != "")
non_empty_attributes_strategy = st.dictionaries(
keys=st.text(
min_size=1,
max_size=30,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_"),
),
values=st.one_of(
st.text(min_size=1, max_size=50),
st.integers(min_value=0, max_value=10000),
st.booleans(),
),
min_size=1,
max_size=10,
)
raw_references_strategy = st.lists(
st.text(min_size=0, max_size=100),
min_size=0,
max_size=5,
)
discovered_resource_strategy = st.builds(
DiscoveredResource,
resource_type=resource_type_strategy,
unique_id=unique_id_strategy,
name=name_strategy,
provider=provider_type_strategy,
platform_category=platform_category_strategy,
architecture=cpu_architecture_strategy,
endpoint=endpoint_strategy,
attributes=non_empty_attributes_strategy,
raw_references=raw_references_strategy,
)
# ---------------------------------------------------------------------------
# Property Tests
# ---------------------------------------------------------------------------
class TestResourceInventoryCompleteness:
"""Property 1: Resource inventory completeness.
**Validates: Requirements 1.2**
For any discovered resource from any on-premises provider, the resulting
inventory entry SHALL contain non-empty values for resource_type, unique_id,
name, provider, platform_category, architecture, and attributes fields.
"""
@given(resource=discovered_resource_strategy)
def test_resource_type_is_non_empty(self, resource: DiscoveredResource):
"""resource_type field must be a non-empty string."""
assert isinstance(resource.resource_type, str)
assert len(resource.resource_type) > 0
assert resource.resource_type.strip() != ""
@given(resource=discovered_resource_strategy)
def test_unique_id_is_non_empty(self, resource: DiscoveredResource):
"""unique_id field must be a non-empty string."""
assert isinstance(resource.unique_id, str)
assert len(resource.unique_id) > 0
assert resource.unique_id.strip() != ""
@given(resource=discovered_resource_strategy)
def test_name_is_non_empty(self, resource: DiscoveredResource):
"""name field must be a non-empty string."""
assert isinstance(resource.name, str)
assert len(resource.name) > 0
assert resource.name.strip() != ""
@given(resource=discovered_resource_strategy)
def test_provider_is_valid_enum(self, resource: DiscoveredResource):
"""provider field must be a valid ProviderType enum value."""
assert isinstance(resource.provider, ProviderType)
assert resource.provider is not None
@given(resource=discovered_resource_strategy)
def test_platform_category_is_valid_enum(self, resource: DiscoveredResource):
"""platform_category field must be a valid PlatformCategory enum value."""
assert isinstance(resource.platform_category, PlatformCategory)
assert resource.platform_category is not None
@given(resource=discovered_resource_strategy)
def test_architecture_is_valid_enum(self, resource: DiscoveredResource):
"""architecture field must be a valid CpuArchitecture enum value."""
assert isinstance(resource.architecture, CpuArchitecture)
assert resource.architecture is not None
@given(resource=discovered_resource_strategy)
def test_attributes_is_non_empty_dict(self, resource: DiscoveredResource):
"""attributes field must be a non-empty dictionary."""
assert isinstance(resource.attributes, dict)
assert len(resource.attributes) > 0
@given(resource=discovered_resource_strategy)
def test_all_mandatory_fields_populated(self, resource: DiscoveredResource):
"""All mandatory fields must be non-empty/non-None simultaneously."""
# resource_type
assert isinstance(resource.resource_type, str) and len(resource.resource_type) > 0
# unique_id
assert isinstance(resource.unique_id, str) and len(resource.unique_id) > 0
# name
assert isinstance(resource.name, str) and len(resource.name) > 0
# provider
assert isinstance(resource.provider, ProviderType)
# platform_category
assert isinstance(resource.platform_category, PlatformCategory)
# architecture
assert isinstance(resource.architecture, CpuArchitecture)
# attributes
assert isinstance(resource.attributes, dict) and len(resource.attributes) > 0
@given(
resources=st.lists(discovered_resource_strategy, min_size=1, max_size=10),
warnings=st.lists(st.text(min_size=0, max_size=50), max_size=3),
errors=st.lists(st.text(min_size=0, max_size=50), max_size=3),
)
@settings(max_examples=50)
def test_scan_result_resources_all_have_mandatory_fields(
self, resources, warnings, errors
):
"""When a mock plugin produces resources in a ScanResult, every resource has all required fields."""
scan_result = ScanResult(
resources=resources,
warnings=warnings,
errors=errors,
scan_timestamp="2024-01-15T10:30:00Z",
profile_hash="test_hash_abc123",
is_partial=False,
)
for resource in scan_result.resources:
# resource_type is non-empty string
assert isinstance(resource.resource_type, str)
assert len(resource.resource_type) > 0
assert resource.resource_type.strip() != ""
# unique_id is non-empty string
assert isinstance(resource.unique_id, str)
assert len(resource.unique_id) > 0
assert resource.unique_id.strip() != ""
# name is non-empty string
assert isinstance(resource.name, str)
assert len(resource.name) > 0
assert resource.name.strip() != ""
# provider is valid enum
assert isinstance(resource.provider, ProviderType)
# platform_category is valid enum
assert isinstance(resource.platform_category, PlatformCategory)
# architecture is valid enum
assert isinstance(resource.architecture, CpuArchitecture)
# attributes is non-empty dict
assert isinstance(resource.attributes, dict)
assert len(resource.attributes) > 0

View File

@@ -0,0 +1,257 @@
"""Property-based tests for ScanProfile validation completeness.
**Validates: Requirements 6.1, 6.6, 6.7**
Property 20: Scan profile validation completeness
For any scan profile with K invalid fields (missing provider, empty credentials,
unreachable endpoints, filters exceeding 200 entries, or unsupported resource types),
the validation error SHALL list all K invalid fields in a single response.
"""
from hypothesis import given, assume, settings
from hypothesis import strategies as st
from iac_reverse.models import (
MAX_RESOURCE_TYPE_FILTERS,
PROVIDER_SUPPORTED_RESOURCE_TYPES,
ProviderType,
ScanProfile,
)
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
non_empty_credentials_strategy = st.dictionaries(
keys=st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=("L", "N", "P"))),
values=st.text(min_size=1, max_size=50),
min_size=1,
max_size=5,
)
empty_credentials_strategy = st.just({})
def valid_resource_types_strategy(provider: ProviderType) -> st.SearchStrategy:
"""Generate a list of valid resource types for the given provider."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
return st.lists(st.sampled_from(supported), min_size=0, max_size=min(len(supported), 10))
invalid_resource_type_strategy = st.text(
min_size=5, max_size=30,
alphabet=st.characters(whitelist_categories=("L",))
).filter(
lambda t: all(t not in types for types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values())
)
# ---------------------------------------------------------------------------
# Property Tests
# ---------------------------------------------------------------------------
class TestScanProfileValidationCompleteness:
"""Property 20: Scan profile validation completeness.
**Validates: Requirements 6.1, 6.6, 6.7**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
def test_valid_profile_returns_no_errors(self, provider, credentials):
"""A profile with non-empty credentials and no filters is always valid."""
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
errors = profile.validate()
assert errors == [], f"Expected no errors for valid profile, got: {errors}"
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
def test_valid_profile_with_valid_filters_returns_no_errors(self, provider, credentials):
"""A profile with valid credentials and valid resource type filters is valid."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=supported,
)
errors = profile.validate()
assert errors == [], f"Expected no errors for valid profile with valid filters, got: {errors}"
@given(provider=provider_type_strategy)
def test_empty_credentials_always_produces_credentials_error(self, provider):
"""Empty credentials must always produce an error mentioning 'credentials'."""
profile = ScanProfile(
provider=provider,
credentials={},
resource_type_filters=None,
)
errors = profile.validate()
assert len(errors) >= 1
assert any("credentials" in e for e in errors), (
f"Expected error mentioning 'credentials', got: {errors}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
extra_count=st.integers(min_value=1, max_value=50),
)
def test_oversized_filters_produces_count_error(self, provider, credentials, extra_count):
"""Filters exceeding MAX_RESOURCE_TYPE_FILTERS must produce an error about the count limit."""
# Build a list that exceeds the limit using valid types repeated
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
oversized_count = MAX_RESOURCE_TYPE_FILTERS + extra_count
filters = (supported * (oversized_count // len(supported) + 1))[:oversized_count]
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=filters,
)
errors = profile.validate()
assert any(
"at most" in e or str(MAX_RESOURCE_TYPE_FILTERS) in e
for e in errors
), f"Expected error mentioning count limit, got: {errors}"
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
invalid_types=st.lists(invalid_resource_type_strategy, min_size=1, max_size=5),
)
def test_unsupported_types_produces_unsupported_error(self, provider, credentials, invalid_types):
"""Unsupported resource types must produce an error mentioning them."""
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=invalid_types,
)
errors = profile.validate()
assert any("unsupported" in e.lower() for e in errors), (
f"Expected error mentioning unsupported types, got: {errors}"
)
@given(
provider=provider_type_strategy,
invalid_types=st.lists(invalid_resource_type_strategy, min_size=1, max_size=3),
)
def test_no_short_circuit_credentials_and_unsupported(self, provider, invalid_types):
"""When both credentials are empty AND unsupported types exist, both errors are reported."""
profile = ScanProfile(
provider=provider,
credentials={},
resource_type_filters=invalid_types,
)
errors = profile.validate()
assert len(errors) >= 2, f"Expected at least 2 errors, got {len(errors)}: {errors}"
assert any("credentials" in e for e in errors), (
f"Expected credentials error, got: {errors}"
)
assert any("unsupported" in e.lower() for e in errors), (
f"Expected unsupported types error, got: {errors}"
)
@given(
provider=provider_type_strategy,
extra_count=st.integers(min_value=1, max_value=20),
)
def test_no_short_circuit_credentials_and_oversized(self, provider, extra_count):
"""When both credentials are empty AND filters exceed limit, both errors are reported."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
oversized_count = MAX_RESOURCE_TYPE_FILTERS + extra_count
filters = (supported * (oversized_count // len(supported) + 1))[:oversized_count]
profile = ScanProfile(
provider=provider,
credentials={},
resource_type_filters=filters,
)
errors = profile.validate()
assert len(errors) >= 2, f"Expected at least 2 errors, got {len(errors)}: {errors}"
assert any("credentials" in e for e in errors), (
f"Expected credentials error, got: {errors}"
)
assert any(
"at most" in e or str(MAX_RESOURCE_TYPE_FILTERS) in e
for e in errors
), f"Expected count limit error, got: {errors}"
@given(
provider=provider_type_strategy,
extra_count=st.integers(min_value=1, max_value=10),
invalid_types=st.lists(invalid_resource_type_strategy, min_size=1, max_size=3),
)
def test_no_short_circuit_all_three_issues(self, provider, extra_count, invalid_types):
"""When credentials empty, filters oversized, AND unsupported types exist, all errors reported."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
oversized_count = MAX_RESOURCE_TYPE_FILTERS + extra_count
# Mix valid types (to reach oversized count) with invalid types
valid_padding = (supported * (oversized_count // len(supported) + 1))[:oversized_count]
filters = valid_padding + invalid_types
profile = ScanProfile(
provider=provider,
credentials={},
resource_type_filters=filters,
)
errors = profile.validate()
assert len(errors) >= 3, f"Expected at least 3 errors, got {len(errors)}: {errors}"
assert any("credentials" in e for e in errors), (
f"Expected credentials error, got: {errors}"
)
assert any(
"at most" in e or str(MAX_RESOURCE_TYPE_FILTERS) in e
for e in errors
), f"Expected count limit error, got: {errors}"
assert any("unsupported" in e.lower() for e in errors), (
f"Expected unsupported types error, got: {errors}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
def test_empty_list_filters_is_valid(self, provider, credentials):
"""An empty resource_type_filters list (not None) should be valid."""
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=[],
)
errors = profile.validate()
assert errors == [], f"Expected no errors for empty filter list, got: {errors}"
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
count=st.integers(min_value=1, max_value=MAX_RESOURCE_TYPE_FILTERS),
)
def test_filters_at_or_below_limit_with_valid_types_is_valid(self, provider, credentials, count):
"""Any number of valid filters at or below the limit should produce no count error."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
# Repeat valid types to reach the desired count
filters = (supported * (count // len(supported) + 1))[:count]
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=filters,
)
errors = profile.validate()
# Should not have a count limit error
assert not any(
"at most" in e or (str(MAX_RESOURCE_TYPE_FILTERS) in e and "entries" in e)
for e in errors
), f"Unexpected count limit error for {count} filters: {errors}"

View File

@@ -0,0 +1,608 @@
"""Property-based tests for Scanner behavior.
**Validates: Requirements 1.3, 1.4, 1.5, 1.7**
Property 2: Authentication error descriptiveness
For any provider type and any authentication failure reason, the error returned
by the Scanner SHALL contain both the provider name string and the failure reason string.
Property 3: Graceful degradation on unsupported resource types
For any scan request containing a mix of supported and unsupported resource types,
the Scanner SHALL produce warnings for each unsupported type AND return a complete
inventory for all supported types.
Property 4: Progress reporting frequency
The Scanner SHALL report progress at least once per resource type completion.
Property 5: Partial inventory preservation on failure
If the Provider API connection is lost during an active scan, the Scanner SHALL
return a partial resource inventory.
"""
from typing import Callable
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from iac_reverse.models import (
CpuArchitecture,
DiscoveredResource,
PlatformCategory,
PROVIDER_SUPPORTED_RESOURCE_TYPES,
ProviderType,
ScanProfile,
ScanProgress,
ScanResult,
)
from iac_reverse.plugin_base import ProviderPlugin
from iac_reverse.scanner.scanner import (
AuthenticationError,
ConnectionLostError,
Scanner,
)
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
non_empty_string_strategy = st.text(
min_size=1,
max_size=100,
alphabet=st.characters(whitelist_categories=("L", "N", "P", "S")),
).filter(lambda s: s.strip())
non_empty_credentials_strategy = st.dictionaries(
keys=st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=("L", "N"))),
values=st.text(min_size=1, max_size=50),
min_size=1,
max_size=5,
)
unsupported_resource_type_strategy = st.text(
min_size=5,
max_size=30,
alphabet=st.characters(whitelist_categories=("L",)),
).filter(
lambda t: all(t not in types for types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values())
)
# ---------------------------------------------------------------------------
# Mock Plugin Implementations
# ---------------------------------------------------------------------------
class FailingAuthPlugin(ProviderPlugin):
"""A plugin that always fails authentication with a given reason."""
def __init__(self, failure_reason: str):
self.failure_reason = failure_reason
def authenticate(self, credentials: dict[str, str]) -> None:
raise RuntimeError(self.failure_reason)
def get_platform_category(self) -> PlatformCategory:
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
return ["http://localhost:8080"]
def list_supported_resource_types(self) -> list[str]:
return ["mock_resource"]
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
return ScanResult(
resources=[], warnings=[], errors=[],
scan_timestamp="", profile_hash="",
)
class GracefulDegradationPlugin(ProviderPlugin):
"""A plugin that supports specific resource types and discovers resources for them."""
def __init__(self, supported_types: list[str]):
self._supported_types = supported_types
def authenticate(self, credentials: dict[str, str]) -> None:
pass # Always succeeds
def get_platform_category(self) -> PlatformCategory:
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
return ["http://localhost:8080"]
def list_supported_resource_types(self) -> list[str]:
return self._supported_types
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
# Create one resource per supported resource type requested
resources = []
for i, rt in enumerate(resource_types):
resources.append(
DiscoveredResource(
resource_type=rt,
unique_id=f"id-{rt}-{i}",
name=f"resource-{rt}-{i}",
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AMD64,
endpoint="http://localhost:8080",
attributes={"key": "value"},
)
)
progress_callback(ScanProgress(
current_resource_type=rt,
resources_discovered=i + 1,
resource_types_completed=i + 1,
total_resource_types=len(resource_types),
))
return ScanResult(
resources=resources,
warnings=[],
errors=[],
scan_timestamp="",
profile_hash="",
)
class ProgressTrackingPlugin(ProviderPlugin):
"""A plugin that reports progress per resource type."""
def __init__(self, supported_types: list[str]):
self._supported_types = supported_types
def authenticate(self, credentials: dict[str, str]) -> None:
pass
def get_platform_category(self) -> PlatformCategory:
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
return ["http://localhost:8080"]
def list_supported_resource_types(self) -> list[str]:
return self._supported_types
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
resources = []
for i, rt in enumerate(resource_types):
resource = DiscoveredResource(
resource_type=rt,
unique_id=f"id-{rt}-{i}",
name=f"resource-{rt}-{i}",
provider=ProviderType.KUBERNETES,
platform_category=PlatformCategory.CONTAINER_ORCHESTRATION,
architecture=CpuArchitecture.AMD64,
endpoint="http://localhost:8080",
attributes={},
)
resources.append(resource)
progress_callback(ScanProgress(
current_resource_type=rt,
resources_discovered=i + 1,
resource_types_completed=i + 1,
total_resource_types=len(resource_types),
))
return ScanResult(
resources=resources,
warnings=[],
errors=[],
scan_timestamp="",
profile_hash="",
)
class ConnectionLossPlugin(ProviderPlugin):
"""A plugin that loses connection after discovering some resources."""
def __init__(self, supported_types: list[str], fail_after: int):
self._supported_types = supported_types
self._fail_after = fail_after
def authenticate(self, credentials: dict[str, str]) -> None:
pass
def get_platform_category(self) -> PlatformCategory:
return PlatformCategory.CONTAINER_ORCHESTRATION
def list_endpoints(self) -> list[str]:
return ["http://localhost:8080"]
def list_supported_resource_types(self) -> list[str]:
return self._supported_types
def detect_architecture(self, endpoint: str) -> CpuArchitecture:
return CpuArchitecture.AMD64
def discover_resources(
self,
endpoints: list[str],
resource_types: list[str],
progress_callback: Callable[[ScanProgress], None],
) -> ScanResult:
# Simulate connection loss by raising ConnectionError
raise ConnectionError(
f"Connection lost after discovering {self._fail_after} resources"
)
# ---------------------------------------------------------------------------
# Property Tests
# ---------------------------------------------------------------------------
class TestAuthenticationErrorDescriptiveness:
"""Property 2: Authentication error descriptiveness.
For any provider type and any authentication failure reason, the error
returned by the Scanner SHALL contain both the provider name string and
the failure reason string.
**Validates: Requirements 1.3**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
failure_reason=non_empty_string_strategy,
)
@settings(max_examples=100)
def test_auth_error_contains_provider_name_and_reason(
self, provider, credentials, failure_reason
):
"""AuthenticationError must contain both provider name and failure reason."""
plugin = FailingAuthPlugin(failure_reason=failure_reason)
profile = ScanProfile(
provider=provider,
credentials=credentials,
)
scanner = Scanner(profile=profile, plugin=plugin)
with pytest.raises(AuthenticationError) as exc_info:
scanner.scan()
error = exc_info.value
# The error must contain the provider name
assert provider.value in str(error), (
f"Expected provider name '{provider.value}' in error message, "
f"got: '{str(error)}'"
)
# The error must contain the failure reason
assert failure_reason in str(error), (
f"Expected failure reason '{failure_reason}' in error message, "
f"got: '{str(error)}'"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
failure_reason=non_empty_string_strategy,
)
@settings(max_examples=100)
def test_auth_error_attributes_match(self, provider, credentials, failure_reason):
"""AuthenticationError attributes must store provider_name and reason."""
plugin = FailingAuthPlugin(failure_reason=failure_reason)
profile = ScanProfile(
provider=provider,
credentials=credentials,
)
scanner = Scanner(profile=profile, plugin=plugin)
with pytest.raises(AuthenticationError) as exc_info:
scanner.scan()
error = exc_info.value
assert error.provider_name == provider.value
assert error.reason == failure_reason
class TestGracefulDegradationOnUnsupportedTypes:
"""Property 3: Graceful degradation on unsupported resource types.
For any scan request containing a mix of supported and unsupported resource
types, the Scanner SHALL produce warnings for each unsupported type AND
return a complete inventory for all supported types.
**Validates: Requirements 1.4**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
unsupported_types=st.lists(unsupported_resource_type_strategy, min_size=1, max_size=5),
)
@settings(max_examples=100)
def test_unsupported_types_produce_warnings(
self, provider, credentials, unsupported_types
):
"""Each unsupported resource type must produce a warning."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = GracefulDegradationPlugin(supported_types=supported)
# Mix supported and unsupported types
mixed_filters = list(supported[:2]) + unsupported_types
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=mixed_filters,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# There must be a warning for each unsupported type
for unsupported in unsupported_types:
assert any(unsupported in w for w in result.warnings), (
f"Expected warning for unsupported type '{unsupported}', "
f"got warnings: {result.warnings}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
unsupported_types=st.lists(unsupported_resource_type_strategy, min_size=1, max_size=5),
)
@settings(max_examples=100)
def test_supported_types_still_discovered(
self, provider, credentials, unsupported_types
):
"""Supported types must still be fully discovered despite unsupported types."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = GracefulDegradationPlugin(supported_types=supported)
# Use at least one supported type plus unsupported types
supported_subset = supported[:2]
mixed_filters = supported_subset + unsupported_types
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=mixed_filters,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# All supported types in the filter should have resources discovered
discovered_types = {r.resource_type for r in result.resources}
for st_type in supported_subset:
assert st_type in discovered_types, (
f"Expected supported type '{st_type}' to be discovered, "
f"but only found: {discovered_types}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
unsupported_types=st.lists(unsupported_resource_type_strategy, min_size=1, max_size=5),
)
@settings(max_examples=100)
def test_warning_count_matches_unsupported_count(
self, provider, credentials, unsupported_types
):
"""Number of warnings must be at least the number of unsupported types."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = GracefulDegradationPlugin(supported_types=supported)
# Only unsupported types in filter
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=unsupported_types,
)
scanner = Scanner(profile=profile, plugin=plugin)
result = scanner.scan()
# Deduplicate unsupported types for comparison
unique_unsupported = set(unsupported_types)
assert len(result.warnings) >= len(unique_unsupported), (
f"Expected at least {len(unique_unsupported)} warnings, "
f"got {len(result.warnings)}: {result.warnings}"
)
class TestProgressReportingFrequency:
"""Property 4: Progress reporting frequency.
For any scan across N resource types, the progress callback SHALL be
invoked at least N times, once per resource type completion.
**Validates: Requirements 1.5**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_progress_reported_at_least_once_per_resource_type(
self, provider, credentials
):
"""Progress callback must be invoked at least once per resource type."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = ProgressTrackingPlugin(supported_types=supported)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None, # Scan all supported types
)
progress_reports: list[ScanProgress] = []
def track_progress(progress: ScanProgress) -> None:
progress_reports.append(progress)
scanner = Scanner(profile=profile, plugin=plugin)
scanner.scan(progress_callback=track_progress)
# Must have at least N progress reports for N resource types
assert len(progress_reports) >= len(supported), (
f"Expected at least {len(supported)} progress reports, "
f"got {len(progress_reports)}"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
)
@settings(max_examples=100)
def test_progress_reports_cover_all_resource_types(
self, provider, credentials
):
"""Progress reports must cover every resource type being scanned."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = ProgressTrackingPlugin(supported_types=supported)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
progress_reports: list[ScanProgress] = []
def track_progress(progress: ScanProgress) -> None:
progress_reports.append(progress)
scanner = Scanner(profile=profile, plugin=plugin)
scanner.scan(progress_callback=track_progress)
# Every resource type should appear in at least one progress report
reported_types = {p.current_resource_type for p in progress_reports}
for rt in supported:
assert rt in reported_types, (
f"Expected resource type '{rt}' in progress reports, "
f"but only found: {reported_types}"
)
class TestPartialInventoryPreservationOnFailure:
"""Property 5: Partial inventory preservation on failure.
If the Provider API connection is lost during an active scan, the Scanner
SHALL return a partial resource inventory.
**Validates: Requirements 1.7**
"""
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
fail_after=st.integers(min_value=0, max_value=10),
)
@settings(max_examples=100)
def test_connection_loss_raises_with_partial_result(
self, provider, credentials, fail_after
):
"""Connection loss must raise ConnectionLostError with partial result."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = ConnectionLossPlugin(
supported_types=supported,
fail_after=fail_after,
)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
scanner = Scanner(profile=profile, plugin=plugin)
with pytest.raises(ConnectionLostError) as exc_info:
scanner.scan()
error = exc_info.value
# Must have a partial_result attribute
assert hasattr(error, "partial_result")
partial = error.partial_result
assert isinstance(partial, ScanResult)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
fail_after=st.integers(min_value=0, max_value=10),
)
@settings(max_examples=100)
def test_partial_result_is_marked_as_partial(
self, provider, credentials, fail_after
):
"""Partial result from connection loss must have is_partial=True."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = ConnectionLossPlugin(
supported_types=supported,
fail_after=fail_after,
)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
scanner = Scanner(profile=profile, plugin=plugin)
with pytest.raises(ConnectionLostError) as exc_info:
scanner.scan()
partial = exc_info.value.partial_result
assert partial.is_partial is True, (
"Partial result from connection loss must have is_partial=True"
)
@given(
provider=provider_type_strategy,
credentials=non_empty_credentials_strategy,
fail_after=st.integers(min_value=0, max_value=10),
)
@settings(max_examples=100)
def test_partial_result_contains_error_info(
self, provider, credentials, fail_after
):
"""Partial result must contain error information about the failure."""
supported = PROVIDER_SUPPORTED_RESOURCE_TYPES[provider]
plugin = ConnectionLossPlugin(
supported_types=supported,
fail_after=fail_after,
)
profile = ScanProfile(
provider=provider,
credentials=credentials,
resource_type_filters=None,
)
scanner = Scanner(profile=profile, plugin=plugin)
with pytest.raises(ConnectionLostError) as exc_info:
scanner.scan()
partial = exc_info.value.partial_result
# Must have at least one error or warning indicating the failure
assert len(partial.errors) > 0 or len(partial.warnings) > 0, (
"Partial result must contain error/warning info about the connection loss"
)

View File

@@ -0,0 +1,567 @@
"""Property-based tests for the State Builder.
**Validates: Requirements 4.1, 4.2, 4.4, 4.5**
Properties tested:
- Property 16: State file structural validity
- Property 17: State entry completeness and schema correctness
"""
import json
import re
import uuid
from hypothesis import given, settings, assume
from hypothesis import strategies as st
from iac_reverse.generator.sanitize import sanitize_identifier
from iac_reverse.models import (
CodeGenerationResult,
CpuArchitecture,
DependencyGraph,
DiscoveredResource,
GeneratedFile,
PlatformCategory,
PROVIDER_SUPPORTED_RESOURCE_TYPES,
ProviderType,
ResourceRelationship,
)
from iac_reverse.state_builder import StateBuilder
# ---------------------------------------------------------------------------
# Hypothesis Strategies
# ---------------------------------------------------------------------------
provider_type_strategy = st.sampled_from(list(ProviderType))
platform_category_strategy = st.sampled_from(list(PlatformCategory))
cpu_architecture_strategy = st.sampled_from(list(CpuArchitecture))
# All supported resource types across all providers (flat list)
ALL_SUPPORTED_RESOURCE_TYPES = []
for _types in PROVIDER_SUPPORTED_RESOURCE_TYPES.values():
ALL_SUPPORTED_RESOURCE_TYPES.extend(_types)
resource_type_strategy = st.sampled_from(ALL_SUPPORTED_RESOURCE_TYPES)
# Strategy for resource names (valid identifiers with some variety)
resource_name_strategy = st.text(
min_size=1,
max_size=20,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-"),
).filter(lambda s: s.strip() != "")
# Strategy for unique IDs (non-empty strings)
unique_id_strategy = st.text(
min_size=1,
max_size=40,
alphabet=st.characters(whitelist_categories=("L", "N"), whitelist_characters="_-/:."),
).filter(lambda s: s.strip() != "")
# Strategy for simple attribute values
simple_attr_value_strategy = st.one_of(
st.text(
min_size=1,
max_size=30,
alphabet=st.characters(
whitelist_categories=("L", "N"), whitelist_characters="_-./: "
),
).filter(lambda s: s.strip() != ""),
st.integers(min_value=0, max_value=10000),
st.booleans(),
)
# Strategy for attribute dictionaries (non-empty)
attributes_strategy = st.dictionaries(
keys=st.text(
min_size=1,
max_size=15,
alphabet=st.characters(whitelist_categories=("L",), whitelist_characters="_"),
).filter(lambda s: s.strip() != "" and s[0].isalpha()),
values=simple_attr_value_strategy,
min_size=1,
max_size=5,
)
# Strategy for provider version strings (semver-like)
provider_version_strategy = st.from_regex(r"[1-9][0-9]{0,1}\.[0-9]{1,2}\.[0-9]{1,2}", fullmatch=True)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_resource(
unique_id: str,
resource_type: str = "kubernetes_deployment",
name: str = "my_resource",
provider: ProviderType = ProviderType.KUBERNETES,
platform_category: PlatformCategory = PlatformCategory.CONTAINER_ORCHESTRATION,
architecture: CpuArchitecture = CpuArchitecture.AMD64,
attributes: dict | None = None,
raw_references: list[str] | None = None,
) -> DiscoveredResource:
"""Helper to create a DiscoveredResource with sensible defaults."""
return DiscoveredResource(
resource_type=resource_type,
unique_id=unique_id,
name=name,
provider=provider,
platform_category=platform_category,
architecture=architecture,
endpoint="https://api.internal.lab:6443",
attributes=attributes or {"key": "value"},
raw_references=raw_references or [],
)
def make_dependency_graph(
resources: list[DiscoveredResource],
relationships: list[ResourceRelationship] | None = None,
) -> DependencyGraph:
"""Helper to create a DependencyGraph from resources."""
return DependencyGraph(
resources=resources,
relationships=relationships or [],
topological_order=[r.unique_id for r in resources],
cycles=[],
unresolved_references=[],
)
def make_code_generation_result() -> CodeGenerationResult:
"""Helper to create a minimal CodeGenerationResult."""
return CodeGenerationResult(
resource_files=[
GeneratedFile(filename="main.tf", content="", resource_count=0)
],
variables_file=GeneratedFile(
filename="variables.tf", content="", resource_count=0
),
provider_file=GeneratedFile(
filename="provider.tf", content="", resource_count=0
),
)
# ---------------------------------------------------------------------------
# Composite strategies
# ---------------------------------------------------------------------------
@st.composite
def mappable_resource_strategy(draw):
"""Generate a single DiscoveredResource that is mappable to state.
A mappable resource has a non-empty unique_id and a recognized resource type.
"""
resource_type = draw(resource_type_strategy)
name = draw(resource_name_strategy)
unique_id = draw(unique_id_strategy)
provider = draw(provider_type_strategy)
platform_category = draw(platform_category_strategy)
architecture = draw(cpu_architecture_strategy)
attributes = draw(attributes_strategy)
return make_resource(
unique_id=unique_id,
resource_type=resource_type,
name=name,
provider=provider,
platform_category=platform_category,
architecture=architecture,
attributes=attributes,
)
@st.composite
def multiple_mappable_resources_strategy(draw):
"""Generate a list of mappable resources with unique IDs."""
num_resources = draw(st.integers(min_value=1, max_value=5))
resources = []
seen_ids = set()
for _ in range(num_resources):
resource = draw(mappable_resource_strategy())
# Ensure unique IDs are distinct
if resource.unique_id in seen_ids:
continue
seen_ids.add(resource.unique_id)
resources.append(resource)
assume(len(resources) >= 1)
return resources
@st.composite
def resource_with_sensitive_attrs_strategy(draw):
"""Generate a resource with attributes that include sensitive-looking keys."""
resource_type = draw(resource_type_strategy)
name = draw(resource_name_strategy)
unique_id = draw(unique_id_strategy)
# Include at least one sensitive key
sensitive_key = draw(st.sampled_from([
"password", "api_secret", "auth_token", "private_key", "tls_certificate",
]))
sensitive_value = draw(st.text(min_size=1, max_size=20, alphabet="abcdefghijklmnop"))
# Also include non-sensitive attributes
normal_attrs = draw(attributes_strategy)
normal_attrs[sensitive_key] = sensitive_value
return make_resource(
unique_id=unique_id,
resource_type=resource_type,
name=name,
attributes=normal_attrs,
)
# ---------------------------------------------------------------------------
# Property 16: State file structural validity
# ---------------------------------------------------------------------------
class TestStateFileStructuralValidity:
"""Property 16: State file structural validity.
**Validates: Requirements 4.1**
For any set of resources, the generated state file has version=4,
valid UUID lineage, serial=1, and valid JSON structure.
"""
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_file_version_is_4(
self, resources: list[DiscoveredResource]
):
"""The generated state file always has version=4."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert state_file.version == 4, (
f"Expected version=4, got version={state_file.version}"
)
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_file_has_valid_uuid_lineage(
self, resources: list[DiscoveredResource]
):
"""The generated state file has a valid UUID lineage."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
# Lineage should be a valid UUID
try:
parsed_uuid = uuid.UUID(state_file.lineage)
except ValueError:
raise AssertionError(
f"Lineage '{state_file.lineage}' is not a valid UUID"
)
assert parsed_uuid.version == 4, (
f"Expected UUID version 4, got version {parsed_uuid.version}"
)
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_file_serial_is_1(
self, resources: list[DiscoveredResource]
):
"""The generated state file always has serial=1."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert state_file.serial == 1, (
f"Expected serial=1, got serial={state_file.serial}"
)
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_file_produces_valid_json(
self, resources: list[DiscoveredResource]
):
"""The state file serializes to valid JSON via to_json()."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
json_str = state_file.to_json()
# Must parse as valid JSON
try:
parsed = json.loads(json_str)
except json.JSONDecodeError as e:
raise AssertionError(
f"State file to_json() produced invalid JSON: {e}"
)
assert isinstance(parsed, dict), "State JSON root must be a dict"
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_json_has_required_top_level_fields(
self, resources: list[DiscoveredResource]
):
"""The serialized state JSON has version, terraform_version, serial, lineage, resources."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
parsed = json.loads(state_file.to_json())
required_fields = {"version", "terraform_version", "serial", "lineage", "resources"}
missing = required_fields - set(parsed.keys())
assert not missing, (
f"State JSON missing required top-level fields: {missing}"
)
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_json_resource_entries_have_required_fields(
self, resources: list[DiscoveredResource]
):
"""Each resource entry in the JSON has mode, type, name, provider, and instances."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
parsed = json.loads(state_file.to_json())
required_resource_fields = {"mode", "type", "name", "provider", "instances"}
for i, entry in enumerate(parsed["resources"]):
missing = required_resource_fields - set(entry.keys())
assert not missing, (
f"Resource entry {i} missing required fields: {missing}. "
f"Entry keys: {list(entry.keys())}"
)
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_json_instances_have_schema_and_attributes(
self, resources: list[DiscoveredResource]
):
"""Each instance in the state JSON has schema_version, attributes, sensitive_attributes, dependencies."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
parsed = json.loads(state_file.to_json())
required_instance_fields = {
"schema_version", "attributes", "sensitive_attributes", "dependencies"
}
for i, entry in enumerate(parsed["resources"]):
for j, instance in enumerate(entry["instances"]):
missing = required_instance_fields - set(instance.keys())
assert not missing, (
f"Resource {i}, instance {j} missing fields: {missing}. "
f"Instance keys: {list(instance.keys())}"
)
# ---------------------------------------------------------------------------
# Property 17: State entry completeness and schema correctness
# ---------------------------------------------------------------------------
class TestStateEntryCompletenessAndSchemaCorrectness:
"""Property 17: State entry completeness and schema correctness.
**Validates: Requirements 4.4, 4.5**
For any resource, the state entry has non-empty resource_type,
resource_name, provider_id, and attributes matching the discovery data.
"""
@given(resource=mappable_resource_strategy())
@settings(max_examples=100)
def test_state_entry_has_non_empty_resource_type(
self, resource: DiscoveredResource
):
"""Each state entry has a non-empty resource_type."""
builder = StateBuilder()
graph = make_dependency_graph([resource])
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert len(state_file.resources) == 1
entry = state_file.resources[0]
assert entry.resource_type != "", (
"State entry resource_type must not be empty"
)
assert entry.resource_type == resource.resource_type, (
f"Expected resource_type '{resource.resource_type}', "
f"got '{entry.resource_type}'"
)
@given(resource=mappable_resource_strategy())
@settings(max_examples=100)
def test_state_entry_has_non_empty_resource_name(
self, resource: DiscoveredResource
):
"""Each state entry has a non-empty resource_name (sanitized)."""
builder = StateBuilder()
graph = make_dependency_graph([resource])
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert len(state_file.resources) == 1
entry = state_file.resources[0]
assert entry.resource_name != "", (
"State entry resource_name must not be empty"
)
# The name should be a sanitized version of the original
expected_name = sanitize_identifier(resource.name)
assert entry.resource_name == expected_name, (
f"Expected resource_name '{expected_name}', "
f"got '{entry.resource_name}'"
)
@given(resource=mappable_resource_strategy())
@settings(max_examples=100)
def test_state_entry_has_non_empty_provider_id(
self, resource: DiscoveredResource
):
"""Each state entry has a non-empty provider_id matching the resource's unique_id."""
builder = StateBuilder()
graph = make_dependency_graph([resource])
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert len(state_file.resources) == 1
entry = state_file.resources[0]
assert entry.provider_id != "", (
"State entry provider_id must not be empty"
)
assert entry.provider_id == resource.unique_id, (
f"Expected provider_id '{resource.unique_id}', "
f"got '{entry.provider_id}'"
)
@given(resource=mappable_resource_strategy())
@settings(max_examples=100)
def test_state_entry_attributes_match_discovery_data(
self, resource: DiscoveredResource
):
"""State entry attributes contain all attributes from the discovered resource."""
builder = StateBuilder()
graph = make_dependency_graph([resource])
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert len(state_file.resources) == 1
entry = state_file.resources[0]
# All discovery attributes should be present in the state entry
for key, value in resource.attributes.items():
assert key in entry.attributes, (
f"Discovery attribute '{key}' missing from state entry attributes. "
f"State attrs: {list(entry.attributes.keys())}"
)
assert entry.attributes[key] == value, (
f"Attribute '{key}' mismatch: discovery={value}, "
f"state={entry.attributes[key]}"
)
@given(
resource=mappable_resource_strategy(),
provider_version=provider_version_strategy,
)
@settings(max_examples=100)
def test_state_entry_schema_version_matches_provider_version(
self, resource: DiscoveredResource, provider_version: str
):
"""State entry schema_version matches the major version from provider_version."""
builder = StateBuilder()
graph = make_dependency_graph([resource])
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, provider_version)
assert len(state_file.resources) == 1
entry = state_file.resources[0]
# Schema version should be the major version number
expected_schema_version = int(provider_version.split(".")[0])
assert entry.schema_version == expected_schema_version, (
f"Expected schema_version={expected_schema_version} "
f"(from provider_version='{provider_version}'), "
f"got schema_version={entry.schema_version}"
)
@given(resource=resource_with_sensitive_attrs_strategy())
@settings(max_examples=100)
def test_state_entry_marks_sensitive_attributes(
self, resource: DiscoveredResource
):
"""State entry identifies and marks sensitive attributes correctly."""
builder = StateBuilder()
graph = make_dependency_graph([resource])
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
assert len(state_file.resources) == 1
entry = state_file.resources[0]
# Sensitive attributes list should not be empty when resource has
# attributes with sensitive patterns (password, secret, token, key, certificate)
sensitive_patterns = ["password", "secret", "token", "key", "certificate"]
has_sensitive = any(
any(pattern in attr_key.lower() for pattern in sensitive_patterns)
for attr_key in resource.attributes.keys()
)
if has_sensitive:
assert len(entry.sensitive_attributes) > 0, (
f"Resource has sensitive-looking attributes "
f"{list(resource.attributes.keys())} but sensitive_attributes "
f"is empty"
)
@given(resources=multiple_mappable_resources_strategy())
@settings(max_examples=100)
def test_state_json_id_field_matches_provider_id(
self, resources: list[DiscoveredResource]
):
"""In the serialized JSON, each instance's attributes.id matches the provider_id."""
builder = StateBuilder()
graph = make_dependency_graph(resources)
code_result = make_code_generation_result()
state_file = builder.build(code_result, graph, "1.0.0")
parsed = json.loads(state_file.to_json())
for i, entry in enumerate(parsed["resources"]):
for instance in entry["instances"]:
assert "id" in instance["attributes"], (
f"Resource entry {i} instance missing 'id' in attributes"
)
# The id should be non-empty
assert instance["attributes"]["id"] != "", (
f"Resource entry {i} has empty 'id' attribute"
)

1
tests/unit/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Unit tests for IaC Reverse Engineering Tool."""

Some files were not shown because too many files have changed in this diff Show More