From fcda599ec7d67a103fec939e231cd5b48f413475 Mon Sep 17 00:00:00 2001 From: gyurix Date: Mon, 8 Jun 2026 17:02:13 +0200 Subject: [PATCH] added test go implementation --- network-go/config/config_test.go | 231 +++++++++++++++++++ network-go/docker/docker.go | 25 +- network-go/firewall/firewall.go | 6 +- network-go/firewall/firewall_test.go | 333 +++++++++++++++++++++++++++ network-go/implementation.md | 20 -- network-go/iptables/iptables.go | 35 +-- network-go/mock/mock.go | 209 +++++++++++++++++ network-go/resolver/resolver_test.go | 118 ++++++++++ network-go/watcher/watcher_test.go | 179 ++++++++++++++ 9 files changed, 1112 insertions(+), 44 deletions(-) create mode 100644 network-go/config/config_test.go create mode 100644 network-go/firewall/firewall_test.go create mode 100644 network-go/mock/mock.go create mode 100644 network-go/resolver/resolver_test.go create mode 100644 network-go/watcher/watcher_test.go diff --git a/network-go/config/config_test.go b/network-go/config/config_test.go new file mode 100644 index 0000000..0172cdb --- /dev/null +++ b/network-go/config/config_test.go @@ -0,0 +1,231 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +const testJSON = `{ + "networks": { + "smarthost-loadbalancer": { + "network_name": "smarthost-loadbalancer", + "subnet": "172.18.103.0/24", + "gateway": "172.18.103.1" + }, + "smarthost_backend-1": { + "network_name": "smarthost_backend-1", + "subnet": "172.18.104.0/24", + "gateway": "172.18.104.1" + } + }, + "ips": { + "172.18.103.2": { + "ip": "172.18.103.2", + "container_name": "smarthostloadbalancer", + "selector": "smarthostloadbalancer", + "service_name": "smarthost-proxy" + }, + "172.18.104.2": { + "ip": "172.18.104.2", + "container_name": "smarthostbackend-1", + "selector": "smarthostbackend-1", + "service_name": "smarthost-proxy" + } + }, + "policies": [ + { + "service_name": "smarthost-proxy", + "container_name": "smarthost_loadbalancer", + "selector": "smarthostloadbalancer", + "from": "publicbackend", + "port": 80, + "proto": "tcp" + }, + { + "service_name": "smarthost-proxy", + "container_name": "smarthost_loadbalancer", + "selector": "smarthostloadbalancer", + "name": "wireguardproxy", + "iface": "wg0", + "nat": "dnat", + "to": "smarthostloadbalancer", + "port": 80, + "proto": "tcp" + } + ] +}` + +func TestLoad(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "networks.json") + if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load() returned error: %v", err) + } + + if len(cfg.Networks) != 2 { + t.Errorf("expected 2 networks, got %d", len(cfg.Networks)) + } + + if cfg.Networks["smarthost-loadbalancer"].Subnet != "172.18.103.0/24" { + t.Errorf("unexpected subnet: %s", cfg.Networks["smarthost-loadbalancer"].Subnet) + } + + if len(cfg.IPs) != 2 { + t.Errorf("expected 2 IPs, got %d", len(cfg.IPs)) + } + + if cfg.IPs["172.18.103.2"].ContainerName != "smarthostloadbalancer" { + t.Errorf("unexpected container name: %s", cfg.IPs["172.18.103.2"].ContainerName) + } + + if len(cfg.Policies) != 2 { + t.Errorf("expected 2 policies, got %d", len(cfg.Policies)) + } + + if cfg.Policies[0].From != "publicbackend" { + t.Errorf("unexpected from: %s", cfg.Policies[0].From) + } + + if cfg.Policies[1].Nat != "dnat" { + t.Errorf("unexpected nat: %s", cfg.Policies[1].Nat) + } +} + +func TestLoadFileNotFound(t *testing.T) { + _, err := Load("/nonexistent/path/networks.json") + if err == nil { + t.Error("expected error for nonexistent file, got nil") + } +} + +func TestLoadInvalidJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.json") + if err := os.WriteFile(path, []byte("{invalid json"), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Error("expected error for invalid JSON, got nil") + } +} + +func TestToCIDR(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"172.18.103.0", "172.18.103.0/24"}, + {"172.18.103.2", "172.18.103.2"}, + {"172.18.103.0/24", "172.18.103.0/24"}, + {"10.0.0.0", "10.0.0.0/24"}, + {"invalid", "invalid"}, + } + + for _, tt := range tests { + got := ToCIDR(tt.input) + if got != tt.want { + t.Errorf("ToCIDR(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestNetworkPrefix(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"172.18.103.2", "172.18.103.0/24"}, + {"10.0.0.1", "10.0.0.0/24"}, + {"invalid", "invalid"}, + } + + for _, tt := range tests { + got := NetworkPrefix(tt.input) + if got != tt.want { + t.Errorf("NetworkPrefix(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestIsIP(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"172.18.103.2", true}, + {"10.0.0.0", true}, + {"255.255.255.255", true}, + {"publicbackend", false}, + {"172.18.103.0/24", false}, + {"", false}, + } + + for _, tt := range tests { + got := IsIP(tt.input) + if got != tt.want { + t.Errorf("IsIP(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +func TestNetworkConfigParseCIDR(t *testing.T) { + nc := NetworkConfig{Subnet: "172.18.103.0/24"} + ipNet, err := nc.ParseCIDR() + if err != nil { + t.Fatalf("ParseCIDR() returned error: %v", err) + } + if ipNet.String() != "172.18.103.0/24" { + t.Errorf("ParseCIDR() = %s, want 172.18.103.0/24", ipNet.String()) + } + + nc2 := NetworkConfig{Subnet: "invalid"} + _, err = nc2.ParseCIDR() + if err == nil { + t.Error("expected error for invalid subnet, got nil") + } +} + +func TestLoadReproducible(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "networks.json") + if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + // Load twice and compare + cfg1, err := Load(path) + if err != nil { + t.Fatalf("first Load() error: %v", err) + } + + cfg2, err := Load(path) + if err != nil { + t.Fatalf("second Load() error: %v", err) + } + + // Verify reproducibility + if len(cfg1.Networks) != len(cfg2.Networks) { + t.Errorf("reproducibility: network count mismatch %d vs %d", len(cfg1.Networks), len(cfg2.Networks)) + } + if len(cfg1.IPs) != len(cfg2.IPs) { + t.Errorf("reproducibility: IP count mismatch %d vs %d", len(cfg1.IPs), len(cfg2.IPs)) + } + if len(cfg1.Policies) != len(cfg2.Policies) { + t.Errorf("reproducibility: policy count mismatch %d vs %d", len(cfg1.Policies), len(cfg2.Policies)) + } + + for name, net1 := range cfg1.Networks { + net2 := cfg2.Networks[name] + if net1.Subnet != net2.Subnet || net1.Gateway != net2.Gateway { + t.Errorf("reproducibility: network %s mismatch", name) + } + } +} \ No newline at end of file diff --git a/network-go/docker/docker.go b/network-go/docker/docker.go index b376729..b9c72ac 100644 --- a/network-go/docker/docker.go +++ b/network-go/docker/docker.go @@ -15,11 +15,27 @@ import ( "firewall_containers/network-go/config" ) +// DockerAPI defines the interface for Docker operations, enabling mock implementations for testing +type DockerAPI interface { + Close() error + EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error + RemoveNetwork(ctx context.Context, networkName string) error + ConnectContainer(ctx context.Context, containerName, networkName, ip string) error + DisconnectContainer(ctx context.Context, containerName, networkName string) error + InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error) + WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error + GetContainerPID(ctx context.Context, containerName string) (int, error) + AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error +} + // Client wraps the Docker SDK client type Client struct { cli *client.Client } +// Ensure Client implements DockerAPI +var _ DockerAPI = (*Client)(nil) + // NewClient creates a new Docker client func NewClient() (*Client, error) { cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) @@ -36,7 +52,6 @@ func (c *Client) Close() error { // EnsureNetwork creates a Docker network if it does not already exist func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error { - // Check if network already exists existingNetworks, err := c.cli.NetworkList(ctx, network.ListOptions{ Filters: filters.NewArgs(filters.Arg("name", netCfg.NetworkName)), }) @@ -46,12 +61,10 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) for _, n := range existingNetworks { if n.Name == netCfg.NetworkName { - // Network already exists, skip creation return nil } } - // Parse subnet and gateway _, ipNet, err := net.ParseCIDR(netCfg.Subnet) if err != nil { return fmt.Errorf("failed to parse subnet %s: %w", netCfg.Subnet, err) @@ -62,7 +75,6 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) return fmt.Errorf("failed to parse gateway IP %s", netCfg.Gateway) } - // Create the network createOpts := network.CreateOptions{ Driver: "bridge", IPAM: &network.IPAM{ @@ -84,7 +96,7 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) return fmt.Errorf("failed to create network %s: %w", netCfg.NetworkName, err) } - _ = resp // response contains ID and warnings + _ = resp return nil } @@ -154,6 +166,7 @@ func (c *Client) WaitForContainerRunning(ctx context.Context, containerName stri } } +// GetContainerPID returns the PID of a container for nsenter operations func (c *Client) GetContainerPID(ctx context.Context, containerName string) (int, error) { cont, err := c.cli.ContainerInspect(ctx, containerName) if err != nil { @@ -184,4 +197,4 @@ func (c *Client) AddRouteInContainer(ctx context.Context, containerName, network return fmt.Errorf("failed to add route in container %s: %w\noutput: %s", containerName, err, string(output)) } return nil -} +} \ No newline at end of file diff --git a/network-go/firewall/firewall.go b/network-go/firewall/firewall.go index fc3d2e3..2155529 100644 --- a/network-go/firewall/firewall.go +++ b/network-go/firewall/firewall.go @@ -16,14 +16,14 @@ import ( // Orchestrator reconciles the networks.json configuration into Docker networks // and iptables firewall rules type Orchestrator struct { - dockerClient *docker.Client - iptablesMgr *iptables.Manager + dockerClient docker.DockerAPI + iptablesMgr iptables.IPTablesAPI resolver *resolver.Resolver debug bool } // NewOrchestrator creates a new firewall orchestrator -func NewOrchestrator(dockerClient *docker.Client, iptablesMgr *iptables.Manager, cfg *config.NetworksConfig) *Orchestrator { +func NewOrchestrator(dockerClient docker.DockerAPI, iptablesMgr iptables.IPTablesAPI, cfg *config.NetworksConfig) *Orchestrator { return &Orchestrator{ dockerClient: dockerClient, iptablesMgr: iptablesMgr, diff --git a/network-go/firewall/firewall_test.go b/network-go/firewall/firewall_test.go new file mode 100644 index 0000000..2482dde --- /dev/null +++ b/network-go/firewall/firewall_test.go @@ -0,0 +1,333 @@ +package firewall + +import ( + "context" + "testing" + + "firewall_containers/network-go/config" + "firewall_containers/network-go/mock" +) + +func testConfig() *config.NetworksConfig { + return &config.NetworksConfig{ + Networks: map[string]config.NetworkConfig{ + "smarthost-loadbalancer": { + NetworkName: "smarthost-loadbalancer", + Subnet: "172.18.103.0/24", + Gateway: "172.18.103.1", + }, + "smarthost_backend-1": { + NetworkName: "smarthost_backend-1", + Subnet: "172.18.104.0/24", + Gateway: "172.18.104.1", + }, + }, + IPs: map[string]config.IPConfig{ + "172.18.103.2": { + IP: "172.18.103.2", + ContainerName: "smarthostloadbalancer", + Selector: "smarthostloadbalancer", + ServiceName: "smarthost-proxy", + }, + "172.18.104.2": { + IP: "172.18.104.2", + ContainerName: "smarthostbackend-1", + Selector: "smarthostbackend-1", + ServiceName: "smarthost-proxy", + }, + }, + Policies: []config.PolicyConfig{}, + } +} + +func TestReconcileAllCreatesNetworks(t *testing.T) { + cfg := testConfig() + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.ReconcileAll(ctx, cfg) + + if !docker.EnsureNetworkCalled { + t.Error("EnsureNetwork was not called") + } +} + +func TestReconcileAllConnectsContainers(t *testing.T) { + cfg := testConfig() + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.ReconcileAll(ctx, cfg) + + if !docker.ConnectContainerCalled { + t.Error("ConnectContainer was not called") + } +} + +func TestReconcileAllEnablesIPForward(t *testing.T) { + cfg := testConfig() + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.ReconcileAll(ctx, cfg) + + if !iptables.EnsureIPForwardCalled { + t.Error("EnsureIPForward was not called") + } +} + +func TestReconcilePoliciesForwardRule(t *testing.T) { + cfg := testConfig() + cfg.Policies = []config.PolicyConfig{ + { + ServiceName: "smarthost-proxy", + ContainerName: "smarthostloadbalancer", + Selector: "smarthostloadbalancer", + From: "publicbackend", + Port: 80, + Proto: "tcp", + }, + } + + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables"} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.reconcilePolicies(ctx, cfg) + + // "publicbackend" should be resolved to an IP from the config (if it matches) + // Since "publicbackend" is not in the IPs map, sourceIP will be empty + if !iptables.InsertForwardAcceptCalled { + t.Error("InsertForwardAccept was not called") + } + + // Should use FORWARD chain (not iptables-legacy) + if iptables.InsertForwardAcceptChain != "FORWARD" { + t.Errorf("expected FORWARD chain, got %s", iptables.InsertForwardAcceptChain) + } +} + +func TestReconcilePoliciesForwardRuleWithLegacy(t *testing.T) { + cfg := testConfig() + cfg.Policies = []config.PolicyConfig{ + { + ServiceName: "smarthost-proxy", + From: "publicbackend", + Port: 80, + Proto: "tcp", + }, + } + + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables-legacy"} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.reconcilePolicies(ctx, cfg) + + if !iptables.InsertForwardAcceptCalled { + t.Error("InsertForwardAccept was not called") + } + + // Should use DOCKER-USER chain when iptables-legacy + if iptables.InsertForwardAcceptChain != "DOCKER-USER" { + t.Errorf("expected DOCKER-USER chain, got %s", iptables.InsertForwardAcceptChain) + } +} + +func TestReconcilePoliciesDNATWithInterface(t *testing.T) { + cfg := testConfig() + cfg.Policies = []config.PolicyConfig{ + { + ServiceName: "smarthost-proxy", + Name: "wireguardproxy", + Selector: "smarthostloadbalancer", + Iface: "wg0", + Nat: "dnat", + To: "smarthostloadbalancer", + Port: 80, + Proto: "tcp", + }, + } + + docker := &mock.MockDockerClient{ + GetContainerPIDResult: 0, // simulate no PID available + GetContainerPIDErr: assertError("container not running"), + } + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.reconcilePolicies(ctx, cfg) + + // When GetContainerPID fails, should fall back to interface-based rule + if !iptables.InsertPreroutingRuleOnInterfaceCalled { + t.Error("InsertPreroutingRuleOnInterface was not called (should fall back from nsenter)") + } + + if len(iptables.InsertPreroutingRuleOnInterfaceArgs) > 0 { + iface := iptables.InsertPreroutingRuleOnInterfaceArgs[0] + if iface != "wg0" { + t.Errorf("expected interface wg0, got %s", iface) + } + } +} + +func TestReconcilePoliciesDNATWithContainerPID(t *testing.T) { + cfg := testConfig() + cfg.Policies = []config.PolicyConfig{ + { + ServiceName: "smarthost-proxy", + Name: "wireguardproxy", + Selector: "smarthostloadbalancer", + Nat: "dnat", + To: "smarthostloadbalancer", + Port: 80, + Proto: "tcp", + }, + } + + docker := &mock.MockDockerClient{ + GetContainerPIDResult: 1234, + GetContainerPIDErr: nil, + } + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.reconcilePolicies(ctx, cfg) + + if !docker.GetContainerPIDCalled { + t.Error("GetContainerPID was not called") + } + + if !iptables.InsertPreroutingRuleInContainerCalled { + t.Error("InsertPreroutingRuleInContainer was not called") + } + + if iptables.InsertPreroutingRuleInContainerPID != 1234 { + t.Errorf("expected PID 1234, got %d", iptables.InsertPreroutingRuleInContainerPID) + } +} + +func TestReconcilePoliciesUnresolvedTarget(t *testing.T) { + cfg := testConfig() + cfg.Policies = []config.PolicyConfig{ + { + ServiceName: "test", + Nat: "dnat", + Selector: "container1", + To: "nonexistent-target", + Port: 80, + }, + } + + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + ctx := context.Background() + orch.reconcilePolicies(ctx, cfg) + + // Should not call GetContainerPID when target can't be resolved + if docker.GetContainerPIDCalled { + t.Error("GetContainerPID should not be called when target is unresolvable") + } +} + +func TestResolveIPDirectIP(t *testing.T) { + cfg := testConfig() + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + // Direct IP should be returned as CIDR + result := orch.resolveIP("172.18.103.2") + if result != "172.18.103.2" { + t.Errorf("expected 172.18.103.2, got %s", result) + } + + // .0 ending should be converted to /24 + result = orch.resolveIP("172.18.103.0") + if result != "172.18.103.0/24" { + t.Errorf("expected 172.18.103.0/24, got %s", result) + } +} + +func TestResolveIPFromConfig(t *testing.T) { + cfg := testConfig() + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + + // Should resolve container name from config + result := orch.resolveIP("smarthostloadbalancer") + if result != "172.18.103.2" { + t.Errorf("expected 172.18.103.2, got %s", result) + } +} + +func TestFindNetworkForIP(t *testing.T) { + cfg := testConfig() + + tests := []struct { + ip string + want string + }{ + {"172.18.103.5", "smarthost-loadbalancer"}, + {"172.18.104.5", "smarthost_backend-1"}, + {"10.0.0.1", ""}, + {"", ""}, + } + + for _, tt := range tests { + got := findNetworkForIP(cfg, tt.ip) + if got != tt.want { + t.Errorf("findNetworkForIP(%q) = %q, want %q", tt.ip, got, tt.want) + } + } +} + +func TestReconcileAllReproducible(t *testing.T) { + cfg := testConfig() + cfg.Policies = []config.PolicyConfig{ + { + ServiceName: "test-svc", + From: "publicbackend", + Port: 80, + Proto: "tcp", + }, + } + + // Run reconciliation twice with separate mocks + for i := 0; i < 2; i++ { + docker := &mock.MockDockerClient{} + iptables := &mock.MockIPTablesManager{} + orch := NewOrchestrator(docker, iptables, cfg) + ctx := context.Background() + orch.ReconcileAll(ctx, cfg) + + if !docker.EnsureNetworkCalled { + t.Errorf("run %d: EnsureNetwork not called", i) + } + if !iptables.InsertForwardAcceptCalled { + t.Errorf("run %d: InsertForwardAccept not called", i) + } + } +} + +// assertError is a helper to create a simple error for tests +type simpleErr struct{ msg string } + +func (e simpleErr) Error() string { return e.msg } + +func assertError(msg string) error { + return simpleErr{msg: msg} +} \ No newline at end of file diff --git a/network-go/implementation.md b/network-go/implementation.md index 1b8f248..6e2bb62 100644 --- a/network-go/implementation.md +++ b/network-go/implementation.md @@ -88,26 +88,6 @@ Without these, the program cannot: - Insert PREROUTING/POSTROUTING rules inside other containers - Add routes to container network namespaces -### Minimal Docker Compose Example - -```yaml -version: "3.8" -services: - network-go: - build: ./network-go - network_mode: host - pid: "host" - cap_add: - - NET_ADMIN - - SYS_ADMIN - volumes: - - /var/run/docker.sock:/var/run/docker.sock - - /etc/user/config:/etc/user/config - environment: - - WATCH_PERIOD_SECONDS=30 - - DEBUG=false -``` - ## Configuration — `/etc/user/config/networks.json` ```json diff --git a/network-go/iptables/iptables.go b/network-go/iptables/iptables.go index c7d2a3c..e69692f 100644 --- a/network-go/iptables/iptables.go +++ b/network-go/iptables/iptables.go @@ -6,12 +6,31 @@ import ( "strings" ) +// IPTablesAPI defines the interface for iptables operations, enabling mock implementations for testing +type IPTablesAPI interface { + Binary() string + EnsureIPForward() error + EnsureEstablishedRelated(chain string) error + DeleteLine(chain string, lineNum string) error + InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error + InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error + InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error + InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error + InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error + DeleteForwardAccept(chain, comment string) error + InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error + InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error +} + // Manager manages iptables rules via the iptables/iptables-legacy CLI type Manager struct { - binary string // /usr/sbin/iptables or /usr/sbin/iptables-legacy + binary string debug bool } +// Ensure Manager implements IPTablesAPI +var _ IPTablesAPI = (*Manager)(nil) + // NewManager creates a new iptables manager, auto-detecting the binary func NewManager(debug bool) *Manager { m := &Manager{debug: debug} @@ -83,17 +102,14 @@ func (m *Manager) EnsureIPForward() error { } // EnsureEstablishedRelated inserts an ESTABLISHED,RELATED accept rule at the top of a chain -// if it doesn't already exist func (m *Manager) EnsureEstablishedRelated(chain string) error { checkArgs := []string{"-w", "-n", "-L", chain} cmd := exec.Command(m.binary, checkArgs...) output, err := cmd.Output() if err != nil { - // Chain may not exist, create it return nil } - // Only insert if ESTABLISHED,RELATED rule is not present if !strings.Contains(string(output), "ESTABLISHED") || !strings.Contains(string(output), "RELATED") { args := []string{"-w", "-I", chain, "-m", "state", "--state", "established,related", "-j", "ACCEPT"} return m.run(args...) @@ -114,7 +130,6 @@ func (m *Manager) DeleteLineInContainer(pid int, table, chain, lineNum string) e } // getLineNumbers returns line numbers matching certain criteria in a chain/table -// This implements the grep logic from the shell script: iptables -w --line-number -n -L $CHAIN | grep ... func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []string { args := []string{"-w", "--line-number", "-n", "-L", chain} if table != "" { @@ -150,7 +165,6 @@ func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) [] // deleteMatchingLines deletes all lines in a chain matching the given patterns func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...string) error { lines := m.getLineNumbers(chain, table, grepPatterns...) - // Reverse order (highest line first) so deletions don't shift line numbers for i := len(lines) - 1; i >= 0; i-- { if err := m.DeleteLine(chain, lines[i]); err != nil { return err @@ -161,7 +175,6 @@ func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...strin // deleteMatchingLinesInContainer deletes matching lines inside a container namespace func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, grepPatterns ...string) error { - // For container namespaces, we use a different approach: list via nsenter + grep iptPath := "/sbin/iptables-legacy" if !strings.Contains(m.binary, "legacy") { iptPath = "/sbin/iptables" @@ -192,7 +205,6 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g } } - // Delete in reverse order for i := len(matchingLines) - 1; i >= 0; i-- { if err := m.DeleteLineInContainer(pid, table, chain, matchingLines[i]); err != nil { return err @@ -203,13 +215,11 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g // InsertPreroutingRule inserts a DNAT PREROUTING rule on the host func (m *Manager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error { - // First, delete existing matching rules patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment} if err := m.deleteMatchingLines("PREROUTING", "nat", patterns...); err != nil { return fmt.Errorf("failed to delete old PREROUTING rules: %w", err) } - // Insert the new rule args := []string{ "-w", "-t", "nat", "-I", "PREROUTING", "-d", sourceIP, @@ -236,7 +246,6 @@ func (m *Manager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targ // InsertPostroutingMasquerade inserts a MASQUERADE POSTROUTING rule on the host func (m *Manager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error { - // Delete existing matching rules first patterns := []string{"MASQUERADE", comment, sourceCIDR, sourcePort} if err := m.deleteMatchingLines("POSTROUTING", "nat", patterns...); err != nil { return fmt.Errorf("failed to delete old POSTROUTING rules: %w", err) @@ -273,7 +282,6 @@ func (m *Manager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, target // InsertForwardAccept inserts a FORWARD ACCEPT rule on the host func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error { - // Build grep patterns to match existing rules var grepPatterns []string grepPatterns = append(grepPatterns, proto) if sourceIP != "" { @@ -289,12 +297,10 @@ func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePo grepPatterns = append(grepPatterns, targetPort) } - // Delete old matching rules if err := m.deleteMatchingLines(chain, "", grepPatterns...); err != nil { return fmt.Errorf("failed to delete old FORWARD rules: %w", err) } - // Build iptables args args := []string{"-w", "-I", chain, "-p", proto} if sourceIP != "" { args = append(args, "-s", sourceIP) @@ -326,7 +332,6 @@ func (m *Manager) DeleteForwardAccept(chain, comment string) error { // InsertPreroutingRuleInContainer inserts a DNAT PREROUTING rule inside a container namespace func (m *Manager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error { - // Delete existing first patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment} if err := m.deleteMatchingLinesInContainer(pid, "nat", "PREROUTING", patterns...); err != nil { return fmt.Errorf("failed to delete old container PREROUTING rules: %w", err) diff --git a/network-go/mock/mock.go b/network-go/mock/mock.go new file mode 100644 index 0000000..dee35dc --- /dev/null +++ b/network-go/mock/mock.go @@ -0,0 +1,209 @@ +package mock + +import ( + "context" + "time" + + "github.com/docker/docker/api/types" + + "firewall_containers/network-go/config" + "firewall_containers/network-go/docker" + "firewall_containers/network-go/iptables" +) + +// Compile-time interface conformance checks +var _ docker.DockerAPI = (*MockDockerClient)(nil) +var _ iptables.IPTablesAPI = (*MockIPTablesManager)(nil) + +// MockDockerClient implements docker.DockerAPI for testing +type MockDockerClient struct { + EnsureNetworkCalled bool + EnsureNetworkCfg config.NetworkConfig + EnsureNetworkErr error + + ConnectContainerCalled bool + ConnectContainerName string + ConnectContainerNetwork string + ConnectContainerIP string + ConnectContainerErr error + + WaitForRunningCalled bool + WaitForRunningName string + + GetContainerPIDCalled bool + GetContainerPIDName string + GetContainerPIDResult int + GetContainerPIDErr error + + AddRouteCalled bool + AddRouteContainer string + AddRouteNetwork string + AddRouteGateway string + AddRouteErr error + + InspectContainerErr error + RemoveNetworkErr error + DisconnectContainerErr error +} + +func (m *MockDockerClient) Close() error { return nil } + +func (m *MockDockerClient) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error { + m.EnsureNetworkCalled = true + m.EnsureNetworkCfg = netCfg + return m.EnsureNetworkErr +} + +func (m *MockDockerClient) RemoveNetwork(ctx context.Context, networkName string) error { + return m.RemoveNetworkErr +} + +func (m *MockDockerClient) ConnectContainer(ctx context.Context, containerName, networkName, ip string) error { + m.ConnectContainerCalled = true + m.ConnectContainerName = containerName + m.ConnectContainerNetwork = networkName + m.ConnectContainerIP = ip + return m.ConnectContainerErr +} + +func (m *MockDockerClient) DisconnectContainer(ctx context.Context, containerName, networkName string) error { + return m.DisconnectContainerErr +} + +func (m *MockDockerClient) InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error) { + return nil, m.InspectContainerErr +} + +func (m *MockDockerClient) WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error { + m.WaitForRunningCalled = true + m.WaitForRunningName = containerName + return nil +} + +func (m *MockDockerClient) GetContainerPID(ctx context.Context, containerName string) (int, error) { + m.GetContainerPIDCalled = true + m.GetContainerPIDName = containerName + return m.GetContainerPIDResult, m.GetContainerPIDErr +} + +func (m *MockDockerClient) AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error { + m.AddRouteCalled = true + m.AddRouteContainer = containerName + m.AddRouteNetwork = network + m.AddRouteGateway = gateway + return m.AddRouteErr +} + +// MockIPTablesManager implements iptables.IPTablesAPI for testing +type MockIPTablesManager struct { + BinaryResult string + EnsureIPForwardCalled bool + EnsureIPForwardErr error + EnsureEstablishedRelatedCalled bool + EnsureEstablishedRelatedChain string + EnsureEstablishedRelatedErr error + + InsertPreroutingRuleCalled bool + InsertPreroutingRuleArgs []string + InsertPreroutingRuleErr error + + InsertPreroutingRuleOnInterfaceCalled bool + InsertPreroutingRuleOnInterfaceArgs []string + InsertPreroutingRuleOnInterfaceErr error + + InsertPostroutingMasqueradeCalled bool + InsertPostroutingMasqueradeArgs []string + InsertPostroutingMasqueradeErr error + + InsertForwardAcceptCalled bool + InsertForwardAcceptChain string + InsertForwardAcceptSourceIP string + InsertForwardAcceptTargetIP string + InsertForwardAcceptProto string + InsertForwardAcceptSourcePort string + InsertForwardAcceptTargetPort string + InsertForwardAcceptComment string + InsertForwardAcceptErr error + + InsertPreroutingRuleInContainerCalled bool + InsertPreroutingRuleInContainerPID int + InsertPreroutingRuleInContainerArgs []string + InsertPreroutingRuleInContainerErr error + + InsertPostroutingMasqueradeInContainerCalled bool + InsertPostroutingMasqueradeInContainerErr error + DeleteForwardAcceptErr error + DeleteLineErr error +} + +func (m *MockIPTablesManager) Binary() string { + if m.BinaryResult == "" { + return "/usr/sbin/iptables" + } + return m.BinaryResult +} + +func (m *MockIPTablesManager) EnsureIPForward() error { + m.EnsureIPForwardCalled = true + return m.EnsureIPForwardErr +} + +func (m *MockIPTablesManager) EnsureEstablishedRelated(chain string) error { + m.EnsureEstablishedRelatedCalled = true + m.EnsureEstablishedRelatedChain = chain + return m.EnsureEstablishedRelatedErr +} + +func (m *MockIPTablesManager) DeleteLine(chain string, lineNum string) error { + return m.DeleteLineErr +} + +func (m *MockIPTablesManager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error { + m.InsertPreroutingRuleCalled = true + m.InsertPreroutingRuleArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment} + return m.InsertPreroutingRuleErr +} + +func (m *MockIPTablesManager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error { + m.InsertPreroutingRuleOnInterfaceCalled = true + m.InsertPreroutingRuleOnInterfaceArgs = []string{iface, proto, sourcePort, targetIP, targetPort, comment} + return m.InsertPreroutingRuleOnInterfaceErr +} + +func (m *MockIPTablesManager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error { + m.InsertPostroutingMasqueradeCalled = true + m.InsertPostroutingMasqueradeArgs = []string{sourceCIDR, proto, sourcePort, comment} + return m.InsertPostroutingMasqueradeErr +} + +func (m *MockIPTablesManager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error { + return nil +} + +func (m *MockIPTablesManager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error { + m.InsertForwardAcceptCalled = true + m.InsertForwardAcceptChain = chain + m.InsertForwardAcceptSourceIP = sourceIP + m.InsertForwardAcceptTargetIP = targetIP + m.InsertForwardAcceptProto = proto + m.InsertForwardAcceptSourcePort = sourcePort + m.InsertForwardAcceptTargetPort = targetPort + m.InsertForwardAcceptComment = comment + return m.InsertForwardAcceptErr +} + +func (m *MockIPTablesManager) DeleteForwardAccept(chain, comment string) error { + return m.DeleteForwardAcceptErr +} + +func (m *MockIPTablesManager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error { + m.InsertPreroutingRuleInContainerCalled = true + m.InsertPreroutingRuleInContainerPID = pid + m.InsertPreroutingRuleInContainerArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment} + return m.InsertPreroutingRuleInContainerErr +} + +func (m *MockIPTablesManager) InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error { + m.InsertPostroutingMasqueradeInContainerCalled = true + return m.InsertPostroutingMasqueradeInContainerErr +} \ No newline at end of file diff --git a/network-go/resolver/resolver_test.go b/network-go/resolver/resolver_test.go new file mode 100644 index 0000000..c93f196 --- /dev/null +++ b/network-go/resolver/resolver_test.go @@ -0,0 +1,118 @@ +package resolver + +import ( + "testing" + + "firewall_containers/network-go/config" +) + +func makeTestConfig() *config.NetworksConfig { + return &config.NetworksConfig{ + Networks: map[string]config.NetworkConfig{ + "test-net": {NetworkName: "test-net", Subnet: "172.18.103.0/24", Gateway: "172.18.103.1"}, + }, + IPs: map[string]config.IPConfig{ + "172.18.103.2": {IP: "172.18.103.2", ContainerName: "smarthostloadbalancer", Selector: "smarthostloadbalancer", ServiceName: "smarthost-proxy"}, + "172.18.104.2": {IP: "172.18.104.2", ContainerName: "smarthostbackend-1", Selector: "smarthostbackend-1", ServiceName: "smarthost-proxy"}, + }, + Policies: []config.PolicyConfig{}, + } +} + +func TestResolveByContainerName(t *testing.T) { + cfg := makeTestConfig() + r := NewResolver(cfg) + + ips := r.Resolve("smarthostloadbalancer") + if len(ips) != 1 { + t.Fatalf("expected 1 IP, got %d", len(ips)) + } + if ips[0] != "172.18.103.2" { + t.Errorf("expected 172.18.103.2, got %s", ips[0]) + } +} + +func TestResolveBySelector(t *testing.T) { + cfg := makeTestConfig() + r := NewResolver(cfg) + + ips := r.Resolve("smarthostbackend-1") + if len(ips) != 1 { + t.Fatalf("expected 1 IP, got %d", len(ips)) + } + if ips[0] != "172.18.104.2" { + t.Errorf("expected 172.18.104.2, got %s", ips[0]) + } +} + +func TestResolveNotFound(t *testing.T) { + cfg := makeTestConfig() + r := NewResolver(cfg) + + ips := r.Resolve("nonexistent-container") + if len(ips) != 0 { + t.Errorf("expected 0 IPs, got %d", len(ips)) + } +} + +func TestResolveNilConfig(t *testing.T) { + r := NewResolver(nil) + + ips := r.Resolve("anything") + if len(ips) != 0 { + t.Errorf("expected 0 IPs with nil config, got %d", len(ips)) + } +} + +func TestSetConfig(t *testing.T) { + r := NewResolver(nil) + + // Initially nil config + ips := r.Resolve("test") + if len(ips) != 0 { + t.Errorf("expected 0 IPs with nil config, got %d", len(ips)) + } + + // Update with valid config + cfg := makeTestConfig() + r.SetConfig(cfg) + + ips = r.Resolve("smarthostloadbalancer") + if len(ips) != 1 { + t.Fatalf("expected 1 IP after SetConfig, got %d", len(ips)) + } + if ips[0] != "172.18.103.2" { + t.Errorf("expected 172.18.103.2, got %s", ips[0]) + } +} + +func TestResolveReproducible(t *testing.T) { + cfg := makeTestConfig() + r := NewResolver(cfg) + + ips1 := r.Resolve("smarthostloadbalancer") + ips2 := r.Resolve("smarthostloadbalancer") + + if len(ips1) != len(ips2) { + t.Errorf("reproducibility: result count mismatch %d vs %d", len(ips1), len(ips2)) + } + if len(ips1) > 0 && ips1[0] != ips2[0] { + t.Errorf("reproducibility: IP mismatch %s vs %s", ips1[0], ips2[0]) + } +} + +func TestResolveMultipleMatch(t *testing.T) { + cfg := &config.NetworksConfig{ + IPs: map[string]config.IPConfig{ + "10.0.0.1": {IP: "10.0.0.1", ContainerName: "app-1", Selector: "app-1"}, + "10.0.0.2": {IP: "10.0.0.2", ContainerName: "app-2", Selector: "app-2"}, + }, + } + r := NewResolver(cfg) + + // "app-x" has a dash, so prefix "app" is extracted and matches both app-1 and app-2 + ips := r.Resolve("app-x") + if len(ips) != 2 { + t.Errorf("expected 2 IPs for prefix match, got %d", len(ips)) + } +} diff --git a/network-go/watcher/watcher_test.go b/network-go/watcher/watcher_test.go new file mode 100644 index 0000000..7e99b41 --- /dev/null +++ b/network-go/watcher/watcher_test.go @@ -0,0 +1,179 @@ +package watcher + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestWatcherDetectsChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + changeDetected := make(chan bool, 1) + onChange := func() { + select { + case changeDetected <- true: + default: + } + } + + fw := NewFileWatcher(path, 100*time.Millisecond, onChange) + fw.Start() + defer fw.Stop() + + // Wait for initial hash to be computed + time.Sleep(200 * time.Millisecond) + + // Modify the file + if err := os.WriteFile(path, []byte(`{"version": 2}`), 0644); err != nil { + t.Fatalf("failed to modify test file: %v", err) + } + + // Wait for change detection + select { + case detected := <-changeDetected: + if !detected { + t.Error("expected change detection, got false") + } + case <-time.After(2 * time.Second): + t.Error("timeout waiting for file change detection") + } +} + +func TestWatcherNoChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + changeDetected := make(chan bool, 1) + onChange := func() { + select { + case changeDetected <- true: + default: + } + } + + fw := NewFileWatcher(path, 100*time.Millisecond, onChange) + fw.Start() + defer fw.Stop() + + // Wait without modifying the file + time.Sleep(300 * time.Millisecond) + + // Should not detect a change + select { + case <-changeDetected: + t.Error("unexpected change detection without file modification") + default: + // Expected: no change detected + } +} + +func TestWatcherMultipleChanges(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + changeCount := 0 + var mu chan struct{} + + onChange := func() { + changeCount++ + if mu != nil { + mu <- struct{}{} + } + } + + fw := NewFileWatcher(path, 50*time.Millisecond, onChange) + fw.Start() + defer fw.Stop() + + // Wait for initial hash + time.Sleep(100 * time.Millisecond) + + // Make 3 changes + for i := 2; i <= 4; i++ { + if err := os.WriteFile(path, []byte(`{"v": `+string(rune('0'+i))+`}`), 0644); err != nil { + t.Fatalf("failed to modify file: %v", err) + } + time.Sleep(150 * time.Millisecond) + } + + if changeCount < 2 { + t.Errorf("expected at least 2 change detections, got %d", changeCount) + } +} + +func TestWatcherStop(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + fw := NewFileWatcher(path, 50*time.Millisecond, func() {}) + fw.Start() + + time.Sleep(100 * time.Millisecond) + fw.Stop() + + // Should not panic on double stop or further operations + time.Sleep(100 * time.Millisecond) +} + +func TestWatcherNonexistentFile(t *testing.T) { + onChange := func() { + t.Error("onChange should not be called for nonexistent file") + } + + fw := NewFileWatcher("/nonexistent/file.json", 100*time.Millisecond, onChange) + fw.Start() + defer fw.Stop() + + // Wait — should not call onChange since file doesn't exist + time.Sleep(300 * time.Millisecond) +} + +func TestWatcherReproducible(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.json") + if err := os.WriteFile(path, []byte(`{"data": "value"}`), 0644); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + // Create two watchers on the same file + var detected1, detected2 bool + + fw1 := NewFileWatcher(path, 50*time.Millisecond, func() { detected1 = true }) + fw2 := NewFileWatcher(path, 50*time.Millisecond, func() { detected2 = true }) + + fw1.Start() + fw2.Start() + defer fw1.Stop() + defer fw2.Stop() + + time.Sleep(100 * time.Millisecond) + + // Modify file + if err := os.WriteFile(path, []byte(`{"data": "changed"}`), 0644); err != nil { + t.Fatalf("failed to modify file: %v", err) + } + + time.Sleep(300 * time.Millisecond) + + if !detected1 { + t.Error("watcher 1 did not detect change") + } + if !detected2 { + t.Error("watcher 2 did not detect change") + } +} \ No newline at end of file