This commit is contained in:
@@ -0,0 +1,231 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const testJSON = `{
|
||||
"networks": {
|
||||
"smarthost-loadbalancer": {
|
||||
"network_name": "smarthost-loadbalancer",
|
||||
"subnet": "172.18.103.0/24",
|
||||
"gateway": "172.18.103.1"
|
||||
},
|
||||
"smarthost_backend-1": {
|
||||
"network_name": "smarthost_backend-1",
|
||||
"subnet": "172.18.104.0/24",
|
||||
"gateway": "172.18.104.1"
|
||||
}
|
||||
},
|
||||
"ips": {
|
||||
"172.18.103.2": {
|
||||
"ip": "172.18.103.2",
|
||||
"container_name": "smarthostloadbalancer",
|
||||
"selector": "smarthostloadbalancer",
|
||||
"service_name": "smarthost-proxy"
|
||||
},
|
||||
"172.18.104.2": {
|
||||
"ip": "172.18.104.2",
|
||||
"container_name": "smarthostbackend-1",
|
||||
"selector": "smarthostbackend-1",
|
||||
"service_name": "smarthost-proxy"
|
||||
}
|
||||
},
|
||||
"policies": [
|
||||
{
|
||||
"service_name": "smarthost-proxy",
|
||||
"container_name": "smarthost_loadbalancer",
|
||||
"selector": "smarthostloadbalancer",
|
||||
"from": "publicbackend",
|
||||
"port": 80,
|
||||
"proto": "tcp"
|
||||
},
|
||||
{
|
||||
"service_name": "smarthost-proxy",
|
||||
"container_name": "smarthost_loadbalancer",
|
||||
"selector": "smarthostloadbalancer",
|
||||
"name": "wireguardproxy",
|
||||
"iface": "wg0",
|
||||
"nat": "dnat",
|
||||
"to": "smarthostloadbalancer",
|
||||
"port": 80,
|
||||
"proto": "tcp"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "networks.json")
|
||||
if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Load() returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Networks) != 2 {
|
||||
t.Errorf("expected 2 networks, got %d", len(cfg.Networks))
|
||||
}
|
||||
|
||||
if cfg.Networks["smarthost-loadbalancer"].Subnet != "172.18.103.0/24" {
|
||||
t.Errorf("unexpected subnet: %s", cfg.Networks["smarthost-loadbalancer"].Subnet)
|
||||
}
|
||||
|
||||
if len(cfg.IPs) != 2 {
|
||||
t.Errorf("expected 2 IPs, got %d", len(cfg.IPs))
|
||||
}
|
||||
|
||||
if cfg.IPs["172.18.103.2"].ContainerName != "smarthostloadbalancer" {
|
||||
t.Errorf("unexpected container name: %s", cfg.IPs["172.18.103.2"].ContainerName)
|
||||
}
|
||||
|
||||
if len(cfg.Policies) != 2 {
|
||||
t.Errorf("expected 2 policies, got %d", len(cfg.Policies))
|
||||
}
|
||||
|
||||
if cfg.Policies[0].From != "publicbackend" {
|
||||
t.Errorf("unexpected from: %s", cfg.Policies[0].From)
|
||||
}
|
||||
|
||||
if cfg.Policies[1].Nat != "dnat" {
|
||||
t.Errorf("unexpected nat: %s", cfg.Policies[1].Nat)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFileNotFound(t *testing.T) {
|
||||
_, err := Load("/nonexistent/path/networks.json")
|
||||
if err == nil {
|
||||
t.Error("expected error for nonexistent file, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadInvalidJSON(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "bad.json")
|
||||
if err := os.WriteFile(path, []byte("{invalid json"), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid JSON, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToCIDR(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"172.18.103.0", "172.18.103.0/24"},
|
||||
{"172.18.103.2", "172.18.103.2"},
|
||||
{"172.18.103.0/24", "172.18.103.0/24"},
|
||||
{"10.0.0.0", "10.0.0.0/24"},
|
||||
{"invalid", "invalid"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := ToCIDR(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("ToCIDR(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{"172.18.103.2", "172.18.103.0/24"},
|
||||
{"10.0.0.1", "10.0.0.0/24"},
|
||||
{"invalid", "invalid"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := NetworkPrefix(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("NetworkPrefix(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want bool
|
||||
}{
|
||||
{"172.18.103.2", true},
|
||||
{"10.0.0.0", true},
|
||||
{"255.255.255.255", true},
|
||||
{"publicbackend", false},
|
||||
{"172.18.103.0/24", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := IsIP(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsIP(%q) = %v, want %v", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNetworkConfigParseCIDR(t *testing.T) {
|
||||
nc := NetworkConfig{Subnet: "172.18.103.0/24"}
|
||||
ipNet, err := nc.ParseCIDR()
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCIDR() returned error: %v", err)
|
||||
}
|
||||
if ipNet.String() != "172.18.103.0/24" {
|
||||
t.Errorf("ParseCIDR() = %s, want 172.18.103.0/24", ipNet.String())
|
||||
}
|
||||
|
||||
nc2 := NetworkConfig{Subnet: "invalid"}
|
||||
_, err = nc2.ParseCIDR()
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid subnet, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadReproducible(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "networks.json")
|
||||
if err := os.WriteFile(path, []byte(testJSON), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
// Load twice and compare
|
||||
cfg1, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("first Load() error: %v", err)
|
||||
}
|
||||
|
||||
cfg2, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("second Load() error: %v", err)
|
||||
}
|
||||
|
||||
// Verify reproducibility
|
||||
if len(cfg1.Networks) != len(cfg2.Networks) {
|
||||
t.Errorf("reproducibility: network count mismatch %d vs %d", len(cfg1.Networks), len(cfg2.Networks))
|
||||
}
|
||||
if len(cfg1.IPs) != len(cfg2.IPs) {
|
||||
t.Errorf("reproducibility: IP count mismatch %d vs %d", len(cfg1.IPs), len(cfg2.IPs))
|
||||
}
|
||||
if len(cfg1.Policies) != len(cfg2.Policies) {
|
||||
t.Errorf("reproducibility: policy count mismatch %d vs %d", len(cfg1.Policies), len(cfg2.Policies))
|
||||
}
|
||||
|
||||
for name, net1 := range cfg1.Networks {
|
||||
net2 := cfg2.Networks[name]
|
||||
if net1.Subnet != net2.Subnet || net1.Gateway != net2.Gateway {
|
||||
t.Errorf("reproducibility: network %s mismatch", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -15,11 +15,27 @@ import (
|
||||
"firewall_containers/network-go/config"
|
||||
)
|
||||
|
||||
// DockerAPI defines the interface for Docker operations, enabling mock implementations for testing
|
||||
type DockerAPI interface {
|
||||
Close() error
|
||||
EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error
|
||||
RemoveNetwork(ctx context.Context, networkName string) error
|
||||
ConnectContainer(ctx context.Context, containerName, networkName, ip string) error
|
||||
DisconnectContainer(ctx context.Context, containerName, networkName string) error
|
||||
InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error)
|
||||
WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error
|
||||
GetContainerPID(ctx context.Context, containerName string) (int, error)
|
||||
AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error
|
||||
}
|
||||
|
||||
// Client wraps the Docker SDK client
|
||||
type Client struct {
|
||||
cli *client.Client
|
||||
}
|
||||
|
||||
// Ensure Client implements DockerAPI
|
||||
var _ DockerAPI = (*Client)(nil)
|
||||
|
||||
// NewClient creates a new Docker client
|
||||
func NewClient() (*Client, error) {
|
||||
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
|
||||
@@ -36,7 +52,6 @@ func (c *Client) Close() error {
|
||||
|
||||
// EnsureNetwork creates a Docker network if it does not already exist
|
||||
func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
|
||||
// Check if network already exists
|
||||
existingNetworks, err := c.cli.NetworkList(ctx, network.ListOptions{
|
||||
Filters: filters.NewArgs(filters.Arg("name", netCfg.NetworkName)),
|
||||
})
|
||||
@@ -46,12 +61,10 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
|
||||
|
||||
for _, n := range existingNetworks {
|
||||
if n.Name == netCfg.NetworkName {
|
||||
// Network already exists, skip creation
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Parse subnet and gateway
|
||||
_, ipNet, err := net.ParseCIDR(netCfg.Subnet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse subnet %s: %w", netCfg.Subnet, err)
|
||||
@@ -62,7 +75,6 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
|
||||
return fmt.Errorf("failed to parse gateway IP %s", netCfg.Gateway)
|
||||
}
|
||||
|
||||
// Create the network
|
||||
createOpts := network.CreateOptions{
|
||||
Driver: "bridge",
|
||||
IPAM: &network.IPAM{
|
||||
@@ -84,7 +96,7 @@ func (c *Client) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig)
|
||||
return fmt.Errorf("failed to create network %s: %w", netCfg.NetworkName, err)
|
||||
}
|
||||
|
||||
_ = resp // response contains ID and warnings
|
||||
_ = resp
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -154,6 +166,7 @@ func (c *Client) WaitForContainerRunning(ctx context.Context, containerName stri
|
||||
}
|
||||
}
|
||||
|
||||
// GetContainerPID returns the PID of a container for nsenter operations
|
||||
func (c *Client) GetContainerPID(ctx context.Context, containerName string) (int, error) {
|
||||
cont, err := c.cli.ContainerInspect(ctx, containerName)
|
||||
if err != nil {
|
||||
|
||||
@@ -16,14 +16,14 @@ import (
|
||||
// Orchestrator reconciles the networks.json configuration into Docker networks
|
||||
// and iptables firewall rules
|
||||
type Orchestrator struct {
|
||||
dockerClient *docker.Client
|
||||
iptablesMgr *iptables.Manager
|
||||
dockerClient docker.DockerAPI
|
||||
iptablesMgr iptables.IPTablesAPI
|
||||
resolver *resolver.Resolver
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewOrchestrator creates a new firewall orchestrator
|
||||
func NewOrchestrator(dockerClient *docker.Client, iptablesMgr *iptables.Manager, cfg *config.NetworksConfig) *Orchestrator {
|
||||
func NewOrchestrator(dockerClient docker.DockerAPI, iptablesMgr iptables.IPTablesAPI, cfg *config.NetworksConfig) *Orchestrator {
|
||||
return &Orchestrator{
|
||||
dockerClient: dockerClient,
|
||||
iptablesMgr: iptablesMgr,
|
||||
|
||||
@@ -0,0 +1,333 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"firewall_containers/network-go/config"
|
||||
"firewall_containers/network-go/mock"
|
||||
)
|
||||
|
||||
func testConfig() *config.NetworksConfig {
|
||||
return &config.NetworksConfig{
|
||||
Networks: map[string]config.NetworkConfig{
|
||||
"smarthost-loadbalancer": {
|
||||
NetworkName: "smarthost-loadbalancer",
|
||||
Subnet: "172.18.103.0/24",
|
||||
Gateway: "172.18.103.1",
|
||||
},
|
||||
"smarthost_backend-1": {
|
||||
NetworkName: "smarthost_backend-1",
|
||||
Subnet: "172.18.104.0/24",
|
||||
Gateway: "172.18.104.1",
|
||||
},
|
||||
},
|
||||
IPs: map[string]config.IPConfig{
|
||||
"172.18.103.2": {
|
||||
IP: "172.18.103.2",
|
||||
ContainerName: "smarthostloadbalancer",
|
||||
Selector: "smarthostloadbalancer",
|
||||
ServiceName: "smarthost-proxy",
|
||||
},
|
||||
"172.18.104.2": {
|
||||
IP: "172.18.104.2",
|
||||
ContainerName: "smarthostbackend-1",
|
||||
Selector: "smarthostbackend-1",
|
||||
ServiceName: "smarthost-proxy",
|
||||
},
|
||||
},
|
||||
Policies: []config.PolicyConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileAllCreatesNetworks(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.ReconcileAll(ctx, cfg)
|
||||
|
||||
if !docker.EnsureNetworkCalled {
|
||||
t.Error("EnsureNetwork was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileAllConnectsContainers(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.ReconcileAll(ctx, cfg)
|
||||
|
||||
if !docker.ConnectContainerCalled {
|
||||
t.Error("ConnectContainer was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileAllEnablesIPForward(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.ReconcileAll(ctx, cfg)
|
||||
|
||||
if !iptables.EnsureIPForwardCalled {
|
||||
t.Error("EnsureIPForward was not called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcilePoliciesForwardRule(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Policies = []config.PolicyConfig{
|
||||
{
|
||||
ServiceName: "smarthost-proxy",
|
||||
ContainerName: "smarthostloadbalancer",
|
||||
Selector: "smarthostloadbalancer",
|
||||
From: "publicbackend",
|
||||
Port: 80,
|
||||
Proto: "tcp",
|
||||
},
|
||||
}
|
||||
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables"}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.reconcilePolicies(ctx, cfg)
|
||||
|
||||
// "publicbackend" should be resolved to an IP from the config (if it matches)
|
||||
// Since "publicbackend" is not in the IPs map, sourceIP will be empty
|
||||
if !iptables.InsertForwardAcceptCalled {
|
||||
t.Error("InsertForwardAccept was not called")
|
||||
}
|
||||
|
||||
// Should use FORWARD chain (not iptables-legacy)
|
||||
if iptables.InsertForwardAcceptChain != "FORWARD" {
|
||||
t.Errorf("expected FORWARD chain, got %s", iptables.InsertForwardAcceptChain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcilePoliciesForwardRuleWithLegacy(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Policies = []config.PolicyConfig{
|
||||
{
|
||||
ServiceName: "smarthost-proxy",
|
||||
From: "publicbackend",
|
||||
Port: 80,
|
||||
Proto: "tcp",
|
||||
},
|
||||
}
|
||||
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{BinaryResult: "/usr/sbin/iptables-legacy"}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.reconcilePolicies(ctx, cfg)
|
||||
|
||||
if !iptables.InsertForwardAcceptCalled {
|
||||
t.Error("InsertForwardAccept was not called")
|
||||
}
|
||||
|
||||
// Should use DOCKER-USER chain when iptables-legacy
|
||||
if iptables.InsertForwardAcceptChain != "DOCKER-USER" {
|
||||
t.Errorf("expected DOCKER-USER chain, got %s", iptables.InsertForwardAcceptChain)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcilePoliciesDNATWithInterface(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Policies = []config.PolicyConfig{
|
||||
{
|
||||
ServiceName: "smarthost-proxy",
|
||||
Name: "wireguardproxy",
|
||||
Selector: "smarthostloadbalancer",
|
||||
Iface: "wg0",
|
||||
Nat: "dnat",
|
||||
To: "smarthostloadbalancer",
|
||||
Port: 80,
|
||||
Proto: "tcp",
|
||||
},
|
||||
}
|
||||
|
||||
docker := &mock.MockDockerClient{
|
||||
GetContainerPIDResult: 0, // simulate no PID available
|
||||
GetContainerPIDErr: assertError("container not running"),
|
||||
}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.reconcilePolicies(ctx, cfg)
|
||||
|
||||
// When GetContainerPID fails, should fall back to interface-based rule
|
||||
if !iptables.InsertPreroutingRuleOnInterfaceCalled {
|
||||
t.Error("InsertPreroutingRuleOnInterface was not called (should fall back from nsenter)")
|
||||
}
|
||||
|
||||
if len(iptables.InsertPreroutingRuleOnInterfaceArgs) > 0 {
|
||||
iface := iptables.InsertPreroutingRuleOnInterfaceArgs[0]
|
||||
if iface != "wg0" {
|
||||
t.Errorf("expected interface wg0, got %s", iface)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcilePoliciesDNATWithContainerPID(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Policies = []config.PolicyConfig{
|
||||
{
|
||||
ServiceName: "smarthost-proxy",
|
||||
Name: "wireguardproxy",
|
||||
Selector: "smarthostloadbalancer",
|
||||
Nat: "dnat",
|
||||
To: "smarthostloadbalancer",
|
||||
Port: 80,
|
||||
Proto: "tcp",
|
||||
},
|
||||
}
|
||||
|
||||
docker := &mock.MockDockerClient{
|
||||
GetContainerPIDResult: 1234,
|
||||
GetContainerPIDErr: nil,
|
||||
}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.reconcilePolicies(ctx, cfg)
|
||||
|
||||
if !docker.GetContainerPIDCalled {
|
||||
t.Error("GetContainerPID was not called")
|
||||
}
|
||||
|
||||
if !iptables.InsertPreroutingRuleInContainerCalled {
|
||||
t.Error("InsertPreroutingRuleInContainer was not called")
|
||||
}
|
||||
|
||||
if iptables.InsertPreroutingRuleInContainerPID != 1234 {
|
||||
t.Errorf("expected PID 1234, got %d", iptables.InsertPreroutingRuleInContainerPID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcilePoliciesUnresolvedTarget(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Policies = []config.PolicyConfig{
|
||||
{
|
||||
ServiceName: "test",
|
||||
Nat: "dnat",
|
||||
Selector: "container1",
|
||||
To: "nonexistent-target",
|
||||
Port: 80,
|
||||
},
|
||||
}
|
||||
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
orch.reconcilePolicies(ctx, cfg)
|
||||
|
||||
// Should not call GetContainerPID when target can't be resolved
|
||||
if docker.GetContainerPIDCalled {
|
||||
t.Error("GetContainerPID should not be called when target is unresolvable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveIPDirectIP(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
// Direct IP should be returned as CIDR
|
||||
result := orch.resolveIP("172.18.103.2")
|
||||
if result != "172.18.103.2" {
|
||||
t.Errorf("expected 172.18.103.2, got %s", result)
|
||||
}
|
||||
|
||||
// .0 ending should be converted to /24
|
||||
result = orch.resolveIP("172.18.103.0")
|
||||
if result != "172.18.103.0/24" {
|
||||
t.Errorf("expected 172.18.103.0/24, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveIPFromConfig(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
|
||||
// Should resolve container name from config
|
||||
result := orch.resolveIP("smarthostloadbalancer")
|
||||
if result != "172.18.103.2" {
|
||||
t.Errorf("expected 172.18.103.2, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindNetworkForIP(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
want string
|
||||
}{
|
||||
{"172.18.103.5", "smarthost-loadbalancer"},
|
||||
{"172.18.104.5", "smarthost_backend-1"},
|
||||
{"10.0.0.1", ""},
|
||||
{"", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := findNetworkForIP(cfg, tt.ip)
|
||||
if got != tt.want {
|
||||
t.Errorf("findNetworkForIP(%q) = %q, want %q", tt.ip, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconcileAllReproducible(t *testing.T) {
|
||||
cfg := testConfig()
|
||||
cfg.Policies = []config.PolicyConfig{
|
||||
{
|
||||
ServiceName: "test-svc",
|
||||
From: "publicbackend",
|
||||
Port: 80,
|
||||
Proto: "tcp",
|
||||
},
|
||||
}
|
||||
|
||||
// Run reconciliation twice with separate mocks
|
||||
for i := 0; i < 2; i++ {
|
||||
docker := &mock.MockDockerClient{}
|
||||
iptables := &mock.MockIPTablesManager{}
|
||||
orch := NewOrchestrator(docker, iptables, cfg)
|
||||
ctx := context.Background()
|
||||
orch.ReconcileAll(ctx, cfg)
|
||||
|
||||
if !docker.EnsureNetworkCalled {
|
||||
t.Errorf("run %d: EnsureNetwork not called", i)
|
||||
}
|
||||
if !iptables.InsertForwardAcceptCalled {
|
||||
t.Errorf("run %d: InsertForwardAccept not called", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// assertError is a helper to create a simple error for tests
|
||||
type simpleErr struct{ msg string }
|
||||
|
||||
func (e simpleErr) Error() string { return e.msg }
|
||||
|
||||
func assertError(msg string) error {
|
||||
return simpleErr{msg: msg}
|
||||
}
|
||||
@@ -88,26 +88,6 @@ Without these, the program cannot:
|
||||
- Insert PREROUTING/POSTROUTING rules inside other containers
|
||||
- Add routes to container network namespaces
|
||||
|
||||
### Minimal Docker Compose Example
|
||||
|
||||
```yaml
|
||||
version: "3.8"
|
||||
services:
|
||||
network-go:
|
||||
build: ./network-go
|
||||
network_mode: host
|
||||
pid: "host"
|
||||
cap_add:
|
||||
- NET_ADMIN
|
||||
- SYS_ADMIN
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /etc/user/config:/etc/user/config
|
||||
environment:
|
||||
- WATCH_PERIOD_SECONDS=30
|
||||
- DEBUG=false
|
||||
```
|
||||
|
||||
## Configuration — `/etc/user/config/networks.json`
|
||||
|
||||
```json
|
||||
|
||||
@@ -6,12 +6,31 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IPTablesAPI defines the interface for iptables operations, enabling mock implementations for testing
|
||||
type IPTablesAPI interface {
|
||||
Binary() string
|
||||
EnsureIPForward() error
|
||||
EnsureEstablishedRelated(chain string) error
|
||||
DeleteLine(chain string, lineNum string) error
|
||||
InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error
|
||||
InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error
|
||||
InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error
|
||||
InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error
|
||||
InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error
|
||||
DeleteForwardAccept(chain, comment string) error
|
||||
InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error
|
||||
InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error
|
||||
}
|
||||
|
||||
// Manager manages iptables rules via the iptables/iptables-legacy CLI
|
||||
type Manager struct {
|
||||
binary string // /usr/sbin/iptables or /usr/sbin/iptables-legacy
|
||||
binary string
|
||||
debug bool
|
||||
}
|
||||
|
||||
// Ensure Manager implements IPTablesAPI
|
||||
var _ IPTablesAPI = (*Manager)(nil)
|
||||
|
||||
// NewManager creates a new iptables manager, auto-detecting the binary
|
||||
func NewManager(debug bool) *Manager {
|
||||
m := &Manager{debug: debug}
|
||||
@@ -83,17 +102,14 @@ func (m *Manager) EnsureIPForward() error {
|
||||
}
|
||||
|
||||
// EnsureEstablishedRelated inserts an ESTABLISHED,RELATED accept rule at the top of a chain
|
||||
// if it doesn't already exist
|
||||
func (m *Manager) EnsureEstablishedRelated(chain string) error {
|
||||
checkArgs := []string{"-w", "-n", "-L", chain}
|
||||
cmd := exec.Command(m.binary, checkArgs...)
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
// Chain may not exist, create it
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only insert if ESTABLISHED,RELATED rule is not present
|
||||
if !strings.Contains(string(output), "ESTABLISHED") || !strings.Contains(string(output), "RELATED") {
|
||||
args := []string{"-w", "-I", chain, "-m", "state", "--state", "established,related", "-j", "ACCEPT"}
|
||||
return m.run(args...)
|
||||
@@ -114,7 +130,6 @@ func (m *Manager) DeleteLineInContainer(pid int, table, chain, lineNum string) e
|
||||
}
|
||||
|
||||
// getLineNumbers returns line numbers matching certain criteria in a chain/table
|
||||
// This implements the grep logic from the shell script: iptables -w --line-number -n -L $CHAIN | grep ...
|
||||
func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []string {
|
||||
args := []string{"-w", "--line-number", "-n", "-L", chain}
|
||||
if table != "" {
|
||||
@@ -150,7 +165,6 @@ func (m *Manager) getLineNumbers(chain, table string, grepPatterns ...string) []
|
||||
// deleteMatchingLines deletes all lines in a chain matching the given patterns
|
||||
func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...string) error {
|
||||
lines := m.getLineNumbers(chain, table, grepPatterns...)
|
||||
// Reverse order (highest line first) so deletions don't shift line numbers
|
||||
for i := len(lines) - 1; i >= 0; i-- {
|
||||
if err := m.DeleteLine(chain, lines[i]); err != nil {
|
||||
return err
|
||||
@@ -161,7 +175,6 @@ func (m *Manager) deleteMatchingLines(chain, table string, grepPatterns ...strin
|
||||
|
||||
// deleteMatchingLinesInContainer deletes matching lines inside a container namespace
|
||||
func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, grepPatterns ...string) error {
|
||||
// For container namespaces, we use a different approach: list via nsenter + grep
|
||||
iptPath := "/sbin/iptables-legacy"
|
||||
if !strings.Contains(m.binary, "legacy") {
|
||||
iptPath = "/sbin/iptables"
|
||||
@@ -192,7 +205,6 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g
|
||||
}
|
||||
}
|
||||
|
||||
// Delete in reverse order
|
||||
for i := len(matchingLines) - 1; i >= 0; i-- {
|
||||
if err := m.DeleteLineInContainer(pid, table, chain, matchingLines[i]); err != nil {
|
||||
return err
|
||||
@@ -203,13 +215,11 @@ func (m *Manager) deleteMatchingLinesInContainer(pid int, table, chain string, g
|
||||
|
||||
// InsertPreroutingRule inserts a DNAT PREROUTING rule on the host
|
||||
func (m *Manager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||
// First, delete existing matching rules
|
||||
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
|
||||
if err := m.deleteMatchingLines("PREROUTING", "nat", patterns...); err != nil {
|
||||
return fmt.Errorf("failed to delete old PREROUTING rules: %w", err)
|
||||
}
|
||||
|
||||
// Insert the new rule
|
||||
args := []string{
|
||||
"-w", "-t", "nat", "-I", "PREROUTING",
|
||||
"-d", sourceIP,
|
||||
@@ -236,7 +246,6 @@ func (m *Manager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targ
|
||||
|
||||
// InsertPostroutingMasquerade inserts a MASQUERADE POSTROUTING rule on the host
|
||||
func (m *Manager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
|
||||
// Delete existing matching rules first
|
||||
patterns := []string{"MASQUERADE", comment, sourceCIDR, sourcePort}
|
||||
if err := m.deleteMatchingLines("POSTROUTING", "nat", patterns...); err != nil {
|
||||
return fmt.Errorf("failed to delete old POSTROUTING rules: %w", err)
|
||||
@@ -273,7 +282,6 @@ func (m *Manager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, target
|
||||
|
||||
// InsertForwardAccept inserts a FORWARD ACCEPT rule on the host
|
||||
func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
|
||||
// Build grep patterns to match existing rules
|
||||
var grepPatterns []string
|
||||
grepPatterns = append(grepPatterns, proto)
|
||||
if sourceIP != "" {
|
||||
@@ -289,12 +297,10 @@ func (m *Manager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePo
|
||||
grepPatterns = append(grepPatterns, targetPort)
|
||||
}
|
||||
|
||||
// Delete old matching rules
|
||||
if err := m.deleteMatchingLines(chain, "", grepPatterns...); err != nil {
|
||||
return fmt.Errorf("failed to delete old FORWARD rules: %w", err)
|
||||
}
|
||||
|
||||
// Build iptables args
|
||||
args := []string{"-w", "-I", chain, "-p", proto}
|
||||
if sourceIP != "" {
|
||||
args = append(args, "-s", sourceIP)
|
||||
@@ -326,7 +332,6 @@ func (m *Manager) DeleteForwardAccept(chain, comment string) error {
|
||||
|
||||
// InsertPreroutingRuleInContainer inserts a DNAT PREROUTING rule inside a container namespace
|
||||
func (m *Manager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||
// Delete existing first
|
||||
patterns := []string{"DNAT", sourcePort, targetIP, targetPort, comment}
|
||||
if err := m.deleteMatchingLinesInContainer(pid, "nat", "PREROUTING", patterns...); err != nil {
|
||||
return fmt.Errorf("failed to delete old container PREROUTING rules: %w", err)
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types"
|
||||
|
||||
"firewall_containers/network-go/config"
|
||||
"firewall_containers/network-go/docker"
|
||||
"firewall_containers/network-go/iptables"
|
||||
)
|
||||
|
||||
// Compile-time interface conformance checks
|
||||
var _ docker.DockerAPI = (*MockDockerClient)(nil)
|
||||
var _ iptables.IPTablesAPI = (*MockIPTablesManager)(nil)
|
||||
|
||||
// MockDockerClient implements docker.DockerAPI for testing
|
||||
type MockDockerClient struct {
|
||||
EnsureNetworkCalled bool
|
||||
EnsureNetworkCfg config.NetworkConfig
|
||||
EnsureNetworkErr error
|
||||
|
||||
ConnectContainerCalled bool
|
||||
ConnectContainerName string
|
||||
ConnectContainerNetwork string
|
||||
ConnectContainerIP string
|
||||
ConnectContainerErr error
|
||||
|
||||
WaitForRunningCalled bool
|
||||
WaitForRunningName string
|
||||
|
||||
GetContainerPIDCalled bool
|
||||
GetContainerPIDName string
|
||||
GetContainerPIDResult int
|
||||
GetContainerPIDErr error
|
||||
|
||||
AddRouteCalled bool
|
||||
AddRouteContainer string
|
||||
AddRouteNetwork string
|
||||
AddRouteGateway string
|
||||
AddRouteErr error
|
||||
|
||||
InspectContainerErr error
|
||||
RemoveNetworkErr error
|
||||
DisconnectContainerErr error
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) Close() error { return nil }
|
||||
|
||||
func (m *MockDockerClient) EnsureNetwork(ctx context.Context, netCfg config.NetworkConfig) error {
|
||||
m.EnsureNetworkCalled = true
|
||||
m.EnsureNetworkCfg = netCfg
|
||||
return m.EnsureNetworkErr
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) RemoveNetwork(ctx context.Context, networkName string) error {
|
||||
return m.RemoveNetworkErr
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) ConnectContainer(ctx context.Context, containerName, networkName, ip string) error {
|
||||
m.ConnectContainerCalled = true
|
||||
m.ConnectContainerName = containerName
|
||||
m.ConnectContainerNetwork = networkName
|
||||
m.ConnectContainerIP = ip
|
||||
return m.ConnectContainerErr
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) DisconnectContainer(ctx context.Context, containerName, networkName string) error {
|
||||
return m.DisconnectContainerErr
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) InspectContainer(ctx context.Context, containerName string) (*types.ContainerJSON, error) {
|
||||
return nil, m.InspectContainerErr
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) WaitForContainerRunning(ctx context.Context, containerName string, timeout time.Duration) error {
|
||||
m.WaitForRunningCalled = true
|
||||
m.WaitForRunningName = containerName
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) GetContainerPID(ctx context.Context, containerName string) (int, error) {
|
||||
m.GetContainerPIDCalled = true
|
||||
m.GetContainerPIDName = containerName
|
||||
return m.GetContainerPIDResult, m.GetContainerPIDErr
|
||||
}
|
||||
|
||||
func (m *MockDockerClient) AddRouteInContainer(ctx context.Context, containerName, network, gateway string) error {
|
||||
m.AddRouteCalled = true
|
||||
m.AddRouteContainer = containerName
|
||||
m.AddRouteNetwork = network
|
||||
m.AddRouteGateway = gateway
|
||||
return m.AddRouteErr
|
||||
}
|
||||
|
||||
// MockIPTablesManager implements iptables.IPTablesAPI for testing
|
||||
type MockIPTablesManager struct {
|
||||
BinaryResult string
|
||||
EnsureIPForwardCalled bool
|
||||
EnsureIPForwardErr error
|
||||
EnsureEstablishedRelatedCalled bool
|
||||
EnsureEstablishedRelatedChain string
|
||||
EnsureEstablishedRelatedErr error
|
||||
|
||||
InsertPreroutingRuleCalled bool
|
||||
InsertPreroutingRuleArgs []string
|
||||
InsertPreroutingRuleErr error
|
||||
|
||||
InsertPreroutingRuleOnInterfaceCalled bool
|
||||
InsertPreroutingRuleOnInterfaceArgs []string
|
||||
InsertPreroutingRuleOnInterfaceErr error
|
||||
|
||||
InsertPostroutingMasqueradeCalled bool
|
||||
InsertPostroutingMasqueradeArgs []string
|
||||
InsertPostroutingMasqueradeErr error
|
||||
|
||||
InsertForwardAcceptCalled bool
|
||||
InsertForwardAcceptChain string
|
||||
InsertForwardAcceptSourceIP string
|
||||
InsertForwardAcceptTargetIP string
|
||||
InsertForwardAcceptProto string
|
||||
InsertForwardAcceptSourcePort string
|
||||
InsertForwardAcceptTargetPort string
|
||||
InsertForwardAcceptComment string
|
||||
InsertForwardAcceptErr error
|
||||
|
||||
InsertPreroutingRuleInContainerCalled bool
|
||||
InsertPreroutingRuleInContainerPID int
|
||||
InsertPreroutingRuleInContainerArgs []string
|
||||
InsertPreroutingRuleInContainerErr error
|
||||
|
||||
InsertPostroutingMasqueradeInContainerCalled bool
|
||||
InsertPostroutingMasqueradeInContainerErr error
|
||||
DeleteForwardAcceptErr error
|
||||
DeleteLineErr error
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) Binary() string {
|
||||
if m.BinaryResult == "" {
|
||||
return "/usr/sbin/iptables"
|
||||
}
|
||||
return m.BinaryResult
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) EnsureIPForward() error {
|
||||
m.EnsureIPForwardCalled = true
|
||||
return m.EnsureIPForwardErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) EnsureEstablishedRelated(chain string) error {
|
||||
m.EnsureEstablishedRelatedCalled = true
|
||||
m.EnsureEstablishedRelatedChain = chain
|
||||
return m.EnsureEstablishedRelatedErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) DeleteLine(chain string, lineNum string) error {
|
||||
return m.DeleteLineErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertPreroutingRule(sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||
m.InsertPreroutingRuleCalled = true
|
||||
m.InsertPreroutingRuleArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment}
|
||||
return m.InsertPreroutingRuleErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertPreroutingRuleOnInterface(iface, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||
m.InsertPreroutingRuleOnInterfaceCalled = true
|
||||
m.InsertPreroutingRuleOnInterfaceArgs = []string{iface, proto, sourcePort, targetIP, targetPort, comment}
|
||||
return m.InsertPreroutingRuleOnInterfaceErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertPostroutingMasquerade(sourceCIDR, proto, sourcePort, comment string) error {
|
||||
m.InsertPostroutingMasqueradeCalled = true
|
||||
m.InsertPostroutingMasqueradeArgs = []string{sourceCIDR, proto, sourcePort, comment}
|
||||
return m.InsertPostroutingMasqueradeErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertPostroutingMasqueradeForTarget(targetCIDR, proto, targetPort, comment string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertForwardAccept(chain, sourceIP, targetIP, proto, sourcePort, targetPort, comment string) error {
|
||||
m.InsertForwardAcceptCalled = true
|
||||
m.InsertForwardAcceptChain = chain
|
||||
m.InsertForwardAcceptSourceIP = sourceIP
|
||||
m.InsertForwardAcceptTargetIP = targetIP
|
||||
m.InsertForwardAcceptProto = proto
|
||||
m.InsertForwardAcceptSourcePort = sourcePort
|
||||
m.InsertForwardAcceptTargetPort = targetPort
|
||||
m.InsertForwardAcceptComment = comment
|
||||
return m.InsertForwardAcceptErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) DeleteForwardAccept(chain, comment string) error {
|
||||
return m.DeleteForwardAcceptErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertPreroutingRuleInContainer(pid int, sourceIP, proto, sourcePort, targetIP, targetPort, comment string) error {
|
||||
m.InsertPreroutingRuleInContainerCalled = true
|
||||
m.InsertPreroutingRuleInContainerPID = pid
|
||||
m.InsertPreroutingRuleInContainerArgs = []string{sourceIP, proto, sourcePort, targetIP, targetPort, comment}
|
||||
return m.InsertPreroutingRuleInContainerErr
|
||||
}
|
||||
|
||||
func (m *MockIPTablesManager) InsertPostroutingMasqueradeInContainer(pid int, sourceCIDR, proto, sourcePort, comment string) error {
|
||||
m.InsertPostroutingMasqueradeInContainerCalled = true
|
||||
return m.InsertPostroutingMasqueradeInContainerErr
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
package resolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"firewall_containers/network-go/config"
|
||||
)
|
||||
|
||||
func makeTestConfig() *config.NetworksConfig {
|
||||
return &config.NetworksConfig{
|
||||
Networks: map[string]config.NetworkConfig{
|
||||
"test-net": {NetworkName: "test-net", Subnet: "172.18.103.0/24", Gateway: "172.18.103.1"},
|
||||
},
|
||||
IPs: map[string]config.IPConfig{
|
||||
"172.18.103.2": {IP: "172.18.103.2", ContainerName: "smarthostloadbalancer", Selector: "smarthostloadbalancer", ServiceName: "smarthost-proxy"},
|
||||
"172.18.104.2": {IP: "172.18.104.2", ContainerName: "smarthostbackend-1", Selector: "smarthostbackend-1", ServiceName: "smarthost-proxy"},
|
||||
},
|
||||
Policies: []config.PolicyConfig{},
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveByContainerName(t *testing.T) {
|
||||
cfg := makeTestConfig()
|
||||
r := NewResolver(cfg)
|
||||
|
||||
ips := r.Resolve("smarthostloadbalancer")
|
||||
if len(ips) != 1 {
|
||||
t.Fatalf("expected 1 IP, got %d", len(ips))
|
||||
}
|
||||
if ips[0] != "172.18.103.2" {
|
||||
t.Errorf("expected 172.18.103.2, got %s", ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBySelector(t *testing.T) {
|
||||
cfg := makeTestConfig()
|
||||
r := NewResolver(cfg)
|
||||
|
||||
ips := r.Resolve("smarthostbackend-1")
|
||||
if len(ips) != 1 {
|
||||
t.Fatalf("expected 1 IP, got %d", len(ips))
|
||||
}
|
||||
if ips[0] != "172.18.104.2" {
|
||||
t.Errorf("expected 172.18.104.2, got %s", ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveNotFound(t *testing.T) {
|
||||
cfg := makeTestConfig()
|
||||
r := NewResolver(cfg)
|
||||
|
||||
ips := r.Resolve("nonexistent-container")
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("expected 0 IPs, got %d", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveNilConfig(t *testing.T) {
|
||||
r := NewResolver(nil)
|
||||
|
||||
ips := r.Resolve("anything")
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("expected 0 IPs with nil config, got %d", len(ips))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetConfig(t *testing.T) {
|
||||
r := NewResolver(nil)
|
||||
|
||||
// Initially nil config
|
||||
ips := r.Resolve("test")
|
||||
if len(ips) != 0 {
|
||||
t.Errorf("expected 0 IPs with nil config, got %d", len(ips))
|
||||
}
|
||||
|
||||
// Update with valid config
|
||||
cfg := makeTestConfig()
|
||||
r.SetConfig(cfg)
|
||||
|
||||
ips = r.Resolve("smarthostloadbalancer")
|
||||
if len(ips) != 1 {
|
||||
t.Fatalf("expected 1 IP after SetConfig, got %d", len(ips))
|
||||
}
|
||||
if ips[0] != "172.18.103.2" {
|
||||
t.Errorf("expected 172.18.103.2, got %s", ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveReproducible(t *testing.T) {
|
||||
cfg := makeTestConfig()
|
||||
r := NewResolver(cfg)
|
||||
|
||||
ips1 := r.Resolve("smarthostloadbalancer")
|
||||
ips2 := r.Resolve("smarthostloadbalancer")
|
||||
|
||||
if len(ips1) != len(ips2) {
|
||||
t.Errorf("reproducibility: result count mismatch %d vs %d", len(ips1), len(ips2))
|
||||
}
|
||||
if len(ips1) > 0 && ips1[0] != ips2[0] {
|
||||
t.Errorf("reproducibility: IP mismatch %s vs %s", ips1[0], ips2[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveMultipleMatch(t *testing.T) {
|
||||
cfg := &config.NetworksConfig{
|
||||
IPs: map[string]config.IPConfig{
|
||||
"10.0.0.1": {IP: "10.0.0.1", ContainerName: "app-1", Selector: "app-1"},
|
||||
"10.0.0.2": {IP: "10.0.0.2", ContainerName: "app-2", Selector: "app-2"},
|
||||
},
|
||||
}
|
||||
r := NewResolver(cfg)
|
||||
|
||||
// "app-x" has a dash, so prefix "app" is extracted and matches both app-1 and app-2
|
||||
ips := r.Resolve("app-x")
|
||||
if len(ips) != 2 {
|
||||
t.Errorf("expected 2 IPs for prefix match, got %d", len(ips))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,179 @@
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestWatcherDetectsChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
changeDetected := make(chan bool, 1)
|
||||
onChange := func() {
|
||||
select {
|
||||
case changeDetected <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
|
||||
fw.Start()
|
||||
defer fw.Stop()
|
||||
|
||||
// Wait for initial hash to be computed
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Modify the file
|
||||
if err := os.WriteFile(path, []byte(`{"version": 2}`), 0644); err != nil {
|
||||
t.Fatalf("failed to modify test file: %v", err)
|
||||
}
|
||||
|
||||
// Wait for change detection
|
||||
select {
|
||||
case detected := <-changeDetected:
|
||||
if !detected {
|
||||
t.Error("expected change detection, got false")
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("timeout waiting for file change detection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcherNoChange(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
if err := os.WriteFile(path, []byte(`{"version": 1}`), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
changeDetected := make(chan bool, 1)
|
||||
onChange := func() {
|
||||
select {
|
||||
case changeDetected <- true:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
|
||||
fw.Start()
|
||||
defer fw.Stop()
|
||||
|
||||
// Wait without modifying the file
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
// Should not detect a change
|
||||
select {
|
||||
case <-changeDetected:
|
||||
t.Error("unexpected change detection without file modification")
|
||||
default:
|
||||
// Expected: no change detected
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcherMultipleChanges(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
changeCount := 0
|
||||
var mu chan struct{}
|
||||
|
||||
onChange := func() {
|
||||
changeCount++
|
||||
if mu != nil {
|
||||
mu <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
fw := NewFileWatcher(path, 50*time.Millisecond, onChange)
|
||||
fw.Start()
|
||||
defer fw.Stop()
|
||||
|
||||
// Wait for initial hash
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Make 3 changes
|
||||
for i := 2; i <= 4; i++ {
|
||||
if err := os.WriteFile(path, []byte(`{"v": `+string(rune('0'+i))+`}`), 0644); err != nil {
|
||||
t.Fatalf("failed to modify file: %v", err)
|
||||
}
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
if changeCount < 2 {
|
||||
t.Errorf("expected at least 2 change detections, got %d", changeCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWatcherStop(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
if err := os.WriteFile(path, []byte(`{"v": 1}`), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
fw := NewFileWatcher(path, 50*time.Millisecond, func() {})
|
||||
fw.Start()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
fw.Stop()
|
||||
|
||||
// Should not panic on double stop or further operations
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWatcherNonexistentFile(t *testing.T) {
|
||||
onChange := func() {
|
||||
t.Error("onChange should not be called for nonexistent file")
|
||||
}
|
||||
|
||||
fw := NewFileWatcher("/nonexistent/file.json", 100*time.Millisecond, onChange)
|
||||
fw.Start()
|
||||
defer fw.Stop()
|
||||
|
||||
// Wait — should not call onChange since file doesn't exist
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestWatcherReproducible(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.json")
|
||||
if err := os.WriteFile(path, []byte(`{"data": "value"}`), 0644); err != nil {
|
||||
t.Fatalf("failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
// Create two watchers on the same file
|
||||
var detected1, detected2 bool
|
||||
|
||||
fw1 := NewFileWatcher(path, 50*time.Millisecond, func() { detected1 = true })
|
||||
fw2 := NewFileWatcher(path, 50*time.Millisecond, func() { detected2 = true })
|
||||
|
||||
fw1.Start()
|
||||
fw2.Start()
|
||||
defer fw1.Stop()
|
||||
defer fw2.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// Modify file
|
||||
if err := os.WriteFile(path, []byte(`{"data": "changed"}`), 0644); err != nil {
|
||||
t.Fatalf("failed to modify file: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
|
||||
if !detected1 {
|
||||
t.Error("watcher 1 did not detect change")
|
||||
}
|
||||
if !detected2 {
|
||||
t.Error("watcher 2 did not detect change")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user