Files
firewall_containers/network-go/firewall/firewall.go
gyurix 246346f8b1
continuous-integration/drone/push Build is passing
feat(docker, firewall): Add stateful network connection check and optimize NAT rules
This adds an IsConnected method to verify if a container is already connected to a network with the expected IP, preventing redundant operations. In reconcileIPs, it skips reconnections if the state is correct. In applyNATRule, MASQUERADE is now applied in the same namespace as DNAT (container or host) for consistent and accurate rule application.
2026-06-15 23:39:58 +02:00

339 lines
12 KiB
Go

package firewall
import (
"context"
"net"
"strconv"
"strings"
"time"
"firewall_containers/network-go/config"
"firewall_containers/network-go/docker"
"firewall_containers/network-go/iptables"
"firewall_containers/network-go/logger"
"firewall_containers/network-go/resolver"
)
// Orchestrator reconciles the networks.json configuration into Docker networks
// and iptables firewall rules.
//
// Construct it with NewOrchestrator: the zero value has nil dependencies, and
// every reconciliation method dereferences them.
type Orchestrator struct {
	// dockerClient creates networks, resolves container names, checks and
	// creates network attachments, and looks up container PIDs.
	dockerClient docker.DockerAPI
	// iptablesMgr enables ip_forward and inserts FORWARD, PREROUTING (DNAT)
	// and POSTROUTING (MASQUERADE) rules, on the host or in a container
	// namespace identified by PID.
	iptablesMgr iptables.IPTablesAPI
	// resolver maps names used in policies to IP addresses from networks.json;
	// it is refreshed from the latest config at the start of ReconcileAll.
	resolver *resolver.Resolver
	// NOTE(review): debug is never read by any code in this file and is not set
	// by NewOrchestrator; confirm whether it is used elsewhere or can be removed.
	debug bool
}
// NewOrchestrator returns an Orchestrator that drives dockerClient and
// iptablesMgr, with a name resolver seeded from cfg. The debug flag is left at
// its zero value (false).
func NewOrchestrator(dockerClient docker.DockerAPI, iptablesMgr iptables.IPTablesAPI, cfg *config.NetworksConfig) *Orchestrator {
	orch := new(Orchestrator)
	orch.dockerClient = dockerClient
	orch.iptablesMgr = iptablesMgr
	orch.resolver = resolver.NewResolver(cfg)
	return orch
}
// ReconcileAll brings Docker networks, container attachments and firewall rules
// in line with cfg.
//
// It first points the resolver at cfg. It then tries to enable IP forwarding;
// a failure is only logged. Finally it runs the network, IP-attachment and
// policy passes, in that order. Errors are logged by each pass and never
// returned to the caller.
func (o *Orchestrator) ReconcileAll(ctx context.Context, cfg *config.NetworksConfig) {
	logger.Info("FIREWALL: starting full reconciliation")
	logger.Debug("FIREWALL: config has %d networks, %d IPs, %d policies",
		len(cfg.Networks), len(cfg.IPs), len(cfg.Policies))

	// Policy name lookups must see this config, not a previous one.
	o.resolver.SetConfig(cfg)

	// Enabling ip_forward can fail inside containers with a read-only
	// filesystem, so treat failure as a warning and carry on.
	fwdErr := o.iptablesMgr.EnsureIPForward()
	if fwdErr == nil {
		logger.Info("FIREWALL: IP forwarding enabled")
	} else {
		logger.Warn("FIREWALL: could not enable ip_forward: %v", fwdErr)
	}

	// Networks are created before containers are attached to them; firewall
	// policies are applied last.
	o.reconcileNetworks(ctx, cfg)
	o.reconcileIPs(ctx, cfg)
	o.reconcilePolicies(ctx, cfg)

	logger.Info("FIREWALL: full reconciliation completed")
}
// reconcileNetworks asks the Docker client to ensure that every network in
// cfg.Networks exists. A failure is logged for that network and the loop moves
// on to the next one.
func (o *Orchestrator) reconcileNetworks(ctx context.Context, cfg *config.NetworksConfig) {
	for key, spec := range cfg.Networks {
		logger.Info("FIREWALL: ensuring network %s (name=%s, subnet=%s, gateway=%s)",
			key, spec.NetworkName, spec.Subnet, spec.Gateway)

		err := o.dockerClient.EnsureNetwork(ctx, spec)
		if err != nil {
			logger.Error("FIREWALL: failed to ensure network %s: %v", key, err)
			continue
		}
		logger.Debug("FIREWALL: network %s ready", key)
	}
}
// reconcileIPs attaches containers to networks using the addresses listed in
// cfg.IPs. For each entry it does the following:
//
//   - It picks the network whose subnet contains the address. Entries with no
//     such network are skipped with a warning.
//   - It resolves the real container name via FindContainerName, falling back
//     to the configured name when the lookup fails.
//   - It skips the entry when IsConnected reports the container already holds
//     that IP on that network.
//   - Otherwise it waits up to 10 seconds for the container to be running
//     (proceeding regardless), then connects it with the requested IP.
//
// Every failure is logged; none of them stop the loop.
func (o *Orchestrator) reconcileIPs(ctx context.Context, cfg *config.NetworksConfig) {
	const runWait = 10 * time.Second

	for addr, entry := range cfg.IPs {
		netName := findNetworkForIP(cfg, addr)
		if netName == "" {
			logger.Warn("FIREWALL: no network found for IP %s (container=%s, selector=%s)",
				addr, entry.ContainerName, entry.Selector)
			continue
		}

		logger.Info("FIREWALL: resolving container name for IP %s (container=%s, selector=%s)",
			addr, entry.ContainerName, entry.Selector)

		// The lookup may fuzzy-match, mirroring the old shell script's
		// `docker ps | grep $D"-"`. The actual name can therefore differ
		// from the configured one.
		target, lookupErr := o.dockerClient.FindContainerName(ctx, entry.ContainerName, entry.Selector)
		switch {
		case lookupErr != nil:
			logger.Warn("FIREWALL: container %s (selector=%s) not found: %v, using config name anyway",
				entry.ContainerName, entry.Selector, lookupErr)
			target = entry.ContainerName
		case target != entry.ContainerName:
			logger.Info("FIREWALL: container resolved: config_name=%s -> actual=%s",
				entry.ContainerName, target)
		}

		// Avoid a redundant reconnect when the attachment is already correct.
		if o.dockerClient.IsConnected(ctx, target, netName, addr) {
			logger.Debug("FIREWALL: container %s already connected to %s with IP %s, skipping",
				target, netName, addr)
			continue
		}

		logger.Info("FIREWALL: connecting container %s to network %s with IP %s",
			target, netName, addr)

		readyCtx, stop := context.WithTimeout(ctx, runWait)
		readyErr := o.dockerClient.WaitForContainerRunning(readyCtx, target, runWait)
		stop()
		if readyErr == nil {
			logger.Debug("FIREWALL: container %s is running", target)
		} else {
			logger.Warn("FIREWALL: container %s not running yet: %v, connecting anyway",
				target, readyErr)
		}

		if connErr := o.dockerClient.ConnectContainer(ctx, target, netName, addr); connErr != nil {
			logger.Error("FIREWALL: failed to connect container %s to %s: %v",
				target, netName, connErr)
			continue
		}
		logger.Info("FIREWALL: container %s connected to network %s with IP %s",
			target, netName, addr)
	}
}
// reconcilePolicies translates each entry of cfg.Policies into iptables rules.
//
// Each policy first gets:
//
//   - a protocol, defaulting to "tcp" when Proto is empty;
//   - a port string produced by strconv.Itoa(Port);
//   - an iptables comment that joins the non-empty Name and ServiceName with
//     "-". This gives "Name-ServiceName", "Name", "ServiceName" or "", and
//     never a stray dash.
//
// It is then dispatched as follows:
//
//   - From set: a FORWARD ACCEPT rule via applyForwardRule. This takes
//     precedence when Nat is also set.
//   - Otherwise, Nat set: a NAT rule via applyNATRule.
//   - Neither set: the policy is logged as an unhandled pattern and skipped.
func (o *Orchestrator) reconcilePolicies(ctx context.Context, cfg *config.NetworksConfig) {
	for i, policy := range cfg.Policies {
		logger.Info("FIREWALL: processing policy[%d]", i)
		logger.Debug("FIREWALL: policy[%d] details: service=%s container=%s selector=%s from=%s to=%s port=%d proto=%s nat=%s iface=%s",
			i, policy.ServiceName, policy.ContainerName, policy.Selector,
			policy.From, policy.To, policy.Port, policy.Proto, policy.Nat, policy.Iface)

		proto := policy.Proto
		if proto == "" {
			proto = "tcp"
		}
		// NOTE(review): an unset Port becomes "0"; confirm how the iptables
		// layer treats a "0" port.
		port := strconv.Itoa(policy.Port)

		// Build the comment to match the shell script's NAME-COMMENT pattern.
		// Both parts are used when present; empty parts are dropped so the
		// result never has a leading or trailing "-".
		parts := make([]string, 0, 2)
		if policy.Name != "" {
			parts = append(parts, policy.Name)
		}
		if policy.ServiceName != "" {
			parts = append(parts, policy.ServiceName)
		}
		comment := strings.Join(parts, "-")
		logger.Debug("FIREWALL: policy[%d] comment=%q", i, comment)

		switch {
		case policy.From != "":
			// CASE 1: a FORWARD ACCEPT rule.
			o.applyForwardRule(ctx, cfg, policy, proto, port, comment)
		case policy.Nat != "":
			// CASE 2: a DNAT/MASQUERADE rule.
			o.applyNATRule(ctx, cfg, policy, proto, port, comment)
		default:
			logger.Warn("FIREWALL: policy[%d] unhandled pattern — service=%s container=%s selector=%s from=%s to=%s port=%d proto=%s nat=%s",
				i, policy.ServiceName, policy.ContainerName, policy.Selector, policy.From, policy.To, policy.Port, policy.Proto, policy.Nat)
		}
	}
}
// applyForwardRule installs an ACCEPT rule for a policy that has a "from" field.
//
// policy.From is resolved through resolveIP, and so is policy.To when it is
// set. If either name cannot be resolved, the policy is skipped with a warning.
// The reason: an omitted destination is deliberately passed to
// InsertForwardAccept as "", which means "no restriction". An unresolved name
// (also "") would therefore widen the rule to accept any source or
// destination. NOTE(review): confirm InsertForwardAccept treats "" as "any".
//
// The rule is placed in DOCKER-USER when the iptables binary is
// /usr/sbin/iptables-legacy, and in FORWARD otherwise. An established/related
// rule is ensured in that chain first. ctx and cfg are currently unused.
func (o *Orchestrator) applyForwardRule(ctx context.Context, cfg *config.NetworksConfig, policy config.PolicyConfig, proto, port, comment string) {
	sourceIP := o.resolveIP(policy.From)
	if sourceIP == "" {
		logger.Warn("FIREWALL: cannot resolve source %s for forward policy (comment=%q), skipping",
			policy.From, comment)
		return
	}
	targetIP := ""
	if policy.To != "" {
		targetIP = o.resolveIP(policy.To)
		if targetIP == "" {
			logger.Warn("FIREWALL: cannot resolve target %s for forward policy (comment=%q), skipping",
				policy.To, comment)
			return
		}
	}
	logger.Info("FIREWALL: forward rule: from=%q (IP=%s) to=%q (IP=%s) proto=%s port=%s",
		policy.From, sourceIP, policy.To, targetIP, proto, port)

	// Choose the chain: DOCKER-USER for iptables-legacy, FORWARD otherwise.
	chain := "FORWARD"
	if o.iptablesMgr.Binary() == "/usr/sbin/iptables-legacy" {
		chain = "DOCKER-USER"
	}
	logger.Debug("FIREWALL: using iptables chain=%s (binary=%s)", chain, o.iptablesMgr.Binary())

	// Ensure the established/related rule exists at the top of the chain.
	if err := o.iptablesMgr.EnsureEstablishedRelated(chain); err != nil {
		logger.Error("FIREWALL: failed to ensure established/related rule in %s: %v", chain, err)
	} else {
		logger.Debug("FIREWALL: established/related rule ensured in %s", chain)
	}

	// Insert the ACCEPT rule. The empty argument is passed through unchanged
	// from the original call (presumably the source port).
	if err := o.iptablesMgr.InsertForwardAccept(chain, sourceIP, targetIP, proto, "", port, comment); err != nil {
		logger.Error("FIREWALL: failed to insert FORWARD ACCEPT rule in %s: %v", chain, err)
	} else {
		logger.Info("FIREWALL: FORWARD ACCEPT rule inserted: chain=%s src=%s dst=%s proto=%s port=%s comment=%q",
			chain, sourceIP, targetIP, proto, port, comment)
	}
}
// applyNATRule installs the NAT rules for a policy that has a "nat" field.
//
// The "to" target is resolved through resolveIP; an unresolvable target is
// logged and skipped. Only nat == "dnat" is supported; any other value is
// logged as unsupported and skipped.
//
// For DNAT, traffic for proto/port is redirected to the target on the same
// port. The rule is placed in the first location that works:
//
//  1. Inside the network namespace of the container named by Selector,
//     ContainerName or Name (the first non-empty one), reached via its PID.
//  2. In host PREROUTING on policy.Iface, if the container rule was not placed
//     and Iface is set.
//
// A POSTROUTING MASQUERADE for the target's IPv4 /24 is then added in the same
// namespace as the DNAT rule, so return traffic routes back the same way. That
// is the container when the container rule was placed there, and the host
// otherwise. Non-IPv4 targets get no MASQUERADE. cfg is currently unused.
func (o *Orchestrator) applyNATRule(ctx context.Context, cfg *config.NetworksConfig, policy config.PolicyConfig, proto, port, comment string) {
	to := policy.To
	logger.Info("FIREWALL: NAT rule: to=%s proto=%s port=%s nat=%s iface=%s",
		to, proto, port, policy.Nat, policy.Iface)

	targetIP := o.resolveIP(to)
	logger.Debug("FIREWALL: resolved target %q -> IP=%q", to, targetIP)
	if targetIP == "" {
		logger.Warn("FIREWALL: cannot resolve target %s for nat policy", to)
		return
	}

	if policy.Nat != "dnat" {
		// Unknown NAT types used to be ignored silently; surface them.
		logger.Warn("FIREWALL: unsupported nat type %q for target %s, skipping", policy.Nat, targetIP)
		return
	}

	// Container whose namespace should hold the DNAT rule:
	// Selector, then ContainerName, then Name.
	selector := policy.Selector
	if selector == "" {
		selector = policy.ContainerName
	}
	if selector == "" {
		selector = policy.Name
	}
	logger.Debug("FIREWALL: DNAT selector=%s", selector)

	// Preferred placement: inside the container namespace via its PID.
	usedContainer := false
	var containerPID int
	if selector != "" {
		pid, err := o.dockerClient.GetContainerPID(ctx, selector)
		if err != nil {
			logger.Warn("FIREWALL: cannot get PID for container %s: %v, trying host rules", selector, err)
		} else {
			logger.Info("FIREWALL: inserting DNAT rule in container %s (PID=%d)", selector, pid)
			if err := o.iptablesMgr.InsertPreroutingRuleInContainer(pid, "0.0.0.0/0", proto, port, targetIP, port, comment); err != nil {
				logger.Error("FIREWALL: failed to insert container PREROUTING rule: %v", err)
			} else {
				logger.Info("FIREWALL: DNAT rule inserted in container %s: target=%s proto=%s port=%s",
					selector, targetIP, proto, port)
				usedContainer = true
				containerPID = pid
			}
		}
	}

	// Fallback: host-level PREROUTING on the configured interface.
	hostDNAT := false
	if !usedContainer && policy.Iface != "" {
		logger.Info("FIREWALL: inserting host-level DNAT rule on interface %s", policy.Iface)
		if err := o.iptablesMgr.InsertPreroutingRuleOnInterface(policy.Iface, proto, port, targetIP, port, comment); err != nil {
			logger.Error("FIREWALL: failed to insert interface PREROUTING rule on %s: %v", policy.Iface, err)
		} else {
			logger.Info("FIREWALL: host DNAT rule inserted: iface=%s target=%s proto=%s port=%s",
				policy.Iface, targetIP, proto, port)
			hostDNAT = true
		}
	}
	if !usedContainer && !hostDNAT {
		// Previously this case was silent. MASQUERADE is still added on the
		// host below, as before.
		logger.Warn("FIREWALL: no DNAT rule was inserted for target %s (selector=%q, iface=%q)",
			targetIP, selector, policy.Iface)
	}

	// Derive the IPv4 /24 holding the target. targetIP may be a bare IP or a
	// CIDR such as "10.0.0.5/32", so strip any prefix length before parsing.
	addr := targetIP
	if slash := strings.IndexByte(addr, '/'); slash >= 0 {
		addr = addr[:slash]
	}
	ip4 := net.ParseIP(addr).To4()
	if ip4 == nil {
		logger.Debug("FIREWALL: target %s is not an IPv4 address, skipping POSTROUTING MASQUERADE", targetIP)
		return
	}
	mask := net.CIDRMask(24, 32)
	targetSubnet := (&net.IPNet{IP: ip4.Mask(mask), Mask: mask}).String()
	masqComment := comment + "-masq"

	logger.Info("FIREWALL: inserting POSTROUTING MASQUERADE for %s", targetSubnet)
	if usedContainer && containerPID > 0 {
		// Same namespace as the container DNAT rule.
		logger.Info("FIREWALL: POSTROUTING MASQUERADE in container PID %d", containerPID)
		if err := o.iptablesMgr.InsertPostroutingMasqueradeInContainer(containerPID, targetSubnet, proto, port, masqComment); err != nil {
			logger.Error("FIREWALL: failed to insert container POSTROUTING MASQUERADE: %v", err)
		} else {
			logger.Info("FIREWALL: container POSTROUTING MASQUERADE inserted: subnet=%s proto=%s port=%s",
				targetSubnet, proto, port)
		}
		return
	}

	logger.Info("FIREWALL: POSTROUTING MASQUERADE on host")
	if err := o.iptablesMgr.InsertPostroutingMasquerade(targetSubnet, proto, port, masqComment); err != nil {
		logger.Error("FIREWALL: failed to insert host POSTROUTING MASQUERADE: %v", err)
	} else {
		logger.Info("FIREWALL: host POSTROUTING MASQUERADE inserted: subnet=%s proto=%s port=%s",
			targetSubnet, proto, port)
	}
}
// resolveIP turns name into an address string.
//
// A literal IP is normalised with config.ToCIDR. Any other name is looked up
// with the resolver, which reads networks.json, and the first match is
// returned. An empty result means the name could not be resolved.
func (o *Orchestrator) resolveIP(name string) string {
	if !config.IsIP(name) {
		matches := o.resolver.Resolve(name)
		if len(matches) == 0 {
			logger.Debug("FIREWALL: resolveIP(%q): not found", name)
			return ""
		}
		first := matches[0]
		logger.Debug("FIREWALL: resolveIP(%q): resolved -> %s", name, first)
		return first
	}

	cidr := config.ToCIDR(name)
	logger.Debug("FIREWALL: resolveIP(%q): direct IP -> %s", name, cidr)
	return cidr
}
// findNetworkForIP returns the NetworkName of the configured network whose
// subnet contains ip. It returns "" when ip is not a valid IP address or no
// configured subnet contains it. Networks whose subnet fails to parse are
// skipped.
//
// When several configured subnets contain ip, the most specific one (longest
// prefix) wins. Ties go to the lexicographically smallest NetworkName. Go
// randomizes map iteration order, so without this rule overlapping
// configurations would yield a different network on different runs.
func findNetworkForIP(cfg *config.NetworksConfig, ip string) string {
	// Parse once, outside the loop; an invalid IP can never match.
	parsedIP := net.ParseIP(ip)
	if parsedIP == nil {
		return ""
	}

	best := ""
	bestOnes := -1
	for _, netCfg := range cfg.Networks {
		subnet, err := netCfg.ParseCIDR()
		if err != nil || !subnet.Contains(parsedIP) {
			continue
		}
		ones, _ := subnet.Mask.Size()
		if ones > bestOnes || (ones == bestOnes && netCfg.NetworkName < best) {
			best = netCfg.NetworkName
			bestOnes = ones
		}
	}
	return best
}