diff --git a/pkg/queue/certificate/watcher.go b/pkg/queue/certificate/watcher.go new file mode 100644 index 000000000000..e5f2bfc1a120 --- /dev/null +++ b/pkg/queue/certificate/watcher.go @@ -0,0 +1,118 @@ +package certificate + +import ( + "crypto/sha256" + "crypto/tls" + "os" + "path" + "sync" + "time" + + "go.uber.org/zap" +) + +const ( + reloadInterval = 1 * time.Minute +) + +// CertWatcher watches certificate and key files and reloads them if they change on disk. +type CertWatcher struct { + certPath string + certChecksum [sha256.Size]byte + keyPath string + keyChecksum [sha256.Size]byte + + certificate *tls.Certificate + + logger *zap.SugaredLogger + ticker *time.Ticker + stop chan struct{} + mux sync.RWMutex +} + +// NewCertWatcher creates a CertWatcher and watches +// the certificate and key files. It reloads the contents on file change. +// Make sure to stop the CertWatcher using Stop() upon destroy. +func NewCertWatcher(certPath, keyPath string, logger *zap.SugaredLogger) (*CertWatcher, error) { + cw := &CertWatcher{ + certPath: certPath, + keyPath: keyPath, + logger: logger, + ticker: time.NewTicker(reloadInterval), + stop: make(chan struct{}), + mux: sync.RWMutex{}, + } + + certDir := path.Dir(cw.certPath) + keyDir := path.Dir(cw.keyPath) + + cw.logger.Info("Starting to watch the following directories for changes", + zap.String("certDir", certDir), zap.String("keyDir", keyDir)) + + // initial load + cw.loadCert() + + go cw.watch() + + return cw, nil +} + +// Stop shuts down the CertWatcher. Use this with `defer`. +func (cw *CertWatcher) Stop() { + cw.logger.Info("Stopping file watcher") + close(cw.stop) + cw.ticker.Stop() +} + +// GetCertificate returns the server certificate for a client-hello request. +func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + cw.mux.RLock() + defer cw.mux.RUnlock() + return cw.certificate, nil +} + +func (cw *CertWatcher) watch() { + for { + select { + case <-cw.stop: + return + + case <-cw.ticker.C: + cw.loadCert() + } + } +} + +func (cw *CertWatcher) loadCert() { + var err error + certFile, err := os.ReadFile(cw.certPath) + if err != nil { + cw.logger.Error("failed to load certificate file", zap.String("certPath", cw.certPath), zap.Error(err)) + return + } + keyFile, err := os.ReadFile(cw.keyPath) + if err != nil { + cw.logger.Error("failed to load key file", zap.String("keyPath", cw.keyPath), zap.Error(err)) + return + } + + certChecksum := sha256.Sum256(certFile) + keyChecksum := sha256.Sum256(keyFile) + + if certChecksum != cw.certChecksum || keyChecksum != cw.keyChecksum { + keyPair, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath) + if err != nil { + cw.logger.Error("failed to load and parse certificate", zap.Error(err)) + return + } + + cw.mux.Lock() + defer cw.mux.Unlock() + + cw.certificate = &keyPair + cw.certChecksum = certChecksum + cw.keyChecksum = keyChecksum + + cw.logger.Info("Certificate and/or key have changed on disk and were reloaded.") + } +} diff --git a/pkg/queue/sharedmain/main.go b/pkg/queue/sharedmain/main.go index 0400d416457a..81fca603f360 100644 --- a/pkg/queue/sharedmain/main.go +++ b/pkg/queue/sharedmain/main.go @@ -18,6 +18,7 @@ package sharedmain import ( "context" + "crypto/tls" "errors" "fmt" "net/http" @@ -29,6 +30,7 @@ import ( "go.opencensus.io/plugin/ochttp" "go.uber.org/automaxprocs/maxprocs" "go.uber.org/zap" + "knative.dev/serving/pkg/queue/certificate" "k8s.io/apimachinery/pkg/types" @@ -245,16 +247,21 @@ func Main(opts ...Option) error { httpServers["profile"] = profiling.NewServer(profiling.NewHandler(logger, true)) } - tlsServers := map[string]*http.Server{ - "main": mainServer(":"+env.QueueServingTLSPort, mainHandler), - "admin": adminServer(":"+strconv.Itoa(networking.QueueAdminPort), adminHandler), - } + tlsServers := make(map[string]*http.Server) + var certWatcher *certificate.CertWatcher + var err error if tlsEnabled { + tlsServers["main"] = mainServer(":"+env.QueueServingTLSPort, mainHandler) + tlsServers["admin"] = adminServer(":"+strconv.Itoa(networking.QueueAdminPort), adminHandler) + + certWatcher, err = certificate.NewCertWatcher(certPath, keyPath, logger) + if err != nil { + logger.Fatal("failed to create certWatcher", zap.Error(err)) + } + // Drop admin http server since the admin TLS server is listening on the same port delete(httpServers, "admin") - } else { - tlsServers = map[string]*http.Server{} } logger.Info("Starting queue-proxy") @@ -271,9 +278,12 @@ func Main(opts ...Option) error { } for name, server := range tlsServers { go func(name string, s *http.Server) { - // Don't forward ErrServerClosed as that indicates we're already shutting down. logger.Info("Starting tls server ", name, s.Addr) - if err := s.ListenAndServeTLS(certPath, keyPath); err != nil && !errors.Is(err, http.ErrServerClosed) { + s.TLSConfig = &tls.Config{ + GetCertificate: certWatcher.GetCertificate, + } + // Don't forward ErrServerClosed as that indicates we're already shutting down. + if err := s.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { errCh <- fmt.Errorf("%s server failed to serve: %w", name, err) } }(name, server) diff --git a/pkg/reconciler/revision/resources/deploy.go b/pkg/reconciler/revision/resources/deploy.go index 8ba95784dc12..94942245a5a6 100644 --- a/pkg/reconciler/revision/resources/deploy.go +++ b/pkg/reconciler/revision/resources/deploy.go @@ -132,8 +132,16 @@ func certVolume(secret string) corev1.Volume { return corev1.Volume{ Name: certVolumeName, VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: secret, + Projected: &corev1.ProjectedVolumeSource{ + Sources: []corev1.VolumeProjection{ + { + Secret: &corev1.SecretProjection{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: secret, + }, + }, + }, + }, }, }, }