diff --git a/cmd/livepeer/starter/flags.go b/cmd/livepeer/starter/flags.go
index ff3395623d..13ce8721b4 100644
--- a/cmd/livepeer/starter/flags.go
+++ b/cmd/livepeer/starter/flags.go
@@ -75,8 +75,8 @@ func NewLivepeerConfig(fs *flag.FlagSet) LivepeerConfig {
 	cfg.LiveAIAuthWebhookURL = fs.String("liveAIAuthWebhookUrl", "", "Live AI RTMP authentication webhook URL")
 	cfg.LivePaymentInterval = fs.Duration("livePaymentInterval", *cfg.LivePaymentInterval, "Interval to pay process Gateway <> Orchestrator Payments for Live AI Video")
 	cfg.LiveOutSegmentTimeout = fs.Duration("liveOutSegmentTimeout", *cfg.LiveOutSegmentTimeout, "Timeout duration to wait the output segment to be available in the Live AI pipeline; defaults to no timeout")
-	cfg.LiveAICapRefreshModels = fs.String("liveAICapRefreshModels", "", "Comma separated list of models to periodically fetch capacity for. Leave unset to switch off periodic refresh.")
 	cfg.LiveAISaveNSegments = fs.Int("liveAISaveNSegments", 10, "Set how many segments to save to disk for debugging (both input and output)")
+	cfg.LiveAICapReportInterval = fs.Duration("liveAICapReportInterval", *cfg.LiveAICapReportInterval, "Interval to report Live AI container capacity metrics")
 
 	// Onchain:
 	cfg.EthAcctAddr = fs.String("ethAcctAddr", *cfg.EthAcctAddr, "Existing Eth account address. For use when multiple ETH accounts exist in the keystore directory")
diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go
index 4c7caf27db..158e433b88 100755
--- a/cmd/livepeer/starter/starter.go
+++ b/cmd/livepeer/starter/starter.go
@@ -184,7 +184,7 @@ type LivepeerConfig struct {
 	LiveAIHeartbeatInterval *time.Duration
 	LivePaymentInterval     *time.Duration
 	LiveOutSegmentTimeout   *time.Duration
-	LiveAICapRefreshModels  *string
+	LiveAICapReportInterval *time.Duration
 	LiveAISaveNSegments     *int
 }
 
@@ -241,6 +241,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
 	defaultLiveOutSegmentTimeout := 0 * time.Second
 	defaultGatewayHost := ""
 	defaultLiveAIHeartbeatInterval := 5 * time.Second
+	defaultLiveAICapReportInterval := 25 * time.Minute
 
 	// Onchain:
 	defaultEthAcctAddr := ""
@@ -359,6 +360,7 @@ func DefaultLivepeerConfig() LivepeerConfig {
 		LiveOutSegmentTimeout:   &defaultLiveOutSegmentTimeout,
 		GatewayHost:             &defaultGatewayHost,
 		LiveAIHeartbeatInterval: &defaultLiveAIHeartbeatInterval,
+		LiveAICapReportInterval: &defaultLiveAICapReportInterval,
 
 		// Onchain:
 		EthAcctAddr: &defaultEthAcctAddr,
@@ -1591,7 +1593,7 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 	if *cfg.Network != "offchain" {
 		ctx, cancel := context.WithCancel(ctx)
 		defer cancel()
-		dbOrchPoolCache, err := discovery.NewDBOrchestratorPoolCache(ctx, n, timeWatcher, orchBlacklist, *cfg.DiscoveryTimeout)
+		dbOrchPoolCache, err := discovery.NewDBOrchestratorPoolCache(ctx, n, timeWatcher, orchBlacklist, *cfg.DiscoveryTimeout, *cfg.LiveAICapReportInterval)
 		if err != nil {
 			exit("Could not create orchestrator pool with DB cache: %v", err)
 		}
@@ -1756,9 +1758,6 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) {
 	if cfg.LiveAITrickleHostForRunner != nil {
 		n.LiveAITrickleHostForRunner = *cfg.LiveAITrickleHostForRunner
 	}
-	if cfg.LiveAICapRefreshModels != nil && *cfg.LiveAICapRefreshModels != "" {
-		n.LiveAICapRefreshModels = strings.Split(*cfg.LiveAICapRefreshModels, ",")
-	}
 	n.LiveAISaveNSegments = cfg.LiveAISaveNSegments
 
 	//Create Livepeer Node
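The hunks above replace the liveAICapRefreshModels string flag with a liveAICapReportInterval duration flag (default 25m) and thread it through LivepeerConfig. Below is a minimal standalone sketch of how such a duration flag parses an operator override; the FlagSet name and the 5m value are illustrative only and not part of this change.

package main

import (
	"flag"
	"fmt"
	"time"
)

func main() {
	// Hypothetical FlagSet standing in for the node's flag parsing.
	fs := flag.NewFlagSet("livepeer-sketch", flag.ExitOnError)
	interval := fs.Duration("liveAICapReportInterval", 25*time.Minute, "Interval to report Live AI container capacity metrics")

	// An operator overriding the 25m default on the command line.
	_ = fs.Parse([]string{"-liveAICapReportInterval=5m"})
	fmt.Println(*interval) // prints 5m0s
}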
diff --git a/core/livepeernode.go b/core/livepeernode.go
index fdb501f0ca..e9dde61db0 100644
--- a/core/livepeernode.go
+++ b/core/livepeernode.go
@@ -168,7 +168,6 @@ type LivepeerNode struct {
 	LiveAIHeartbeatInterval time.Duration
 	LivePaymentInterval     time.Duration
 	LiveOutSegmentTimeout   time.Duration
-	LiveAICapRefreshModels  []string
 	LiveAISaveNSegments     *int
 
 	// Gateway
diff --git a/discovery/db_discovery.go b/discovery/db_discovery.go
index 3fc8a37446..6666ee099b 100644
--- a/discovery/db_discovery.go
+++ b/discovery/db_discovery.go
@@ -14,6 +14,7 @@ import (
 	"github.com/livepeer/go-livepeer/core"
 	"github.com/livepeer/go-livepeer/eth"
 	lpTypes "github.com/livepeer/go-livepeer/eth/types"
+	"github.com/livepeer/go-livepeer/monitor"
 	"github.com/livepeer/go-livepeer/net"
 	"github.com/livepeer/go-livepeer/pm"
 	"github.com/livepeer/go-livepeer/server"
@@ -21,27 +22,25 @@ import (
 	"github.com/golang/glog"
 )
 
-var cacheRefreshInterval = 25 * time.Minute
-var getTicker = func() *time.Ticker {
-	return time.NewTicker(cacheRefreshInterval)
-}
+var networkCapabilitiesReportingInterval = 25 * time.Minute
 
 type ticketParamsValidator interface {
 	ValidateTicketParams(ticketParams *pm.TicketParams) error
 }
 
 type DBOrchestratorPoolCache struct {
-	store                 common.OrchestratorStore
-	lpEth                 eth.LivepeerEthClient
-	ticketParamsValidator ticketParamsValidator
-	rm                    common.RoundsManager
-	bcast                 common.Broadcaster
-	orchBlacklist         []string
-	discoveryTimeout      time.Duration
-	node                  *core.LivepeerNode
+	store                           common.OrchestratorStore
+	lpEth                           eth.LivepeerEthClient
+	ticketParamsValidator           ticketParamsValidator
+	rm                              common.RoundsManager
+	bcast                           common.Broadcaster
+	orchBlacklist                   []string
+	discoveryTimeout                time.Duration
+	node                            *core.LivepeerNode
+	lastNetworkCapabilitiesReported time.Time
 }
 
-func NewDBOrchestratorPoolCache(ctx context.Context, node *core.LivepeerNode, rm common.RoundsManager, orchBlacklist []string, discoveryTimeout time.Duration) (*DBOrchestratorPoolCache, error) {
+func NewDBOrchestratorPoolCache(ctx context.Context, node *core.LivepeerNode, rm common.RoundsManager, orchBlacklist []string, discoveryTimeout time.Duration, liveAICapReportInterval time.Duration) (*DBOrchestratorPoolCache, error) {
 	if node.Eth == nil {
 		return nil, fmt.Errorf("could not create DBOrchestratorPoolCache: LivepeerEthClient is nil")
 	}
@@ -66,7 +65,7 @@ func NewDBOrchestratorPoolCache(ctx context.Context, node *core.LivepeerNode, rm
 			return err
 		}
 
-		if err := dbo.pollOrchestratorInfo(ctx); err != nil {
+		if err := dbo.pollOrchestratorInfo(ctx, liveAICapReportInterval); err != nil {
 			return err
 		}
 		return nil
@@ -252,13 +251,13 @@ func (dbo *DBOrchestratorPoolCache) cacheOrchestratorStake() error {
 	return nil
 }
 
-func (dbo *DBOrchestratorPoolCache) pollOrchestratorInfo(ctx context.Context) error {
+func (dbo *DBOrchestratorPoolCache) pollOrchestratorInfo(ctx context.Context, liveAICapReportInterval time.Duration) error {
 	if err := dbo.cacheOrchInfos(); err != nil {
 		glog.Errorf("unable to poll orchestrator info: %v", err)
 		return err
 	}
 
-	ticker := getTicker()
+	ticker := time.NewTicker(liveAICapReportInterval)
 	go func() {
 		for {
 			select {
@@ -393,12 +392,59 @@ func (dbo *DBOrchestratorPoolCache) cacheOrchInfos() error {
 				i = numOrchs //exit loop
 			}
 		}
-	//save network capabilities in LivepeerNode
-	dbo.node.UpdateNetworkCapabilities(orchNetworkCapabilities)
+
+	// Only update network capabilities every 25 minutes
+	if time.Since(dbo.lastNetworkCapabilitiesReported) >= networkCapabilitiesReportingInterval {
+		// Save network capabilities in LivepeerNode
+		dbo.node.UpdateNetworkCapabilities(orchNetworkCapabilities)
+
+		dbo.lastNetworkCapabilitiesReported = time.Now()
+	}
+
+	// Report AI container capacity metrics
+	reportAICapacityFromNetworkCapabilities(orchNetworkCapabilities)
 
 	return nil
 }
 
+func reportAICapacityFromNetworkCapabilities(orchNetworkCapabilities []*common.OrchNetworkCapabilities) {
+	// Build structured capacity data
+	modelCapacities := make(map[string]*monitor.ModelAICapacities)
+
+	for _, orchCap := range orchNetworkCapabilities {
+		models := getModelCapsFromNetCapabilities(orchCap.Capabilities)
+
+		for modelID, model := range models {
+			if _, exists := modelCapacities[modelID]; !exists {
+				modelCapacities[modelID] = &monitor.ModelAICapacities{
+					ModelID:       modelID,
+					Orchestrators: make(map[string]monitor.AIContainerCapacity),
+				}
+			}
+
+			capacity := monitor.AIContainerCapacity{
+				Idle:  int(model.Capacity),
+				InUse: int(model.CapacityInUse),
+			}
+			modelCapacities[modelID].Orchestrators[orchCap.OrchURI] = capacity
+		}
+	}
+
+	monitor.ReportAIContainerCapacity(modelCapacities)
+}
+
+func getModelCapsFromNetCapabilities(caps *net.Capabilities) map[string]*net.Capabilities_CapabilityConstraints_ModelConstraint {
+	if caps == nil || caps.Constraints == nil || caps.Constraints.PerCapability == nil {
+		return nil
+	}
+	liveAI, ok := caps.Constraints.PerCapability[uint32(core.Capability_LiveVideoToVideo)]
+	if !ok {
+		return nil
+	}
+
+	return liveAI.Models
+}
+
 func (dbo *DBOrchestratorPoolCache) Broadcaster() common.Broadcaster {
 	return dbo.bcast
 }
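With this change the discovery poll ticker runs at the configurable liveAICapReportInterval, while pushes to UpdateNetworkCapabilities are additionally throttled by the fixed 25-minute networkCapabilitiesReportingInterval via the lastNetworkCapabilitiesReported timestamp. The following self-contained sketch shows that gating pattern with shortened intervals; the print statements are placeholders for the real reporting and update calls.

package main

import (
	"fmt"
	"time"
)

func main() {
	reportInterval := 1 * time.Second       // stands in for liveAICapReportInterval
	capabilitiesInterval := 3 * time.Second // stands in for networkCapabilitiesReportingInterval
	var lastCapabilitiesReport time.Time    // zero value, so the first tick always updates

	ticker := time.NewTicker(reportInterval)
	defer ticker.Stop()

	for i := 0; i < 5; i++ {
		<-ticker.C
		// Container capacity metrics are reported on every tick.
		fmt.Println("report AI container capacity")

		// Network capabilities are only pushed once enough time has passed.
		if time.Since(lastCapabilitiesReport) >= capabilitiesInterval {
			fmt.Println("update network capabilities")
			lastCapabilitiesReport = time.Now()
		}
	}
}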
diff --git a/discovery/discovery.go b/discovery/discovery.go
index 0135aa85a9..6c70dcb4cd 100644
--- a/discovery/discovery.go
+++ b/discovery/discovery.go
@@ -246,7 +246,6 @@ func (o *orchestratorPool) GetOrchestrators(ctx context.Context, numOrchestrator
 	for _, i := range rand.Perm(numAvailableOrchs) {
 		go getOrchInfo(ctx, common.OrchestratorDescriptor{linfos[i], nil}, 0, odCh, errCh, allOrchDescrCh)
 	}
-	go reportLiveAICapacity(allOrchDescrCh, caps)
 
 	// use a timer to time out the entire get info loop below
 	cutoffTimer := time.NewTimer(maxGetOrchestratorCutoffTimeout)
@@ -326,62 +325,6 @@ func (o *orchestratorPool) GetOrchestrators(ctx context.Context, numOrchestrator
 	return ods, nil
 }
 
-func getModelCaps(caps *net.Capabilities) map[string]*net.Capabilities_CapabilityConstraints_ModelConstraint {
-	if caps == nil || caps.Constraints == nil || caps.Constraints.PerCapability == nil {
-		return nil
-	}
-	liveAI, ok := caps.Constraints.PerCapability[uint32(core.Capability_LiveVideoToVideo)]
-	if !ok {
-		return nil
-	}
-
-	return liveAI.Models
-}
-
-func reportLiveAICapacity(ch chan common.OrchestratorDescriptor, caps common.CapabilityComparator) {
-	if !monitor.Enabled {
-		return
-	}
-	modelsReq := getModelCaps(caps.ToNetCapabilities())
-
-	var allOrchInfo []common.OrchestratorDescriptor
-	var done bool
-	for {
-		select {
-		case od := <-ch:
-			allOrchInfo = append(allOrchInfo, od)
-		case <-time.After(maxGetOrchestratorCutoffTimeout):
-			done = true
-		}
-		if done {
-			break
-		}
-	}
-
-	idleContainersByModelAndOrchestrator := make(map[string]map[string]int)
-	for _, od := range allOrchInfo {
-		var models map[string]*net.Capabilities_CapabilityConstraints_ModelConstraint
-		if od.RemoteInfo != nil {
-			models = getModelCaps(od.RemoteInfo.Capabilities)
-		}
-
-		for modelID := range modelsReq {
-			idle := 0
-			if models != nil {
-				if model, ok := models[modelID]; ok {
-					idle = int(model.Capacity)
-				}
-			}
-
-			if _, exists := idleContainersByModelAndOrchestrator[modelID]; !exists {
-				idleContainersByModelAndOrchestrator[modelID] = make(map[string]int)
-			}
-			idleContainersByModelAndOrchestrator[modelID][od.LocalInfo.URL.String()] = idle
-		}
-	}
-
-	monitor.AIContainersIdleAfterGatewayDiscovery(idleContainersByModelAndOrchestrator)
-}
-
 func (o *orchestratorPool) Size() int {
 	return len(o.infos)
 }
diff --git a/discovery/discovery_test.go b/discovery/discovery_test.go
index c10a322d46..361a6526d9 100644
--- a/discovery/discovery_test.go
+++ b/discovery/discovery_test.go
@@ -51,7 +51,7 @@ func TestNewDBOrchestratorPoolCache_NilEthClient_ReturnsError(t *testing.T) {
 	}
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	assert.Nil(pool)
 	assert.EqualError(err, "could not create DBOrchestratorPoolCache: LivepeerEthClient is nil")
 }
@@ -173,7 +173,7 @@ func sync_TestDBOrchestratorPoolCacheSize(t *testing.T) {
 		goleak.VerifyNone(t, common.IgnoreRoutines()...)
 	}()
 
-	emptyPool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	emptyPool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 	require.NotNil(emptyPool)
 	assert.Equal(0, emptyPool.Size())
@@ -184,7 +184,7 @@ func sync_TestDBOrchestratorPoolCacheSize(t *testing.T) {
 		dbh.UpdateOrch(ethOrchToDBOrch(o))
 	}
 
-	nonEmptyPool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	nonEmptyPool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 	require.NotNil(nonEmptyPool)
 	assert.Equal(len(addresses), nonEmptyPool.Size())
@@ -232,7 +232,7 @@ func TestNewDBOrchestorPoolCache_NoEthAddress(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, rm, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, rm, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.Nil(err)
 
 	// Check that serverGetOrchInfo returns early and the orchestrator isn't updated
@@ -282,7 +282,7 @@ func TestNewDBOrchestratorPoolCache_InvalidPrices(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, rm, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, rm, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.Nil(err)
 
 	// priceInfo.PixelsPerUnit = 0
@@ -346,7 +346,7 @@ func sync_TestNewDBOrchestratorPoolCache_GivenListOfOrchs_CreatesPoolCacheCorrec
 
 	sender.On("ValidateTicketParams", mock.Anything).Return(nil).Times(3)
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	dbOrchs, err := dbh.SelectOrchs(nil)
 	require.NoError(err)
 	assert.Equal(pool.Size(), 3)
@@ -422,7 +422,7 @@ func TestNewDBOrchestratorPoolCache_TestURLs(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 	// bad URLs are inserted in the database but are not included in the working set, as there is no returnable query for getting their priceInfo
 	// And if URL is updated it won't be picked up until next cache update
@@ -455,7 +455,7 @@ func TestNewDBOrchestratorPoolCache_TestURLs_Empty(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 	assert.Equal(0, pool.Size())
 	infos := pool.GetInfos()
@@ -527,10 +527,10 @@ func sync_TestNewDBOrchestorPoolCache_PollOrchestratorInfo(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	origCacheRefreshInterval := cacheRefreshInterval
-	cacheRefreshInterval = 200 * time.Millisecond
-	defer func() { cacheRefreshInterval = origCacheRefreshInterval }()
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 700*time.Millisecond)
+	origCacheRefreshInterval := networkCapabilitiesReportingInterval
+	networkCapabilitiesReportingInterval = 200 * time.Millisecond
+	defer func() { networkCapabilitiesReportingInterval = origCacheRefreshInterval }()
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 700*time.Millisecond, 200*time.Millisecond)
 	require.NoError(err)
 
 	// Ensure orchestrators exist in DB
@@ -686,7 +686,7 @@ func sync_TestCachedPool_AllOrchestratorsTooExpensive_ReturnsAllOrchestrators(t
 
 	sender.On("ValidateTicketParams", mock.Anything).Return(nil)
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 
 	// ensuring orchs exist in DB
@@ -775,7 +775,7 @@ func sync_TestCachedPool_GetOrchestrators_MaxBroadcastPriceNotSet(t *testing.T)
 
 	sender.On("ValidateTicketParams", mock.Anything).Return(nil)
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 
 	// ensuring orchs exist in DB
@@ -881,7 +881,7 @@ func sync_TestCachedPool_N_OrchestratorsGoodPricing_ReturnsNOrchestrators(t *tes
 
 	sender.On("ValidateTicketParams", mock.Anything).Return(nil)
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 
 	// ensuring orchs exist in DB
@@ -976,7 +976,7 @@ func TestCachedPool_GetOrchestrators_TicketParamsValidation(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 
 	// Test 25 out of 50 orchs pass ticket params validation
@@ -1065,7 +1065,7 @@ func sync_TestCachedPool_GetOrchestrators_OnlyActiveOrchestrators(t *testing.T)
 
 	sender.On("ValidateTicketParams", mock.Anything).Return(nil)
 
-	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{round: big.NewInt(24)}, []string{}, 500*time.Millisecond)
+	pool, err := NewDBOrchestratorPoolCache(ctx, node, &stubRoundsManager{round: big.NewInt(24)}, []string{}, 500*time.Millisecond, 1*time.Minute)
 	require.NoError(err)
 
 	// ensuring orchs exist in DB
@@ -1649,7 +1649,7 @@ func TestSetGetOrchestratorTimeout(t *testing.T) {
 	}
 
 	//set timeout to 1000ms
-	poolCache, err := NewDBOrchestratorPoolCache(context.TODO(), node, &stubRoundsManager{}, []string{}, 1000*time.Millisecond)
+	poolCache, err := NewDBOrchestratorPoolCache(context.TODO(), node, &stubRoundsManager{}, []string{}, 1000*time.Millisecond, 1*time.Minute)
 	assert.Nil(err)
 	//confirm the timeout is now 1000ms
 	assert.Equal(poolCache.discoveryTimeout, 1000*time.Millisecond)
diff --git a/monitor/census.go b/monitor/census.go
index 0b37faf8b7..24e017b991 100644
--- a/monitor/census.go
+++ b/monitor/census.go
@@ -73,6 +73,18 @@ const (
 	segTypeRec = "recorded" // segment in the stream for which recording is enabled
 )
 
+// AIContainerCapacity holds capacity information for AI containers
+type AIContainerCapacity struct {
+	Idle  int
+	InUse int
+}
+
+// ModelAICapacities holds all orchestrator capacities for a specific model
+type ModelAICapacities struct {
+	ModelID       string
+	Orchestrators map[string]AIContainerCapacity // orchURI -> capacity
+}
+
 const (
 	//mpeg7-sign comparison fail of fast verification
 	FVType1Error = 1
@@ -197,24 +209,24 @@ type (
 		mSceneClassification *stats.Int64Measure
 
 		// Metrics for AI jobs
-		mAIModelsRequested                       *stats.Int64Measure
-		mAIRequestLatencyScore                   *stats.Float64Measure
-		mAIRequestPrice                          *stats.Float64Measure
-		mAIRequestError                          *stats.Int64Measure
-		mAIResultDownloaded                      *stats.Int64Measure
-		mAIResultDownloadTime                    *stats.Float64Measure
-		mAIResultUploaded                        *stats.Int64Measure
-		mAIResultUploadTime                      *stats.Float64Measure
-		mAIResultSaveFailed                      *stats.Int64Measure
-		mAIContainersInUse                       *stats.Int64Measure
-		mAIContainersIdle                        *stats.Int64Measure
-		aiContainersIdleByPipelineByOrchestrator map[string]map[string]int
-		mAIGPUsIdle                              *stats.Int64Measure
-		mAICurrentLivePipelines                  *stats.Int64Measure
-		aiLiveSessionsByPipeline                 map[string]int
-		mAIFirstSegmentDelay                     *stats.Int64Measure
-		mAILiveAttempts                          *stats.Int64Measure
-		mAINumOrchs                              *stats.Int64Measure
+		mAIModelsRequested          *stats.Int64Measure
+		mAIRequestLatencyScore      *stats.Float64Measure
+		mAIRequestPrice             *stats.Float64Measure
+		mAIRequestError             *stats.Int64Measure
+		mAIResultDownloaded         *stats.Int64Measure
+		mAIResultDownloadTime       *stats.Float64Measure
+		mAIResultUploaded           *stats.Int64Measure
+		mAIResultUploadTime         *stats.Float64Measure
+		mAIResultSaveFailed         *stats.Int64Measure
+		mAIContainersInUse          *stats.Int64Measure
+		mAIContainersIdle           *stats.Int64Measure
+		aiContainersCapacityByModel map[string]*ModelAICapacities
+		mAIGPUsIdle                 *stats.Int64Measure
+		mAICurrentLivePipelines     *stats.Int64Measure
+		aiLiveSessionsByPipeline    map[string]int
+		mAIFirstSegmentDelay        *stats.Int64Measure
+		mAILiveAttempts             *stats.Int64Measure
+		mAINumOrchs                 *stats.Int64Measure
 
 		mAIWhipTransportBytesReceived *stats.Int64Measure
 		mAIWhipTransportBytesSent     *stats.Int64Measure
@@ -391,7 +403,7 @@ func InitCensus(nodeType NodeType, version string) {
 	census.mAIResultSaveFailed = stats.Int64("ai_result_upload_failed_total", "AIResultUploadFailed", "tot")
 	census.mAIContainersInUse = stats.Int64("ai_container_in_use", "Number of containers currently used for AI processing", "tot")
 	census.mAIContainersIdle = stats.Int64("ai_container_idle", "Number of containers currently available for AI processing", "tot")
-	census.aiContainersIdleByPipelineByOrchestrator = make(map[string]map[string]int)
+	census.aiContainersCapacityByModel = make(map[string]*ModelAICapacities)
 	census.mAIGPUsIdle = stats.Int64("ai_gpus_idle", "Number of idle GPUs (with no configured container)", "tot")
 	census.mAICurrentLivePipelines = stats.Int64("ai_current_live_pipelines", "Number of live AI pipelines currently running", "tot")
 	census.aiLiveSessionsByPipeline = make(map[string]int)
@@ -1019,7 +1031,7 @@ {
 			Name:        "ai_container_in_use",
 			Measure:     census.mAIContainersInUse,
 			Description: "Number of containers currently used for AI processing",
-			TagKeys:     append([]tag.Key{census.kPipeline, census.kModelName}, baseTags...),
+			TagKeys:     append([]tag.Key{census.kPipeline, census.kModelName, census.kOrchestratorURI}, baseTags...),
 			Aggregation: view.LastValue(),
 		},
 		{
@@ -2018,43 +2030,56 @@ func AIContainersInUse(currentContainersInUse int, pipeline, modelID string) {
 	}
 }
 
-func AIContainersIdleAfterGatewayDiscovery(idleContainersByPipelinesAndOrchestrator map[string]map[string]int) {
+func ReportAIContainerCapacity(modelCapacities map[string]*ModelAICapacities) {
 	census.lock.Lock()
 	defer census.lock.Unlock()
 
-	// Reset all existing pipeline idleContainers to zero first.
+	// Reset all existing model container counts to zero first.
 	// This ensures we don't have any stale counts.
-	for k, v := range census.aiContainersIdleByPipelineByOrchestrator {
-		for k2 := range v {
-			census.aiContainersIdleByPipelineByOrchestrator[k][k2] = 0
+	for _, modelCap := range census.aiContainersCapacityByModel {
+		for orchURI := range modelCap.Orchestrators {
+			modelCap.Orchestrators[orchURI] = AIContainerCapacity{Idle: 0, InUse: 0}
 		}
 	}
-	// Update counts.
-	for pipeline, v := range idleContainersByPipelinesAndOrchestrator {
-		for orchestrator, count := range v {
-			if _, exists := census.aiContainersIdleByPipelineByOrchestrator[pipeline]; !exists {
-				census.aiContainersIdleByPipelineByOrchestrator[pipeline] = make(map[string]int)
+
+	// Update counts with new data
+	for modelID, newModelCap := range modelCapacities {
+		if _, exists := census.aiContainersCapacityByModel[modelID]; !exists {
+			census.aiContainersCapacityByModel[modelID] = &ModelAICapacities{
+				ModelID:       modelID,
+				Orchestrators: make(map[string]AIContainerCapacity),
 			}
-			census.aiContainersIdleByPipelineByOrchestrator[pipeline][orchestrator] = count
+		}
+		for orchURI, capacity := range newModelCap.Orchestrators {
+			census.aiContainersCapacityByModel[modelID].Orchestrators[orchURI] = capacity
 		}
 	}
 
-	// Record metrics for all pipelines for all orchestrators
-	for model, v := range census.aiContainersIdleByPipelineByOrchestrator {
-		for orchURL, v2 := range v {
+	// Record metrics for all models for all orchestrators
+	for modelID, modelCap := range census.aiContainersCapacityByModel {
+		for orchURI, capacity := range modelCap.Orchestrators {
+			// Record idle containers metric
+			if err := stats.RecordWithTags(census.ctx,
+				[]tag.Mutator{tag.Insert(census.kModelName, modelID), tag.Insert(census.kOrchestratorURI, orchURI)},
+				census.mAIContainersIdle.M(int64(capacity.Idle))); err != nil {
+				glog.Errorf("Error recording idle containers metric err=%q", err)
+			}
+
+			// Record in-use containers metric
 			if err := stats.RecordWithTags(census.ctx,
-				[]tag.Mutator{tag.Insert(census.kModelName, model), tag.Insert(census.kOrchestratorURI, orchURL)},
-				census.mAIContainersIdle.M(int64(v2))); err != nil {
-				glog.Errorf("Error recording metrics err=%q", err)
+				[]tag.Mutator{tag.Insert(census.kModelName, modelID), tag.Insert(census.kOrchestratorURI, orchURI)},
+				census.mAIContainersInUse.M(int64(capacity.InUse))); err != nil {
+				glog.Errorf("Error recording in-use containers metric err=%q", err)
 			}
-			if v2 == 0 {
+
+			if capacity.Idle == 0 && capacity.InUse == 0 {
 				// Remove zero counts, no need to report it again
-				delete(census.aiContainersIdleByPipelineByOrchestrator[model], orchURL)
+				delete(census.aiContainersCapacityByModel[modelID].Orchestrators, orchURI)
 			}
 		}
-		if len(census.aiContainersIdleByPipelineByOrchestrator[model]) == 0 {
-			// If there are no more pipelines for this model, remove it from the map
-			delete(census.aiContainersIdleByPipelineByOrchestrator, model)
+		if len(census.aiContainersCapacityByModel[modelID].Orchestrators) == 0 {
+			// If there are no more orchestrators for this model, remove it from the map
+			delete(census.aiContainersCapacityByModel, modelID)
 		}
 	}
 }
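The census changes above expose ReportAIContainerCapacity together with the AIContainerCapacity and ModelAICapacities types, recording both ai_container_idle and ai_container_in_use tagged by model name and orchestrator URI. Below is a sketch of feeding it sample data; it assumes metrics have already been initialized via InitCensus (which happens when monitoring is enabled), and the model ID and orchestrator URI are made-up example values.

package main

import "github.com/livepeer/go-livepeer/monitor"

func main() {
	// Example values only; real data comes from the orchestrators' advertised capabilities.
	capacities := map[string]*monitor.ModelAICapacities{
		"example-model": {
			ModelID: "example-model",
			Orchestrators: map[string]monitor.AIContainerCapacity{
				"https://orchestrator.example.com:8935": {Idle: 3, InUse: 1},
			},
		},
	}

	// Records ai_container_idle and ai_container_in_use, tagged by model name and orchestrator URI.
	monitor.ReportAIContainerCapacity(capacities)
}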
diff --git a/server/ai_session.go b/server/ai_session.go
index eeaf92587e..8aac600d3b 100644
--- a/server/ai_session.go
+++ b/server/ai_session.go
@@ -478,53 +478,6 @@ func (sel *AISessionSelector) getSessions(ctx context.Context) ([]*BroadcastSess
 	return selectOrchestrator(ctx, sel.node, streamParams, numOrchs, sel.suspender, common.ScoreAtLeast(0), func(sessionID string) {})
 }
 
-type noopSus struct{}
-
-func (n noopSus) Suspended(orch string) int {
-	return 0
-}
-
-func (c *AISessionManager) refreshOrchCapacity(modelIDs []string) {
-	if len(modelIDs) < 1 {
-		return
-	}
-
-	pool := c.node.OrchestratorPool
-	if pool == nil {
-		return
-	}
-	clog.Infof(context.Background(), "Starting periodic orchestrator refresh for capacity reporting")
-
-	modelsReq := make(map[string]*core.ModelConstraint)
-	for _, modelID := range modelIDs {
-		modelsReq[modelID] = &core.ModelConstraint{
-			Warm:          false,
-			RunnerVersion: c.node.Capabilities.MinRunnerVersionConstraint(core.Capability_LiveVideoToVideo, modelID),
-		}
-	}
-	go func() {
-		refreshInterval := 10 * time.Second
-		ticker := time.NewTicker(refreshInterval)
-		defer ticker.Stop()
-		for {
-			select {
-			case <-ticker.C:
-				ctx, cancel := context.WithTimeout(context.Background(), refreshInterval)
-				capabilityConstraints := core.PerCapabilityConstraints{
-					core.Capability_LiveVideoToVideo: {Models: modelsReq},
-				}
-				caps := core.NewCapabilities(append(core.DefaultCapabilities(), core.Capability_LiveVideoToVideo), nil)
-				caps.SetPerCapabilityConstraints(capabilityConstraints)
-				caps.SetMinVersionConstraint(c.node.Capabilities.MinVersionConstraint())
-
-				pool.GetOrchestrators(ctx, pool.Size(), noopSus{}, caps, common.ScoreAtLeast(0))
-
-				cancel()
-			}
-		}
-	}()
-}
-
 type AISessionManager struct {
 	node      *core.LivepeerNode
 	selectors map[string]*AISessionSelector
@@ -539,7 +492,6 @@ func NewAISessionManager(node *core.LivepeerNode, ttl time.Duration) *AISessionM
 		mu:        sync.Mutex{},
 		ttl:       ttl,
 	}
-	sessionManager.refreshOrchCapacity(node.LiveAICapRefreshModels)
 	return sessionManager
 }