added test go implementation
continuous-integration/drone/push Build encountered an error

This commit is contained in:
gyurix
2026-06-08 17:02:13 +02:00
parent a555cce680
commit fcda599ec7
9 changed files with 1112 additions and 44 deletions
+231
View File
@@ -0,0 +1,231 @@
package config
import (
"os"
"path/filepath"
"testing"
)
const testJSON = `{
"networks": {
"smarthost-loadbalancer": {
"network_name": "smarthost-loadbalancer",
"subnet": "172.18.103.0/24",
"gateway": "172.18.103.1"
},
"smarthost_backend-1": {
"network_name": "smarthost_backend-1",
"subnet": "172.18.104.0/24",
"gateway": "172.18.104.1"
}
},
"ips": {
"172.18.103.2": {
"ip": "172.18.103.2",
"container_name": "smarthostloadbalancer",
"selector": "smarthostloadbalancer",
"service_name": "smarthost-proxy"
},
"172.18.104.2": {
"ip": "172.18.104.2",
"container_name": "smarthostbackend-1",
"selector": "smarthostbackend-1",
"service_name": "smarthost-proxy"
}
},
"policies": [
{
"service_name": "smarthost-proxy",
"container_name": "smarthost_loadbalancer",
"selector": "smarthostloadbalancer",
"from": "publicbackend",
"port": 80,
"proto": "tcp"
},
{
"service_name": "smarthost-proxy",
"container_name": "smarthost_loadbalancer",
"selector": "smarthostloadbalancer",
"name": "wireguardproxy",
"iface": "wg0",
"nat": "dnat",
"to": "smarthostloadbalancer",
"port": 80,
"proto": "tcp"
}
]
}`
func TestLoad(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "networks.json")
if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load() returned error: %v", err)
}
if len(cfg.Networks) != 2 {
t.Errorf("expected 2 networks, got %d", len(cfg.Networks))
}
if cfg.Networks["smarthost-loadbalancer"].Subnet != "172.18.103.0/24" {
t.Errorf("unexpected subnet: %s", cfg.Networks["smarthost-loadbalancer"].Subnet)
}
if len(cfg.IPs) != 2 {
t.Errorf("expected 2 IPs, got %d", len(cfg.IPs))
}
if cfg.IPs["172.18.103.2"].ContainerName != "smarthostloadbalancer" {
t.Errorf("unexpected container name: %s", cfg.IPs["172.18.103.2"].ContainerName)
}
if len(cfg.Policies) != 2 {
t.Errorf("expected 2 policies, got %d", len(cfg.Policies))
}
if cfg.Policies[0].From != "publicbackend" {
t.Errorf("unexpected from: %s", cfg.Policies[0].From)
}
if cfg.Policies[1].Nat != "dnat" {
t.Errorf("unexpected nat: %s", cfg.Policies[1].Nat)
}
}
func TestLoadFileNotFound(t *testing.T) {
_, err := Load("/nonexistent/path/networks.json")
if err == nil {
t.Error("expected error for nonexistent file, got nil")
}
}
func TestLoadInvalidJSON(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad.json")
if err := os.WriteFile(path, []byte("{invalid json"), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
_, err := Load(path)
if err == nil {
t.Error("expected error for invalid JSON, got nil")
}
}
func TestToCIDR(t *testing.T) {
tests := []struct {
input string
want string
}{
{"172.18.103.0", "172.18.103.0/24"},
{"172.18.103.2", "172.18.103.2"},
{"172.18.103.0/24", "172.18.103.0/24"},
{"10.0.0.0", "10.0.0.0/24"},
{"invalid", "invalid"},
}
for _, tt := range tests {
got := ToCIDR(tt.input)
if got != tt.want {
t.Errorf("ToCIDR(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestNetworkPrefix(t *testing.T) {
tests := []struct {
input string
want string
}{
{"172.18.103.2", "172.18.103.0/24"},
{"10.0.0.1", "10.0.0.0/24"},
{"invalid", "invalid"},
}
for _, tt := range tests {
got := NetworkPrefix(tt.input)
if got != tt.want {
t.Errorf("NetworkPrefix(%q) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestIsIP(t *testing.T) {
tests := []struct {
input string
want bool
}{
{"172.18.103.2", true},
{"10.0.0.0", true},
{"255.255.255.255", true},
{"publicbackend", false},
{"172.18.103.0/24", false},
{"", false},
}
for _, tt := range tests {
got := IsIP(tt.input)
if got != tt.want {
t.Errorf("IsIP(%q) = %v, want %v", tt.input, got, tt.want)
}
}
}
func TestNetworkConfigParseCIDR(t *testing.T) {
nc := NetworkConfig{Subnet: "172.18.103.0/24"}
ipNet, err := nc.ParseCIDR()
if err != nil {
t.Fatalf("ParseCIDR() returned error: %v", err)
}
if ipNet.String() != "172.18.103.0/24" {
t.Errorf("ParseCIDR() = %s, want 172.18.103.0/24", ipNet.String())
}
nc2 := NetworkConfig{Subnet: "invalid"}
_, err = nc2.ParseCIDR()
if err == nil {
t.Error("expected error for invalid subnet, got nil")
}
}
func TestLoadReproducible(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "networks.json")
if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
// Load twice and compare
cfg1, err := Load(path)
if err != nil {
t.Fatalf("first Load() error: %v", err)
}
cfg2, err := Load(path)
if err != nil {
t.Fatalf("second Load() error: %v", err)
}
// Verify reproducibility
if len(cfg1.Networks) != len(cfg2.Networks) {
t.Errorf("reproducibility: network count mismatch %d vs %d", len(cfg1.Networks), len(cfg2.Networks))
}
if len(cfg1.IPs) != len(cfg2.IPs) {
t.Errorf("reproducibility: IP count mismatch %d vs %d", len(cfg1.IPs), len(cfg2.IPs))
}
if len(cfg1.Policies) != len(cfg2.Policies) {
t.Errorf("reproducibility: policy count mismatch %d vs %d", len(cfg1.Policies), len(cfg2.Policies))
}
for name, net1 := range cfg1.Networks {
net2 := cfg2.Networks[name]
if net1.Subnet != net2.Subnet || net1.Gateway != net2.Gateway {
t.Errorf("reproducibility: network %s mismatch", name)
}
}
}
+19 -6
View File
@@ -15,11 +15,27 @@ import (
"firewall_containers/network-go/config" "firewall_containers/network-go/config"
) )
// DockerAPI defines the interface for Docker operations, enabling mock implementations for testing
type DockerAPI interface {
Close() error
EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error
RemoveNetwork(ctx context.Context, networkName string) error
ConnectContainer(ctx context.Context, containerName, networkName, ip string) error
DisconnectContainer(ctx context.Context, containerName, networkName string) error
InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error)
WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error
GetContainerPID(ctx context.Context, containerName string) (int, error)
AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error
}
// Client wraps the Docker SDK client // Client wraps the Docker SDK client
type Client struct { type Client struct {
cli *client.Client cli *client.Client
} }
// Ensure Client implements DockerAPI
var _ DockerAPI = (*Client)(nil)
// NewClient creates a new Docker client // NewClient creates a new Docker client
func NewClient() (*Client, error) { func NewClient() (*Client, error) {
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
@@ -36,7 +52,6 @@ func (c *Client) Close() error {
// EnsureNetwork creates a Docker network if it does not already exist // EnsureNetwork creates a Docker network if it does not already exist
func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error { func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
// Check if network already exists
existingNetworks, err := c.cli.NetworkList(ctx, network.ListOptions{ existingNetworks, err := c.cli.NetworkList(ctx, network.ListOptions{
Filters: filters.NewArgs(filters.Arg("name", netCfg.NetworkName)), Filters: filters.NewArgs(filters.Arg("name", netCfg.NetworkName)),
}) })
@@ -46,12 +61,10 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
for _, n := range existingNetworks { for _, n := range existingNetworks {
if n.Name == netCfg.NetworkName { if n.Name == netCfg.NetworkName {
// Network already exists, skip creation
return nil return nil
} }
} }
// Parse subnet and gateway
_, ipNet, err := net.ParseCIDR(netCfg.Subnet) _, ipNet, err := net.ParseCIDR(netCfg.Subnet)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse subnet %s: %w", netCfg.Subnet, err) return fmt.Errorf("failed to parse subnet %s: %w", netCfg.Subnet, err)
@@ -62,7 +75,6 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
return fmt.Errorf("failed to parse gateway IP %s", netCfg.Gateway) return fmt.Errorf("failed to parse gateway IP %s", netCfg.Gateway)
} }
// Create the network
createOpts := network.CreateOptions{ createOpts := network.CreateOptions{
Driver: "bridge", Driver: "bridge",
IPAM: &network.IPAM{ IPAM: &network.IPAM{
@@ -84,7 +96,7 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
return fmt.Errorf("failed to create network %s: %w", netCfg.NetworkName, err) return fmt.Errorf("failed to create network %s: %w", netCfg.NetworkName, err)
} }
_ = resp // response contains ID and warnings _ = resp
return nil return nil
} }
@@ -154,6 +166,7 @@ func (c *Client) WaitForContainerRunning(ctx context.Context, containerName stri
} }
} }
// GetContainerPID returns the PID of a container for nsenter operations
func (c *Client) GetContainerPID(ctx context.Context, containerName string) (int, error) { func (c *Client) GetContainerPID(ctx context.Context, containerName string) (int, error) {
cont, err := c.cli.ContainerInspect(ctx, containerName) cont, err := c.cli.ContainerInspect(ctx, containerName)
if err != nil { if err != nil {
@@ -184,4 +197,4 @@ func (c *Client) AddRouteInContainer(ctx context.Context, containerName, network
return fmt.Errorf("failed to add route in container %s: %w\noutput: %s", containerName, err, string(output)) return fmt.Errorf("failed to add route in container %s: %w\noutput: %s", containerName, err, string(output))
} }
return nil return nil
} }
+3 -3
View File
@@ -16,14 +16,14 @@ import (
// Orchestrator reconciles the networks.json configuration into Docker networks // Orchestrator reconciles the networks.json configuration into Docker networks
// and iptables firewall rules // and iptables firewall rules
type Orchestrator struct { type Orchestrator struct {
dockerClient *docker.Client dockerClient docker.DockerAPI
iptablesMgr *iptables.Manager iptablesMgr iptables.IPTablesAPI
resolver *resolver.Resolver resolver *resolver.Resolver
debug bool debug bool
} }
// NewOrchestrator creates a new firewall orchestrator // NewOrchestrator creates a new firewall orchestrator
func NewOrchestrator(dockerClient *docker.Client, iptablesMgr *iptables.Manager, cfg *config.NetworksConfig) *Orchestrator { func NewOrchestrator(dockerClient docker.DockerAPI, iptablesMgr iptables.IPTablesAPI, cfg *config.NetworksConfig) *Orchestrator {
return &Orchestrator{ return &Orchestrator{
dockerClient: dockerClient, dockerClient: dockerClient,
iptablesMgr: iptablesMgr, iptablesMgr: iptablesMgr,
+333
View File
@@ -0,0 +1,333 @@
package firewall
import (
"context"
"testing"
"firewall_containers/network-go/config"
"firewall_containers/network-go/mock"
)
func testConfig() *config.NetworksConfig {
return &config.NetworksConfig{
Networks: map[string]config.NetworkConfig{
"smarthost-loadbalancer": {
NetworkName: "smarthost-loadbalancer",
Subnet: "172.18.103.0/24",
Gateway: "172.18.103.1",
},
"smarthost_backend-1": {
NetworkName: "smarthost_backend-1",
Subnet: "172.18.104.0/24",
Gateway: "172.18.104.1",
},
},
IPs: map[string]config.IPConfig{
"172.18.103.2": {
IP: "172.18.103.2",
ContainerName: "smarthostloadbalancer",
Selector: "smarthostloadbalancer",
ServiceName: "smarthost-proxy",
},
"172.18.104.2": {
IP: "172.18.104.2",
ContainerName: "smarthostbackend-1",
Selector: "smarthostbackend-1",
ServiceName: "smarthost-proxy",
},
},
Policies: []config.PolicyConfig{},
}
}
func TestReconcileAllCreatesNetworks(t *testing.T) {
cfg := testConfig()
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.ReconcileAll(ctx, cfg)
if !docker.EnsureNetworkCalled {
t.Error("EnsureNetwork was not called")
}
}
func TestReconcileAllConnectsContainers(t *testing.T) {
cfg := testConfig()
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.ReconcileAll(ctx, cfg)
if !docker.ConnectContainerCalled {
t.Error("ConnectContainer was not called")
}
}
func TestReconcileAllEnablesIPForward(t *testing.T) {
cfg := testConfig()
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.ReconcileAll(ctx, cfg)
if !iptables.EnsureIPForwardCalled {
t.Error("EnsureIPForward was not called")
}
}
func TestReconcilePoliciesForwardRule(t *testing.T) {
cfg := testConfig()
cfg.Policies = []config.PolicyConfig{
{
ServiceName: "smarthost-proxy",
ContainerName: "smarthostloadbalancer",
Selector: "smarthostloadbalancer",
From: "publicbackend",
Port: 80,
Proto: "tcp",
},
}
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables"}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.reconcilePolicies(ctx, cfg)
// "publicbackend" should be resolved to an IP from the config (if it matches)
// Since "publicbackend" is not in the IPs map, sourceIP will be empty
if !iptables.InsertForwardAcceptCalled {
t.Error("InsertForwardAccept was not called")
}
// Should use FORWARD chain (not iptables-legacy)
if iptables.InsertForwardAcceptChain != "FORWARD" {
t.Errorf("expected FORWARD chain, got %s", iptables.InsertForwardAcceptChain)
}
}
func TestReconcilePoliciesForwardRuleWithLegacy(t *testing.T) {
cfg := testConfig()
cfg.Policies = []config.PolicyConfig{
{
ServiceName: "smarthost-proxy",
From: "publicbackend",
Port: 80,
Proto: "tcp",
},
}
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables-legacy"}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.reconcilePolicies(ctx, cfg)
if !iptables.InsertForwardAcceptCalled {
t.Error("InsertForwardAccept was not called")
}
// Should use DOCKER-USER chain when iptables-legacy
if iptables.InsertForwardAcceptChain != "DOCKER-USER" {
t.Errorf("expected DOCKER-USER chain, got %s", iptables.InsertForwardAcceptChain)
}
}
func TestReconcilePoliciesDNATWithInterface(t *testing.T) {
cfg := testConfig()
cfg.Policies = []config.PolicyConfig{
{
ServiceName: "smarthost-proxy",
Name: "wireguardproxy",
Selector: "smarthostloadbalancer",
Iface: "wg0",
Nat: "dnat",
To: "smarthostloadbalancer",
Port: 80,
Proto: "tcp",
},
}
docker := &mock.MockDockerClient{
GetContainerPIDResult: 0, // simulate no PID available
GetContainerPIDErr: assertError("container not running"),
}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.reconcilePolicies(ctx, cfg)
// When GetContainerPID fails, should fall back to interface-based rule
if !iptables.InsertPreroutingRuleOnInterfaceCalled {
t.Error("InsertPreroutingRuleOnInterface was not called (should fall back from nsenter)")
}
if len(iptables.InsertPreroutingRuleOnInterfaceArgs) > 0 {
iface := iptables.InsertPreroutingRuleOnInterfaceArgs[0]
if iface != "wg0" {
t.Errorf("expected interface wg0, got %s", iface)
}
}
}
func TestReconcilePoliciesDNATWithContainerPID(t *testing.T) {
cfg := testConfig()
cfg.Policies = []config.PolicyConfig{
{
ServiceName: "smarthost-proxy",
Name: "wireguardproxy",
Selector: "smarthostloadbalancer",
Nat: "dnat",
To: "smarthostloadbalancer",
Port: 80,
Proto: "tcp",
},
}
docker := &mock.MockDockerClient{
GetContainerPIDResult: 1234,
GetContainerPIDErr: nil,
}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.reconcilePolicies(ctx, cfg)
if !docker.GetContainerPIDCalled {
t.Error("GetContainerPID was not called")
}
if !iptables.InsertPreroutingRuleInContainerCalled {
t.Error("InsertPreroutingRuleInContainer was not called")
}
if iptables.InsertPreroutingRuleInContainerPID != 1234 {
t.Errorf("expected PID 1234, got %d", iptables.InsertPreroutingRuleInContainerPID)
}
}
func TestReconcilePoliciesUnresolvedTarget(t *testing.T) {
cfg := testConfig()
cfg.Policies = []config.PolicyConfig{
{
ServiceName: "test",
Nat: "dnat",
Selector: "container1",
To: "nonexistent-target",
Port: 80,
},
}
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.reconcilePolicies(ctx, cfg)
// Should not call GetContainerPID when target can't be resolved
if docker.GetContainerPIDCalled {
t.Error("GetContainerPID should not be called when target is unresolvable")
}
}
func TestResolveIPDirectIP(t *testing.T) {
cfg := testConfig()
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
// Direct IP should be returned as CIDR
result := orch.resolveIP("172.18.103.2")
if result != "172.18.103.2" {
t.Errorf("expected 172.18.103.2, got %s", result)
}
// .0 ending should be converted to /24
result = orch.resolveIP("172.18.103.0")
if result != "172.18.103.0/24" {
t.Errorf("expected 172.18.103.0/24, got %s", result)
}
}
func TestResolveIPFromConfig(t *testing.T) {
cfg := testConfig()
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
// Should resolve container name from config
result := orch.resolveIP("smarthostloadbalancer")
if result != "172.18.103.2" {
t.Errorf("expected 172.18.103.2, got %s", result)
}
}
func TestFindNetworkForIP(t *testing.T) {
cfg := testConfig()
tests := []struct {
ip string
want string
}{
{"172.18.103.5", "smarthost-loadbalancer"},
{"172.18.104.5", "smarthost_backend-1"},
{"10.0.0.1", ""},
{"", ""},
}
for _, tt := range tests {
got := findNetworkForIP(cfg, tt.ip)
if got != tt.want {
t.Errorf("findNetworkForIP(%q) = %q, want %q", tt.ip, got, tt.want)
}
}
}
func TestReconcileAllReproducible(t *testing.T) {
cfg := testConfig()
cfg.Policies = []config.PolicyConfig{
{
ServiceName: "test-svc",
From: "publicbackend",
Port: 80,
Proto: "tcp",
},
}
// Run reconciliation twice with separate mocks
for i := 0; i < 2; i++ {
docker := &mock.MockDockerClient{}
iptables := &mock.MockIPTablesManager{}
orch := NewOrchestrator(docker, iptables, cfg)
ctx := context.Background()
orch.ReconcileAll(ctx, cfg)
if !docker.EnsureNetworkCalled {
t.Errorf("run %d: EnsureNetwork not called", i)
}
if !iptables.InsertForwardAcceptCalled {
t.Errorf("run %d: InsertForwardAccept not called", i)
}
}
}
// assertError is a helper to create a simple error for tests
type simpleErr struct{ msg string }
func (e simpleErr) Error() string { return e.msg }
func assertError(msg string) error {
return simpleErr{msg: msg}
}
-20
View File
@@ -88,26 +88,6 @@ Without these, the program cannot:
- Insert PREROUTING/POSTROUTING rules inside other containers - Insert PREROUTING/POSTROUTING rules inside other containers
- Add routes to container network namespaces - Add routes to container network namespaces
### Minimal Docker Compose Example
```yaml
version: "3.8"
services:
network-go:
build: ./network-go
network_mode: host
pid: "host"
cap_add:
- NET_ADMIN
- SYS_ADMIN
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /etc/user/config:/etc/user/config
environment:
- WATCH_PERIOD_SECONDS=30
- DEBUG=false
```
## Configuration — `/etc/user/config/networks.json` ## Configuration — `/etc/user/config/networks.json`
```json ```json
+20 -15
View File
@@ -6,12 +6,31 @@ import (
"strings" "strings"
) )
// IPTablesAPI defines the interface for iptables operations, enabling mock implementations for testing
type IPTablesAPI interface {
Binary() string
EnsureIPForward() error
EnsureEstablishedRelated(chain string) error
DeleteLine(chain string, lineNum string) error
InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error
InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error
InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error
InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error
InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error
DeleteForwardAccept(chain, comment string) error
InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error
InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error
}
// Manager manages iptables rules via the iptables/iptables-legacy CLI // Manager manages iptables rules via the iptables/iptables-legacy CLI
type Manager struct { type Manager struct {
binary string // /usr/sbin/iptables or /usr/sbin/iptables-legacy binary string
debug bool debug bool
} }
// Ensure Manager implements IPTablesAPI
var _ IPTablesAPI = (*Manager)(nil)
// NewManager creates a new iptables manager, auto-detecting the binary // NewManager creates a new iptables manager, auto-detecting the binary
func NewManager(debug bool) *Manager { func NewManager(debug bool) *Manager {
m := &Manager{debug: debug} m := &Manager{debug: debug}
@@ -83,17 +102,14 @@ func (m *Manager) EnsureIPForward() error {
} }
// EnsureEstablishedRelated inserts an ESTABLISHED,RELATED accept rule at the top of a chain // EnsureEstablishedRelated inserts an ESTABLISHED,RELATED accept rule at the top of a chain
// if it doesn't already exist
func (m *Manager) EnsureEstablishedRelated(chain string) error { func (m *Manager) EnsureEstablishedRelated(chain string) error {
checkArgs := []string{"-w", "-n", "-L", chain} checkArgs := []string{"-w", "-n", "-L", chain}
cmd := exec.Command(m.binary, checkArgs...) cmd := exec.Command(m.binary, checkArgs...)
output, err := cmd.Output() output, err := cmd.Output()
if err != nil { if err != nil {
// Chain may not exist, create it
return nil return nil
} }
// Only insert if ESTABLISHED,RELATED rule is not present
if !strings.Contains(string(output), "ESTABLISHED") || !strings.Contains(string(output), "RELATED") { if !strings.Contains(string(output), "ESTABLISHED") || !strings.Contains(string(output), "RELATED") {
args := []string{"-w", "-I", chain, "-m", "state", "--state", "established,related", "-j", "ACCEPT"} args := []string{"-w", "-I", chain, "-m", "state", "--state", "established,related", "-j", "ACCEPT"}
return m.run(args...) return m.run(args...)
@@ -114,7 +130,6 @@ func (m *Manager) DeleteLineInContainer(pid int, table, chain, lineNum string) e
} }
// getLineNumbers returns line numbers matching certain criteria in a chain/table // getLineNumbers returns line numbers matching certain criteria in a chain/table
// This implements the grep logic from the shell script: iptables -w --line-number -n -L $CHAIN | grep ...
func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []string { func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []string {
args := []string{"-w", "--line-number", "-n", "-L", chain} args := []string{"-w", "--line-number", "-n", "-L", chain}
if table != "" { if table != "" {
@@ -150,7 +165,6 @@ func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []
// deleteMatchingLines deletes all lines in a chain matching the given patterns // deleteMatchingLines deletes all lines in a chain matching the given patterns
func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...string) error { func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...string) error {
lines := m.getLineNumbers(chain, table, grepPatterns...) lines := m.getLineNumbers(chain, table, grepPatterns...)
// Reverse order (highest line first) so deletions don't shift line numbers
for i := len(lines) - 1; i >= 0; i-- { for i := len(lines) - 1; i >= 0; i-- {
if err := m.DeleteLine(chain, lines[i]); err != nil { if err := m.DeleteLine(chain, lines[i]); err != nil {
return err return err
@@ -161,7 +175,6 @@ func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...strin
// deleteMatchingLinesInContainer deletes matching lines inside a container namespace // deleteMatchingLinesInContainer deletes matching lines inside a container namespace
func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, grepPatterns ...string) error { func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, grepPatterns ...string) error {
// For container namespaces, we use a different approach: list via nsenter + grep
iptPath := "/sbin/iptables-legacy" iptPath := "/sbin/iptables-legacy"
if !strings.Contains(m.binary, "legacy") { if !strings.Contains(m.binary, "legacy") {
iptPath = "/sbin/iptables" iptPath = "/sbin/iptables"
@@ -192,7 +205,6 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g
} }
} }
// Delete in reverse order
for i := len(matchingLines) - 1; i >= 0; i-- { for i := len(matchingLines) - 1; i >= 0; i-- {
if err := m.DeleteLineInContainer(pid, table, chain, matchingLines[i]); err != nil { if err := m.DeleteLineInContainer(pid, table, chain, matchingLines[i]); err != nil {
return err return err
@@ -203,13 +215,11 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g
// InsertPreroutingRule inserts a DNAT PREROUTING rule on the host // InsertPreroutingRule inserts a DNAT PREROUTING rule on the host
func (m *Manager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error { func (m *Manager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
// First, delete existing matching rules
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment} patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
if err := m.deleteMatchingLines("PREROUTING", "nat", patterns...); err != nil { if err := m.deleteMatchingLines("PREROUTING", "nat", patterns...); err != nil {
return fmt.Errorf("failed to delete old PREROUTING rules: %w", err) return fmt.Errorf("failed to delete old PREROUTING rules: %w", err)
} }
// Insert the new rule
args := []string{ args := []string{
"-w", "-t", "nat", "-I", "PREROUTING", "-w", "-t", "nat", "-I", "PREROUTING",
"-d", sourceIP, "-d", sourceIP,
@@ -236,7 +246,6 @@ func (m *Manager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targ
// InsertPostroutingMasquerade inserts a MASQUERADE POSTROUTING rule on the host // InsertPostroutingMasquerade inserts a MASQUERADE POSTROUTING rule on the host
func (m *Manager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error { func (m *Manager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
// Delete existing matching rules first
patterns := []string{"MASQUERADE", comment, sourceCIDR, sourcePort} patterns := []string{"MASQUERADE", comment, sourceCIDR, sourcePort}
if err := m.deleteMatchingLines("POSTROUTING", "nat", patterns...); err != nil { if err := m.deleteMatchingLines("POSTROUTING", "nat", patterns...); err != nil {
return fmt.Errorf("failed to delete old POSTROUTING rules: %w", err) return fmt.Errorf("failed to delete old POSTROUTING rules: %w", err)
@@ -273,7 +282,6 @@ func (m *Manager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, target
// InsertForwardAccept inserts a FORWARD ACCEPT rule on the host // InsertForwardAccept inserts a FORWARD ACCEPT rule on the host
func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error { func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
// Build grep patterns to match existing rules
var grepPatterns []string var grepPatterns []string
grepPatterns = append(grepPatterns, proto) grepPatterns = append(grepPatterns, proto)
if sourceIP != "" { if sourceIP != "" {
@@ -289,12 +297,10 @@ func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePo
grepPatterns = append(grepPatterns, targetPort) grepPatterns = append(grepPatterns, targetPort)
} }
// Delete old matching rules
if err := m.deleteMatchingLines(chain, "", grepPatterns...); err != nil { if err := m.deleteMatchingLines(chain, "", grepPatterns...); err != nil {
return fmt.Errorf("failed to delete old FORWARD rules: %w", err) return fmt.Errorf("failed to delete old FORWARD rules: %w", err)
} }
// Build iptables args
args := []string{"-w", "-I", chain, "-p", proto} args := []string{"-w", "-I", chain, "-p", proto}
if sourceIP != "" { if sourceIP != "" {
args = append(args, "-s", sourceIP) args = append(args, "-s", sourceIP)
@@ -326,7 +332,6 @@ func (m *Manager) DeleteForwardAccept(chain, comment string) error {
// InsertPreroutingRuleInContainer inserts a DNAT PREROUTING rule inside a container namespace // InsertPreroutingRuleInContainer inserts a DNAT PREROUTING rule inside a container namespace
func (m *Manager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error { func (m *Manager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
// Delete existing first
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment} patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
if err := m.deleteMatchingLinesInContainer(pid, "nat", "PREROUTING", patterns...); err != nil { if err := m.deleteMatchingLinesInContainer(pid, "nat", "PREROUTING", patterns...); err != nil {
return fmt.Errorf("failed to delete old container PREROUTING rules: %w", err) return fmt.Errorf("failed to delete old container PREROUTING rules: %w", err)
+209
View File
@@ -0,0 +1,209 @@
package mock
import (
"context"
"time"
"github.com/docker/docker/api/types"
"firewall_containers/network-go/config"
"firewall_containers/network-go/docker"
"firewall_containers/network-go/iptables"
)
// Compile-time interface conformance checks
var _ docker.DockerAPI = (*MockDockerClient)(nil)
var _ iptables.IPTablesAPI = (*MockIPTablesManager)(nil)
// MockDockerClient implements docker.DockerAPI for testing
type MockDockerClient struct {
EnsureNetworkCalled bool
EnsureNetworkCfg config.NetworkConfig
EnsureNetworkErr error
ConnectContainerCalled bool
ConnectContainerName string
ConnectContainerNetwork string
ConnectContainerIP string
ConnectContainerErr error
WaitForRunningCalled bool
WaitForRunningName string
GetContainerPIDCalled bool
GetContainerPIDName string
GetContainerPIDResult int
GetContainerPIDErr error
AddRouteCalled bool
AddRouteContainer string
AddRouteNetwork string
AddRouteGateway string
AddRouteErr error
InspectContainerErr error
RemoveNetworkErr error
DisconnectContainerErr error
}
func (m *MockDockerClient) Close() error { return nil }
func (m *MockDockerClient) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
m.EnsureNetworkCalled = true
m.EnsureNetworkCfg = netCfg
return m.EnsureNetworkErr
}
func (m *MockDockerClient) RemoveNetwork(ctx context.Context, networkName string) error {
return m.RemoveNetworkErr
}
func (m *MockDockerClient) ConnectContainer(ctx context.Context, containerName, networkName, ip string) error {
m.ConnectContainerCalled = true
m.ConnectContainerName = containerName
m.ConnectContainerNetwork = networkName
m.ConnectContainerIP = ip
return m.ConnectContainerErr
}
func (m *MockDockerClient) DisconnectContainer(ctx context.Context, containerName, networkName string) error {
return m.DisconnectContainerErr
}
func (m *MockDockerClient) InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error) {
return nil, m.InspectContainerErr
}
func (m *MockDockerClient) WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error {
m.WaitForRunningCalled = true
m.WaitForRunningName = containerName
return nil
}
func (m *MockDockerClient) GetContainerPID(ctx context.Context, containerName string) (int, error) {
m.GetContainerPIDCalled = true
m.GetContainerPIDName = containerName
return m.GetContainerPIDResult, m.GetContainerPIDErr
}
func (m *MockDockerClient) AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error {
m.AddRouteCalled = true
m.AddRouteContainer = containerName
m.AddRouteNetwork = network
m.AddRouteGateway = gateway
return m.AddRouteErr
}
// MockIPTablesManager implements iptables.IPTablesAPI for testing
type MockIPTablesManager struct {
BinaryResult string
EnsureIPForwardCalled bool
EnsureIPForwardErr error
EnsureEstablishedRelatedCalled bool
EnsureEstablishedRelatedChain string
EnsureEstablishedRelatedErr error
InsertPreroutingRuleCalled bool
InsertPreroutingRuleArgs []string
InsertPreroutingRuleErr error
InsertPreroutingRuleOnInterfaceCalled bool
InsertPreroutingRuleOnInterfaceArgs []string
InsertPreroutingRuleOnInterfaceErr error
InsertPostroutingMasqueradeCalled bool
InsertPostroutingMasqueradeArgs []string
InsertPostroutingMasqueradeErr error
InsertForwardAcceptCalled bool
InsertForwardAcceptChain string
InsertForwardAcceptSourceIP string
InsertForwardAcceptTargetIP string
InsertForwardAcceptProto string
InsertForwardAcceptSourcePort string
InsertForwardAcceptTargetPort string
InsertForwardAcceptComment string
InsertForwardAcceptErr error
InsertPreroutingRuleInContainerCalled bool
InsertPreroutingRuleInContainerPID int
InsertPreroutingRuleInContainerArgs []string
InsertPreroutingRuleInContainerErr error
InsertPostroutingMasqueradeInContainerCalled bool
InsertPostroutingMasqueradeInContainerErr error
DeleteForwardAcceptErr error
DeleteLineErr error
}
func (m *MockIPTablesManager) Binary() string {
if m.BinaryResult == "" {
return "/usr/sbin/iptables"
}
return m.BinaryResult
}
func (m *MockIPTablesManager) EnsureIPForward() error {
m.EnsureIPForwardCalled = true
return m.EnsureIPForwardErr
}
func (m *MockIPTablesManager) EnsureEstablishedRelated(chain string) error {
m.EnsureEstablishedRelatedCalled = true
m.EnsureEstablishedRelatedChain = chain
return m.EnsureEstablishedRelatedErr
}
func (m *MockIPTablesManager) DeleteLine(chain string, lineNum string) error {
return m.DeleteLineErr
}
func (m *MockIPTablesManager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
m.InsertPreroutingRuleCalled = true
m.InsertPreroutingRuleArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment}
return m.InsertPreroutingRuleErr
}
func (m *MockIPTablesManager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error {
m.InsertPreroutingRuleOnInterfaceCalled = true
m.InsertPreroutingRuleOnInterfaceArgs = []string{iface, proto, sourcePort, targetIP, targetPort, comment}
return m.InsertPreroutingRuleOnInterfaceErr
}
func (m *MockIPTablesManager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
m.InsertPostroutingMasqueradeCalled = true
m.InsertPostroutingMasqueradeArgs = []string{sourceCIDR, proto, sourcePort, comment}
return m.InsertPostroutingMasqueradeErr
}
func (m *MockIPTablesManager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error {
return nil
}
func (m *MockIPTablesManager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
m.InsertForwardAcceptCalled = true
m.InsertForwardAcceptChain = chain
m.InsertForwardAcceptSourceIP = sourceIP
m.InsertForwardAcceptTargetIP = targetIP
m.InsertForwardAcceptProto = proto
m.InsertForwardAcceptSourcePort = sourcePort
m.InsertForwardAcceptTargetPort = targetPort
m.InsertForwardAcceptComment = comment
return m.InsertForwardAcceptErr
}
func (m *MockIPTablesManager) DeleteForwardAccept(chain, comment string) error {
return m.DeleteForwardAcceptErr
}
func (m *MockIPTablesManager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
m.InsertPreroutingRuleInContainerCalled = true
m.InsertPreroutingRuleInContainerPID = pid
m.InsertPreroutingRuleInContainerArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment}
return m.InsertPreroutingRuleInContainerErr
}
func (m *MockIPTablesManager) InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error {
m.InsertPostroutingMasqueradeInContainerCalled = true
return m.InsertPostroutingMasqueradeInContainerErr
}
+118
View File
@@ -0,0 +1,118 @@
package resolver
import (
"testing"
"firewall_containers/network-go/config"
)
func makeTestConfig() *config.NetworksConfig {
return &config.NetworksConfig{
Networks: map[string]config.NetworkConfig{
"test-net": {NetworkName: "test-net", Subnet: "172.18.103.0/24", Gateway: "172.18.103.1"},
},
IPs: map[string]config.IPConfig{
"172.18.103.2": {IP: "172.18.103.2", ContainerName: "smarthostloadbalancer", Selector: "smarthostloadbalancer", ServiceName: "smarthost-proxy"},
"172.18.104.2": {IP: "172.18.104.2", ContainerName: "smarthostbackend-1", Selector: "smarthostbackend-1", ServiceName: "smarthost-proxy"},
},
Policies: []config.PolicyConfig{},
}
}
func TestResolveByContainerName(t *testing.T) {
cfg := makeTestConfig()
r := NewResolver(cfg)
ips := r.Resolve("smarthostloadbalancer")
if len(ips) != 1 {
t.Fatalf("expected 1 IP, got %d", len(ips))
}
if ips[0] != "172.18.103.2" {
t.Errorf("expected 172.18.103.2, got %s", ips[0])
}
}
func TestResolveBySelector(t *testing.T) {
cfg := makeTestConfig()
r := NewResolver(cfg)
ips := r.Resolve("smarthostbackend-1")
if len(ips) != 1 {
t.Fatalf("expected 1 IP, got %d", len(ips))
}
if ips[0] != "172.18.104.2" {
t.Errorf("expected 172.18.104.2, got %s", ips[0])
}
}
func TestResolveNotFound(t *testing.T) {
cfg := makeTestConfig()
r := NewResolver(cfg)
ips := r.Resolve("nonexistent-container")
if len(ips) != 0 {
t.Errorf("expected 0 IPs, got %d", len(ips))
}
}
func TestResolveNilConfig(t *testing.T) {
r := NewResolver(nil)
ips := r.Resolve("anything")
if len(ips) != 0 {
t.Errorf("expected 0 IPs with nil config, got %d", len(ips))
}
}
func TestSetConfig(t *testing.T) {
r := NewResolver(nil)
// Initially nil config
ips := r.Resolve("test")
if len(ips) != 0 {
t.Errorf("expected 0 IPs with nil config, got %d", len(ips))
}
// Update with valid config
cfg := makeTestConfig()
r.SetConfig(cfg)
ips = r.Resolve("smarthostloadbalancer")
if len(ips) != 1 {
t.Fatalf("expected 1 IP after SetConfig, got %d", len(ips))
}
if ips[0] != "172.18.103.2" {
t.Errorf("expected 172.18.103.2, got %s", ips[0])
}
}
func TestResolveReproducible(t *testing.T) {
cfg := makeTestConfig()
r := NewResolver(cfg)
ips1 := r.Resolve("smarthostloadbalancer")
ips2 := r.Resolve("smarthostloadbalancer")
if len(ips1) != len(ips2) {
t.Errorf("reproducibility: result count mismatch %d vs %d", len(ips1), len(ips2))
}
if len(ips1) > 0 && ips1[0] != ips2[0] {
t.Errorf("reproducibility: IP mismatch %s vs %s", ips1[0], ips2[0])
}
}
func TestResolveMultipleMatch(t *testing.T) {
cfg := &config.NetworksConfig{
IPs: map[string]config.IPConfig{
"10.0.0.1": {IP: "10.0.0.1", ContainerName: "app-1", Selector: "app-1"},
"10.0.0.2": {IP: "10.0.0.2", ContainerName: "app-2", Selector: "app-2"},
},
}
r := NewResolver(cfg)
// "app-x" has a dash, so prefix "app" is extracted and matches both app-1 and app-2
ips := r.Resolve("app-x")
if len(ips) != 2 {
t.Errorf("expected 2 IPs for prefix match, got %d", len(ips))
}
}
+179
View File
@@ -0,0 +1,179 @@
package watcher
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestWatcherDetectsChange(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
changeDetected := make(chan bool, 1)
onChange := func() {
select {
case changeDetected <- true:
default:
}
}
fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
fw.Start()
defer fw.Stop()
// Wait for initial hash to be computed
time.Sleep(200 * time.Millisecond)
// Modify the file
if err := os.WriteFile(path, []byte(`{"version": 2}`), 0644); err != nil {
t.Fatalf("failed to modify test file: %v", err)
}
// Wait for change detection
select {
case detected := <-changeDetected:
if !detected {
t.Error("expected change detection, got false")
}
case <-time.After(2 * time.Second):
t.Error("timeout waiting for file change detection")
}
}
func TestWatcherNoChange(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
changeDetected := make(chan bool, 1)
onChange := func() {
select {
case changeDetected <- true:
default:
}
}
fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
fw.Start()
defer fw.Stop()
// Wait without modifying the file
time.Sleep(300 * time.Millisecond)
// Should not detect a change
select {
case <-changeDetected:
t.Error("unexpected change detection without file modification")
default:
// Expected: no change detected
}
}
func TestWatcherMultipleChanges(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
changeCount := 0
var mu chan struct{}
onChange := func() {
changeCount++
if mu != nil {
mu <- struct{}{}
}
}
fw := NewFileWatcher(path, 50*time.Millisecond, onChange)
fw.Start()
defer fw.Stop()
// Wait for initial hash
time.Sleep(100 * time.Millisecond)
// Make 3 changes
for i := 2; i <= 4; i++ {
if err := os.WriteFile(path, []byte(`{"v": `+string(rune('0'+i))+`}`), 0644); err != nil {
t.Fatalf("failed to modify file: %v", err)
}
time.Sleep(150 * time.Millisecond)
}
if changeCount < 2 {
t.Errorf("expected at least 2 change detections, got %d", changeCount)
}
}
func TestWatcherStop(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
fw := NewFileWatcher(path, 50*time.Millisecond, func() {})
fw.Start()
time.Sleep(100 * time.Millisecond)
fw.Stop()
// Should not panic on double stop or further operations
time.Sleep(100 * time.Millisecond)
}
func TestWatcherNonexistentFile(t *testing.T) {
onChange := func() {
t.Error("onChange should not be called for nonexistent file")
}
fw := NewFileWatcher("/nonexistent/file.json", 100*time.Millisecond, onChange)
fw.Start()
defer fw.Stop()
// Wait — should not call onChange since file doesn't exist
time.Sleep(300 * time.Millisecond)
}
func TestWatcherReproducible(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.json")
if err := os.WriteFile(path, []byte(`{"data": "value"}`), 0644); err != nil {
t.Fatalf("failed to write test file: %v", err)
}
// Create two watchers on the same file
var detected1, detected2 bool
fw1 := NewFileWatcher(path, 50*time.Millisecond, func() { detected1 = true })
fw2 := NewFileWatcher(path, 50*time.Millisecond, func() { detected2 = true })
fw1.Start()
fw2.Start()
defer fw1.Stop()
defer fw2.Stop()
time.Sleep(100 * time.Millisecond)
// Modify file
if err := os.WriteFile(path, []byte(`{"data": "changed"}`), 0644); err != nil {
t.Fatalf("failed to modify file: %v", err)
}
time.Sleep(300 * time.Millisecond)
if !detected1 {
t.Error("watcher 1 did not detect change")
}
if !detected2 {
t.Error("watcher 2 did not detect change")
}
}