This commit is contained in:
@@ -0,0 +1,231 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const testJSON = `{
|
||||||
|
"networks": {
|
||||||
|
"smarthost-loadbalancer": {
|
||||||
|
"network_name": "smarthost-loadbalancer",
|
||||||
|
"subnet": "172.18.103.0/24",
|
||||||
|
"gateway": "172.18.103.1"
|
||||||
|
},
|
||||||
|
"smarthost_backend-1": {
|
||||||
|
"network_name": "smarthost_backend-1",
|
||||||
|
"subnet": "172.18.104.0/24",
|
||||||
|
"gateway": "172.18.104.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"ips": {
|
||||||
|
"172.18.103.2": {
|
||||||
|
"ip": "172.18.103.2",
|
||||||
|
"container_name": "smarthostloadbalancer",
|
||||||
|
"selector": "smarthostloadbalancer",
|
||||||
|
"service_name": "smarthost-proxy"
|
||||||
|
},
|
||||||
|
"172.18.104.2": {
|
||||||
|
"ip": "172.18.104.2",
|
||||||
|
"container_name": "smarthostbackend-1",
|
||||||
|
"selector": "smarthostbackend-1",
|
||||||
|
"service_name": "smarthost-proxy"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"policies": [
|
||||||
|
{
|
||||||
|
"service_name": "smarthost-proxy",
|
||||||
|
"container_name": "smarthost_loadbalancer",
|
||||||
|
"selector": "smarthostloadbalancer",
|
||||||
|
"from": "publicbackend",
|
||||||
|
"port": 80,
|
||||||
|
"proto": "tcp"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"service_name": "smarthost-proxy",
|
||||||
|
"container_name": "smarthost_loadbalancer",
|
||||||
|
"selector": "smarthostloadbalancer",
|
||||||
|
"name": "wireguardproxy",
|
||||||
|
"iface": "wg0",
|
||||||
|
"nat": "dnat",
|
||||||
|
"to": "smarthostloadbalancer",
|
||||||
|
"port": 80,
|
||||||
|
"proto": "tcp"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}`
|
||||||
|
|
||||||
|
func TestLoad(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "networks.json")
|
||||||
|
if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load() returned error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Networks) != 2 {
|
||||||
|
t.Errorf("expected 2 networks, got %d", len(cfg.Networks))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Networks["smarthost-loadbalancer"].Subnet != "172.18.103.0/24" {
|
||||||
|
t.Errorf("unexpected subnet: %s", cfg.Networks["smarthost-loadbalancer"].Subnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.IPs) != 2 {
|
||||||
|
t.Errorf("expected 2 IPs, got %d", len(cfg.IPs))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.IPs["172.18.103.2"].ContainerName != "smarthostloadbalancer" {
|
||||||
|
t.Errorf("unexpected container name: %s", cfg.IPs["172.18.103.2"].ContainerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Policies) != 2 {
|
||||||
|
t.Errorf("expected 2 policies, got %d", len(cfg.Policies))
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Policies[0].From != "publicbackend" {
|
||||||
|
t.Errorf("unexpected from: %s", cfg.Policies[0].From)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Policies[1].Nat != "dnat" {
|
||||||
|
t.Errorf("unexpected nat: %s", cfg.Policies[1].Nat)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadFileNotFound(t *testing.T) {
|
||||||
|
_, err := Load("/nonexistent/path/networks.json")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for nonexistent file, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadInvalidJSON(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "bad.json")
|
||||||
|
if err := os.WriteFile(path, []byte("{invalid json"), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid JSON, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToCIDR(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"172.18.103.0", "172.18.103.0/24"},
|
||||||
|
{"172.18.103.2", "172.18.103.2"},
|
||||||
|
{"172.18.103.0/24", "172.18.103.0/24"},
|
||||||
|
{"10.0.0.0", "10.0.0.0/24"},
|
||||||
|
{"invalid", "invalid"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := ToCIDR(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("ToCIDR(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkPrefix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"172.18.103.2", "172.18.103.0/24"},
|
||||||
|
{"10.0.0.1", "10.0.0.0/24"},
|
||||||
|
{"invalid", "invalid"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := NetworkPrefix(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("NetworkPrefix(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsIP(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"172.18.103.2", true},
|
||||||
|
{"10.0.0.0", true},
|
||||||
|
{"255.255.255.255", true},
|
||||||
|
{"publicbackend", false},
|
||||||
|
{"172.18.103.0/24", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := IsIP(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsIP(%q) = %v, want %v", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkConfigParseCIDR(t *testing.T) {
|
||||||
|
nc := NetworkConfig{Subnet: "172.18.103.0/24"}
|
||||||
|
ipNet, err := nc.ParseCIDR()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ParseCIDR() returned error: %v", err)
|
||||||
|
}
|
||||||
|
if ipNet.String() != "172.18.103.0/24" {
|
||||||
|
t.Errorf("ParseCIDR() = %s, want 172.18.103.0/24", ipNet.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
nc2 := NetworkConfig{Subnet: "invalid"}
|
||||||
|
_, err = nc2.ParseCIDR()
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid subnet, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadReproducible(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "networks.json")
|
||||||
|
if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load twice and compare
|
||||||
|
cfg1, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg2, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second Load() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify reproducibility
|
||||||
|
if len(cfg1.Networks) != len(cfg2.Networks) {
|
||||||
|
t.Errorf("reproducibility: network count mismatch %d vs %d", len(cfg1.Networks), len(cfg2.Networks))
|
||||||
|
}
|
||||||
|
if len(cfg1.IPs) != len(cfg2.IPs) {
|
||||||
|
t.Errorf("reproducibility: IP count mismatch %d vs %d", len(cfg1.IPs), len(cfg2.IPs))
|
||||||
|
}
|
||||||
|
if len(cfg1.Policies) != len(cfg2.Policies) {
|
||||||
|
t.Errorf("reproducibility: policy count mismatch %d vs %d", len(cfg1.Policies), len(cfg2.Policies))
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, net1 := range cfg1.Networks {
|
||||||
|
net2 := cfg2.Networks[name]
|
||||||
|
if net1.Subnet != net2.Subnet || net1.Gateway != net2.Gateway {
|
||||||
|
t.Errorf("reproducibility: network %s mismatch", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -15,11 +15,27 @@ import (
|
|||||||
"firewall_containers/network-go/config"
|
"firewall_containers/network-go/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DockerAPI defines the interface for Docker operations, enabling mock implementations for testing
|
||||||
|
type DockerAPI interface {
|
||||||
|
Close() error
|
||||||
|
EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error
|
||||||
|
RemoveNetwork(ctx context.Context, networkName string) error
|
||||||
|
ConnectContainer(ctx context.Context, containerName, networkName, ip string) error
|
||||||
|
DisconnectContainer(ctx context.Context, containerName, networkName string) error
|
||||||
|
InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error)
|
||||||
|
WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error
|
||||||
|
GetContainerPID(ctx context.Context, containerName string) (int, error)
|
||||||
|
AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error
|
||||||
|
}
|
||||||
|
|
||||||
// Client wraps the Docker SDK client
|
// Client wraps the Docker SDK client
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cli *client.Client
|
cli *client.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure Client implements DockerAPI
|
||||||
|
var _ DockerAPI = (*Client)(nil)
|
||||||
|
|
||||||
// NewClient creates a new Docker client
|
// NewClient creates a new Docker client
|
||||||
func NewClient() (*Client, error) {
|
func NewClient() (*Client, error) {
|
||||||
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||||
@@ -36,7 +52,6 @@ func (c *Client) Close() error {
|
|||||||
|
|
||||||
// EnsureNetwork creates a Docker network if it does not already exist
|
// EnsureNetwork creates a Docker network if it does not already exist
|
||||||
func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
|
func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
|
||||||
// Check if network already exists
|
|
||||||
existingNetworks, err := c.cli.NetworkList(ctx, network.ListOptions{
|
existingNetworks, err := c.cli.NetworkList(ctx, network.ListOptions{
|
||||||
Filters: filters.NewArgs(filters.Arg("name", netCfg.NetworkName)),
|
Filters: filters.NewArgs(filters.Arg("name", netCfg.NetworkName)),
|
||||||
})
|
})
|
||||||
@@ -46,12 +61,10 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
|
|||||||
|
|
||||||
for _, n := range existingNetworks {
|
for _, n := range existingNetworks {
|
||||||
if n.Name == netCfg.NetworkName {
|
if n.Name == netCfg.NetworkName {
|
||||||
// Network already exists, skip creation
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse subnet and gateway
|
|
||||||
_, ipNet, err := net.ParseCIDR(netCfg.Subnet)
|
_, ipNet, err := net.ParseCIDR(netCfg.Subnet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to parse subnet %s: %w", netCfg.Subnet, err)
|
return fmt.Errorf("failed to parse subnet %s: %w", netCfg.Subnet, err)
|
||||||
@@ -62,7 +75,6 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
|
|||||||
return fmt.Errorf("failed to parse gateway IP %s", netCfg.Gateway)
|
return fmt.Errorf("failed to parse gateway IP %s", netCfg.Gateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the network
|
|
||||||
createOpts := network.CreateOptions{
|
createOpts := network.CreateOptions{
|
||||||
Driver: "bridge",
|
Driver: "bridge",
|
||||||
IPAM: &network.IPAM{
|
IPAM: &network.IPAM{
|
||||||
@@ -84,7 +96,7 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
|
|||||||
return fmt.Errorf("failed to create network %s: %w", netCfg.NetworkName, err)
|
return fmt.Errorf("failed to create network %s: %w", netCfg.NetworkName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = resp // response contains ID and warnings
|
_ = resp
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,6 +166,7 @@ func (c *Client) WaitForContainerRunning(ctx context.Context, containerName stri
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetContainerPID returns the PID of a container for nsenter operations
|
||||||
func (c *Client) GetContainerPID(ctx context.Context, containerName string) (int, error) {
|
func (c *Client) GetContainerPID(ctx context.Context, containerName string) (int, error) {
|
||||||
cont, err := c.cli.ContainerInspect(ctx, containerName)
|
cont, err := c.cli.ContainerInspect(ctx, containerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -184,4 +197,4 @@ func (c *Client) AddRouteInContainer(ctx context.Context, containerName, network
|
|||||||
return fmt.Errorf("failed to add route in container %s: %w\noutput: %s", containerName, err, string(output))
|
return fmt.Errorf("failed to add route in container %s: %w\noutput: %s", containerName, err, string(output))
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -16,14 +16,14 @@ import (
|
|||||||
// Orchestrator reconciles the networks.json configuration into Docker networks
|
// Orchestrator reconciles the networks.json configuration into Docker networks
|
||||||
// and iptables firewall rules
|
// and iptables firewall rules
|
||||||
type Orchestrator struct {
|
type Orchestrator struct {
|
||||||
dockerClient *docker.Client
|
dockerClient docker.DockerAPI
|
||||||
iptablesMgr *iptables.Manager
|
iptablesMgr iptables.IPTablesAPI
|
||||||
resolver *resolver.Resolver
|
resolver *resolver.Resolver
|
||||||
debug bool
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOrchestrator creates a new firewall orchestrator
|
// NewOrchestrator creates a new firewall orchestrator
|
||||||
func NewOrchestrator(dockerClient *docker.Client, iptablesMgr *iptables.Manager, cfg *config.NetworksConfig) *Orchestrator {
|
func NewOrchestrator(dockerClient docker.DockerAPI, iptablesMgr iptables.IPTablesAPI, cfg *config.NetworksConfig) *Orchestrator {
|
||||||
return &Orchestrator{
|
return &Orchestrator{
|
||||||
dockerClient: dockerClient,
|
dockerClient: dockerClient,
|
||||||
iptablesMgr: iptablesMgr,
|
iptablesMgr: iptablesMgr,
|
||||||
|
|||||||
@@ -0,0 +1,333 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"firewall_containers/network-go/config"
|
||||||
|
"firewall_containers/network-go/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testConfig() *config.NetworksConfig {
|
||||||
|
return &config.NetworksConfig{
|
||||||
|
Networks: map[string]config.NetworkConfig{
|
||||||
|
"smarthost-loadbalancer": {
|
||||||
|
NetworkName: "smarthost-loadbalancer",
|
||||||
|
Subnet: "172.18.103.0/24",
|
||||||
|
Gateway: "172.18.103.1",
|
||||||
|
},
|
||||||
|
"smarthost_backend-1": {
|
||||||
|
NetworkName: "smarthost_backend-1",
|
||||||
|
Subnet: "172.18.104.0/24",
|
||||||
|
Gateway: "172.18.104.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
IPs: map[string]config.IPConfig{
|
||||||
|
"172.18.103.2": {
|
||||||
|
IP: "172.18.103.2",
|
||||||
|
ContainerName: "smarthostloadbalancer",
|
||||||
|
Selector: "smarthostloadbalancer",
|
||||||
|
ServiceName: "smarthost-proxy",
|
||||||
|
},
|
||||||
|
"172.18.104.2": {
|
||||||
|
IP: "172.18.104.2",
|
||||||
|
ContainerName: "smarthostbackend-1",
|
||||||
|
Selector: "smarthostbackend-1",
|
||||||
|
ServiceName: "smarthost-proxy",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Policies: []config.PolicyConfig{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcileAllCreatesNetworks(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.ReconcileAll(ctx, cfg)
|
||||||
|
|
||||||
|
if !docker.EnsureNetworkCalled {
|
||||||
|
t.Error("EnsureNetwork was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcileAllConnectsContainers(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.ReconcileAll(ctx, cfg)
|
||||||
|
|
||||||
|
if !docker.ConnectContainerCalled {
|
||||||
|
t.Error("ConnectContainer was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcileAllEnablesIPForward(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.ReconcileAll(ctx, cfg)
|
||||||
|
|
||||||
|
if !iptables.EnsureIPForwardCalled {
|
||||||
|
t.Error("EnsureIPForward was not called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcilePoliciesForwardRule(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Policies = []config.PolicyConfig{
|
||||||
|
{
|
||||||
|
ServiceName: "smarthost-proxy",
|
||||||
|
ContainerName: "smarthostloadbalancer",
|
||||||
|
Selector: "smarthostloadbalancer",
|
||||||
|
From: "publicbackend",
|
||||||
|
Port: 80,
|
||||||
|
Proto: "tcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables"}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.reconcilePolicies(ctx, cfg)
|
||||||
|
|
||||||
|
// "publicbackend" should be resolved to an IP from the config (if it matches)
|
||||||
|
// Since "publicbackend" is not in the IPs map, sourceIP will be empty
|
||||||
|
if !iptables.InsertForwardAcceptCalled {
|
||||||
|
t.Error("InsertForwardAccept was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should use FORWARD chain (not iptables-legacy)
|
||||||
|
if iptables.InsertForwardAcceptChain != "FORWARD" {
|
||||||
|
t.Errorf("expected FORWARD chain, got %s", iptables.InsertForwardAcceptChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcilePoliciesForwardRuleWithLegacy(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Policies = []config.PolicyConfig{
|
||||||
|
{
|
||||||
|
ServiceName: "smarthost-proxy",
|
||||||
|
From: "publicbackend",
|
||||||
|
Port: 80,
|
||||||
|
Proto: "tcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables-legacy"}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.reconcilePolicies(ctx, cfg)
|
||||||
|
|
||||||
|
if !iptables.InsertForwardAcceptCalled {
|
||||||
|
t.Error("InsertForwardAccept was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should use DOCKER-USER chain when iptables-legacy
|
||||||
|
if iptables.InsertForwardAcceptChain != "DOCKER-USER" {
|
||||||
|
t.Errorf("expected DOCKER-USER chain, got %s", iptables.InsertForwardAcceptChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcilePoliciesDNATWithInterface(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Policies = []config.PolicyConfig{
|
||||||
|
{
|
||||||
|
ServiceName: "smarthost-proxy",
|
||||||
|
Name: "wireguardproxy",
|
||||||
|
Selector: "smarthostloadbalancer",
|
||||||
|
Iface: "wg0",
|
||||||
|
Nat: "dnat",
|
||||||
|
To: "smarthostloadbalancer",
|
||||||
|
Port: 80,
|
||||||
|
Proto: "tcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
docker := &mock.MockDockerClient{
|
||||||
|
GetContainerPIDResult: 0, // simulate no PID available
|
||||||
|
GetContainerPIDErr: assertError("container not running"),
|
||||||
|
}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.reconcilePolicies(ctx, cfg)
|
||||||
|
|
||||||
|
// When GetContainerPID fails, should fall back to interface-based rule
|
||||||
|
if !iptables.InsertPreroutingRuleOnInterfaceCalled {
|
||||||
|
t.Error("InsertPreroutingRuleOnInterface was not called (should fall back from nsenter)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(iptables.InsertPreroutingRuleOnInterfaceArgs) > 0 {
|
||||||
|
iface := iptables.InsertPreroutingRuleOnInterfaceArgs[0]
|
||||||
|
if iface != "wg0" {
|
||||||
|
t.Errorf("expected interface wg0, got %s", iface)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcilePoliciesDNATWithContainerPID(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Policies = []config.PolicyConfig{
|
||||||
|
{
|
||||||
|
ServiceName: "smarthost-proxy",
|
||||||
|
Name: "wireguardproxy",
|
||||||
|
Selector: "smarthostloadbalancer",
|
||||||
|
Nat: "dnat",
|
||||||
|
To: "smarthostloadbalancer",
|
||||||
|
Port: 80,
|
||||||
|
Proto: "tcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
docker := &mock.MockDockerClient{
|
||||||
|
GetContainerPIDResult: 1234,
|
||||||
|
GetContainerPIDErr: nil,
|
||||||
|
}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.reconcilePolicies(ctx, cfg)
|
||||||
|
|
||||||
|
if !docker.GetContainerPIDCalled {
|
||||||
|
t.Error("GetContainerPID was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !iptables.InsertPreroutingRuleInContainerCalled {
|
||||||
|
t.Error("InsertPreroutingRuleInContainer was not called")
|
||||||
|
}
|
||||||
|
|
||||||
|
if iptables.InsertPreroutingRuleInContainerPID != 1234 {
|
||||||
|
t.Errorf("expected PID 1234, got %d", iptables.InsertPreroutingRuleInContainerPID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcilePoliciesUnresolvedTarget(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Policies = []config.PolicyConfig{
|
||||||
|
{
|
||||||
|
ServiceName: "test",
|
||||||
|
Nat: "dnat",
|
||||||
|
Selector: "container1",
|
||||||
|
To: "nonexistent-target",
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.reconcilePolicies(ctx, cfg)
|
||||||
|
|
||||||
|
// Should not call GetContainerPID when target can't be resolved
|
||||||
|
if docker.GetContainerPIDCalled {
|
||||||
|
t.Error("GetContainerPID should not be called when target is unresolvable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveIPDirectIP(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
// Direct IP should be returned as CIDR
|
||||||
|
result := orch.resolveIP("172.18.103.2")
|
||||||
|
if result != "172.18.103.2" {
|
||||||
|
t.Errorf("expected 172.18.103.2, got %s", result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// .0 ending should be converted to /24
|
||||||
|
result = orch.resolveIP("172.18.103.0")
|
||||||
|
if result != "172.18.103.0/24" {
|
||||||
|
t.Errorf("expected 172.18.103.0/24, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveIPFromConfig(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
|
||||||
|
// Should resolve container name from config
|
||||||
|
result := orch.resolveIP("smarthostloadbalancer")
|
||||||
|
if result != "172.18.103.2" {
|
||||||
|
t.Errorf("expected 172.18.103.2, got %s", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindNetworkForIP(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"172.18.103.5", "smarthost-loadbalancer"},
|
||||||
|
{"172.18.104.5", "smarthost_backend-1"},
|
||||||
|
{"10.0.0.1", ""},
|
||||||
|
{"", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := findNetworkForIP(cfg, tt.ip)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("findNetworkForIP(%q) = %q, want %q", tt.ip, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconcileAllReproducible(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.Policies = []config.PolicyConfig{
|
||||||
|
{
|
||||||
|
ServiceName: "test-svc",
|
||||||
|
From: "publicbackend",
|
||||||
|
Port: 80,
|
||||||
|
Proto: "tcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run reconciliation twice with separate mocks
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
docker := &mock.MockDockerClient{}
|
||||||
|
iptables := &mock.MockIPTablesManager{}
|
||||||
|
orch := NewOrchestrator(docker, iptables, cfg)
|
||||||
|
ctx := context.Background()
|
||||||
|
orch.ReconcileAll(ctx, cfg)
|
||||||
|
|
||||||
|
if !docker.EnsureNetworkCalled {
|
||||||
|
t.Errorf("run %d: EnsureNetwork not called", i)
|
||||||
|
}
|
||||||
|
if !iptables.InsertForwardAcceptCalled {
|
||||||
|
t.Errorf("run %d: InsertForwardAccept not called", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertError is a helper to create a simple error for tests
|
||||||
|
type simpleErr struct{ msg string }
|
||||||
|
|
||||||
|
func (e simpleErr) Error() string { return e.msg }
|
||||||
|
|
||||||
|
func assertError(msg string) error {
|
||||||
|
return simpleErr{msg: msg}
|
||||||
|
}
|
||||||
@@ -88,26 +88,6 @@ Without these, the program cannot:
|
|||||||
- Insert PREROUTING/POSTROUTING rules inside other containers
|
- Insert PREROUTING/POSTROUTING rules inside other containers
|
||||||
- Add routes to container network namespaces
|
- Add routes to container network namespaces
|
||||||
|
|
||||||
### Minimal Docker Compose Example
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
version: "3.8"
|
|
||||||
services:
|
|
||||||
network-go:
|
|
||||||
build: ./network-go
|
|
||||||
network_mode: host
|
|
||||||
pid: "host"
|
|
||||||
cap_add:
|
|
||||||
- NET_ADMIN
|
|
||||||
- SYS_ADMIN
|
|
||||||
volumes:
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock
|
|
||||||
- /etc/user/config:/etc/user/config
|
|
||||||
environment:
|
|
||||||
- WATCH_PERIOD_SECONDS=30
|
|
||||||
- DEBUG=false
|
|
||||||
```
|
|
||||||
|
|
||||||
## Configuration — `/etc/user/config/networks.json`
|
## Configuration — `/etc/user/config/networks.json`
|
||||||
|
|
||||||
```json
|
```json
|
||||||
|
|||||||
@@ -6,12 +6,31 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IPTablesAPI defines the interface for iptables operations, enabling mock implementations for testing
|
||||||
|
type IPTablesAPI interface {
|
||||||
|
Binary() string
|
||||||
|
EnsureIPForward() error
|
||||||
|
EnsureEstablishedRelated(chain string) error
|
||||||
|
DeleteLine(chain string, lineNum string) error
|
||||||
|
InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error
|
||||||
|
InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error
|
||||||
|
InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error
|
||||||
|
InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error
|
||||||
|
InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error
|
||||||
|
DeleteForwardAccept(chain, comment string) error
|
||||||
|
InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error
|
||||||
|
InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error
|
||||||
|
}
|
||||||
|
|
||||||
// Manager manages iptables rules via the iptables/iptables-legacy CLI
|
// Manager manages iptables rules via the iptables/iptables-legacy CLI
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
binary string // /usr/sbin/iptables or /usr/sbin/iptables-legacy
|
binary string
|
||||||
debug bool
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ensure Manager implements IPTablesAPI
|
||||||
|
var _ IPTablesAPI = (*Manager)(nil)
|
||||||
|
|
||||||
// NewManager creates a new iptables manager, auto-detecting the binary
|
// NewManager creates a new iptables manager, auto-detecting the binary
|
||||||
func NewManager(debug bool) *Manager {
|
func NewManager(debug bool) *Manager {
|
||||||
m := &Manager{debug: debug}
|
m := &Manager{debug: debug}
|
||||||
@@ -83,17 +102,14 @@ func (m *Manager) EnsureIPForward() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// EnsureEstablishedRelated inserts an ESTABLISHED,RELATED accept rule at the top of a chain
|
// EnsureEstablishedRelated inserts an ESTABLISHED,RELATED accept rule at the top of a chain
|
||||||
// if it doesn't already exist
|
|
||||||
func (m *Manager) EnsureEstablishedRelated(chain string) error {
|
func (m *Manager) EnsureEstablishedRelated(chain string) error {
|
||||||
checkArgs := []string{"-w", "-n", "-L", chain}
|
checkArgs := []string{"-w", "-n", "-L", chain}
|
||||||
cmd := exec.Command(m.binary, checkArgs...)
|
cmd := exec.Command(m.binary, checkArgs...)
|
||||||
output, err := cmd.Output()
|
output, err := cmd.Output()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Chain may not exist, create it
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only insert if ESTABLISHED,RELATED rule is not present
|
|
||||||
if !strings.Contains(string(output), "ESTABLISHED") || !strings.Contains(string(output), "RELATED") {
|
if !strings.Contains(string(output), "ESTABLISHED") || !strings.Contains(string(output), "RELATED") {
|
||||||
args := []string{"-w", "-I", chain, "-m", "state", "--state", "established,related", "-j", "ACCEPT"}
|
args := []string{"-w", "-I", chain, "-m", "state", "--state", "established,related", "-j", "ACCEPT"}
|
||||||
return m.run(args...)
|
return m.run(args...)
|
||||||
@@ -114,7 +130,6 @@ func (m *Manager) DeleteLineInContainer(pid int, table, chain, lineNum string) e
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getLineNumbers returns line numbers matching certain criteria in a chain/table
|
// getLineNumbers returns line numbers matching certain criteria in a chain/table
|
||||||
// This implements the grep logic from the shell script: iptables -w --line-number -n -L $CHAIN | grep ...
|
|
||||||
func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []string {
|
func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []string {
|
||||||
args := []string{"-w", "--line-number", "-n", "-L", chain}
|
args := []string{"-w", "--line-number", "-n", "-L", chain}
|
||||||
if table != "" {
|
if table != "" {
|
||||||
@@ -150,7 +165,6 @@ func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []
|
|||||||
// deleteMatchingLines deletes all lines in a chain matching the given patterns
|
// deleteMatchingLines deletes all lines in a chain matching the given patterns
|
||||||
func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...string) error {
|
func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...string) error {
|
||||||
lines := m.getLineNumbers(chain, table, grepPatterns...)
|
lines := m.getLineNumbers(chain, table, grepPatterns...)
|
||||||
// Reverse order (highest line first) so deletions don't shift line numbers
|
|
||||||
for i := len(lines) - 1; i >= 0; i-- {
|
for i := len(lines) - 1; i >= 0; i-- {
|
||||||
if err := m.DeleteLine(chain, lines[i]); err != nil {
|
if err := m.DeleteLine(chain, lines[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -161,7 +175,6 @@ func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...strin
|
|||||||
|
|
||||||
// deleteMatchingLinesInContainer deletes matching lines inside a container namespace
|
// deleteMatchingLinesInContainer deletes matching lines inside a container namespace
|
||||||
func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, grepPatterns ...string) error {
|
func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, grepPatterns ...string) error {
|
||||||
// For container namespaces, we use a different approach: list via nsenter + grep
|
|
||||||
iptPath := "/sbin/iptables-legacy"
|
iptPath := "/sbin/iptables-legacy"
|
||||||
if !strings.Contains(m.binary, "legacy") {
|
if !strings.Contains(m.binary, "legacy") {
|
||||||
iptPath = "/sbin/iptables"
|
iptPath = "/sbin/iptables"
|
||||||
@@ -192,7 +205,6 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete in reverse order
|
|
||||||
for i := len(matchingLines) - 1; i >= 0; i-- {
|
for i := len(matchingLines) - 1; i >= 0; i-- {
|
||||||
if err := m.DeleteLineInContainer(pid, table, chain, matchingLines[i]); err != nil {
|
if err := m.DeleteLineInContainer(pid, table, chain, matchingLines[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -203,13 +215,11 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g
|
|||||||
|
|
||||||
// InsertPreroutingRule inserts a DNAT PREROUTING rule on the host
|
// InsertPreroutingRule inserts a DNAT PREROUTING rule on the host
|
||||||
func (m *Manager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
func (m *Manager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||||
// First, delete existing matching rules
|
|
||||||
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
|
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
|
||||||
if err := m.deleteMatchingLines("PREROUTING", "nat", patterns...); err != nil {
|
if err := m.deleteMatchingLines("PREROUTING", "nat", patterns...); err != nil {
|
||||||
return fmt.Errorf("failed to delete old PREROUTING rules: %w", err)
|
return fmt.Errorf("failed to delete old PREROUTING rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert the new rule
|
|
||||||
args := []string{
|
args := []string{
|
||||||
"-w", "-t", "nat", "-I", "PREROUTING",
|
"-w", "-t", "nat", "-I", "PREROUTING",
|
||||||
"-d", sourceIP,
|
"-d", sourceIP,
|
||||||
@@ -236,7 +246,6 @@ func (m *Manager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targ
|
|||||||
|
|
||||||
// InsertPostroutingMasquerade inserts a MASQUERADE POSTROUTING rule on the host
|
// InsertPostroutingMasquerade inserts a MASQUERADE POSTROUTING rule on the host
|
||||||
func (m *Manager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
|
func (m *Manager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
|
||||||
// Delete existing matching rules first
|
|
||||||
patterns := []string{"MASQUERADE", comment, sourceCIDR, sourcePort}
|
patterns := []string{"MASQUERADE", comment, sourceCIDR, sourcePort}
|
||||||
if err := m.deleteMatchingLines("POSTROUTING", "nat", patterns...); err != nil {
|
if err := m.deleteMatchingLines("POSTROUTING", "nat", patterns...); err != nil {
|
||||||
return fmt.Errorf("failed to delete old POSTROUTING rules: %w", err)
|
return fmt.Errorf("failed to delete old POSTROUTING rules: %w", err)
|
||||||
@@ -273,7 +282,6 @@ func (m *Manager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, target
|
|||||||
|
|
||||||
// InsertForwardAccept inserts a FORWARD ACCEPT rule on the host
|
// InsertForwardAccept inserts a FORWARD ACCEPT rule on the host
|
||||||
func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
|
func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
|
||||||
// Build grep patterns to match existing rules
|
|
||||||
var grepPatterns []string
|
var grepPatterns []string
|
||||||
grepPatterns = append(grepPatterns, proto)
|
grepPatterns = append(grepPatterns, proto)
|
||||||
if sourceIP != "" {
|
if sourceIP != "" {
|
||||||
@@ -289,12 +297,10 @@ func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePo
|
|||||||
grepPatterns = append(grepPatterns, targetPort)
|
grepPatterns = append(grepPatterns, targetPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete old matching rules
|
|
||||||
if err := m.deleteMatchingLines(chain, "", grepPatterns...); err != nil {
|
if err := m.deleteMatchingLines(chain, "", grepPatterns...); err != nil {
|
||||||
return fmt.Errorf("failed to delete old FORWARD rules: %w", err)
|
return fmt.Errorf("failed to delete old FORWARD rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build iptables args
|
|
||||||
args := []string{"-w", "-I", chain, "-p", proto}
|
args := []string{"-w", "-I", chain, "-p", proto}
|
||||||
if sourceIP != "" {
|
if sourceIP != "" {
|
||||||
args = append(args, "-s", sourceIP)
|
args = append(args, "-s", sourceIP)
|
||||||
@@ -326,7 +332,6 @@ func (m *Manager) DeleteForwardAccept(chain, comment string) error {
|
|||||||
|
|
||||||
// InsertPreroutingRuleInContainer inserts a DNAT PREROUTING rule inside a container namespace
|
// InsertPreroutingRuleInContainer inserts a DNAT PREROUTING rule inside a container namespace
|
||||||
func (m *Manager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
func (m *Manager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||||
// Delete existing first
|
|
||||||
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
|
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
|
||||||
if err := m.deleteMatchingLinesInContainer(pid, "nat", "PREROUTING", patterns...); err != nil {
|
if err := m.deleteMatchingLinesInContainer(pid, "nat", "PREROUTING", patterns...); err != nil {
|
||||||
return fmt.Errorf("failed to delete old container PREROUTING rules: %w", err)
|
return fmt.Errorf("failed to delete old container PREROUTING rules: %w", err)
|
||||||
|
|||||||
@@ -0,0 +1,209 @@
|
|||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/docker/docker/api/types"
|
||||||
|
|
||||||
|
"firewall_containers/network-go/config"
|
||||||
|
"firewall_containers/network-go/docker"
|
||||||
|
"firewall_containers/network-go/iptables"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Compile-time interface conformance checks
|
||||||
|
var _ docker.DockerAPI = (*MockDockerClient)(nil)
|
||||||
|
var _ iptables.IPTablesAPI = (*MockIPTablesManager)(nil)
|
||||||
|
|
||||||
|
// MockDockerClient implements docker.DockerAPI for testing
|
||||||
|
type MockDockerClient struct {
|
||||||
|
EnsureNetworkCalled bool
|
||||||
|
EnsureNetworkCfg config.NetworkConfig
|
||||||
|
EnsureNetworkErr error
|
||||||
|
|
||||||
|
ConnectContainerCalled bool
|
||||||
|
ConnectContainerName string
|
||||||
|
ConnectContainerNetwork string
|
||||||
|
ConnectContainerIP string
|
||||||
|
ConnectContainerErr error
|
||||||
|
|
||||||
|
WaitForRunningCalled bool
|
||||||
|
WaitForRunningName string
|
||||||
|
|
||||||
|
GetContainerPIDCalled bool
|
||||||
|
GetContainerPIDName string
|
||||||
|
GetContainerPIDResult int
|
||||||
|
GetContainerPIDErr error
|
||||||
|
|
||||||
|
AddRouteCalled bool
|
||||||
|
AddRouteContainer string
|
||||||
|
AddRouteNetwork string
|
||||||
|
AddRouteGateway string
|
||||||
|
AddRouteErr error
|
||||||
|
|
||||||
|
InspectContainerErr error
|
||||||
|
RemoveNetworkErr error
|
||||||
|
DisconnectContainerErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) Close() error { return nil }
|
||||||
|
|
||||||
|
func (m *MockDockerClient) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
|
||||||
|
m.EnsureNetworkCalled = true
|
||||||
|
m.EnsureNetworkCfg = netCfg
|
||||||
|
return m.EnsureNetworkErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) RemoveNetwork(ctx context.Context, networkName string) error {
|
||||||
|
return m.RemoveNetworkErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) ConnectContainer(ctx context.Context, containerName, networkName, ip string) error {
|
||||||
|
m.ConnectContainerCalled = true
|
||||||
|
m.ConnectContainerName = containerName
|
||||||
|
m.ConnectContainerNetwork = networkName
|
||||||
|
m.ConnectContainerIP = ip
|
||||||
|
return m.ConnectContainerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) DisconnectContainer(ctx context.Context, containerName, networkName string) error {
|
||||||
|
return m.DisconnectContainerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error) {
|
||||||
|
return nil, m.InspectContainerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error {
|
||||||
|
m.WaitForRunningCalled = true
|
||||||
|
m.WaitForRunningName = containerName
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) GetContainerPID(ctx context.Context, containerName string) (int, error) {
|
||||||
|
m.GetContainerPIDCalled = true
|
||||||
|
m.GetContainerPIDName = containerName
|
||||||
|
return m.GetContainerPIDResult, m.GetContainerPIDErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDockerClient) AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error {
|
||||||
|
m.AddRouteCalled = true
|
||||||
|
m.AddRouteContainer = containerName
|
||||||
|
m.AddRouteNetwork = network
|
||||||
|
m.AddRouteGateway = gateway
|
||||||
|
return m.AddRouteErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockIPTablesManager implements iptables.IPTablesAPI for testing
|
||||||
|
type MockIPTablesManager struct {
|
||||||
|
BinaryResult string
|
||||||
|
EnsureIPForwardCalled bool
|
||||||
|
EnsureIPForwardErr error
|
||||||
|
EnsureEstablishedRelatedCalled bool
|
||||||
|
EnsureEstablishedRelatedChain string
|
||||||
|
EnsureEstablishedRelatedErr error
|
||||||
|
|
||||||
|
InsertPreroutingRuleCalled bool
|
||||||
|
InsertPreroutingRuleArgs []string
|
||||||
|
InsertPreroutingRuleErr error
|
||||||
|
|
||||||
|
InsertPreroutingRuleOnInterfaceCalled bool
|
||||||
|
InsertPreroutingRuleOnInterfaceArgs []string
|
||||||
|
InsertPreroutingRuleOnInterfaceErr error
|
||||||
|
|
||||||
|
InsertPostroutingMasqueradeCalled bool
|
||||||
|
InsertPostroutingMasqueradeArgs []string
|
||||||
|
InsertPostroutingMasqueradeErr error
|
||||||
|
|
||||||
|
InsertForwardAcceptCalled bool
|
||||||
|
InsertForwardAcceptChain string
|
||||||
|
InsertForwardAcceptSourceIP string
|
||||||
|
InsertForwardAcceptTargetIP string
|
||||||
|
InsertForwardAcceptProto string
|
||||||
|
InsertForwardAcceptSourcePort string
|
||||||
|
InsertForwardAcceptTargetPort string
|
||||||
|
InsertForwardAcceptComment string
|
||||||
|
InsertForwardAcceptErr error
|
||||||
|
|
||||||
|
InsertPreroutingRuleInContainerCalled bool
|
||||||
|
InsertPreroutingRuleInContainerPID int
|
||||||
|
InsertPreroutingRuleInContainerArgs []string
|
||||||
|
InsertPreroutingRuleInContainerErr error
|
||||||
|
|
||||||
|
InsertPostroutingMasqueradeInContainerCalled bool
|
||||||
|
InsertPostroutingMasqueradeInContainerErr error
|
||||||
|
DeleteForwardAcceptErr error
|
||||||
|
DeleteLineErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) Binary() string {
|
||||||
|
if m.BinaryResult == "" {
|
||||||
|
return "/usr/sbin/iptables"
|
||||||
|
}
|
||||||
|
return m.BinaryResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) EnsureIPForward() error {
|
||||||
|
m.EnsureIPForwardCalled = true
|
||||||
|
return m.EnsureIPForwardErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) EnsureEstablishedRelated(chain string) error {
|
||||||
|
m.EnsureEstablishedRelatedCalled = true
|
||||||
|
m.EnsureEstablishedRelatedChain = chain
|
||||||
|
return m.EnsureEstablishedRelatedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) DeleteLine(chain string, lineNum string) error {
|
||||||
|
return m.DeleteLineErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||||
|
m.InsertPreroutingRuleCalled = true
|
||||||
|
m.InsertPreroutingRuleArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment}
|
||||||
|
return m.InsertPreroutingRuleErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||||
|
m.InsertPreroutingRuleOnInterfaceCalled = true
|
||||||
|
m.InsertPreroutingRuleOnInterfaceArgs = []string{iface, proto, sourcePort, targetIP, targetPort, comment}
|
||||||
|
return m.InsertPreroutingRuleOnInterfaceErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
|
||||||
|
m.InsertPostroutingMasqueradeCalled = true
|
||||||
|
m.InsertPostroutingMasqueradeArgs = []string{sourceCIDR, proto, sourcePort, comment}
|
||||||
|
return m.InsertPostroutingMasqueradeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
|
||||||
|
m.InsertForwardAcceptCalled = true
|
||||||
|
m.InsertForwardAcceptChain = chain
|
||||||
|
m.InsertForwardAcceptSourceIP = sourceIP
|
||||||
|
m.InsertForwardAcceptTargetIP = targetIP
|
||||||
|
m.InsertForwardAcceptProto = proto
|
||||||
|
m.InsertForwardAcceptSourcePort = sourcePort
|
||||||
|
m.InsertForwardAcceptTargetPort = targetPort
|
||||||
|
m.InsertForwardAcceptComment = comment
|
||||||
|
return m.InsertForwardAcceptErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) DeleteForwardAccept(chain, comment string) error {
|
||||||
|
return m.DeleteForwardAcceptErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||||
|
m.InsertPreroutingRuleInContainerCalled = true
|
||||||
|
m.InsertPreroutingRuleInContainerPID = pid
|
||||||
|
m.InsertPreroutingRuleInContainerArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment}
|
||||||
|
return m.InsertPreroutingRuleInContainerErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockIPTablesManager) InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error {
|
||||||
|
m.InsertPostroutingMasqueradeInContainerCalled = true
|
||||||
|
return m.InsertPostroutingMasqueradeInContainerErr
|
||||||
|
}
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
package resolver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"firewall_containers/network-go/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeTestConfig() *config.NetworksConfig {
|
||||||
|
return &config.NetworksConfig{
|
||||||
|
Networks: map[string]config.NetworkConfig{
|
||||||
|
"test-net": {NetworkName: "test-net", Subnet: "172.18.103.0/24", Gateway: "172.18.103.1"},
|
||||||
|
},
|
||||||
|
IPs: map[string]config.IPConfig{
|
||||||
|
"172.18.103.2": {IP: "172.18.103.2", ContainerName: "smarthostloadbalancer", Selector: "smarthostloadbalancer", ServiceName: "smarthost-proxy"},
|
||||||
|
"172.18.104.2": {IP: "172.18.104.2", ContainerName: "smarthostbackend-1", Selector: "smarthostbackend-1", ServiceName: "smarthost-proxy"},
|
||||||
|
},
|
||||||
|
Policies: []config.PolicyConfig{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveByContainerName(t *testing.T) {
|
||||||
|
cfg := makeTestConfig()
|
||||||
|
r := NewResolver(cfg)
|
||||||
|
|
||||||
|
ips := r.Resolve("smarthostloadbalancer")
|
||||||
|
if len(ips) != 1 {
|
||||||
|
t.Fatalf("expected 1 IP, got %d", len(ips))
|
||||||
|
}
|
||||||
|
if ips[0] != "172.18.103.2" {
|
||||||
|
t.Errorf("expected 172.18.103.2, got %s", ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveBySelector(t *testing.T) {
|
||||||
|
cfg := makeTestConfig()
|
||||||
|
r := NewResolver(cfg)
|
||||||
|
|
||||||
|
ips := r.Resolve("smarthostbackend-1")
|
||||||
|
if len(ips) != 1 {
|
||||||
|
t.Fatalf("expected 1 IP, got %d", len(ips))
|
||||||
|
}
|
||||||
|
if ips[0] != "172.18.104.2" {
|
||||||
|
t.Errorf("expected 172.18.104.2, got %s", ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveNotFound(t *testing.T) {
|
||||||
|
cfg := makeTestConfig()
|
||||||
|
r := NewResolver(cfg)
|
||||||
|
|
||||||
|
ips := r.Resolve("nonexistent-container")
|
||||||
|
if len(ips) != 0 {
|
||||||
|
t.Errorf("expected 0 IPs, got %d", len(ips))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveNilConfig(t *testing.T) {
|
||||||
|
r := NewResolver(nil)
|
||||||
|
|
||||||
|
ips := r.Resolve("anything")
|
||||||
|
if len(ips) != 0 {
|
||||||
|
t.Errorf("expected 0 IPs with nil config, got %d", len(ips))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetConfig(t *testing.T) {
|
||||||
|
r := NewResolver(nil)
|
||||||
|
|
||||||
|
// Initially nil config
|
||||||
|
ips := r.Resolve("test")
|
||||||
|
if len(ips) != 0 {
|
||||||
|
t.Errorf("expected 0 IPs with nil config, got %d", len(ips))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update with valid config
|
||||||
|
cfg := makeTestConfig()
|
||||||
|
r.SetConfig(cfg)
|
||||||
|
|
||||||
|
ips = r.Resolve("smarthostloadbalancer")
|
||||||
|
if len(ips) != 1 {
|
||||||
|
t.Fatalf("expected 1 IP after SetConfig, got %d", len(ips))
|
||||||
|
}
|
||||||
|
if ips[0] != "172.18.103.2" {
|
||||||
|
t.Errorf("expected 172.18.103.2, got %s", ips[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveReproducible(t *testing.T) {
|
||||||
|
cfg := makeTestConfig()
|
||||||
|
r := NewResolver(cfg)
|
||||||
|
|
||||||
|
ips1 := r.Resolve("smarthostloadbalancer")
|
||||||
|
ips2 := r.Resolve("smarthostloadbalancer")
|
||||||
|
|
||||||
|
if len(ips1) != len(ips2) {
|
||||||
|
t.Errorf("reproducibility: result count mismatch %d vs %d", len(ips1), len(ips2))
|
||||||
|
}
|
||||||
|
if len(ips1) > 0 && ips1[0] != ips2[0] {
|
||||||
|
t.Errorf("reproducibility: IP mismatch %s vs %s", ips1[0], ips2[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveMultipleMatch(t *testing.T) {
|
||||||
|
cfg := &config.NetworksConfig{
|
||||||
|
IPs: map[string]config.IPConfig{
|
||||||
|
"10.0.0.1": {IP: "10.0.0.1", ContainerName: "app-1", Selector: "app-1"},
|
||||||
|
"10.0.0.2": {IP: "10.0.0.2", ContainerName: "app-2", Selector: "app-2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
r := NewResolver(cfg)
|
||||||
|
|
||||||
|
// "app-x" has a dash, so prefix "app" is extracted and matches both app-1 and app-2
|
||||||
|
ips := r.Resolve("app-x")
|
||||||
|
if len(ips) != 2 {
|
||||||
|
t.Errorf("expected 2 IPs for prefix match, got %d", len(ips))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
package watcher
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWatcherDetectsChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.json")
|
||||||
|
if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
changeDetected := make(chan bool, 1)
|
||||||
|
onChange := func() {
|
||||||
|
select {
|
||||||
|
case changeDetected <- true:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
|
||||||
|
fw.Start()
|
||||||
|
defer fw.Stop()
|
||||||
|
|
||||||
|
// Wait for initial hash to be computed
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Modify the file
|
||||||
|
if err := os.WriteFile(path, []byte(`{"version": 2}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to modify test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for change detection
|
||||||
|
select {
|
||||||
|
case detected := <-changeDetected:
|
||||||
|
if !detected {
|
||||||
|
t.Error("expected change detection, got false")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("timeout waiting for file change detection")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherNoChange(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.json")
|
||||||
|
if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
changeDetected := make(chan bool, 1)
|
||||||
|
onChange := func() {
|
||||||
|
select {
|
||||||
|
case changeDetected <- true:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
|
||||||
|
fw.Start()
|
||||||
|
defer fw.Stop()
|
||||||
|
|
||||||
|
// Wait without modifying the file
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should not detect a change
|
||||||
|
select {
|
||||||
|
case <-changeDetected:
|
||||||
|
t.Error("unexpected change detection without file modification")
|
||||||
|
default:
|
||||||
|
// Expected: no change detected
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherMultipleChanges(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.json")
|
||||||
|
if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
changeCount := 0
|
||||||
|
var mu chan struct{}
|
||||||
|
|
||||||
|
onChange := func() {
|
||||||
|
changeCount++
|
||||||
|
if mu != nil {
|
||||||
|
mu <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fw := NewFileWatcher(path, 50*time.Millisecond, onChange)
|
||||||
|
fw.Start()
|
||||||
|
defer fw.Stop()
|
||||||
|
|
||||||
|
// Wait for initial hash
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Make 3 changes
|
||||||
|
for i := 2; i <= 4; i++ {
|
||||||
|
if err := os.WriteFile(path, []byte(`{"v": `+string(rune('0'+i))+`}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to modify file: %v", err)
|
||||||
|
}
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
if changeCount < 2 {
|
||||||
|
t.Errorf("expected at least 2 change detections, got %d", changeCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherStop(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.json")
|
||||||
|
if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fw := NewFileWatcher(path, 50*time.Millisecond, func() {})
|
||||||
|
fw.Start()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
fw.Stop()
|
||||||
|
|
||||||
|
// Should not panic on double stop or further operations
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherNonexistentFile(t *testing.T) {
|
||||||
|
onChange := func() {
|
||||||
|
t.Error("onChange should not be called for nonexistent file")
|
||||||
|
}
|
||||||
|
|
||||||
|
fw := NewFileWatcher("/nonexistent/file.json", 100*time.Millisecond, onChange)
|
||||||
|
fw.Start()
|
||||||
|
defer fw.Stop()
|
||||||
|
|
||||||
|
// Wait — should not call onChange since file doesn't exist
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatcherReproducible(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.json")
|
||||||
|
if err := os.WriteFile(path, []byte(`{"data": "value"}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create two watchers on the same file
|
||||||
|
var detected1, detected2 bool
|
||||||
|
|
||||||
|
fw1 := NewFileWatcher(path, 50*time.Millisecond, func() { detected1 = true })
|
||||||
|
fw2 := NewFileWatcher(path, 50*time.Millisecond, func() { detected2 = true })
|
||||||
|
|
||||||
|
fw1.Start()
|
||||||
|
fw2.Start()
|
||||||
|
defer fw1.Stop()
|
||||||
|
defer fw2.Stop()
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Modify file
|
||||||
|
if err := os.WriteFile(path, []byte(`{"data": "changed"}`), 0644); err != nil {
|
||||||
|
t.Fatalf("failed to modify file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
|
||||||
|
if !detected1 {
|
||||||
|
t.Error("watcher 1 did not detect change")
|
||||||
|
}
|
||||||
|
if !detected2 {
|
||||||
|
t.Error("watcher 2 did not detect change")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user