diff --git a/caddyconfig/httpcaddyfile/addresses.go b/caddyconfig/httpcaddyfile/addresses.go index 1121776d98f..ae70a92aaae 100644 --- a/caddyconfig/httpcaddyfile/addresses.go +++ b/caddyconfig/httpcaddyfile/addresses.go @@ -25,6 +25,7 @@ import ( "unicode" "github.com/caddyserver/certmagic" + "go.uber.org/zap" "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" @@ -307,29 +308,75 @@ func (st *ServerType) listenersForServerBlockAddress(sblock serverBlock, addr Ad } // the bind directive specifies hosts (and potentially network), and the protocols to serve them with, but is optional - lnCfgVals := make([]addressesWithProtocols, 0, len(sblock.pile["bind"])) + lnCfgVals := make([]bindOptions, 0, len(sblock.pile["bind"])) for _, cfgVal := range sblock.pile["bind"] { - if val, ok := cfgVal.Value.(addressesWithProtocols); ok { + if val, ok := cfgVal.Value.(bindOptions); ok { lnCfgVals = append(lnCfgVals, val) } } if len(lnCfgVals) == 0 { if defaultBindValues, ok := options["default_bind"].([]ConfigValue); ok { for _, defaultBindValue := range defaultBindValues { - lnCfgVals = append(lnCfgVals, defaultBindValue.Value.(addressesWithProtocols)) + lnCfgVals = append(lnCfgVals, defaultBindValue.Value.(bindOptions)) } } else { - lnCfgVals = []addressesWithProtocols{{ - addresses: []string{""}, - protocols: nil, + lnCfgVals = []bindOptions{{ + addresses: []string{""}, + interfaces: nil, + protocols: nil, }} } } // use a map to prevent duplication + interfaceAddresses := map[string][]string{} listeners := map[string]map[string]struct{}{} for _, lnCfgVal := range lnCfgVals { - for _, lnAddr := range lnCfgVal.addresses { + addresses := []string{} + addresses = append(addresses, lnCfgVal.addresses...) + for _, lnIface := range lnCfgVal.interfaces { + lnNetw, lnDevice, _, err := caddy.SplitNetworkAddress(lnIface) + if err != nil { + return nil, fmt.Errorf("splitting listener interface: %v", err) + } + + ifaceAddresses, ok := interfaceAddresses[lnDevice] + if !ok { + iface, err := net.InterfaceByName(lnDevice) + if err != nil { + return nil, fmt.Errorf("querying listener interface: %v: %v", lnDevice, err) + } + if iface == nil { + return nil, fmt.Errorf("querying listener interface: %v", lnDevice) + } + ifaceAddrs, err := iface.Addrs() + if err != nil { + return nil, fmt.Errorf("querying listener interface addresses: %v: %v", lnDevice, err) + } + for _, ifaceAddr := range ifaceAddrs { + var ip net.IP + switch ifaceAddrValue := ifaceAddr.(type) { + case *net.IPAddr: + ip = ifaceAddrValue.IP + case *net.IPNet: + ip = ifaceAddrValue.IP + default: + caddy.Log().Error("reading listener interface address", zap.String("device", lnDevice), zap.String("address", ifaceAddr.String())) + continue + } + + if len(ip) == net.IPv4len && caddy.IsIPv4Network(lnNetw) || len(ip) == net.IPv6len && caddy.IsIPv6Network(lnNetw) { + ifaceAddresses = append(ifaceAddresses, caddy.JoinNetworkAddress(lnNetw, ip.String(), "")) + } + } + if len(ifaceAddresses) == 0 { + return nil, fmt.Errorf("querying listener interface addresses for network: %v: %v", lnDevice, lnNetw) + } + interfaceAddresses[lnDevice] = ifaceAddresses + } + addresses = append(addresses, ifaceAddresses...) + } + for _, lnAddr := range addresses { lnNetw, lnHost, _, err := caddy.SplitNetworkAddress(lnAddr) if err != nil { return nil, fmt.Errorf("splitting listener address: %v", err) @@ -350,11 +397,10 @@ func (st *ServerType) listenersForServerBlockAddress(sblock serverBlock, addr Ad return listeners, nil } -// addressesWithProtocols associates a list of listen addresses -// with a list of protocols to serve them with -type addressesWithProtocols struct { - addresses []string - protocols []string +type bindOptions struct { + addresses []string + interfaces []string + protocols []string } // Address represents a site address. It contains diff --git a/caddyconfig/httpcaddyfile/builtins.go b/caddyconfig/httpcaddyfile/builtins.go index 061aaa48b8d..a0eb770470b 100644 --- a/caddyconfig/httpcaddyfile/builtins.go +++ b/caddyconfig/httpcaddyfile/builtins.go @@ -57,16 +57,22 @@ func init() { // parseBind parses the bind directive. Syntax: // -// bind [{ -// protocols [h1|h2|h2c|h3] [...] -// }] +// bind [{ +// interfaces +// protocols [h1|h2|h2c|h3] [...] +// }] func parseBind(h Helper) ([]ConfigValue, error) { h.Next() // consume directive name - var addresses, protocols []string + var addresses, interfaces, protocols []string addresses = h.RemainingArgs() for h.NextBlock(0) { switch h.Val() { + case "interfaces": + interfaces = h.RemainingArgs() + if len(interfaces) == 0 { + return nil, h.Errf("interfaces requires one or more arguments") + } case "protocols": protocols = h.RemainingArgs() if len(protocols) == 0 { @@ -77,9 +83,10 @@ func parseBind(h Helper) ([]ConfigValue, error) { } } - return []ConfigValue{{Class: "bind", Value: addressesWithProtocols{ - addresses: addresses, - protocols: protocols, + return []ConfigValue{{Class: "bind", Value: bindOptions{ + addresses: addresses, + interfaces: interfaces, + protocols: protocols, }}}, nil } diff --git a/caddyconfig/httpcaddyfile/options.go b/caddyconfig/httpcaddyfile/options.go index 336c6999f92..c605ff21138 100644 --- a/caddyconfig/httpcaddyfile/options.go +++ b/caddyconfig/httpcaddyfile/options.go @@ -307,8 +307,7 @@ func parseOptSingleString(d *caddyfile.Dispenser, _ any) (any, error) { func parseOptDefaultBind(d *caddyfile.Dispenser, _ any) (any, error) { d.Next() // consume option name - - var addresses, protocols []string + var addresses, interfaces, protocols []string addresses = d.RemainingArgs() if len(addresses) == 0 { @@ -317,6 +316,11 @@ func parseOptDefaultBind(d *caddyfile.Dispenser, _ any) (any, error) { for d.NextBlock(0) { switch d.Val() { + case "interfaces": + interfaces = d.RemainingArgs() + if len(interfaces) == 0 { + return nil, d.Errf("interfaces requires one or more arguments") + } case "protocols": protocols = d.RemainingArgs() if len(protocols) == 0 { @@ -327,9 +331,10 @@ func parseOptDefaultBind(d *caddyfile.Dispenser, _ any) (any, error) { } } - return []ConfigValue{{Class: "bind", Value: addressesWithProtocols{ - addresses: addresses, - protocols: protocols, + return []ConfigValue{{Class: "bind", Value: bindOptions{ + addresses: addresses, + interfaces: interfaces, + protocols: protocols, }}}, nil } diff --git a/caddyconfig/httpcaddyfile/tlsapp.go b/caddyconfig/httpcaddyfile/tlsapp.go index 30948f84fff..b04d42daf48 100644 --- a/caddyconfig/httpcaddyfile/tlsapp.go +++ b/caddyconfig/httpcaddyfile/tlsapp.go @@ -221,7 +221,7 @@ func (st ServerType) buildTLSApp( if acmeIssuer.Challenges.BindHost == "" { // only binding to one host is supported var bindHost string - if asserted, ok := cfgVal.Value.(addressesWithProtocols); ok && len(asserted.addresses) > 0 { + if asserted, ok := cfgVal.Value.(bindOptions); ok && len(asserted.addresses) > 0 { bindHost = asserted.addresses[0] } acmeIssuer.Challenges.BindHost = bindHost @@ -613,7 +613,7 @@ func fillInGlobalACMEDefaults(issuer certmagic.Issuer, options map[string]any) e // In Linux the same call will error with EADDRINUSE whenever the listener for the automation policy is opened if acmeIssuer.Challenges == nil || (acmeIssuer.Challenges.DNS == nil && acmeIssuer.Challenges.BindHost == "") { if defBinds, ok := globalDefaultBind.([]ConfigValue); ok && len(defBinds) > 0 { - if abp, ok := defBinds[0].Value.(addressesWithProtocols); ok && len(abp.addresses) > 0 { + if abp, ok := defBinds[0].Value.(bindOptions); ok && len(abp.addresses) > 0 { if acmeIssuer.Challenges == nil { acmeIssuer.Challenges = new(caddytls.ChallengesConfig) } diff --git a/listen.go b/listen.go index fba9c3a6ba6..466efd79588 100644 --- a/listen.go +++ b/listen.go @@ -21,7 +21,6 @@ import ( "fmt" "net" "os" - "slices" "strconv" "sync" "sync/atomic" @@ -37,7 +36,7 @@ func reuseUnixSocket(_, _ string) (any, error) { func listenReusable(ctx context.Context, lnKey string, network, address string, config net.ListenConfig) (any, error) { var socketFile *os.File - fd := slices.Contains([]string{"fd", "fdgram"}, network) + fd := IsFdNetwork(network) if fd { socketFd, err := strconv.ParseUint(address, 0, strconv.IntSize) if err != nil { @@ -66,8 +65,8 @@ func listenReusable(ctx context.Context, lnKey string, network, address string, } } - datagram := slices.Contains([]string{"udp", "udp4", "udp6", "unixgram", "fdgram"}, network) - if datagram { + packet := IsPacketNetwork(network) + if packet { sharedPc, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { var ( pc net.PacketConn diff --git a/listen_unix.go b/listen_unix.go index d6ae0cb8ebb..856d0980ecf 100644 --- a/listen_unix.go +++ b/listen_unix.go @@ -27,7 +27,6 @@ import ( "io/fs" "net" "os" - "slices" "strconv" "sync" "sync/atomic" @@ -102,7 +101,7 @@ func listenReusable(ctx context.Context, lnKey string, network, address string, socketFile *os.File ) - fd := slices.Contains([]string{"fd", "fdgram"}, network) + fd := IsFdNetwork(network) if fd { socketFd, err := strconv.ParseUint(address, 0, strconv.IntSize) if err != nil { @@ -142,8 +141,8 @@ func listenReusable(ctx context.Context, lnKey string, network, address string, } } - datagram := slices.Contains([]string{"udp", "udp4", "udp6", "unixgram", "fdgram"}, network) - if datagram { + packet := IsPacketNetwork(network) + if packet { if fd { ln, err = net.FilePacketConn(socketFile) } else { @@ -161,7 +160,7 @@ func listenReusable(ctx context.Context, lnKey string, network, address string, listenerPool.LoadOrStore(lnKey, nil) } - if datagram { + if packet { if !fd { // TODO: Not 100% sure this is necessary, but we do this for net.UnixListener, so... if unix, ok := ln.(*net.UnixConn); ok { diff --git a/listeners.go b/listeners.go index b673c86e109..286a76599af 100644 --- a/listeners.go +++ b/listeners.go @@ -38,10 +38,6 @@ import ( "github.com/caddyserver/caddy/v2/internal" ) -// listenFdsStart is the first file descriptor number for systemd socket activation. -// File descriptors 0, 1, 2 are reserved for stdin, stdout, stderr. -const listenFdsStart = 3 - // NetworkAddress represents one or more network addresses. // It contains the individual components for a parsed network // address of the form accepted by ParseNetworkAddress(). @@ -137,42 +133,45 @@ func (na NetworkAddress) ListenAll(ctx context.Context, config net.ListenConfig) // Listen synchronizes binds to unix domain sockets to avoid race conditions // while an existing socket is unlinked. func (na NetworkAddress) Listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) { - if na.IsUnixNetwork() { - unixSocketsMu.Lock() - defer unixSocketsMu.Unlock() - } + var ( + ln any + err error + ) - // check to see if plugin provides listener - if ln, err := getListenerFromPlugin(ctx, na.Network, na.Host, na.port(), portOffset, config); ln != nil || err != nil { + // check to see if plugin provides a listener + if ln, err = getListenerFromPlugin(ctx, na.Network, na.Host, na.port(), portOffset, config); ln != nil || err != nil { return ln, err } // create (or reuse) the listener ourselves - return na.listen(ctx, portOffset, config) -} - -func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) { var ( - ln any - err error address string unixFileMode fs.FileMode ) + // lock other unix sockets from being bound and // split unix socket addr early so lnKey // is independent of permissions bits if na.IsUnixNetwork() { + unixSocketsMu.Lock() + defer unixSocketsMu.Unlock() + address, unixFileMode, err = internal.SplitUnixSocketPermissionsBits(na.Host) if err != nil { return nil, err } } else if na.IsFdNetwork() { - address = na.Host + socketFd, err := strconv.ParseUint(na.Host, 0, strconv.IntSize) + if err != nil { + return nil, fmt.Errorf("invalid file descriptor: %v", err) + } + + address = strconv.FormatUint(uint64(uint(socketFd)+portOffset), 10) } else { address = na.JoinHostPort(portOffset) } - if strings.HasPrefix(na.Network, "ip") { + if na.IsIpNetwork() { ln, err = config.ListenPacket(ctx, na.Network, address) } else { if na.IsUnixNetwork() { @@ -209,11 +208,29 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net } // IsUnixNetwork returns true if na.Network is -// unix, unixgram, or unixpacket. +// unix, unixgram, unixpacket, or unix+h2c. func (na NetworkAddress) IsUnixNetwork() bool { return IsUnixNetwork(na.Network) } +// IsTCPNetwork returns true if na.Network is +// tcp, tcp4, or tcp6. +func (na NetworkAddress) IsTCPNetwork() bool { + return IsTCPNetwork(na.Network) +} + +// IsUDPNetwork returns true if na.Network is +// udp, udp4, or udp6. +func (na NetworkAddress) IsUDPNetwork() bool { + return IsUDPNetwork(na.Network) +} + +// IsIpNetwork returns true if na.Network starts with +// ip: ip4: or ip6: +func (na NetworkAddress) IsIpNetwork() bool { + return IsIpNetwork(na.Network) +} + // IsFdNetwork returns true if na.Network is // fd or fdgram. func (na NetworkAddress) IsFdNetwork() bool { @@ -293,80 +310,12 @@ func (na NetworkAddress) port() string { // The output can be parsed by ParseNetworkAddress(). If the // address is a unix socket, any non-zero port will be dropped. func (na NetworkAddress) String() string { - if na.Network == "tcp" && (na.Host != "" || na.port() != "") { + if na.Network == TCP && (na.Host != "" || na.port() != "") { na.Network = "" // omit default network value for brevity } return JoinNetworkAddress(na.Network, na.Host, na.port()) } -// IsUnixNetwork returns true if the netw is a unix network. -func IsUnixNetwork(netw string) bool { - return strings.HasPrefix(netw, "unix") -} - -// IsFdNetwork returns true if the netw is a fd network. -func IsFdNetwork(netw string) bool { - return strings.HasPrefix(netw, "fd") -} - -// getFdByName returns the file descriptor number for the given -// socket name from systemd's LISTEN_FDNAMES environment variable. -// Socket names are provided by systemd via socket activation. -// -// The name can optionally include an index to handle multiple sockets -// with the same name: "web:0" for first, "web:1" for second, etc. -// If no index is specified, defaults to index 0 (first occurrence). -func getFdByName(nameWithIndex string) (int, error) { - if nameWithIndex == "" { - return 0, fmt.Errorf("socket name cannot be empty") - } - - fdNamesStr := os.Getenv("LISTEN_FDNAMES") - if fdNamesStr == "" { - return 0, fmt.Errorf("LISTEN_FDNAMES environment variable not set") - } - - // Parse name and optional index - parts := strings.Split(nameWithIndex, ":") - if len(parts) > 2 { - return 0, fmt.Errorf("invalid socket name format '%s': too many colons", nameWithIndex) - } - - name := parts[0] - targetIndex := 0 - - if len(parts) > 1 { - var err error - targetIndex, err = strconv.Atoi(parts[1]) - if err != nil { - return 0, fmt.Errorf("invalid socket index '%s': %v", parts[1], err) - } - if targetIndex < 0 { - return 0, fmt.Errorf("socket index cannot be negative: %d", targetIndex) - } - } - - // Parse the socket names - names := strings.Split(fdNamesStr, ":") - - // Find the Nth occurrence of the requested name - matchCount := 0 - for i, fdName := range names { - if fdName == name { - if matchCount == targetIndex { - return listenFdsStart + i, nil - } - matchCount++ - } - } - - if matchCount == 0 { - return 0, fmt.Errorf("socket name '%s' not found in LISTEN_FDNAMES", name) - } - - return 0, fmt.Errorf("socket name '%s' found %d times, but index %d requested", name, matchCount, targetIndex) -} - // ParseNetworkAddress parses addr into its individual // components. The input string is expected to be of // the form "network/host:port-range" where any part is @@ -398,27 +347,9 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui }, err } if IsFdNetwork(network) { - fdAddr := host - - // Handle named socket activation (fdname/name, fdgramname/name) - if strings.HasPrefix(network, "fdname") || strings.HasPrefix(network, "fdgramname") { - fdNum, err := getFdByName(host) - if err != nil { - return NetworkAddress{}, fmt.Errorf("named socket activation: %v", err) - } - fdAddr = strconv.Itoa(fdNum) - - // Normalize network to standard fd/fdgram - if strings.HasPrefix(network, "fdname") { - network = "fd" - } else { - network = "fdgram" - } - } - return NetworkAddress{ Network: network, - Host: fdAddr, + Host: host, }, nil } var start, end uint64 @@ -713,55 +644,12 @@ func (fcql *fakeCloseQuicListener) Close() error { return nil } -// RegisterNetwork registers a network type with Caddy so that if a listener is -// created for that network type, getListener will be invoked to get the listener. -// This should be called during init() and will panic if the network type is standard -// or reserved, or if it is already registered. EXPERIMENTAL and subject to change. -func RegisterNetwork(network string, getListener ListenerFunc) { - network = strings.TrimSpace(strings.ToLower(network)) - - if network == "tcp" || network == "tcp4" || network == "tcp6" || - network == "udp" || network == "udp4" || network == "udp6" || - network == "unix" || network == "unixpacket" || network == "unixgram" || - strings.HasPrefix(network, "ip:") || strings.HasPrefix(network, "ip4:") || strings.HasPrefix(network, "ip6:") || - network == "fd" || network == "fdgram" { - panic("network type " + network + " is reserved") - } - - if _, ok := networkTypes[strings.ToLower(network)]; ok { - panic("network type " + network + " is already registered") - } - - networkTypes[network] = getListener -} - var unixSocketsMu sync.Mutex -// getListenerFromPlugin returns a listener on the given network and address -// if a plugin has registered the network name. It may return (nil, nil) if -// no plugin can provide a listener. -func getListenerFromPlugin(ctx context.Context, network, host, port string, portOffset uint, config net.ListenConfig) (any, error) { - // get listener from plugin if network type is registered - if getListener, ok := networkTypes[network]; ok { - Log().Debug("getting listener from plugin", zap.String("network", network)) - return getListener(ctx, network, host, port, portOffset, config) - } - - return nil, nil -} - func listenerKey(network, addr string) string { return network + "/" + addr } -// ListenerFunc is a function that can return a listener given a network and address. -// The listeners must be capable of overlapping: with Caddy, new configs are loaded -// before old ones are unloaded, so listeners may overlap briefly if the configs -// both need the same listener. EXPERIMENTAL and subject to change. -type ListenerFunc func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error) - -var networkTypes = map[string]ListenerFunc{} - // ListenerWrapper is a type that wraps a listener // so it can modify the input listener's methods. // Modules that implement this interface are found diff --git a/listeners_test.go b/listeners_test.go index c2cc255f21f..a4cadd3aab1 100644 --- a/listeners_test.go +++ b/listeners_test.go @@ -15,7 +15,6 @@ package caddy import ( - "os" "reflect" "testing" @@ -653,286 +652,3 @@ func TestSplitUnixSocketPermissionsBits(t *testing.T) { } } } - -// TestGetFdByName tests the getFdByName function for systemd socket activation. -func TestGetFdByName(t *testing.T) { - // Save original environment - originalFdNames := os.Getenv("LISTEN_FDNAMES") - - // Restore environment after test - defer func() { - if originalFdNames != "" { - os.Setenv("LISTEN_FDNAMES", originalFdNames) - } else { - os.Unsetenv("LISTEN_FDNAMES") - } - }() - - tests := []struct { - name string - fdNames string - socketName string - expectedFd int - expectError bool - }{ - { - name: "simple http socket", - fdNames: "http", - socketName: "http", - expectedFd: 3, - }, - { - name: "multiple different sockets - first", - fdNames: "http:https:dns", - socketName: "http", - expectedFd: 3, - }, - { - name: "multiple different sockets - second", - fdNames: "http:https:dns", - socketName: "https", - expectedFd: 4, - }, - { - name: "multiple different sockets - third", - fdNames: "http:https:dns", - socketName: "dns", - expectedFd: 5, - }, - { - name: "duplicate names - first occurrence (no index)", - fdNames: "web:web:api", - socketName: "web", - expectedFd: 3, - }, - { - name: "duplicate names - first occurrence (explicit index 0)", - fdNames: "web:web:api", - socketName: "web:0", - expectedFd: 3, - }, - { - name: "duplicate names - second occurrence (index 1)", - fdNames: "web:web:api", - socketName: "web:1", - expectedFd: 4, - }, - { - name: "complex duplicates - first api", - fdNames: "web:api:web:api:dns", - socketName: "api:0", - expectedFd: 4, - }, - { - name: "complex duplicates - second api", - fdNames: "web:api:web:api:dns", - socketName: "api:1", - expectedFd: 6, - }, - { - name: "complex duplicates - first web", - fdNames: "web:api:web:api:dns", - socketName: "web:0", - expectedFd: 3, - }, - { - name: "complex duplicates - second web", - fdNames: "web:api:web:api:dns", - socketName: "web:1", - expectedFd: 5, - }, - { - name: "socket not found", - fdNames: "http:https", - socketName: "missing", - expectError: true, - }, - { - name: "empty socket name", - fdNames: "http", - socketName: "", - expectError: true, - }, - { - name: "missing LISTEN_FDNAMES", - fdNames: "", - socketName: "http", - expectError: true, - }, - { - name: "index out of range", - fdNames: "web:web", - socketName: "web:2", - expectError: true, - }, - { - name: "negative index", - fdNames: "web", - socketName: "web:-1", - expectError: true, - }, - { - name: "invalid index format", - fdNames: "web", - socketName: "web:abc", - expectError: true, - }, - { - name: "too many colons", - fdNames: "web", - socketName: "web:0:extra", - expectError: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - // Set up environment - if tc.fdNames != "" { - os.Setenv("LISTEN_FDNAMES", tc.fdNames) - } else { - os.Unsetenv("LISTEN_FDNAMES") - } - - // Test the function - fd, err := getFdByName(tc.socketName) - - if tc.expectError { - if err == nil { - t.Errorf("Expected error but got none") - } - } else { - if err != nil { - t.Errorf("Expected no error but got: %v", err) - } - if fd != tc.expectedFd { - t.Errorf("Expected FD %d but got %d", tc.expectedFd, fd) - } - } - }) - } -} - -// TestParseNetworkAddressFdName tests parsing of fdname and fdgramname addresses. -func TestParseNetworkAddressFdName(t *testing.T) { - // Save and restore environment - originalFdNames := os.Getenv("LISTEN_FDNAMES") - defer func() { - if originalFdNames != "" { - os.Setenv("LISTEN_FDNAMES", originalFdNames) - } else { - os.Unsetenv("LISTEN_FDNAMES") - } - }() - - // Set up test environment - os.Setenv("LISTEN_FDNAMES", "http:https:dns") - - tests := []struct { - input string - expectAddr NetworkAddress - expectErr bool - }{ - { - input: "fdname/http", - expectAddr: NetworkAddress{ - Network: "fd", - Host: "3", - }, - }, - { - input: "fdname/https", - expectAddr: NetworkAddress{ - Network: "fd", - Host: "4", - }, - }, - { - input: "fdname/dns", - expectAddr: NetworkAddress{ - Network: "fd", - Host: "5", - }, - }, - { - input: "fdname/http:0", - expectAddr: NetworkAddress{ - Network: "fd", - Host: "3", - }, - }, - { - input: "fdname/https:0", - expectAddr: NetworkAddress{ - Network: "fd", - Host: "4", - }, - }, - { - input: "fdgramname/http", - expectAddr: NetworkAddress{ - Network: "fdgram", - Host: "3", - }, - }, - { - input: "fdgramname/https", - expectAddr: NetworkAddress{ - Network: "fdgram", - Host: "4", - }, - }, - { - input: "fdgramname/http:0", - expectAddr: NetworkAddress{ - Network: "fdgram", - Host: "3", - }, - }, - { - input: "fdname/nonexistent", - expectErr: true, - }, - { - input: "fdgramname/nonexistent", - expectErr: true, - }, - { - input: "fdname/http:99", - expectErr: true, - }, - { - input: "fdname/invalid:abc", - expectErr: true, - }, - // Test that old fd/N syntax still works - { - input: "fd/7", - expectAddr: NetworkAddress{ - Network: "fd", - Host: "7", - }, - }, - { - input: "fdgram/8", - expectAddr: NetworkAddress{ - Network: "fdgram", - Host: "8", - }, - }, - } - - for i, tc := range tests { - actualAddr, err := ParseNetworkAddress(tc.input) - - if tc.expectErr && err == nil { - t.Errorf("Test %d (%s): Expected error but got none", i, tc.input) - } - if !tc.expectErr && err != nil { - t.Errorf("Test %d (%s): Expected no error but got: %v", i, tc.input, err) - } - if !tc.expectErr && !reflect.DeepEqual(tc.expectAddr, actualAddr) { - t.Errorf("Test %d (%s): Expected %+v but got %+v", i, tc.input, tc.expectAddr, actualAddr) - } - } -} diff --git a/modules/caddyhttp/server.go b/modules/caddyhttp/server.go index f4d3624960d..d3760bf3cc9 100644 --- a/modules/caddyhttp/server.go +++ b/modules/caddyhttp/server.go @@ -620,7 +620,7 @@ func (s *Server) findLastRouteWithHostMatcher() int { // not already done, and then uses that server to serve HTTP/3 over // the listener, with Server s as the handler. func (s *Server) serveHTTP3(addr caddy.NetworkAddress, tlsCfg *tls.Config) error { - h3net, err := getHTTP3Network(addr.Network) + h3net, err := getNetworkHTTP3(addr.Network) if err != nil { return fmt.Errorf("starting HTTP/3 QUIC listener: %v", err) } @@ -1125,16 +1125,7 @@ const ( ClientIPVarKey string = "client_ip" ) -var networkTypesHTTP3 = map[string]string{ - "unixgram": "unixgram", - "udp": "udp", - "udp4": "udp4", - "udp6": "udp6", - "tcp": "udp", - "tcp4": "udp4", - "tcp6": "udp6", - "fdgram": "fdgram", -} +var networkHTTP3Plugins = map[string]string{} // RegisterNetworkHTTP3 registers a mapping from non-HTTP/3 network to HTTP/3 // network. This should be called during init() and will panic if the network @@ -1142,16 +1133,41 @@ var networkTypesHTTP3 = map[string]string{ // // EXPERIMENTAL: Subject to change. func RegisterNetworkHTTP3(originalNetwork, h3Network string) { - if _, ok := networkTypesHTTP3[strings.ToLower(originalNetwork)]; ok { + if caddy.IsReservedNetwork(originalNetwork) { + panic("network type " + originalNetwork + " is reserved") + } + + if _, ok := networkHTTP3Plugins[strings.ToLower(originalNetwork)]; ok { panic("network type " + originalNetwork + " is already registered") } - networkTypesHTTP3[originalNetwork] = h3Network + + networkHTTP3Plugins[originalNetwork] = h3Network } -func getHTTP3Network(originalNetwork string) (string, error) { - h3Network, ok := networkTypesHTTP3[strings.ToLower(originalNetwork)] +func getNetworkHTTP3(originalNetwork string) (string, error) { + switch originalNetwork { + case caddy.UNIXGRAM: + return caddy.UNIXGRAM, nil + case caddy.UDP: + return caddy.UDP, nil + case caddy.UDP4: + return caddy.UDP4, nil + case caddy.UDP6: + return caddy.UDP6, nil + case caddy.TCP: + return caddy.UDP, nil + case caddy.TCP4: + return caddy.UDP4, nil + case caddy.TCP6: + return caddy.UDP6, nil + case caddy.FDGRAM: + return caddy.FDGRAM, nil + } + + h3Network, ok := networkHTTP3Plugins[strings.ToLower(originalNetwork)] if !ok { return "", fmt.Errorf("network '%s' cannot handle HTTP/3 connections", originalNetwork) } + return h3Network, nil } diff --git a/networks.go b/networks.go new file mode 100644 index 00000000000..0b7e3859a96 --- /dev/null +++ b/networks.go @@ -0,0 +1,129 @@ +// Copyright 2015 Matthew Holt and The Caddy Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package caddy + +import ( + "context" + "net" + "strings" + + "go.uber.org/zap" +) + +const ( + UNIX = "unix" + UNIX_H2C = "unix+h2c" + UNIXGRAM = "unixgram" + UNIXPACKET = "unixpacket" + TCP = "tcp" + TCP4 = "tcp4" + TCP6 = "tcp6" + UDP = "udp" + UDP4 = "udp4" + UDP6 = "udp6" + IP_ = "ip:" + IP4_ = "ip4:" + IP6_ = "ip6:" + FD = "fd" + FDGRAM = "fdgram" +) + +// IsUnixNetwork returns true if the netw is a unix network. +func IsUnixNetwork(netw string) bool { + return netw == UNIX || netw == UNIX_H2C || netw == UNIXGRAM || netw == UNIXPACKET +} + +// IsUnixNetwork returns true if the netw is a TCP network. +func IsTCPNetwork(netw string) bool { + return netw == TCP || netw == TCP4 || netw == TCP6 +} + +// IsUnixNetwork returns true if the netw is a UDP network. +func IsUDPNetwork(netw string) bool { + return netw == UDP || netw == UDP4 || netw == UDP6 +} + +// IsIpNetwork returns true if the netw is an ip network. +func IsIpNetwork(netw string) bool { + return strings.HasPrefix(netw, IP_) || strings.HasPrefix(netw, IP4_) || strings.HasPrefix(netw, IP6_) +} + +// IsFdNetwork returns true if the netw is a fd network. +func IsFdNetwork(netw string) bool { + return netw == FD || netw == FDGRAM +} + +func IsReservedNetwork(network string) bool { + return IsUnixNetwork(network) || + IsTCPNetwork(network) || + IsUDPNetwork(network) || + IsIpNetwork(network) || + IsFdNetwork(network) +} + +func IsIPv4Network(netw string) bool { + return netw == TCP || netw == TCP4 || netw == UDP || netw == UDP4 || strings.HasPrefix(netw, IP_) || strings.HasPrefix(netw, IP4_) +} + +func IsIPv6Network(netw string) bool { + return netw == TCP || netw == TCP6 || netw == UDP || netw == UDP6 || strings.HasPrefix(netw, IP_) || strings.HasPrefix(netw, IP6_) +} + +func IsStreamNetwork(netw string) bool { + return netw == UNIX || netw == UNIX_H2C || netw == UNIXPACKET || IsTCPNetwork(netw) || netw == FD +} + +func IsPacketNetwork(netw string) bool { + return netw == UNIXGRAM || IsUDPNetwork(netw) || IsIpNetwork(netw) || netw == FDGRAM +} + +// ListenerFunc is a function that can return a listener given a network and address. +// The listeners must be capable of overlapping: with Caddy, new configs are loaded +// before old ones are unloaded, so listeners may overlap briefly if the configs +// both need the same listener. EXPERIMENTAL and subject to change. +type ListenerFunc func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error) + +var networkPlugins = map[string]ListenerFunc{} + +// RegisterNetwork registers a network plugin with Caddy so that if a listener is +// created for that network plugin, getListener will be invoked to get the listener. +// This should be called during init() and will panic if the network type is standard +// or reserved, or if it is already registered. EXPERIMENTAL and subject to change. +func RegisterNetwork(network string, getListener ListenerFunc) { + network = strings.TrimSpace(strings.ToLower(network)) + + if IsReservedNetwork(network) { + panic("network type " + network + " is reserved") + } + + if _, ok := networkPlugins[strings.ToLower(network)]; ok { + panic("network type " + network + " is already registered") + } + + networkPlugins[network] = getListener +} + +// getListenerFromPlugin returns a listener on the given network and address +// if a plugin has registered the network name. It may return (nil, nil) if +// no plugin can provide a listener. +func getListenerFromPlugin(ctx context.Context, network, host, port string, portOffset uint, config net.ListenConfig) (any, error) { + // get listener from plugin if network is registered + if getListener, ok := networkPlugins[network]; ok { + Log().Debug("getting listener from plugin", zap.String("network", network)) + return getListener(ctx, network, host, port, portOffset, config) + } + + return nil, nil +} diff --git a/replacer.go b/replacer.go index 1a2aa5771d1..7e5e62d51cf 100644 --- a/replacer.go +++ b/replacer.go @@ -36,16 +36,12 @@ func NewReplacer() *Replacer { static: make(map[string]any), mapMutex: &sync.RWMutex{}, } - rep.providers = []replacementProvider{ - globalDefaultReplacementProvider{}, - fileReplacementProvider{}, - ReplacerFunc(rep.fromStatic), - } + rep.providers = append(globalReplacementProviders, ReplacerFunc(rep.fromStatic)) return rep } // NewEmptyReplacer returns a new Replacer, -// without the global default replacements. +// without the global replacements. func NewEmptyReplacer() *Replacer { rep := &Replacer{ static: make(map[string]any), @@ -360,12 +356,11 @@ func (f fileReplacementProvider) replace(key string) (any, bool) { return string(body), true } -// globalDefaultReplacementProvider handles replacements -// that can be used in any context, such as system variables, -// time, or environment variables. -type globalDefaultReplacementProvider struct{} +// defaultReplacementProvider handles replacements +// such as system variables, time, or environment variables. +type defaultReplacementProvider struct{} -func (f globalDefaultReplacementProvider) replace(key string) (any, bool) { +func (f defaultReplacementProvider) replace(key string) (any, bool) { // check environment variable const envPrefix = "env." if strings.HasPrefix(key, envPrefix) { diff --git a/replacer_nosystemd.go b/replacer_nosystemd.go new file mode 100644 index 00000000000..c885e4cb7a7 --- /dev/null +++ b/replacer_nosystemd.go @@ -0,0 +1,8 @@ +//go:build !linux || nosystemd + +package caddy + +var globalReplacementProviders = []replacementProvider{ + defaultReplacementProvider{}, + fileReplacementProvider{}, +} diff --git a/replacer_systemd.go b/replacer_systemd.go new file mode 100644 index 00000000000..518ce365943 --- /dev/null +++ b/replacer_systemd.go @@ -0,0 +1,123 @@ +//go:build linux && !nosystemd + +package caddy + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + + "go.uber.org/zap" +) + +func sdListenFds() (int, error) { + lnPid, ok := os.LookupEnv("LISTEN_PID") + if !ok { + return 0, errors.New("LISTEN_PID is unset") + } + + pid, err := strconv.Atoi(lnPid) + if err != nil { + return 0, err + } + + if pid != os.Getpid() { + return 0, fmt.Errorf("LISTEN_PID does not match pid: %d != %d", pid, os.Getpid()) + } + + lnFds, ok := os.LookupEnv("LISTEN_FDS") + if !ok { + return 0, errors.New("LISTEN_FDS is unset") + } + + fds, err := strconv.Atoi(lnFds) + if err != nil { + return 0, err + } + + return fds, nil +} + +func sdListenFdsWithNames() (map[string][]uint, error) { + const lnFdsStart = 3 + + fds, err := sdListenFds() + if err != nil { + return nil, err + } + + lnFdnames, ok := os.LookupEnv("LISTEN_FDNAMES") + if !ok { + return nil, errors.New("LISTEN_FDNAMES is unset") + } + + fdNames := strings.Split(lnFdnames, ":") + if fds != len(fdNames) { + return nil, fmt.Errorf("LISTEN_FDS does not match LISTEN_FDNAMES length: %d != %d", fds, len(fdNames)) + } + + nameToFiles := make(map[string][]uint, len(fdNames)) + for index, name := range fdNames { + nameToFiles[name] = append(nameToFiles[name], lnFdsStart+uint(index)) + } + + return nameToFiles, nil +} + +func getSdListenFd(nameToFiles map[string][]uint, nameOffset string) (uint, error) { + index := uint(0) + + name, offset, found := strings.Cut(nameOffset, ":") + if found { + off, err := strconv.ParseUint(offset, 0, strconv.IntSize) + if err != nil { + return 0, err + } + index += uint(off) + } + + files, ok := nameToFiles[name] + if !ok { + return 0, fmt.Errorf("invalid listen fd name: %s", name) + } + + if uint(len(files)) <= index { + return 0, fmt.Errorf("invalid listen fd index: %d", index) + } + + return files[index], nil +} + +var initNameToFiles, initNameToFilesErr = sdListenFdsWithNames() + +// systemdReplacementProvider handles {systemd.*} replacements +type systemdReplacementProvider struct{} + +func (f systemdReplacementProvider) replace(key string) (any, bool) { + // check environment variable + const systemdListenPrefix = "systemd.listen." + if strings.HasPrefix(key, systemdListenPrefix) { + if initNameToFilesErr != nil { + Log().Error("unable to read LISTEN_FDNAMES", zap.Error(initNameToFilesErr)) + return nil, false + } + fd, err := getSdListenFd(initNameToFiles, key[len(systemdListenPrefix):]) + if err != nil { + Log().Error("unable to process {"+key+"}", zap.Error(err)) + return nil, false + } + return fd, true + } + + // TODO const systemdCredsPrefix = "systemd.creds." + + return nil, false +} + +var globalReplacementProviders = []replacementProvider{ + defaultReplacementProvider{}, + fileReplacementProvider{}, + systemdReplacementProvider{}, +} diff --git a/replacer_test.go b/replacer_test.go index 4f20bede30f..27774935a48 100644 --- a/replacer_test.go +++ b/replacer_test.go @@ -374,10 +374,6 @@ func TestReplacerMap(t *testing.T) { func TestReplacerNew(t *testing.T) { repl := NewReplacer() - if len(repl.providers) != 3 { - t.Errorf("Expected providers length '%v' got length '%v'", 3, len(repl.providers)) - } - // test if default global replacements are added as the first provider hostname, _ := os.Hostname() wd, _ := os.Getwd() diff --git a/replacer_test_systemd.go b/replacer_test_systemd.go new file mode 100644 index 00000000000..1623497b769 --- /dev/null +++ b/replacer_test_systemd.go @@ -0,0 +1,376 @@ +//go:build linux && !nosystemd + +package caddy + +import ( + "os" + "reflect" + "strconv" + "testing" +) + +// TestGetSdListenFd tests the getSdListenFd function for systemd socket activation. +func TestGetSdListenFd(t *testing.T) { + // Save original environment + originalFdNames := os.Getenv("LISTEN_FDNAMES") + originalFds := os.Getenv("LISTEN_FDS") + originalPid := os.Getenv("LISTEN_PID") + + // Restore environment after test + defer func() { + if originalFdNames != "" { + os.Setenv("LISTEN_FDNAMES", originalFdNames) + } else { + os.Unsetenv("LISTEN_FDNAMES") + } + if originalFds != "" { + os.Setenv("LISTEN_FDS", originalFds) + } else { + os.Unsetenv("LISTEN_FDS") + } + if originalPid != "" { + os.Setenv("LISTEN_PID", originalPid) + } else { + os.Unsetenv("LISTEN_PID") + } + }() + + tests := []struct { + name string + fdNames string + fds string + socketName string + expectedFd uint + expectError bool + }{ + { + name: "simple http socket", + fdNames: "http", + fds: "1", + socketName: "http", + expectedFd: 3, + }, + { + name: "multiple different sockets - first", + fdNames: "http:https:dns", + fds: "3", + socketName: "http", + expectedFd: 3, + }, + { + name: "multiple different sockets - second", + fdNames: "http:https:dns", + fds: "3", + socketName: "https", + expectedFd: 4, + }, + { + name: "multiple different sockets - third", + fdNames: "http:https:dns", + fds: "3", + socketName: "dns", + expectedFd: 5, + }, + { + name: "duplicate names - first occurrence (no index)", + fdNames: "web:web:api", + fds: "3", + socketName: "web", + expectedFd: 3, + }, + { + name: "duplicate names - first occurrence (explicit index 0)", + fdNames: "web:web:api", + fds: "3", + socketName: "web:0", + expectedFd: 3, + }, + { + name: "duplicate names - second occurrence (index 1)", + fdNames: "web:web:api", + fds: "3", + socketName: "web:1", + expectedFd: 4, + }, + { + name: "complex duplicates - first api", + fdNames: "web:api:web:api:dns", + fds: "5", + socketName: "api:0", + expectedFd: 4, + }, + { + name: "complex duplicates - second api", + fdNames: "web:api:web:api:dns", + fds: "5", + socketName: "api:1", + expectedFd: 6, + }, + { + name: "complex duplicates - first web", + fdNames: "web:api:web:api:dns", + fds: "5", + socketName: "web:0", + expectedFd: 3, + }, + { + name: "complex duplicates - second web", + fdNames: "web:api:web:api:dns", + fds: "5", + socketName: "web:1", + expectedFd: 5, + }, + { + name: "socket not found", + fdNames: "http:https", + fds: "2", + socketName: "missing", + expectError: true, + }, + { + name: "empty socket name", + fdNames: "http", + fds: "1", + socketName: "", + expectError: true, + }, + { + name: "missing LISTEN_FDNAMES", + fdNames: "", + fds: "", + socketName: "http", + expectError: true, + }, + { + name: "index out of range", + fdNames: "web:web", + fds: "2", + socketName: "web:2", + expectError: true, + }, + { + name: "negative index", + fdNames: "web", + fds: "1", + socketName: "web:-1", + expectError: true, + }, + { + name: "invalid index format", + fdNames: "web", + fds: "1", + socketName: "web:abc", + expectError: true, + }, + { + name: "too many colons", + fdNames: "web", + fds: "1", + socketName: "web:0:extra", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Set up environment + if tc.fdNames != "" { + os.Setenv("LISTEN_FDNAMES", tc.fdNames) + } else { + os.Unsetenv("LISTEN_FDNAMES") + } + + if tc.fds != "" { + os.Setenv("LISTEN_FDS", tc.fds) + } else { + os.Unsetenv("LISTEN_FDS") + } + + os.Setenv("LISTEN_PID", strconv.Itoa(os.Getpid())) + + // Test the function + var ( + listenFdsWithNames map[string][]uint + err error + fd uint + ) + listenFdsWithNames, err = sdListenFdsWithNames() + if err == nil { + fd, err = getSdListenFd(listenFdsWithNames, tc.socketName) + } + + if tc.expectError { + if err == nil { + t.Errorf("Expected error but got none") + } + } else { + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } + if fd != tc.expectedFd { + t.Errorf("Expected FD %d but got %d", tc.expectedFd, fd) + } + } + }) + } +} + +// TestParseSystemdListenPlaceholder tests parsing of {systemd.listen.name} placeholders. +func TestParseSystemdListenPlaceholder(t *testing.T) { + // Save and restore environment + originalFdNames := os.Getenv("LISTEN_FDNAMES") + originalFds := os.Getenv("LISTEN_FDS") + originalPid := os.Getenv("LISTEN_PID") + + defer func() { + if originalFdNames != "" { + os.Setenv("LISTEN_FDNAMES", originalFdNames) + } else { + os.Unsetenv("LISTEN_FDNAMES") + } + if originalFds != "" { + os.Setenv("LISTEN_FDS", originalFds) + } else { + os.Unsetenv("LISTEN_FDS") + } + if originalPid != "" { + os.Setenv("LISTEN_PID", originalPid) + } else { + os.Unsetenv("LISTEN_PID") + } + }() + + // Set up test environment + os.Setenv("LISTEN_FDNAMES", "http:https:dns") + os.Setenv("LISTEN_FDS", "3") + os.Setenv("LISTEN_PID", strconv.Itoa(os.Getpid())) + + tests := []struct { + input string + expectedAddr NetworkAddress + expectedFd uint + expectErr bool + }{ + { + input: "fd/{systemd.listen.http}", + expectedAddr: NetworkAddress{ + Network: "fd", + Host: "{systemd.listen.http}", + }, + expectedFd: 3, + }, + { + input: "fd/{systemd.listen.https}", + expectedAddr: NetworkAddress{ + Network: "fd", + Host: "{systemd.listen.https}", + }, + expectedFd: 4, + }, + { + input: "fd/{systemd.listen.dns}", + expectedAddr: NetworkAddress{ + Network: "fd", + Host: "{systemd.listen.dns}", + }, + expectedFd: 5, + }, + { + input: "fd/{systemd.listen.http:0}", + expectedAddr: NetworkAddress{ + Network: "fd", + Host: "{systemd.listen.http:0}", + }, + expectedFd: 3, + }, + { + input: "fd/{systemd.listen.https:0}", + expectedAddr: NetworkAddress{ + Network: "fd", + Host: "{systemd.listen.https:0}", + }, + expectedFd: 4, + }, + { + input: "fdgram/{systemd.listen.http}", + expectedAddr: NetworkAddress{ + Network: "fdgram", + Host: "{systemd.listen.http}", + }, + expectedFd: 3, + }, + { + input: "fdgram/{systemd.listen.https}", + expectedAddr: NetworkAddress{ + Network: "fdgram", + Host: "{systemd.listen.https}", + }, + expectedFd: 4, + }, + { + input: "fdgram/{systemd.listen.http:0}", + expectedAddr: NetworkAddress{ + Network: "fdgram", + Host: "http:0", + }, + expectedFd: 3, + }, + { + input: "fd/{systemd.listen.nonexistent}", + expectErr: true, + }, + { + input: "fdgram/{systemd.listen.nonexistent}", + expectErr: true, + }, + { + input: "fd/{systemd.listen.http:99}", + expectErr: true, + }, + { + input: "fd/{systemd.listen.invalid:abc}", + expectErr: true, + }, + // Test that old fd/N syntax still works + { + input: "fd/7", + expectedAddr: NetworkAddress{ + Network: "fd", + Host: "7", + }, + expectedFd: 7, + }, + { + input: "fdgram/8", + expectedAddr: NetworkAddress{ + Network: "fdgram", + Host: "8", + }, + expectedFd: 8, + }, + } + + for i, tc := range tests { + actualAddr, err := ParseNetworkAddress(tc.input) + if err == nil { + var fd uint + fdWide, err := strconv.ParseUint(actualAddr.Host, 0, strconv.IntSize) + if err == nil { + fd = uint(fdWide) + } + + if tc.expectErr && err == nil { + t.Errorf("Test %d (%s): Expected error but got none", i, tc.input) + } + if !tc.expectErr && err != nil { + t.Errorf("Test %d (%s): Expected no error but got: %v", i, tc.input, err) + } + if !tc.expectErr && !reflect.DeepEqual(tc.expectedAddr, actualAddr) { + t.Errorf("Test %d (%s): Expected %+v but got %+v", i, tc.input, tc.expectedAddr, actualAddr) + } + if !tc.expectErr && fd != tc.expectedFd { + t.Errorf("Expected FD %d but got %d", tc.expectedFd, fd) + } + } + } +}