From bf94206849643a77eb48b2138d345704689f41f6 Mon Sep 17 00:00:00 2001
From: gyurix
Date: Mon, 15 Jun 2026 22:40:43 +0200
Subject: [PATCH] feat: Add POSTROUTING MASQUERADE and periodic state reconciliation

- Add POSTROUTING MASQUERADE rule alongside DNAT rules to ensure return
  traffic from container targets can route back through the same
  interface, matching legacy shell script behavior
- Enhance FileWatcher to trigger periodic state reconciliation every
  tick regardless of config file changes, ensuring desired state is
  maintained after container restarts or iptables flushes
---
 network-go/firewall/firewall.go    | 24 ++++++++++++++
 network-go/watcher/watcher.go      | 51 ++++++++++++++++++++----------
 network-go/watcher/watcher_test.go | 20 +++++-------
 3 files changed, 66 insertions(+), 29 deletions(-)

diff --git a/network-go/firewall/firewall.go b/network-go/firewall/firewall.go
index 0b33b61..7df73bc 100644
--- a/network-go/firewall/firewall.go
+++ b/network-go/firewall/firewall.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"net"
 	"strconv"
+	"strings"
 	"time"
 
 	"firewall_containers/network-go/config"
@@ -253,6 +254,29 @@ func (o *Orchestrator) applyNATRule(ctx context.Context, cfg *config.NetworksCon
 					policy.Iface, targetIP, proto, port)
 			}
 		}
+
+		// Always add MASQUERADE on POSTROUTING so return traffic from the
+		// DNAT target can route back through the same interface.
+		// This mirrors the old shell script behavior where POSTROUTING
+		// was always set alongside PREROUTING DNAT rules.
+		// Required regardless of whether DNAT was in container namespace or host.
+		if targetIP != "" {
+			masqComment := comment + "-masq"
+			targetSubnet := ""
+			// Use the target's /24 subnet as the source CIDR for masquerade
+			if strings.Contains(targetIP, ".") {
+				targetSubnet = targetIP[:strings.LastIndex(targetIP, ".")] + ".0/24"
+			}
+			if targetSubnet != "" {
+				logger.Info("FIREWALL: inserting POSTROUTING MASQUERADE for %s", targetSubnet)
+				if err := o.iptablesMgr.InsertPostroutingMasquerade(targetSubnet, proto, port, masqComment); err != nil {
+					logger.Error("FIREWALL: failed to insert POSTROUTING MASQUERADE: %v", err)
+				} else {
+					logger.Info("FIREWALL: POSTROUTING MASQUERADE inserted: subnet=%s proto=%s port=%s",
+						targetSubnet, proto, port)
+				}
+			}
+		}
 	}
 }
 
diff --git a/network-go/watcher/watcher.go b/network-go/watcher/watcher.go
index f92ff76..ad7294a 100644
--- a/network-go/watcher/watcher.go
+++ b/network-go/watcher/watcher.go
@@ -9,23 +9,27 @@ import (
 	"firewall_containers/network-go/logger"
 )
 
-// FileWatcher periodically checks a file for changes using MD5 hash
+// FileWatcher periodically polls a file for changes AND triggers a periodic
+// reconciliation callback regardless of file changes.
 type FileWatcher struct {
-	path     string
-	period   time.Duration
-	lastHash string
-	onChange func()
-	stopCh   chan struct{}
+	path     string
+	period   time.Duration
+	lastHash string
+	onChange func()
+	stopCh   chan struct{}
 }
 
-// NewFileWatcher creates a new file watcher that polls the file at the given period
+// NewFileWatcher creates a new file watcher that:
+//  1. Polls the file for content changes at the given period
+//  2. Triggers a reconciliation callback every period regardless of changes
+//     to ensure desired state is maintained (stateful reconciliation)
 func NewFileWatcher(path string, period time.Duration, onChange func()) *FileWatcher {
 	return &FileWatcher{
-		path:     path,
-		period:   period,
-		lastHash: "",
-		onChange: onChange,
-		stopCh:   make(chan struct{}),
+		path:     path,
+		period:   period,
+		lastHash: "",
+		onChange: onChange,
+		stopCh:   make(chan struct{}),
 	}
 }
 
@@ -38,7 +42,9 @@ func (fw *FileWatcher) hashFile() (string, error) {
 	return fmt.Sprintf("%x", md5.Sum(data)), nil
 }
 
-// Start begins polling the file for changes in a goroutine
+// Start begins polling the file for changes in a goroutine.
+// Every period, it checks if the file changed AND triggers a full reconciliation
+// to maintain the desired state (handles container restarts, iptables flushes, etc.)
 func (fw *FileWatcher) Start() {
 	// Compute initial hash
 	hash, err := fw.hashFile()
@@ -52,7 +58,7 @@ func (fw *FileWatcher) Start() {
 		ticker := time.NewTicker(fw.period)
 		defer ticker.Stop()
 
-		logger.Info("WATCHER: started watching %s every %s", fw.path, fw.period)
+		logger.Info("WATCHER: started watching %s every %s (periodic reconciliation enabled)", fw.path, fw.period)
 
 		for {
 			select {
@@ -66,12 +72,23 @@ func (fw *FileWatcher) Start() {
 					continue
 				}
 
-				if hash != fw.lastHash {
+				fileChanged := hash != fw.lastHash
+				if fileChanged {
 					logger.Info("WATCHER: detected change in %s", fw.path)
 					fw.lastHash = hash
-					if fw.onChange != nil {
-						fw.onChange()
+				}
+
+				// Trigger reconciliation every period to maintain state,
+				// even if the config file hasn't changed.
+				// This ensures container restarts, iptables flushes, etc.
+				// are corrected.
+				if fw.onChange != nil {
+					if fileChanged {
+						logger.Info("WATCHER: triggering reconciliation (config changed)")
+					} else {
+						logger.Debug("WATCHER: triggering periodic state reconciliation")
 					}
+					fw.onChange()
 				}
 			}
 		}
diff --git a/network-go/watcher/watcher_test.go b/network-go/watcher/watcher_test.go
index 7e99b41..8dbad68 100644
--- a/network-go/watcher/watcher_test.go
+++ b/network-go/watcher/watcher_test.go
@@ -52,12 +52,11 @@ func TestWatcherNoChange(t *testing.T) {
 		t.Fatalf("failed to write test file: %v", err)
 	}
 
-	changeDetected := make(chan bool, 1)
+	// With periodic reconciliation, onChange fires every period while this test
+	// sleeps, so count calls through a buffered channel to stay race-free.
+	calls := make(chan struct{}, 64)
 	onChange := func() {
-		select {
-		case changeDetected <- true:
-		default:
-		}
+		calls <- struct{}{}
 	}
 
 	fw := NewFileWatcher(path, 100*time.Millisecond, onChange)
@@ -65,14 +64,11 @@ func TestWatcherNoChange(t *testing.T) {
 	defer fw.Stop()
 
 	// Wait without modifying the file
-	time.Sleep(300 * time.Millisecond)
+	time.Sleep(350 * time.Millisecond)
 
-	// Should not detect a change
-	select {
-	case <-changeDetected:
-		t.Error("unexpected change detection without file modification")
-	default:
-		// Expected: no change detected
+	// A 100ms ticker fires about 3 times in 350ms (at ~0.1s, 0.2s, 0.3s).
+	if got := len(calls); got < 1 {
+		t.Errorf("expected at least 1 periodic reconciliation call, got %d", got)
 	}
 }
 