From 2d14f0b208fd50958360f278e21dbff1f6e4e0d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Cie=C5=9Blak?= Date: Tue, 4 Jun 2024 18:18:12 +0200 Subject: [PATCH] Add telemetry to VNet in Connect (#41587) * Add OnNewConnection to AppProvider * Cache cluster IDs * Pass installation ID to tsh daemon It will be needed to report usage events straight from tsh daemon. It used to be available only in the Electron app which sent this ID with every ReportUsageEvent RPC. * apiserver.New: Create listener after services This way if initializing a service fails, we don't create a listener unnecessarily. * GetCachedClient: Rename argument to clarify usage * Report usage event on VNet connection * Remove debug log when app was already reported * Reuse context in tests * Extract fake web proxy setup to separate function * usageReporter.ReportApp: Use background ctx instead of TCP conn ctx --- integration/teleterm_test.go | 17 +- lib/teleterm/apiserver/apiserver.go | 26 +- lib/teleterm/apiserver/config.go | 13 +- lib/teleterm/clusteridcache/clusteridcache.go | 61 +++++ lib/teleterm/clusters/cluster.go | 4 +- lib/teleterm/config.go | 6 + lib/teleterm/daemon/config.go | 9 + lib/teleterm/daemon/daemon.go | 8 +- lib/teleterm/teleterm.go | 6 + lib/teleterm/teleterm_test.go | 1 + lib/teleterm/vnet/service.go | 215 ++++++++++++++++- lib/teleterm/vnet/service_test.go | 163 +++++++++++++ lib/vnet/app_resolver.go | 40 +++- lib/vnet/vnet_test.go | 222 +++++++++++------- tool/tsh/common/daemon.go | 1 + tool/tsh/common/tsh.go | 3 + tool/tsh/common/vnet_common.go | 6 + .../src/mainProcess/runtimeSettings.ts | 8 +- 18 files changed, 700 insertions(+), 109 deletions(-) create mode 100644 lib/teleterm/clusteridcache/clusteridcache.go create mode 100644 lib/teleterm/vnet/service_test.go diff --git a/integration/teleterm_test.go b/integration/teleterm_test.go index 5a1b900947bf7..14bf6853c7c1b 100644 --- a/integration/teleterm_test.go +++ b/integration/teleterm_test.go @@ -57,6 +57,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleterm/api/uri" "github.com/gravitational/teleport/lib/teleterm/apiserver/handler" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" libutils "github.com/gravitational/teleport/lib/utils" @@ -98,7 +99,7 @@ func TestTeleterm(t *testing.T) { testGetClusterReturnsPropertiesFromAuthServer(t, pack) }) - t.Run("Test headless watcher", func(t *testing.T) { + t.Run("headless watcher", func(t *testing.T) { t.Parallel() testHeadlessWatcher(t, pack, creds) @@ -124,7 +125,7 @@ func TestTeleterm(t *testing.T) { testDeleteConnectMyComputerNode(t, pack) }) - t.Run("TestClientCache", func(t *testing.T) { + t.Run("client cache", func(t *testing.T) { t.Parallel() testClientCache(t, pack, creds) @@ -362,10 +363,13 @@ func testGetClusterReturnsPropertiesFromAuthServer(t *testing.T, pack *dbhelpers }) require.NoError(t, err) + clusterIDCache := clusteridcache.Cache{} + daemonService, err := daemon.New(daemon.Config{ Storage: storage, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + ClusterIDCache: &clusterIDCache, }) require.NoError(t, err) t.Cleanup(func() { @@ -381,15 +385,22 @@ func testGetClusterReturnsPropertiesFromAuthServer(t *testing.T, pack *dbhelpers rootClusterName, _, err := net.SplitHostPort(pack.Root.Cluster.Web) require.NoError(t, err) + clusterURI := uri.NewClusterURI(rootClusterName) response, err := 
handler.GetCluster(context.Background(), &api.GetClusterRequest{ - ClusterUri: uri.NewClusterURI(rootClusterName).String(), + ClusterUri: clusterURI.String(), }) require.NoError(t, err) require.Equal(t, userName, response.LoggedInUser.Name) require.ElementsMatch(t, []string{requestableRoleName}, response.LoggedInUser.RequestableRoles) require.ElementsMatch(t, []string{suggestedReviewer}, response.LoggedInUser.SuggestedReviewers) + + // Verify that cluster ID cache gets updated. + clusterIDFromCache, ok := clusterIDCache.Load(clusterURI) + require.True(t, ok, "ID for cluster %q was not found in the cache", clusterURI) + require.NotEmpty(t, clusterIDFromCache) + require.Equal(t, response.AuthClusterId, clusterIDFromCache) } func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.UserCreds) { diff --git a/lib/teleterm/apiserver/apiserver.go b/lib/teleterm/apiserver/apiserver.go index 5ff1793a0f4c5..40e387638877e 100644 --- a/lib/teleterm/apiserver/apiserver.go +++ b/lib/teleterm/apiserver/apiserver.go @@ -40,18 +40,6 @@ func New(cfg Config) (*APIServer, error) { return nil, trace.Wrap(err) } - // Create the listener, set up the server. - - ls, err := newListener(cfg.HostAddr, cfg.ListeningC) - if err != nil { - return nil, trace.Wrap(err) - } - - grpcServer := grpc.NewServer(cfg.TshdServerCreds, - grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Log)), - grpc.MaxConcurrentStreams(defaults.GRPCMaxConcurrentStreams), - ) - // Create Terminal and VNet services. serviceHandler, err := handler.New( @@ -66,11 +54,25 @@ func New(cfg Config) (*APIServer, error) { vnetService, err := vnet.New(vnet.Config{ DaemonService: cfg.Daemon, InsecureSkipVerify: cfg.InsecureSkipVerify, + ClusterIDCache: cfg.ClusterIDCache, + InstallationID: cfg.InstallationID, }) if err != nil { return nil, trace.Wrap(err) } + // Create the listener, set up the server. 
+ + ls, err := newListener(cfg.HostAddr, cfg.ListeningC) + if err != nil { + return nil, trace.Wrap(err) + } + + grpcServer := grpc.NewServer(cfg.TshdServerCreds, + grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Log)), + grpc.MaxConcurrentStreams(defaults.GRPCMaxConcurrentStreams), + ) + api.RegisterTerminalServiceServer(grpcServer, serviceHandler) vnetapi.RegisterVnetServiceServer(grpcServer, vnetService) diff --git a/lib/teleterm/apiserver/config.go b/lib/teleterm/apiserver/config.go index 086bb957c10f1..1f296dce88893 100644 --- a/lib/teleterm/apiserver/config.go +++ b/lib/teleterm/apiserver/config.go @@ -24,6 +24,7 @@ import ( "google.golang.org/grpc" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/daemon" "github.com/gravitational/teleport/lib/utils" ) @@ -34,7 +35,9 @@ type Config struct { HostAddr string InsecureSkipVerify bool // Daemon is the terminal daemon service - Daemon *daemon.Service + Daemon *daemon.Service + ClusterIDCache *clusteridcache.Cache + InstallationID string // Log is a component logger Log logrus.FieldLogger TshdServerCreds grpc.ServerOption @@ -65,5 +68,13 @@ func (c *Config) CheckAndSetDefaults() error { c.Log = logrus.WithField(teleport.ComponentKey, "conn:apiserver") } + if c.InstallationID == "" { + return trace.BadParameter("missing installation ID") + } + + if c.ClusterIDCache == nil { + c.ClusterIDCache = &clusteridcache.Cache{} + } + return nil } diff --git a/lib/teleterm/clusteridcache/clusteridcache.go b/lib/teleterm/clusteridcache/clusteridcache.go new file mode 100644 index 0000000000000..23a0558cfdb73 --- /dev/null +++ b/lib/teleterm/clusteridcache/clusteridcache.go @@ -0,0 +1,61 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. + +package clusteridcache + +import ( + "sync" + + "github.com/gravitational/teleport/lib/teleterm/api/uri" +) + +// Cache stores cluster IDs indexed by their cluster URIs. +// +// Cluster IDs are required when reporting usage events, but they are not publicly known and can be +// fetched only after logging in to a cluster. Today, most events are sent from the Electron app. +// The Electron app caches cluster IDs on its own. However, sometimes we want to send events +// straight from the tsh daemon, in which case we need to know the ID of a cluster. +// +// Whenever the user logs in and fetches full details of a cluster, the cluster ID gets saved to the +// cache. Later on when tsh daemon wants to send a usage event, it can load the cluster ID from the +// cache. +// +// This cache is never cleared since cluster IDs are saved only for root clusters. Logging in to a +// root cluster overwrites the existing ID under the same URI. 
+// +// TODO(ravicious): Refactor usage reporting to operate on cluster URIs instead of cluster IDs and +// keep the cache only on the side of tsh daemon. Fetch a cluster ID whenever it's first requested +// to avoid an issue with trying to send a usage event before the cluster ID is known. +// https://github.com/gravitational/teleport/issues/23030 +type Cache struct { + m sync.Map +} + +// Store stores the cluster ID for the given uri of a root cluster. +func (c *Cache) Store(uri uri.ResourceURI, clusterID string) { + c.m.Store(uri.String(), clusterID) +} + +// Load returns the cluster ID for the given uri of a root cluster. +func (c *Cache) Load(uri uri.ResourceURI) (string, bool) { + id, ok := c.m.Load(uri.String()) + + if !ok { + return "", false + } + + return id.(string), true +} diff --git a/lib/teleterm/clusters/cluster.go b/lib/teleterm/clusters/cluster.go index 04f7d94873fe2..0d4c16a4a8d4b 100644 --- a/lib/teleterm/clusters/cluster.go +++ b/lib/teleterm/clusters/cluster.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" ) // Cluster describes user settings and access to various resources. @@ -88,7 +89,7 @@ func (c *Cluster) Connected() bool { // GetWithDetails makes requests to the auth server to return details of the current // Cluster that cannot be found on the disk only, including details about the user // and enabled enterprise features. This method requires a valid cert. -func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.ClientI) (*ClusterWithDetails, error) { +func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.ClientI, clusterIDCache *clusteridcache.Cache) (*ClusterWithDetails, error) { var ( clusterPingResponse *webclient.PingResponse webConfig *webclient.WebConfig @@ -142,6 +143,7 @@ func (c *Cluster) GetWithDetails(ctx context.Context, authClient authclient.Clie return trace.Wrap(err) } authClusterID = clusterName.GetClusterID() + clusterIDCache.Store(c.URI, authClusterID) return nil }) return trace.Wrap(err) diff --git a/lib/teleterm/config.go b/lib/teleterm/config.go index 794623db11b9e..0038863a5fb51 100644 --- a/lib/teleterm/config.go +++ b/lib/teleterm/config.go @@ -44,6 +44,8 @@ type Config struct { KubeconfigsDir string // AgentsDir contains agent config files and data directories for Connect My Computer. AgentsDir string + // InstallationID is a unique ID identifying a specific Teleport Connect installation. + InstallationID string } // CheckAndSetDefaults checks and sets default config values. 
@@ -77,5 +79,9 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("missing agents directory") } + if c.InstallationID == "" { + return trace.BadParameter("missing installation ID") + } + return nil } diff --git a/lib/teleterm/daemon/config.go b/lib/teleterm/daemon/config.go index 92e375ddd721c..80cc79d081946 100644 --- a/lib/teleterm/daemon/config.go +++ b/lib/teleterm/daemon/config.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/clientcache" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/services/connectmycomputer" ) @@ -74,6 +75,10 @@ type Config struct { ConnectMyComputerNodeName *connectmycomputer.NodeName CreateClientCacheFunc func(resolver clientcache.NewClientFunc) (ClientCache, error) + // ClusterIDCache gets updated whenever daemon.Service.ResolveClusterWithDetails gets called. + // Since that method is called by the Electron app only for root clusters and typically only once + // after a successful login, this cache doesn't have to be cleared. + ClusterIDCache *clusteridcache.Cache } // ResolveClusterFunc returns a cluster by URI. @@ -174,5 +179,9 @@ func (c *Config) CheckAndSetDefaults() error { } } + if c.ClusterIDCache == nil { + c.ClusterIDCache = &clusteridcache.Cache{} + } + return nil } diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index 09cd4cf489d1e..497dffb7437db 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -280,7 +280,7 @@ func (s *Service) ResolveClusterWithDetails(ctx context.Context, uri string) (*c return nil, nil, trace.Wrap(err) } - withDetails, err := cluster.GetWithDetails(ctx, proxyClient.CurrentCluster()) + withDetails, err := cluster.GetWithDetails(ctx, proxyClient.CurrentCluster(), s.cfg.ClusterIDCache) if err != nil { return nil, nil, trace.Wrap(err) } @@ -1133,9 +1133,9 @@ func (s *Service) findGatewayByTargetURI(targetURI uri.ResourceURI) (gateway.Gat // GetCachedClient returns a client from the cache if it exists, // otherwise it dials the remote server. 
-func (s *Service) GetCachedClient(ctx context.Context, clusterURI uri.ResourceURI) (*client.ClusterClient, error) { - profileName := clusterURI.GetProfileName() - leafClusterName := clusterURI.GetLeafClusterName() +func (s *Service) GetCachedClient(ctx context.Context, resourceURI uri.ResourceURI) (*client.ClusterClient, error) { + profileName := resourceURI.GetProfileName() + leafClusterName := resourceURI.GetLeafClusterName() clt, err := s.clientCache.Get(ctx, profileName, leafClusterName) return clt, trace.Wrap(err) } diff --git a/lib/teleterm/teleterm.go b/lib/teleterm/teleterm.go index 0c03e6ef8f4f9..d2348a7bb93d3 100644 --- a/lib/teleterm/teleterm.go +++ b/lib/teleterm/teleterm.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/credentials/insecure" "github.com/gravitational/teleport/lib/teleterm/apiserver" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" ) @@ -55,12 +56,15 @@ func Serve(ctx context.Context, cfg Config) error { return trace.Wrap(err) } + clusterIDCache := &clusteridcache.Cache{} + daemonService, err := daemon.New(daemon.Config{ Storage: storage, CreateTshdEventsClientCredsFunc: grpcCredentials.tshdEvents, PrehogAddr: cfg.PrehogAddr, KubeconfigsDir: cfg.KubeconfigsDir, AgentsDir: cfg.AgentsDir, + ClusterIDCache: clusterIDCache, }) if err != nil { return trace.Wrap(err) @@ -72,6 +76,8 @@ func Serve(ctx context.Context, cfg Config) error { Daemon: daemonService, TshdServerCreds: grpcCredentials.tshd, ListeningC: cfg.ListeningC, + ClusterIDCache: clusterIDCache, + InstallationID: cfg.InstallationID, }) if err != nil { return trace.Wrap(err) diff --git a/lib/teleterm/teleterm_test.go b/lib/teleterm/teleterm_test.go index e836677f0b507..854273d71c683 100644 --- a/lib/teleterm/teleterm_test.go +++ b/lib/teleterm/teleterm_test.go @@ -120,6 +120,7 @@ func TestStart(t *testing.T) { ListeningC: listeningC, KubeconfigsDir: t.TempDir(), AgentsDir: t.TempDir(), + InstallationID: "foo", } ctx, cancel := context.WithCancel(context.Background()) diff --git a/lib/teleterm/vnet/service.go b/lib/teleterm/vnet/service.go index 3053caebd6d6e..79694156c99b5 100644 --- a/lib/teleterm/vnet/service.go +++ b/lib/teleterm/vnet/service.go @@ -21,16 +21,21 @@ import ( "crypto/tls" "errors" "sync" + "sync/atomic" "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/gravitational/teleport" vnetproto "github.com/gravitational/teleport/api/gen/proto/go/teleport/vnet/v1" "github.com/gravitational/teleport/api/types" + prehogv1alpha "github.com/gravitational/teleport/gen/proto/go/prehog/v1alpha" apiteleterm "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" api "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/vnet/v1" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" + "github.com/gravitational/teleport/lib/teleterm/clusters" "github.com/gravitational/teleport/lib/teleterm/daemon" logutils "github.com/gravitational/teleport/lib/utils/log" "github.com/gravitational/teleport/lib/vnet" @@ -54,6 +59,7 @@ type Service struct { mu sync.Mutex status status processManager *vnet.ProcessManager + usageReporter *usageReporter } // New creates an instance of Service. 
@@ -68,8 +74,17 @@ func New(cfg Config) (*Service, error) { } type Config struct { - DaemonService *daemon.Service + // DaemonService is used to get cached clients and for usage reporting. If DaemonService was not + // one giant blob of methods, Config could accept two separate services instead. + DaemonService *daemon.Service + // InsecureSkipVerify signifies whether VNet is going to skip verifying the identity of the proxy service. InsecureSkipVerify bool + // ClusterIDCache is used for usage reporting to read the cluster ID that needs to be included with + // every event. + ClusterIDCache *clusteridcache.Cache + // InstallationID is a unique ID of this particular Connect installation, used for usage + // reporting. + InstallationID string } // CheckAndSetDefaults checks and sets the defaults @@ -78,6 +93,14 @@ func (c *Config) CheckAndSetDefaults() error { return trace.BadParameter("missing DaemonService") } + if c.ClusterIDCache == nil { + return trace.BadParameter("missing ClusterIDCache") + } + + if c.InstallationID == "" { + return trace.BadParameter("missing InstallationID") + } + return nil } @@ -93,9 +116,25 @@ func (s *Service) Start(ctx context.Context, req *api.StartRequest) (*api.StartR return &api.StartResponse{}, nil } + usageReporter, err := NewUsageReporter(UsageReporterConfig{ + ClientCache: s.cfg.DaemonService, + EventConsumer: s.cfg.DaemonService, + ClusterIDCache: s.cfg.ClusterIDCache, + InstallationID: s.cfg.InstallationID, + }) + if err != nil { + return nil, trace.Wrap(err) + } + defer func() { + if s.status != statusRunning { + usageReporter.Stop() + } + }() + appProvider := &appProvider{ daemonService: s.cfg.DaemonService, insecureSkipVerify: s.cfg.InsecureSkipVerify, + usageReporter: usageReporter, } processManager, err := vnet.SetupAndRun(ctx, appProvider) @@ -124,6 +163,7 @@ func (s *Service) Start(ctx context.Context, req *api.StartRequest) (*api.StartR }() s.processManager = processManager + s.usageReporter = usageReporter s.status = statusRunning return &api.StartResponse{}, nil } @@ -155,6 +195,7 @@ func (s *Service) stopLocked() error { if err != nil && !errors.Is(err, context.Canceled) { return trace.Wrap(err) } + s.usageReporter.Stop() s.status = statusNotRunning return nil @@ -177,6 +218,7 @@ func (s *Service) Close() error { type appProvider struct { daemonService *daemon.Service + usageReporter *usageReporter insecureSkipVerify bool } @@ -271,3 +313,174 @@ func (p *appProvider) GetVnetConfig(ctx context.Context, profileName, leafCluste vnetConfig, err := vnetConfigClient.GetVnetConfig(ctx, &vnetproto.GetVnetConfigRequest{}) return vnetConfig, trace.Wrap(err) } + +// OnNewConnection submits a usage event once per appProvider lifetime. +// That is, if a user makes multiple connections to a single app, OnNewConnection submits a single +// event. This is to mimic how Connect submits events for its app gateways. This lets us compare +// the popularity of VNet and app gateways. +func (p *appProvider) OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error { + // Enqueue the event from a separate goroutine since we don't care about errors anyway and we also + // don't want to slow down VNet connections. + go func() { + uri := uri.NewClusterURI(profileName).AppendLeafCluster(leafClusterName).AppendApp(app.GetName()) + + // Not passing ctx to ReportApp since ctx is tied to the lifetime of the connection. + // If it's a short-lived connection, inheriting its context would interrupt reporting. 
+ err := p.usageReporter.ReportApp(uri) + if err != nil { + log.ErrorContext(ctx, "Failed to submit usage event", "app", uri, "error", err) + } + }() + + return nil +} + +type usageReporter struct { + cfg UsageReporterConfig + // reportedApps contains a set of URIs for apps whose usage has already been reported. + // App gateways (local proxies) in Connect report a single event per gateway created per app. VNet + // needs to replicate this behavior, hence it keeps track of reported apps to report only one + // event per app per VNet's lifespan. + reportedApps map[string]struct{} + // mu protects access to reportedApps. + mu sync.Mutex + // close is used to abort a ReportApp call that's currently in flight. + close chan struct{} + // closed signals that usageReporter has been stopped and no more events should be reported. + closed atomic.Bool +} + +type clientCache interface { + GetCachedClient(context.Context, uri.ResourceURI) (*client.ClusterClient, error) + ResolveClusterURI(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) +} + +type eventConsumer interface { + ReportUsageEvent(*apiteleterm.ReportUsageEventRequest) error +} + +type UsageReporterConfig struct { + ClientCache clientCache + EventConsumer eventConsumer + // ClusterIDCache stores the cluster ID that needs to be included with each usage event. It's updated + // outside of usageReporter – usageReporter merely reads data from it. If the cache does not + // contain an ID for the given cluster, usageReporter drops the event. + ClusterIDCache *clusteridcache.Cache + InstallationID string +} + +func (c *UsageReporterConfig) CheckAndSetDefaults() error { + if c.ClientCache == nil { + return trace.BadParameter("missing ClientCache") + } + + if c.EventConsumer == nil { + return trace.BadParameter("missing EventConsumer") + } + + if c.ClusterIDCache == nil { + return trace.BadParameter("missing ClusterIDCache") + } + + if c.InstallationID == "" { + return trace.BadParameter("missing InstallationID") + } + + return nil +} + +func NewUsageReporter(cfg UsageReporterConfig) (*usageReporter, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + return &usageReporter{ + cfg: cfg, + reportedApps: make(map[string]struct{}), + close: make(chan struct{}), + }, nil +} + +// ReportApp adds an event related to the given app to the events queue, if the app wasn't reported +// already. Only one invocation of ReportApp can be in flight at a time. 
+func (r *usageReporter) ReportApp(appURI uri.ResourceURI) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed.Load() { + return trace.CompareFailed("usage reporter has been stopped") + } + + if _, hasAppBeenReported := r.reportedApps[appURI.String()]; hasAppBeenReported { + return nil + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + select { + case <-r.close: + cancel() + case <-ctx.Done(): + } + }() + + rootClusterURI := appURI.GetRootClusterURI() + client, err := r.cfg.ClientCache.GetCachedClient(ctx, appURI) + if err != nil { + return trace.Wrap(err) + } + rootClusterName := client.RootClusterName() + _, tc, err := r.cfg.ClientCache.ResolveClusterURI(appURI) + if err != nil { + return trace.Wrap(err) + } + + clusterID, ok := r.cfg.ClusterIDCache.Load(rootClusterURI) + if !ok { + return trace.NotFound("cluster ID for %q not found", rootClusterURI) + } + + log.DebugContext(ctx, "Reporting usage event", "app", appURI.String()) + + err = r.cfg.EventConsumer.ReportUsageEvent(&apiteleterm.ReportUsageEventRequest{ + AuthClusterId: clusterID, + PrehogReq: &prehogv1alpha.SubmitConnectEventRequest{ + DistinctId: r.cfg.InstallationID, + Timestamp: timestamppb.Now(), + Event: &prehogv1alpha.SubmitConnectEventRequest_ProtocolUse{ + ProtocolUse: &prehogv1alpha.ConnectProtocolUseEvent{ + ClusterName: rootClusterName, + UserName: tc.Username, + Protocol: "app", + Origin: "vnet", + AccessThrough: "vnet", + }, + }, + }, + }) + if err != nil { + return trace.Wrap(err, "adding usage event to queue") + } + + r.reportedApps[appURI.String()] = struct{}{} + + return nil +} + +// Stop aborts the reporting of an event that's currently in progress and prevents further events +// from being reported. It blocks until the current ReportApp call aborts. +func (r *usageReporter) Stop() { + if r.closed.Load() { + return + } + + // Prevent new calls to ReportApp from being made. + r.closed.Store(true) + // Abort context of the ReportApp call currently in flight. + close(r.close) + // Block until the current ReportApp call aborts. + r.mu.Lock() + defer r.mu.Unlock() +} diff --git a/lib/teleterm/vnet/service_test.go b/lib/teleterm/vnet/service_test.go new file mode 100644 index 0000000000000..9fe1d8ea9da74 --- /dev/null +++ b/lib/teleterm/vnet/service_test.go @@ -0,0 +1,163 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see <https://www.gnu.org/licenses/>. 
+ +package vnet + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + teletermv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/teleterm/v1" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/teleterm/api/uri" + "github.com/gravitational/teleport/lib/teleterm/clusteridcache" + "github.com/gravitational/teleport/lib/teleterm/clusters" +) + +func TestUsageReporter(t *testing.T) { + eventConsumer := fakeEventConsumer{} + + validCluster := uri.NewClusterURI("foo") + clusterWithoutClient := uri.NewClusterURI("no-client") + clusterWithoutProfile := uri.NewClusterURI("no-profile") + clusterWithoutClusterID := uri.NewClusterURI("no-cluster-id") + + clientCache := fakeClientCache{ + validClusterURIs: map[uri.ResourceURI]struct{}{ + validCluster: struct{}{}, + clusterWithoutProfile: struct{}{}, + clusterWithoutClusterID: struct{}{}, + }, + } + + clusterIDcache := clusteridcache.Cache{} + clusterIDcache.Store(uri.NewClusterURI("foo"), "1234") + + usageReporter, err := NewUsageReporter(UsageReporterConfig{ + EventConsumer: &eventConsumer, + ClientCache: &clientCache, + ClusterIDCache: &clusterIDcache, + InstallationID: "4321", + }) + require.NoError(t, err) + t.Cleanup(usageReporter.Stop) + + // Verify that reporting the same app twice adds only one usage event. + err = usageReporter.ReportApp(validCluster.AppendApp("app")) + require.NoError(t, err) + err = usageReporter.ReportApp(validCluster.AppendApp("app")) + require.NoError(t, err) + require.Equal(t, 1, eventConsumer.EventCount()) + + // Verify that reporting an invalid cluster doesn't submit an event. + err = usageReporter.ReportApp(clusterWithoutClient.AppendApp("bar")) + require.True(t, trace.IsNotFound(err), "Not a NotFound error: %#v", err) + require.Equal(t, 1, eventConsumer.EventCount()) + err = usageReporter.ReportApp(clusterWithoutProfile.AppendApp("bar")) + require.True(t, trace.IsNotFound(err), "Not a NotFound error: %#v", err) + require.Equal(t, 1, eventConsumer.EventCount()) + err = usageReporter.ReportApp(clusterWithoutClusterID.AppendApp("bar")) + require.ErrorIs(t, err, trace.NotFound("cluster ID for \"/clusters/no-cluster-id\" not found")) + require.Equal(t, 1, eventConsumer.EventCount()) +} + +func TestUsageReporter_Stop(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + t.Cleanup(cancel) + eventConsumer := fakeEventConsumer{} + clientCache := fakeClientCache{blockingOnCtxC: make(chan struct{}, 1)} + clusterIDCache := clusteridcache.Cache{} + + usageReporter, err := NewUsageReporter(UsageReporterConfig{ + EventConsumer: &eventConsumer, + ClientCache: &clientCache, + ClusterIDCache: &clusterIDCache, + InstallationID: "4321", + }) + require.NoError(t, err) + t.Cleanup(usageReporter.Stop) + + go func() { + select { + case <-ctx.Done(): + case <-clientCache.blockingOnCtxC: + // Wait for ReportApp to start blocking on GetCachedClient. 
+ } + usageReporter.Stop() + }() + + uri := uri.NewClusterURI("foo").AppendApp("bar") + err = usageReporter.ReportApp(uri) + require.ErrorIs(t, err, context.Canceled) + + err = usageReporter.ReportApp(uri) + require.True(t, trace.IsCompareFailed(err), "expected trace.CompareFailed but got %v", err) +} + +type fakeEventConsumer struct { + mu sync.Mutex + events []*teletermv1.ReportUsageEventRequest +} + +func (ec *fakeEventConsumer) ReportUsageEvent(event *teletermv1.ReportUsageEventRequest) error { + ec.mu.Lock() + defer ec.mu.Unlock() + + ec.events = append(ec.events, event) + return nil +} + +func (ec *fakeEventConsumer) EventCount() int { + ec.mu.Lock() + defer ec.mu.Unlock() + + return len(ec.events) +} + +type fakeClientCache struct { + validClusterURIs map[uri.ResourceURI]struct{} + // blockingOnCtxC makes GetCachedClient block until ctx is canceled. fakeClientCache writes to the + // channel just before GetCachedClient starts to block on ctx. + blockingOnCtxC chan struct{} +} + +func (c *fakeClientCache) GetCachedClient(ctx context.Context, appURI uri.ResourceURI) (*client.ClusterClient, error) { + if c.blockingOnCtxC != nil { + c.blockingOnCtxC <- struct{}{} + + <-ctx.Done() + return nil, trace.Wrap(ctx.Err()) + } + + if _, ok := c.validClusterURIs[appURI.GetClusterURI()]; !ok { + return nil, trace.NotFound("client for cluster %q not found", appURI.GetClusterURI()) + } + + return &client.ClusterClient{}, nil +} + +func (c *fakeClientCache) ResolveClusterURI(uri uri.ResourceURI) (*clusters.Cluster, *client.TeleportClient, error) { + if _, ok := c.validClusterURIs[uri.GetClusterURI()]; !ok { + return nil, nil, trace.NotFound("client for cluster %q not found", uri.GetClusterURI()) + } + + return &clusters.Cluster{}, &client.TeleportClient{Config: client.Config{Username: "alice"}}, nil +} diff --git a/lib/vnet/app_resolver.go b/lib/vnet/app_resolver.go index 46a62f8016e96..11813e6ca18b5 100644 --- a/lib/vnet/app_resolver.go +++ b/lib/vnet/app_resolver.go @@ -62,6 +62,14 @@ type AppProvider interface { // GetVnetConfig returns the cluster VnetConfig resource. GetVnetConfig(ctx context.Context, profileName, leafClusterName string) (*vnet.VnetConfig, error) + + // OnNewConnection gets called whenever a new connection is about to be established through VNet. + // By the time OnNewConnection is called, VNet has already verified that the user holds a valid cert for the + // app. + // + // The connection won't be established until OnNewConnection returns. Returning an error prevents + // the connection from being made. + OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error } // DialOptions holds ALPN dial options for dialing apps. @@ -227,7 +235,14 @@ func (r *TCPAppResolver) newTCPAppHandler( leafClusterName: leafClusterName, app: app, } - middleware := client.NewCertChecker(appCertIssuer, r.clock) + certChecker := client.NewCertChecker(appCertIssuer, r.clock) + middleware := &localProxyMiddleware{ + certChecker: certChecker, + appProvider: r.appProvider, + app: app, + profileName: profileName, + leafClusterName: leafClusterName, + } localProxyConfig := alpnproxy.LocalProxyConfig{ RemoteProxyAddr: dialOpts.WebProxyAddr, @@ -293,3 +308,26 @@ func fullyQualify(domain string) string { } return domain + "." } + +// localProxyMiddleware wraps around [client.CertChecker] and additionally makes it so that its +// OnNewConnection method calls the same method of [AppProvider]. 
+type localProxyMiddleware struct { + app types.Application + profileName string + leafClusterName string + certChecker *client.CertChecker + appProvider AppProvider +} + +func (m *localProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy) error { + err := m.certChecker.OnNewConnection(ctx, lp) + if err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(m.appProvider.OnNewConnection(ctx, m.profileName, m.leafClusterName, m.app)) +} + +func (m *localProxyMiddleware) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) error { + return trace.Wrap(m.certChecker.OnStart(ctx, lp)) +} diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index 84bf8434faf63..4243d665081e4 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -33,6 +33,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "testing" "time" @@ -236,9 +237,10 @@ type testClusterSpec struct { } type echoAppProvider struct { - clusters map[string]testClusterSpec - dialOpts DialOptions - reissueAppCert func() tls.Certificate + clusters map[string]testClusterSpec + dialOpts DialOptions + reissueAppCert func() tls.Certificate + onNewConnectionCallCount atomic.Uint32 } // newEchoAppProvider returns an app provider with the list of named apps in each profile and leaf cluster. @@ -330,6 +332,12 @@ func (p *echoAppProvider) GetVnetConfig(ctx context.Context, profileName, leafCl }, nil } +func (p *echoAppProvider) OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error { + p.onNewConnectionCallCount.Add(1) + + return nil +} + // echoAppAuthClient is a fake auth client that answers GetResources requests with a static list of apps and // basic/faked predicate filtering. type echoAppAuthClient struct { @@ -377,87 +385,8 @@ func TestDialFakeApp(t *testing.T) { t.Cleanup(cancel) clock := clockwork.NewFakeClockAt(time.Now()) - ca := newSelfSignedCA(t) - - roots := x509.NewCertPool() - caX509, err := x509.ParseCertificate(ca.Certificate[0]) - require.NoError(t, err) - roots.AddCert(caX509) - - const proxyCN = "testproxy" - proxyCert := newServerCert(t, ca, proxyCN, clock.Now().Add(365*24*time.Hour)) - - proxyTLSConfig := &tls.Config{ - Certificates: []tls.Certificate{proxyCert}, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: roots, - } - - listener, err := tls.Listen("tcp", "localhost:0", proxyTLSConfig) - require.NoError(t, err) - - // Run a fake web proxy that will accept any client connection and echo the input back. - utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ - Name: "web proxy", - Task: func(ctx context.Context) error { - for { - conn, err := listener.Accept() - if err != nil { - if utils.IsOKNetworkError(err) { - return nil - } - return trace.Wrap(err) - } - go func() { - defer conn.Close() - - // Not using require/assert here and below because this is not in the right subtest or in - // the main test goroutine. The test will fail if the conn is not handled. - tlsConn, ok := conn.(*tls.Conn) - if !ok { - t.Log("client conn is not TLS") - return - } - if err := tlsConn.Handshake(); err != nil { - t.Log("error completing tls handshake") - return - } - clientCerts := tlsConn.ConnectionState().PeerCertificates - if len(clientCerts) == 0 { - t.Log("client has no certs") - return - } - // Manually checking the cert expiry compared to the time of the fake clock, since the TLS - // library will only compare the cert expiry to the real clock. 
- // It's important that the fake clock is never far behind the real clock, and that the - // cert NotBefore is always at/before the real current time, so the TLS library is - // satisfied. - if clock.Now().After(clientCerts[0].NotAfter) { - t.Logf("client cert is expired: currentTime=%s expiry=%s", clock.Now(), clientCerts[0].NotAfter) - return - } - - _, err := io.Copy(conn, conn) - if err != nil && !utils.IsOKNetworkError(err) { - t.Logf("error in io.Copy for echo proxy server: %v", err) - } - }() - } - }, - Terminate: func() error { - if err := listener.Close(); !utils.IsOKNetworkError(err) { - return trace.Wrap(err) - } - return nil - }, - }) - - dialOpts := DialOptions{ - WebProxyAddr: listener.Addr().String(), - RootClusterCACertPool: roots, - SNI: proxyCN, - } + dialOpts := mustStartFakeWebProxy(ctx, t, ca, clock) const appCertLifetime = time.Hour reissueClientCert := func() tls.Certificate { @@ -588,6 +517,48 @@ func testEchoConnection(t *testing.T, conn net.Conn) { } } +func TestOnNewConnection(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + clock := clockwork.NewFakeClockAt(time.Now()) + ca := newSelfSignedCA(t) + dialOpts := mustStartFakeWebProxy(ctx, t, ca, clock) + + const appCertLifetime = time.Hour + reissueClientCert := func() tls.Certificate { + return newClientCert(t, ca, "testclient", clock.Now().Add(appCertLifetime)) + } + + appProvider := newEchoAppProvider(map[string]testClusterSpec{ + "root1.example.com": { + apps: []string{"echo1"}, + cidrRange: "192.168.2.0/24", + leafClusters: map[string]testClusterSpec{}, + }, + }, dialOpts, reissueClientCert) + + validAppName := "echo1.root1.example.com" + invalidAppName := "not.an.app.example.com." + + p := newTestPack(t, ctx, clock, appProvider) + + // Attempt to establish a connection to an invalid app and verify that OnNewConnection was not + // called. + lookupCtx, lookupCtxCancel := context.WithTimeout(ctx, 200*time.Millisecond) + defer lookupCtxCancel() + _, err := p.lookupHost(lookupCtx, invalidAppName) + require.Error(t, err, "Expected lookup of an invalid app to fail") + require.Equal(t, uint32(0), appProvider.onNewConnectionCallCount.Load()) + + // Establish a connection to a valid app and verify that OnNewConnection was called. + conn, err := p.dialHost(ctx, validAppName) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, conn.Close()) }) + require.Equal(t, uint32(1), appProvider.onNewConnectionCallCount.Load()) +} + func randomULAAddress() (tcpip.Address, error) { var bytes [16]byte bytes[0] = 0xfd @@ -733,3 +704,88 @@ func newLeafCert(t *testing.T, ca tls.Certificate, cn string, expires time.Time, PrivateKey: priv, } } + +func mustStartFakeWebProxy(ctx context.Context, t *testing.T, ca tls.Certificate, clock clockwork.FakeClock) DialOptions { + t.Helper() + + roots := x509.NewCertPool() + caX509, err := x509.ParseCertificate(ca.Certificate[0]) + require.NoError(t, err) + roots.AddCert(caX509) + + const proxyCN = "testproxy" + proxyCert := newServerCert(t, ca, proxyCN, clock.Now().Add(365*24*time.Hour)) + + proxyTLSConfig := &tls.Config{ + Certificates: []tls.Certificate{proxyCert}, + ClientAuth: tls.RequireAndVerifyClientCert, + ClientCAs: roots, + } + + listener, err := tls.Listen("tcp", "localhost:0", proxyTLSConfig) + require.NoError(t, err) + + // Run a fake web proxy that will accept any client connection and echo the input back. 
+ utils.RunTestBackgroundTask(ctx, t, &utils.TestBackgroundTask{ + Name: "web proxy", + Task: func(ctx context.Context) error { + for { + conn, err := listener.Accept() + if err != nil { + if utils.IsOKNetworkError(err) { + return nil + } + return trace.Wrap(err) + } + go func() { + defer conn.Close() + + // Not using require/assert here and below because this is not in the right subtest or in + // the main test goroutine. The test will fail if the conn is not handled. + tlsConn, ok := conn.(*tls.Conn) + if !ok { + t.Log("client conn is not TLS") + return + } + if err := tlsConn.Handshake(); err != nil { + t.Log("error completing tls handshake") + return + } + clientCerts := tlsConn.ConnectionState().PeerCertificates + if len(clientCerts) == 0 { + t.Log("client has no certs") + return + } + // Manually checking the cert expiry compared to the time of the fake clock, since the TLS + // library will only compare the cert expiry to the real clock. + // It's important that the fake clock is never far behind the real clock, and that the + // cert NotBefore is always at/before the real current time, so the TLS library is + // satisfied. + if clock.Now().After(clientCerts[0].NotAfter) { + t.Logf("client cert is expired: currentTime=%s expiry=%s", clock.Now(), clientCerts[0].NotAfter) + return + } + + _, err := io.Copy(conn, conn) + if err != nil && !utils.IsOKNetworkError(err) { + t.Logf("error in io.Copy for echo proxy server: %v", err) + } + }() + } + }, + Terminate: func() error { + if err := listener.Close(); !utils.IsOKNetworkError(err) { + return trace.Wrap(err) + } + return nil + }, + }) + + dialOpts := DialOptions{ + WebProxyAddr: listener.Addr().String(), + RootClusterCACertPool: roots, + SNI: proxyCN, + } + + return dialOpts +} diff --git a/tool/tsh/common/daemon.go b/tool/tsh/common/daemon.go index 6892f729e4f9b..cc31f4ba8e795 100644 --- a/tool/tsh/common/daemon.go +++ b/tool/tsh/common/daemon.go @@ -49,6 +49,7 @@ func onDaemonStart(cf *CLIConf) error { PrehogAddr: cf.DaemonPrehogAddr, KubeconfigsDir: cf.DaemonKubeconfigsDir, AgentsDir: cf.DaemonAgentsDir, + InstallationID: cf.DaemonInstallationID, }) if err != nil { return trace.Wrap(err) diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 8028a641f29a6..75ce523fda0ac 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -225,6 +225,8 @@ type CLIConf struct { DaemonAgentsDir string // DaemonPid is the PID to be stopped by tsh daemon stop. DaemonPid int + // DaemonInstallationID is a unique ID identifying a specific Teleport Connect installation. + DaemonInstallationID string // DatabaseService specifies the database proxy server to log into. 
DatabaseService string @@ -790,6 +792,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { daemonStart.Flag("prehog-addr", "URL where prehog events should be submitted").StringVar(&cf.DaemonPrehogAddr) daemonStart.Flag("kubeconfigs-dir", "Directory containing kubeconfig for Kubernetes Access").StringVar(&cf.DaemonKubeconfigsDir) daemonStart.Flag("agents-dir", "Directory containing agent config files and data directories for Connect My Computer").StringVar(&cf.DaemonAgentsDir) + daemonStart.Flag("installation-id", "Unique ID identifying a specific Teleport Connect installation").StringVar(&cf.DaemonInstallationID) daemonStop := daemon.Command("stop", "Gracefully stops a process on Windows by sending Ctrl-Break to it.").Hidden() daemonStop.Flag("pid", "PID to be stopped").IntVar(&cf.DaemonPid) diff --git a/tool/tsh/common/vnet_common.go b/tool/tsh/common/vnet_common.go index f753d3124b580..754c61f9277dc 100644 --- a/tool/tsh/common/vnet_common.go +++ b/tool/tsh/common/vnet_common.go @@ -126,6 +126,12 @@ func (p *vnetAppProvider) GetVnetConfig(ctx context.Context, profileName, leafCl return vnetConfig, trace.Wrap(err) } +// OnNewConnection gets called before each VNet connection. It's a noop as tsh doesn't need to do +// anything extra here. +func (p *vnetAppProvider) OnNewConnection(ctx context.Context, profileName, leafClusterName string, app types.Application) error { + return nil +} + // getRootClusterCACertPool returns a certificate pool for the root cluster of the given profile. func (p *vnetAppProvider) getRootClusterCACertPool(ctx context.Context, profileName string) (*x509.CertPool, error) { tc, err := p.newTeleportClient(ctx, profileName, "") diff --git a/web/packages/teleterm/src/mainProcess/runtimeSettings.ts b/web/packages/teleterm/src/mainProcess/runtimeSettings.ts index f757bf22353c9..e93cdaabea4dd 100644 --- a/web/packages/teleterm/src/mainProcess/runtimeSettings.ts +++ b/web/packages/teleterm/src/mainProcess/runtimeSettings.ts @@ -78,6 +78,9 @@ export function getRuntimeSettings(): RuntimeSettings { const logsDir = path.join(userDataDir, 'logs'); // DO NOT expose agentsDir through RuntimeSettings. See the comment in getAgentsDir. const agentsDir = getAgentsDir(userDataDir); + const installationId = loadInstallationId( + path.resolve(app.getPath('userData'), 'installation_id') + ); const tshd = { binaryPath: tshBinPath, @@ -93,6 +96,7 @@ export function getRuntimeSettings(): RuntimeSettings { `--prehog-addr=${staticConfig.prehogAddress}`, `--kubeconfigs-dir=${kubeConfigsDir}`, `--agents-dir=${agentsDir}`, + `--installation-id=${installationId}`, ], }; const sharedProcess = { @@ -135,9 +139,7 @@ export function getRuntimeSettings(): RuntimeSettings { kubeConfigsDir, logsDir, platform: process.platform, - installationId: loadInstallationId( - path.resolve(app.getPath('userData'), 'installation_id') - ), + installationId, arch: os.arch(), osVersion: os.release(), appVersion,
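For reviewers: the dedup-and-stop behavior that usageReporter implements above (each app is reported at most once per VNet run, and Stop blocks until an in-flight ReportApp call returns) can be read in isolation from the Teleport plumbing. The sketch below is illustrative only, not part of this patch: the reporter type, the consume callback, and the plain-string app URIs are stand-ins for usageReporter, EventConsumer.ReportUsageEvent, and uri.ResourceURI; the real implementation additionally cancels a blocked client lookup through its close channel.

package main

import (
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
)

// reporter mirrors the pattern used by usageReporter: a mutex serializes
// reports, a map dedupes apps that were already reported, and an atomic flag
// plus the mutex let Stop block until an in-flight report finishes.
type reporter struct {
	mu       sync.Mutex
	reported map[string]struct{}
	closed   atomic.Bool
	consume  func(appURI string) error // stand-in for EventConsumer.ReportUsageEvent
}

func newReporter(consume func(string) error) *reporter {
	return &reporter{reported: make(map[string]struct{}), consume: consume}
}

// Report submits an event for appURI unless it was already reported or the
// reporter has been stopped. Only one call can be in flight at a time.
func (r *reporter) Report(appURI string) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.closed.Load() {
		return errors.New("reporter has been stopped")
	}
	if _, ok := r.reported[appURI]; ok {
		// Mimic Connect's app gateways: one event per app per reporter lifespan.
		return nil
	}
	if err := r.consume(appURI); err != nil {
		return err
	}
	r.reported[appURI] = struct{}{}
	return nil
}

// Stop prevents further reports and blocks until an in-flight Report returns.
func (r *reporter) Stop() {
	r.closed.Store(true)
	r.mu.Lock()
	defer r.mu.Unlock()
}

func main() {
	events := 0
	r := newReporter(func(appURI string) error {
		events++
		return nil
	})

	_ = r.Report("/clusters/foo/apps/echo") // submits an event
	_ = r.Report("/clusters/foo/apps/echo") // duplicate, no-op
	fmt.Println("events:", events)          // events: 1

	r.Stop()
	fmt.Println(r.Report("/clusters/foo/apps/other")) // reporter has been stopped
}

Serializing Report on the mutex is what lets Stop block until an in-flight report returns: Stop cannot acquire the mutex until Report releases it, which matches how usageReporter.Stop documents its blocking behavior.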