diff --git a/integration/teleterm_test.go b/integration/teleterm_test.go index 5a1b900947bf7..14bf6853c7c1b 100644 --- a/integration/teleterm_test.go +++ b/integration/teleterm_test.go @@ -57,6 +57,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/apiserver/handler" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" libutils "github.com/gravitational/teleport/lib/utils" @@ -98,7 +99,7 @@ func TestTeleterm(t *testing.T) { testGetClusterReturnsPropertiesFromAuthServer(t, pack) }) - t.Run("Test headless watcher", func(t *testing.T) { + t.Run("headless watcher", func(t *testing.T) { t.Parallel() testHeadlessWatcher(t, pack, creds) @@ -124,7 +125,7 @@ func TestTeleterm(t *testing.T) { testDeleteConnectMyComputerNode(t, pack) }) - t.Run("TestClientCache", func(t *testing.T) { + t.Run("client cache", func(t *testing.T) { t.Parallel() testClientCache(t, pack, creds) @@ -362,10 +363,13 @@ func testGetClusterReturnsPropertiesFromAuthServer(t *testing.T, pack *dbhelpers }) require.NoError(t, err) + clusterIDCache := clusteridcache.Cache{} + daemonService, err := daemon.New(daemon.Config{ Storage: storage, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + ClusterIDCache: &clusterIDCache, }) require.NoError(t, err) t.Cleanup(func() { @@ -381,15 +385,22 @@ func testGetClusterReturnsPropertiesFromAuthServer(t *testing.T, pack *dbhelpers rootClusterName, _, err := net.SplitHostPort(pack.Root.Cluster.Web) require.NoError(t, err) + clusterURI := uri.NewClusterURI(rootClusterName) response, err := handler.GetCluster(context.Background(), &api.GetClusterRequest{ - ClusterUri: uri.NewClusterURI(rootClusterName).String(), + ClusterUri: clusterURI.String(), }) require.NoError(t, err) require.Equal(t, userName, response.LoggedInUser.Name) require.ElementsMatch(t, []string{requestableRoleName}, response.LoggedInUser.RequestableRoles) require.ElementsMatch(t, []string{suggestedReviewer}, response.LoggedInUser.SuggestedReviewers) + + // Verify that cluster ID cache gets updated. + clusterIDFromCache, ok := clusterIDCache.Load(clusterURI) + require.True(t, ok, "ID for cluster %q was not found in the cache", clusterURI) + require.NotEmpty(t, clusterIDFromCache) + require.Equal(t, response.AuthClusterId, clusterIDFromCache) } func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { diff --git a/lib/teleterm/apiserver/apiserver.go b/lib/teleterm/apiserver/apiserver.go index 5ff1793a0f4c5..40e387638877e 100644 --- a/lib/teleterm/apiserver/apiserver.go +++ b/lib/teleterm/apiserver/apiserver.go @@ -40,18 +40,6 @@ func New(cfg Config) (*APIServer, error) { return nil, trace.Wrap(err) } - // Create the listener, set up the server. - - ls, err := newListener(cfg.HostAddr, cfg.ListeningC) - if err != nil { - return nil, trace.Wrap(err) - } - - grpcServer := grpc.NewServer(cfg.TshdServerCreds, - grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Log)), - grpc.MaxConcurrentStreams(defaults.GRPCMaxConcurrentStreams), - ) - // Create Terminal and VNet services. 
serviceHandler, err := handler.New( @@ -66,11 +54,25 @@ func New(cfg Config) (*APIServer, error) { vnetService, err := vnet.New(vnet.Config{ DaemonService: cfg.Daemon, InsecureSkipVerify: cfg.InsecureSkipVerify, + ClusterIDCache: cfg.ClusterIDCache, + InstallationID: cfg.InstallationID, }) if err != nil { return nil, trace.Wrap(err) } + // Create the listener, set up the server. + + ls, err := newListener(cfg.HostAddr, cfg.ListeningC) + if err != nil { + return nil, trace.Wrap(err) + } + + grpcServer := grpc.NewServer(cfg.TshdServerCreds, + grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Log)), + grpc.MaxConcurrentStreams(defaults.GRPCMaxConcurrentStreams), + ) + api.RegisterTerminalServiceServer(grpcServer, serviceHandler) vnetapi.RegisterVnetServiceServer(grpcServer, vnetService) diff --git a/lib/teleterm/apiserver/config.go b/lib/teleterm/apiserver/config.go index 086bb957c10f1..1f296dce88893 100644 --- a/lib/teleterm/apiserver/config.go +++ b/lib/teleterm/apiserver/config.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/daemon" "github.com/gravitational/teleport/lib/utils" ) @@ -34,7 +35,9 @@ type Config struct { HostAddr string InsecureSkipVerify bool // Daemon is the terminal daemon service - Daemon *daemon.Service + Daemon *daemon.Service + ClusterIDCache *clusteridcache.Cache + InstallationID string // Log is a component logger Log logrus.FieldLogger TshdServerCreds grpc.ServerOption @@ -65,5 +68,13 @@ func (c *Config) CheckAndSetDefaults() error { c.Log = logrus.WithField(teleport.ComponentKey, "conn:apiserver") } + if c.InstallationID == "" { + return trace.BadParameter("missing installation ID") + } + + if c.ClusterIDCache == nil { + c.ClusterIDCache = &clusteridcache.Cache{} + } + return nil } diff --git a/lib/teleterm/clusteridcache/clusteridcache.go b/lib/teleterm/clusteridcache/clusteridcache.go new file mode 100644 index 0000000000000..23a0558cfdb73 --- /dev/null +++ b/lib/teleterm/clusteridcache/clusteridcache.go @@ -0,0 +1,61 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +package clusteridcache + +import ( + "sync" + + "github.com/gravitational/teleport/lib/teleterm/api/uri" +) + +// Cache stores cluster IDs indexed by their cluster URIs. +// +// Cluster IDs are required when reporting usage events, but they are not publicly known and can be +// fetched only after logging in to a cluster. Today, most events are sent from the Electron app. +// The Electron app caches cluster IDs on its own. However, sometimes we want to send events +// straight from the tsh daemon, in which case we need to know the ID of a cluster. +// +// Whenever the user logs in and fetches full details of a cluster, the cluster ID gets saved to the +// cache.
Later on when tsh daemon wants to send a usage event, it can load the cluster ID from the +// cache. +// +// This cache is never cleared since cluster IDs are saved only for root clusters. Logging in to a +// root cluster overwrites existing ID under the same URI. +// +// TODO(ravicious): Refactor usage reporting to operate on cluster URIs instead of cluster IDs and +// keep the cache only on the side of tsh daemon. Fetch a cluster ID whenever it's first requested +// to avoid an issue with trying to send a usage event before the cluster ID is known. +// https://github.com/gravitational/teleport/issues/23030 +type Cache struct { + m sync.Map +} + +// Store stores the cluster ID for the given uri of a root cluster. +func (c *Cache) Store(uri uri.ResourceURI, clusterID string) { + c.m.Store(uri.String(), clusterID) +} + +// Load returns the cluster ID for the given uri of a root cluster. +func (c *Cache) Load(uri uri.ResourceURI) (string, bool) { + id, ok := c.m.Load(uri.String()) + + if !ok { + return "", false + } + + return id.(string), true +} diff --git a/lib/teleterm/clusters/cluster.go b/lib/teleterm/clusters/cluster.go index 04f7d94873fe2..0d4c16a4a8d4b 100644 --- a/lib/teleterm/clusters/cluster.go +++ b/lib/teleterm/clusters/cluster.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" ) // Cluster describes user settings and access to various resources. @@ -88,7 +89,7 @@ func (c *Cluster) Connected() bool { // GetWithDetails makes requests to the auth server to return details of the current // Cluster that cannot be found on the disk only, including details about the user // and enabled enterprise features. This method requires a valid cert. -func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.ClientI) (*ClusterWithDetails, error) { +func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.ClientI, clusterIDCache *clusteridcache.Cache) (*ClusterWithDetails, error) { var ( clusterPingResponse *webclient.PingResponse webConfig *webclient.WebConfig @@ -142,6 +143,7 @@ func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.Clie return trace.Wrap(err) } authClusterID = clusterName.GetClusterID() + clusterIDCache.Store(c.URI, authClusterID) return nil }) return trace.Wrap(err) diff --git a/lib/teleterm/config.go b/lib/teleterm/config.go index 794623db11b9e..0038863a5fb51 100644 --- a/lib/teleterm/config.go +++ b/lib/teleterm/config.go @@ -44,6 +44,8 @@ type Config struct { KubeconfigsDir string // AgentsDir contains agent config files and data directories for Connect My Computer. AgentsDir string + // InstallationID is a unique ID identifying a specific Teleport Connect installation. + InstallationID string } // CheckAndSetDefaults checks and sets default config values. 
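A note on the `clusteridcache.Cache` introduced above: since it wraps a `sync.Map` keyed by the URI string, concurrent stores and loads need no extra locking. Below is a minimal, self-contained sketch of the intended flow, using only the `Store`/`Load` methods and URI helpers that appear in this diff; the cluster name and ID are made up.

```go
package main

import (
	"fmt"

	"github.com/gravitational/teleport/lib/teleterm/api/uri"
	"github.com/gravitational/teleport/lib/teleterm/clusteridcache"
)

func main() {
	cache := &clusteridcache.Cache{}

	// On login, GetWithDetails stores the ID under the root cluster URI.
	rootCluster := uri.NewClusterURI("teleport.example.com")
	cache.Store(rootCluster, "0194b383-made-up-id")

	// Later, a usage event about any resource in that cluster can recover the ID
	// by reducing the resource URI back to its root cluster URI.
	appURI := rootCluster.AppendApp("dumper")
	if id, ok := cache.Load(appURI.GetRootClusterURI()); ok {
		fmt.Println("cluster ID:", id)
	}
}
```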
@@ -77,5 +79,9 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("missing agents directory") } + if c.InstallationID == "" { + return trace.BadParameter("missing installation ID") + } + return nil } diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index 92e375ddd721c..80cc79d081946 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/clientcache" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/services/connectmycomputer" ) @@ -74,6 +75,10 @@ type Config struct { ConnectMyComputerNodeName *connectmycomputer.NodeName CreateClientCacheFunc func(resolver clientcache.NewClientFunc) (ClientCache, error) + // ClusterIDCache gets updated whenever daemon.Service.ResolveClusterWithDetails gets called. + // Since that method is called by the Electron app only for root clusters and typically only once + // after a successful login, this cache doesn't have to be cleared. + ClusterIDCache *clusteridcache.Cache } // ResolveClusterFunc returns a cluster by URI. @@ -174,5 +179,9 @@ func (c *Config) CheckAndSetDefaults() error { } } + if c.ClusterIDCache == nil { + c.ClusterIDCache = &clusteridcache.Cache{} + } + return nil } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 09cd4cf489d1e..497dffb7437db 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -280,7 +280,7 @@ func (s *Service) ResolveClusterWithDetails(ctx context.Context, uri string) (*c return nil, nil, trace.Wrap(err) } - withDetails, err := cluster.GetWithDetails(ctx, proxyClient.CurrentCluster()) + withDetails, err := cluster.GetWithDetails(ctx, proxyClient.CurrentCluster(), s.cfg.ClusterIDCache) if err != nil { return nil, nil, trace.Wrap(err) } @@ -1133,9 +1133,9 @@ func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gat // GetCachedClient returns a client from the cache if it exists, // otherwise it dials the remote server. 
-func (s *Service) GetCachedClient(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) { - profileName := clusterURI.GetProfileName() - leafClusterName := clusterURI.GetLeafClusterName() +func (s *Service) GetCachedClient(ctx context.Context, resourceURI uri.ResourceURI) (*client.ClusterClient, error) { + profileName := resourceURI.GetProfileName() + leafClusterName := resourceURI.GetLeafClusterName() clt, err := s.clientCache.Get(ctx, profileName, leafClusterName) return clt, trace.Wrap(err) } diff --git a/lib/teleterm/teleterm.go b/lib/teleterm/teleterm.go index 0c03e6ef8f4f9..d2348a7bb93d3 100644 --- a/lib/teleterm/teleterm.go +++ b/lib/teleterm/teleterm.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/gravitational/teleport/lib/teleterm/apiserver" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" ) @@ -55,12 +56,15 @@ func Serve(ctx context.Context, cfg Config) error { return trace.Wrap(err) } + clusterIDCache := &clusteridcache.Cache{} + daemonService, err := daemon.New(daemon.Config{ Storage: storage, CreateTshdEventsClientCredsFunc: grpcCredentials.tshdEvents, PrehogAddr: cfg.PrehogAddr, KubeconfigsDir: cfg.KubeconfigsDir, AgentsDir: cfg.AgentsDir, + ClusterIDCache: clusterIDCache, }) if err != nil { return trace.Wrap(err) @@ -72,6 +76,8 @@ func Serve(ctx context.Context, cfg Config) error { Daemon: daemonService, TshdServerCreds: grpcCredentials.tshd, ListeningC: cfg.ListeningC, + ClusterIDCache: clusterIDCache, + InstallationID: cfg.InstallationID, }) if err != nil { return trace.Wrap(err) diff --git a/lib/teleterm/teleterm_test.go b/lib/teleterm/teleterm_test.go index e836677f0b507..854273d71c683 100644 --- a/lib/teleterm/teleterm_test.go +++ b/lib/teleterm/teleterm_test.go @@ -120,6 +120,7 @@ func TestStart(t *testing.T) { ListeningC: listeningC, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + InstallationID: "foo", } ctx, cancel := context.WithCancel(context.Background()) diff --git a/lib/teleterm/vnet/service.go b/lib/teleterm/vnet/service.go index 3053caebd6d6e..79694156c99b5 100644 --- a/lib/teleterm/vnet/service.go +++ b/lib/teleterm/vnet/service.go @@ -21,16 +21,21 @@ import ( "crypto/tls" "errors" "sync" + "sync/atomic" "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport" vnetproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/vnet/v1" "github.com/gravitational/teleport/api/types" + prehogv1alpha "github.com/gravitational/teleport/gen/proto/go/prehog/v1alpha" apiteleterm "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/vnet/v1" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" + "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" logutils "github.com/gravitational/teleport/lib/utils/log" "github.com/gravitational/teleport/lib/vnet" @@ -54,6 +59,7 @@ type Service struct { mu sync.Mutex status status processManager *vnet.ProcessManager + usageReporter *usageReporter } // New creates an instance of Service. 
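For orientation before the vnet `Config` changes below: callers of `vnet.New` now have to supply the two new fields, as `apiserver.New` does earlier in this diff. A hedged sketch of that wiring follows; the helper function name is hypothetical.

```go
import (
	"github.com/gravitational/teleport/lib/teleterm/clusteridcache"
	"github.com/gravitational/teleport/lib/teleterm/daemon"
	"github.com/gravitational/teleport/lib/teleterm/vnet"
)

// newVNetService is a hypothetical helper showing the new required fields.
func newVNetService(d *daemon.Service, ids *clusteridcache.Cache, installationID string) (*vnet.Service, error) {
	return vnet.New(vnet.Config{
		DaemonService:  d,              // provides cached clients and consumes usage events
		ClusterIDCache: ids,            // must be non-nil per CheckAndSetDefaults below
		InstallationID: installationID, // must be non-empty
	})
}
```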
@@ -68,8 +74,17 @@ func New(cfg Config) (*Service, error) { } type Config struct { - DaemonService *daemon.Service + // DaemonService is used to get cached clients and for usage reporting. If DaemonService was not + // one giant blob of methods, Config could accept two separate services instead. + DaemonService *daemon.Service + // InsecureSkipVerify signifies whether VNet is going to verify the identity of the proxy service. InsecureSkipVerify bool + // ClusterIDCache is used for usage reporting to read cluster ID that needs to be included with + // every event. + ClusterIDCache *clusteridcache.Cache + // InstallationID is a unique ID of this particular Connect installation, used for usage + // reporting. + InstallationID string } // CheckAndSetDefaults checks and sets the defaults @@ -78,6 +93,14 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("missing DaemonService") } + if c.ClusterIDCache == nil { + return trace.BadParameter("missing ClusterIDCache") + } + + if c.InstallationID == "" { + return trace.BadParameter("missing InstallationID") + } + return nil } @@ -93,9 +116,25 @@ func (s *Service) Start(ctx context.Context, req *api.StartRequest) (*api.StartR return &api.StartResponse{}, nil } + usageReporter, err := NewUsageReporter(UsageReporterConfig{ + ClientCache: s.cfg.DaemonService, + EventConsumer: s.cfg.DaemonService, + ClusterIDCache: s.cfg.ClusterIDCache, + InstallationID: s.cfg.InstallationID, + }) + if err != nil { + return nil, trace.Wrap(err) + } + defer func() { + if s.status != statusRunning { + usageReporter.Stop() + } + }() + appProvider := &appProvider{ daemonService: s.cfg.DaemonService, insecureSkipVerify: s.cfg.InsecureSkipVerify, + usageReporter: usageReporter, } processManager, err := vnet.SetupAndRun(ctx, appProvider) @@ -124,6 +163,7 @@ func (s *Service) Start(ctx context.Context, req *api.StartRequest) (*api.StartR }() s.processManager = processManager + s.usageReporter = usageReporter s.status = statusRunning return &api.StartResponse{}, nil } @@ -155,6 +195,7 @@ func (s *Service) stopLocked() error { if err != nil && !errors.Is(err, context.Canceled) { return trace.Wrap(err) } + s.usageReporter.Stop() s.status = statusNotRunning return nil @@ -177,6 +218,7 @@ func (s *Service) Close() error { type appProvider struct { daemonService *daemon.Service + usageReporter *usageReporter insecureSkipVerify bool } @@ -271,3 +313,174 @@ func (p *appProvider) GetVnetConfig(ctx context.Context, profileName, leafCluste vnetConfig, err := vnetConfigClient.GetVnetConfig(ctx, &vnetproto.GetVnetConfigRequest{}) return vnetConfig, trace.Wrap(err) } + +// OnNewConnection submits a usage event once per appProvider lifetime. +// That is, if a user makes multiple connections to a single app, OnNewConnection submits a single +// event. This is to mimic how Connect submits events for its app gateways. This lets us compare +// popularity of VNet and app gateways. +func (p *appProvider) OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error { + // Enqueue the event from a separate goroutine since we don't care about errors anyway and we also + // don't want to slow down VNet connections. + go func() { + uri := uri.NewClusterURI(profileName).AppendLeafCluster(leafClusterName).AppendApp(app.GetName()) + + // Not passing ctx to ReportApp since ctx is tied to the lifetime of the connection. + // If it's a short-lived connection, inheriting its context would interrupt reporting. 
+ err := p.usageReporter.ReportApp(uri) + if err != nil { + log.ErrorContext(ctx, "Failed to submit usage event", "app", uri, "error", err) + } + }() + + return nil +} + +type usageReporter struct { + cfg UsageReporterConfig + // reportedApps contains a set of URIs for apps whose usage has already been reported. + // App gateways (local proxies) in Connect report a single event per gateway created per app. VNet + // needs to replicate this behavior, hence it keeps track of reported apps and reports only one + // event per app per VNet's lifespan. + reportedApps map[string]struct{} + // mu protects access to reportedApps. + mu sync.Mutex + // close is used to abort a ReportApp call that's currently in flight. + close chan struct{} + // closed signals that usageReporter has been stopped and no more events should be reported. + closed atomic.Bool +} + +type clientCache interface { + GetCachedClient(context.Context, uri.ResourceURI) (*client.ClusterClient, error) + ResolveClusterURI(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) +} + +type eventConsumer interface { + ReportUsageEvent(*apiteleterm.ReportUsageEventRequest) error +} + +type UsageReporterConfig struct { + ClientCache clientCache + EventConsumer eventConsumer + // ClusterIDCache stores the cluster IDs that need to be included with each usage event. It's updated + // outside of usageReporter – the reporter merely reads data from it. If the cache does not + // contain the given cluster ID, usageReporter drops the event. + ClusterIDCache *clusteridcache.Cache + InstallationID string +} + +func (c *UsageReporterConfig) CheckAndSetDefaults() error { + if c.ClientCache == nil { + return trace.BadParameter("missing ClientCache") + } + + if c.EventConsumer == nil { + return trace.BadParameter("missing EventConsumer") + } + + if c.ClusterIDCache == nil { + return trace.BadParameter("missing ClusterIDCache") + } + + if c.InstallationID == "" { + return trace.BadParameter("missing InstallationID") + } + + return nil +} + +func NewUsageReporter(cfg UsageReporterConfig) (*usageReporter, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &usageReporter{ + cfg: cfg, + reportedApps: make(map[string]struct{}), + close: make(chan struct{}), + }, nil +} + +// ReportApp adds an event related to the given app to the events queue if the app hasn't been reported +// already. Only one invocation of ReportApp can be in flight at a time.
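Before the implementation that follows, a caller's-eye sketch of the reporter's contract (a hypothetical snippet written as if inside this package, relying on its existing imports; the app URI is made up). Note that `*daemon.Service` satisfies both of the narrow `clientCache` and `eventConsumer` interfaces above, which is why `Start` can pass the same value for both fields.

```go
func reportTwice(d *daemon.Service, ids *clusteridcache.Cache, installationID string) error {
	reporter, err := NewUsageReporter(UsageReporterConfig{
		ClientCache:    d, // same value for both fields:
		EventConsumer:  d, // *daemon.Service satisfies clientCache and eventConsumer
		ClusterIDCache: ids,
		InstallationID: installationID,
	})
	if err != nil {
		return trace.Wrap(err)
	}
	defer reporter.Stop() // aborts an in-flight report, then rejects new ones

	appURI := uri.NewClusterURI("example.com").AppendApp("dumper") // made-up app
	if err := reporter.ReportApp(appURI); err != nil {
		return trace.Wrap(err)
	}
	return trace.Wrap(reporter.ReportApp(appURI)) // no-op: already reported
}
```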
+func (r *usageReporter) ReportApp(appURI uri.ResourceURI) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed.Load() { + return trace.CompareFailed("usage reporter has been stopped") + } + + if _, hasAppBeenReported := r.reportedApps[appURI.String()]; hasAppBeenReported { + return nil + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + select { + case <-r.close: + cancel() + case <-ctx.Done(): + } + }() + + rootClusterURI := appURI.GetRootClusterURI() + client, err := r.cfg.ClientCache.GetCachedClient(ctx, appURI) + if err != nil { + return trace.Wrap(err) + } + rootClusterName := client.RootClusterName() + _, tc, err := r.cfg.ClientCache.ResolveClusterURI(appURI) + if err != nil { + return trace.Wrap(err) + } + + clusterID, ok := r.cfg.ClusterIDCache.Load(rootClusterURI) + if !ok { + return trace.NotFound("cluster ID for %q not found", rootClusterURI) + } + + log.DebugContext(ctx, "Reporting usage event", "app", appURI.String()) + + err = r.cfg.EventConsumer.ReportUsageEvent(&apiteleterm.ReportUsageEventRequest{ + AuthClusterId: clusterID, + PrehogReq: &prehogv1alpha.SubmitConnectEventRequest{ + DistinctId: r.cfg.InstallationID, + Timestamp: timestamppb.Now(), + Event: &prehogv1alpha.SubmitConnectEventRequest_ProtocolUse{ + ProtocolUse: &prehogv1alpha.ConnectProtocolUseEvent{ + ClusterName: rootClusterName, + UserName: tc.Username, + Protocol: "app", + Origin: "vnet", + AccessThrough: "vnet", + }, + }, + }, + }) + if err != nil { + return trace.Wrap(err, "adding usage event to queue") + } + + r.reportedApps[appURI.String()] = struct{}{} + + return nil +} + +// Stop aborts the reporting of an event that's currently in progress and prevents further events +// from being reported. It blocks until the current ReportApp call aborts. +func (r *usageReporter) Stop() { + if r.closed.Load() { + return + } + + // Prevent new calls to ReportApp from being made. + r.closed.Store(true) + // Abort context of the ReportApp call currently in flight. + close(r.close) + // Block until the current ReportApp call aborts. + r.mu.Lock() + defer r.mu.Unlock() +} diff --git a/lib/teleterm/vnet/service_test.go b/lib/teleterm/vnet/service_test.go new file mode 100644 index 0000000000000..9fe1d8ea9da74 --- /dev/null +++ b/lib/teleterm/vnet/service_test.go @@ -0,0 +1,163 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>.
+ +package vnet + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + teletermv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" + "github.com/gravitational/teleport/lib/teleterm/clusters" +) + +func TestUsageReporter(t *testing.T) { + eventConsumer := fakeEventConsumer{} + + validCluster := uri.NewClusterURI("foo") + clusterWithoutClient := uri.NewClusterURI("no-client") + clusterWithoutProfile := uri.NewClusterURI("no-profile") + clusterWithoutClusterID := uri.NewClusterURI("no-cluster-id") + + clientCache := fakeClientCache{ + validClusterURIs: map[uri.ResourceURI]struct{}{ + validCluster: struct{}{}, + clusterWithoutProfile: struct{}{}, + clusterWithoutClusterID: struct{}{}, + }, + } + + clusterIDcache := clusteridcache.Cache{} + clusterIDcache.Store(uri.NewClusterURI("foo"), "1234") + + usageReporter, err := NewUsageReporter(UsageReporterConfig{ + EventConsumer: &eventConsumer, + ClientCache: &clientCache, + ClusterIDCache: &clusterIDcache, + InstallationID: "4321", + }) + require.NoError(t, err) + t.Cleanup(usageReporter.Stop) + + // Verify that reporting the same app twice adds only one usage event. + err = usageReporter.ReportApp(validCluster.AppendApp("app")) + require.NoError(t, err) + err = usageReporter.ReportApp(validCluster.AppendApp("app")) + require.NoError(t, err) + require.Equal(t, 1, eventConsumer.EventCount()) + + // Verify that reporting an invalid cluster doesn't submit an event. + err = usageReporter.ReportApp(clusterWithoutClient.AppendApp("bar")) + require.True(t, trace.IsNotFound(err), "Not a NotFound error: %#v", err) + require.Equal(t, 1, eventConsumer.EventCount()) + err = usageReporter.ReportApp(clusterWithoutProfile.AppendApp("bar")) + require.True(t, trace.IsNotFound(err), "Not a NotFound error: %#v", err) + require.Equal(t, 1, eventConsumer.EventCount()) + err = usageReporter.ReportApp(clusterWithoutClusterID.AppendApp("bar")) + require.ErrorIs(t, err, trace.NotFound("cluster ID for \"/clusters/no-cluster-id\" not found")) + require.Equal(t, 1, eventConsumer.EventCount()) +} + +func TestUsageReporter_Stop(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + eventConsumer := fakeEventConsumer{} + clientCache := fakeClientCache{blockingOnCtxC: make(chan struct{}, 1)} + clusterIDCache := clusteridcache.Cache{} + + usageReporter, err := NewUsageReporter(UsageReporterConfig{ + EventConsumer: &eventConsumer, + ClientCache: &clientCache, + ClusterIDCache: &clusterIDCache, + InstallationID: "4321", + }) + require.NoError(t, err) + t.Cleanup(usageReporter.Stop) + + go func() { + select { + case <-ctx.Done(): + case <-clientCache.blockingOnCtxC: + // Wait for ReportApp to start blocking on GetCachedClient. 
+ } usageReporter.Stop() }() + + uri := uri.NewClusterURI("foo").AppendApp("bar") + err = usageReporter.ReportApp(uri) + require.ErrorIs(t, err, context.Canceled) + + err = usageReporter.ReportApp(uri) + require.True(t, trace.IsCompareFailed(err), "expected trace.CompareFailed but got %v", err) +} + +type fakeEventConsumer struct { + mu sync.Mutex + events []*teletermv1.ReportUsageEventRequest +} + +func (ec *fakeEventConsumer) ReportUsageEvent(event *teletermv1.ReportUsageEventRequest) error { + ec.mu.Lock() + defer ec.mu.Unlock() + + ec.events = append(ec.events, event) + return nil +} + +func (ec *fakeEventConsumer) EventCount() int { + ec.mu.Lock() + defer ec.mu.Unlock() + + return len(ec.events) +} + +type fakeClientCache struct { + validClusterURIs map[uri.ResourceURI]struct{} + // blockingOnCtxC makes GetCachedClient block until ctx is canceled. fakeClientCache writes to the + // channel just before GetCachedClient starts to block on ctx. + blockingOnCtxC chan struct{} +} + +func (c *fakeClientCache) GetCachedClient(ctx context.Context, appURI uri.ResourceURI) (*client.ClusterClient, error) { + if c.blockingOnCtxC != nil { + c.blockingOnCtxC <- struct{}{} + + <-ctx.Done() + return nil, trace.Wrap(ctx.Err()) + } + + if _, ok := c.validClusterURIs[appURI.GetClusterURI()]; !ok { + return nil, trace.NotFound("client for cluster %q not found", appURI.GetClusterURI()) + } + + return &client.ClusterClient{}, nil +} + +func (c *fakeClientCache) ResolveClusterURI(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) { + if _, ok := c.validClusterURIs[uri.GetClusterURI()]; !ok { + return nil, nil, trace.NotFound("client for cluster %q not found", uri.GetClusterURI()) + } + + return &clusters.Cluster{}, &client.TeleportClient{Config: client.Config{Username: "alice"}}, nil +} diff --git a/lib/vnet/app_resolver.go b/lib/vnet/app_resolver.go index 46a62f8016e96..11813e6ca18b5 100644 --- a/lib/vnet/app_resolver.go +++ b/lib/vnet/app_resolver.go @@ -62,6 +62,14 @@ type AppProvider interface { // GetVnetConfig returns the cluster VnetConfig resource. GetVnetConfig(ctx context.Context, profileName, leafClusterName string) (*vnet.VnetConfig, error) + + // OnNewConnection gets called whenever a new connection is about to be established through VNet. + // By the time OnNewConnection is called, VNet has already verified that the user holds a valid cert for the + // app. + // + // The connection won't be established until OnNewConnection returns. Returning an error prevents + // the connection from being made. + OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error } // DialOptions holds ALPN dial options for dialing apps. @@ -227,7 +235,14 @@ func (r *TCPAppResolver) newTCPAppHandler( leafClusterName: leafClusterName, app: app, } - middleware := client.NewCertChecker(appCertIssuer, r.clock) + certChecker := client.NewCertChecker(appCertIssuer, r.clock) + middleware := &localProxyMiddleware{ + certChecker: certChecker, + appProvider: r.appProvider, + app: app, + profileName: profileName, + leafClusterName: leafClusterName, + } localProxyConfig := alpnproxy.LocalProxyConfig{ RemoteProxyAddr: dialOpts.WebProxyAddr, @@ -293,3 +308,26 @@ func fullyQualify(domain string) string { } return domain + "." } + +// localProxyMiddleware wraps around [client.CertChecker] and additionally makes it so that its +// OnNewConnection method calls the same method of [AppProvider].
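The wrapper below is a small decorator: it preserves the cert-checking behavior and bolts the usage callback onto it. A hypothetical compile-time assertion makes the contract explicit, assuming the middleware interface is `alpnproxy.LocalProxyMiddleware` (the one `client.CertChecker` already satisfies; `alpnproxy` is already imported in this file).

```go
// Not part of the diff; fails to compile if the wrapper stops implementing
// OnNewConnection or OnStart.
var _ alpnproxy.LocalProxyMiddleware = (*localProxyMiddleware)(nil)
```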
+type localProxyMiddleware struct { + app types.Application + profileName string + leafClusterName string + certChecker *client.CertChecker + appProvider AppProvider +} + +func (m *localProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy) error { + err := m.certChecker.OnNewConnection(ctx, lp) + if err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(m.appProvider.OnNewConnection(ctx, m.profileName, m.leafClusterName, m.app)) +} + +func (m *localProxyMiddleware) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) error { + return trace.Wrap(m.certChecker.OnStart(ctx, lp)) +} diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index 84bf8434faf63..4243d665081e4 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -33,6 +33,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "testing" "time" @@ -236,9 +237,10 @@ type testClusterSpec struct { } type echoAppProvider struct { - clusters map[string]testClusterSpec - dialOpts DialOptions - reissueAppCert func() tls.Certificate + clusters map[string]testClusterSpec + dialOpts DialOptions + reissueAppCert func() tls.Certificate + onNewConnectionCallCount atomic.Uint32 } // newEchoAppProvider returns an app provider with the list of named apps in each profile and leaf cluster. @@ -330,6 +332,12 @@ func (p *echoAppProvider) GetVnetConfig(ctx context.Context, profileName, leafCl }, nil } +func (p *echoAppProvider) OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error { + p.onNewConnectionCallCount.Add(1) + + return nil +} + // echoAppAuthClient is a fake auth client that answers GetResources requests with a static list of apps and // basic/faked predicate filtering. type echoAppAuthClient struct { @@ -377,87 +385,8 @@ func TestDialFakeApp(t *testing.T) { t.Cleanup(cancel) clock := clockwork.NewFakeClockAt(time.Now()) - ca := newSelfSignedCA(t) - - roots := x509.NewCertPool() - caX509, err := x509.ParseCertificate(ca.Certificate[0]) - require.NoError(t, err) - roots.AddCert(caX509) - - const proxyCN = "testproxy" - proxyCert := newServerCert(t, ca, proxyCN, clock.Now().Add(365*24*time.Hour)) - - proxyTLSConfig := &tls.Config{ - Certificates: []tls.Certificate{proxyCert}, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: roots, - } - - listener, err := tls.Listen("tcp", "localhost:0", proxyTLSConfig) - require.NoError(t, err) - - // Run a fake web proxy that will accept any client connection and echo the input back. - utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ - Name: "web proxy", - Task: func(ctx context.Context) error { - for { - conn, err := listener.Accept() - if err != nil { - if utils.IsOKNetworkError(err) { - return nil - } - return trace.Wrap(err) - } - go func() { - defer conn.Close() - - // Not using require/assert here and below because this is not in the right subtest or in - // the main test goroutine. The test will fail if the conn is not handled. - tlsConn, ok := conn.(*tls.Conn) - if !ok { - t.Log("client conn is not TLS") - return - } - if err := tlsConn.Handshake(); err != nil { - t.Log("error completing tls handshake") - return - } - clientCerts := tlsConn.ConnectionState().PeerCertificates - if len(clientCerts) == 0 { - t.Log("client has no certs") - return - } - // Manually checking the cert expiry compared to the time of the fake clock, since the TLS - // library will only compare the cert expiry to the real clock. 
- // It's important that the fake clock is never far behind the real clock, and that the - // cert NotBefore is always at/before the real current time, so the TLS library is - // satisfied. - if clock.Now().After(clientCerts[0].NotAfter) { - t.Logf("client cert is expired: currentTime=%s expiry=%s", clock.Now(), clientCerts[0].NotAfter) - return - } - - _, err := io.Copy(conn, conn) - if err != nil && !utils.IsOKNetworkError(err) { - t.Logf("error in io.Copy for echo proxy server: %v", err) - } - }() - } - }, - Terminate: func() error { - if err := listener.Close(); !utils.IsOKNetworkError(err) { - return trace.Wrap(err) - } - return nil - }, - }) - - dialOpts := DialOptions{ - WebProxyAddr: listener.Addr().String(), - RootClusterCACertPool: roots, - SNI: proxyCN, - } + dialOpts := mustStartFakeWebProxy(ctx, t, ca, clock) const appCertLifetime = time.Hour reissueClientCert := func() tls.Certificate { @@ -588,6 +517,48 @@ func testEchoConnection(t *testing.T, conn net.Conn) { } } +func TestOnNewConnection(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + clock := clockwork.NewFakeClockAt(time.Now()) + ca := newSelfSignedCA(t) + dialOpts := mustStartFakeWebProxy(ctx, t, ca, clock) + + const appCertLifetime = time.Hour + reissueClientCert := func() tls.Certificate { + return newClientCert(t, ca, "testclient", clock.Now().Add(appCertLifetime)) + } + + appProvider := newEchoAppProvider(map[string]testClusterSpec{ + "root1.example.com": { + apps: []string{"echo1"}, + cidrRange: "192.168.2.0/24", + leafClusters: map[string]testClusterSpec{}, + }, + }, dialOpts, reissueClientCert) + + validAppName := "echo1.root1.example.com" + invalidAppName := "not.an.app.example.com." + + p := newTestPack(t, ctx, clock, appProvider) + + // Attempt to establish a connection to an invalid app and verify that OnNewConnection was not + // called. + lookupCtx, lookupCtxCancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer lookupCtxCancel() + _, err := p.lookupHost(lookupCtx, invalidAppName) + require.Error(t, err, "Expected lookup of an invalid app to fail") + require.Equal(t, uint32(0), appProvider.onNewConnectionCallCount.Load()) + + // Establish a connection to a valid app and verify that OnNewConnection was called. + conn, err := p.dialHost(ctx, validAppName) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, conn.Close()) }) + require.Equal(t, uint32(1), appProvider.onNewConnectionCallCount.Load()) +} + func randomULAAddress() (tcpip.Address, error) { var bytes [16]byte bytes[0] = 0xfd @@ -733,3 +704,88 @@ func newLeafCert(t *testing.T, ca tls.Certificate, cn string, expires time.Time, PrivateKey: priv, } } + +func mustStartFakeWebProxy(ctx context.Context, t *testing.T, ca tls.Certificate, clock clockwork.FakeClock) DialOptions { + t.Helper() + + roots := x509.NewCertPool() + caX509, err := x509.ParseCertificate(ca.Certificate[0]) + require.NoError(t, err) + roots.AddCert(caX509) + + const proxyCN = "testproxy" + proxyCert := newServerCert(t, ca, proxyCN, clock.Now().Add(365*24*time.Hour)) + + proxyTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{proxyCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: roots, + } + + listener, err := tls.Listen("tcp", "localhost:0", proxyTLSConfig) + require.NoError(t, err) + + // Run a fake web proxy that will accept any client connection and echo the input back. 
+ utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ + Name: "web proxy", + Task: func(ctx context.Context) error { + for { + conn, err := listener.Accept() + if err != nil { + if utils.IsOKNetworkError(err) { + return nil + } + return trace.Wrap(err) + } + go func() { + defer conn.Close() + + // Not using require/assert here and below because this is not in the right subtest or in + // the main test goroutine. The test will fail if the conn is not handled. + tlsConn, ok := conn.(*tls.Conn) + if !ok { + t.Log("client conn is not TLS") + return + } + if err := tlsConn.Handshake(); err != nil { + t.Log("error completing tls handshake") + return + } + clientCerts := tlsConn.ConnectionState().PeerCertificates + if len(clientCerts) == 0 { + t.Log("client has no certs") + return + } + // Manually checking the cert expiry compared to the time of the fake clock, since the TLS + // library will only compare the cert expiry to the real clock. + // It's important that the fake clock is never far behind the real clock, and that the + // cert NotBefore is always at/before the real current time, so the TLS library is + // satisfied. + if clock.Now().After(clientCerts[0].NotAfter) { + t.Logf("client cert is expired: currentTime=%s expiry=%s", clock.Now(), clientCerts[0].NotAfter) + return + } + + _, err := io.Copy(conn, conn) + if err != nil && !utils.IsOKNetworkError(err) { + t.Logf("error in io.Copy for echo proxy server: %v", err) + } + }() + } + }, + Terminate: func() error { + if err := listener.Close(); !utils.IsOKNetworkError(err) { + return trace.Wrap(err) + } + return nil + }, + }) + + dialOpts := DialOptions{ + WebProxyAddr: listener.Addr().String(), + RootClusterCACertPool: roots, + SNI: proxyCN, + } + + return dialOpts +} diff --git a/tool/tsh/common/daemon.go b/tool/tsh/common/daemon.go index 6892f729e4f9b..cc31f4ba8e795 100644 --- a/tool/tsh/common/daemon.go +++ b/tool/tsh/common/daemon.go @@ -49,6 +49,7 @@ func onDaemonStart(cf *CLIConf) error { PrehogAddr: cf.DaemonPrehogAddr, KubeconfigsDir: cf.DaemonKubeconfigsDir, AgentsDir: cf.DaemonAgentsDir, + InstallationID: cf.DaemonInstallationID, }) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 8028a641f29a6..75ce523fda0ac 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -225,6 +225,8 @@ type CLIConf struct { DaemonAgentsDir string // DaemonPid is the PID to be stopped by tsh daemon stop. DaemonPid int + // DaemonInstallationID is a unique ID identifying a specific Teleport Connect installation. + DaemonInstallationID string // DatabaseService specifies the database proxy server to log into. 
DatabaseService string @@ -790,6 +792,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { daemonStart.Flag("prehog-addr", "URL where prehog events should be submitted").StringVar(&cf.DaemonPrehogAddr) daemonStart.Flag("kubeconfigs-dir", "Directory containing kubeconfig for Kubernetes Access").StringVar(&cf.DaemonKubeconfigsDir) daemonStart.Flag("agents-dir", "Directory containing agent config files and data directories for Connect My Computer").StringVar(&cf.DaemonAgentsDir) + daemonStart.Flag("installation-id", "Unique ID identifying a specific Teleport Connect installation").StringVar(&cf.DaemonInstallationID) daemonStop := daemon.Command("stop", "Gracefully stops a process on Windows by sending Ctrl-Break to it.").Hidden() daemonStop.Flag("pid", "PID to be stopped").IntVar(&cf.DaemonPid) diff --git a/tool/tsh/common/vnet_common.go b/tool/tsh/common/vnet_common.go index f753d3124b580..754c61f9277dc 100644 --- a/tool/tsh/common/vnet_common.go +++ b/tool/tsh/common/vnet_common.go @@ -126,6 +126,12 @@ func (p *vnetAppProvider) GetVnetConfig(ctx context.Context, profileName, leafCl return vnetConfig, trace.Wrap(err) } +// OnNewConnection gets called before each VNet connection. It's a noop as tsh doesn't need to do +// anything extra here. +func (p *vnetAppProvider) OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error { + return nil +} + // getRootClusterCACertPool returns a certificate pool for the root cluster of the given profile. func (p *vnetAppProvider) getRootClusterCACertPool(ctx context.Context, profileName string) (*x509.CertPool, error) { tc, err := p.newTeleportClient(ctx, profileName, "") diff --git a/web/packages/teleterm/src/mainProcess/runtimeSettings.ts b/web/packages/teleterm/src/mainProcess/runtimeSettings.ts index f757bf22353c9..e93cdaabea4dd 100644 --- a/web/packages/teleterm/src/mainProcess/runtimeSettings.ts +++ b/web/packages/teleterm/src/mainProcess/runtimeSettings.ts @@ -78,6 +78,9 @@ export function getRuntimeSettings(): RuntimeSettings { const logsDir = path.join(userDataDir, 'logs'); // DO NOT expose agentsDir through RuntimeSettings. See the comment in getAgentsDir. const agentsDir = getAgentsDir(userDataDir); + const installationId = loadInstallationId( + path.resolve(app.getPath('userData'), 'installation_id') + ); const tshd = { binaryPath: tshBinPath, @@ -93,6 +96,7 @@ export function getRuntimeSettings(): RuntimeSettings { `--prehog-addr=${staticConfig.prehogAddress}`, `--kubeconfigs-dir=${kubeConfigsDir}`, `--agents-dir=${agentsDir}`, + `--installation-id=${installationId}`, ], }; const sharedProcess = { @@ -135,9 +139,7 @@ export function getRuntimeSettings(): RuntimeSettings { kubeConfigsDir, logsDir, platform: process.platform, - installationId: loadInstallationId( - path.resolve(app.getPath('userData'), 'installation_id') - ), + installationId, arch: os.arch(), osVersion: os.release(), appVersion,
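End to end, the installation ID travels from the Electron app to the usage events: `runtimeSettings.ts` passes `--installation-id` when spawning tsh daemon, `onDaemonStart` forwards it to `teleterm.Serve`, and `usageReporter` stamps it on every event as `DistinctId`. A condensed sketch of the Go side follows, mirroring `onDaemonStart` above with only the fields shown in this diff; the helper name is hypothetical.

```go
func startDaemon(ctx context.Context, cf *CLIConf) error {
	// InstallationID is now mandatory; teleterm.Config.CheckAndSetDefaults
	// returns a BadParameter error when it is empty.
	return trace.Wrap(teleterm.Serve(ctx, teleterm.Config{
		PrehogAddr:     cf.DaemonPrehogAddr,
		KubeconfigsDir: cf.DaemonKubeconfigsDir,
		AgentsDir:      cf.DaemonAgentsDir,
		InstallationID: cf.DaemonInstallationID, // from --installation-id
	}))
}
```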