diff --git a/src/colab/api.ts b/src/colab/api.ts index 3258d422..8879cd6e 100644 --- a/src/colab/api.ts +++ b/src/colab/api.ts @@ -127,6 +127,13 @@ function normalizeVariant(variant: ColabGapiVariant): Variant { } } +export const Accelerator = z.object({ + /** The variant of the assignment. */ + variant: z.enum(ColabGapiVariant).transform(normalizeVariant), + /** The assigned accelerator. */ + models: z.array(z.string().toUpperCase()), +}); + /** * The schema for top level information about a user's tier, usage and * availability in Colab. @@ -139,27 +146,21 @@ export const UserInfoSchema = z.object({ /** The paid Colab Compute Units balance. */ paidComputeUnitsBalance: z.number().optional(), /** The eligible machine accelerators. */ - eligibleAccelerators: z - .array( - z.object({ - /** The variant of the assignment. */ - variant: z.enum(ColabGapiVariant).transform(normalizeVariant), - /** The assigned accelerator. */ - models: z.array(z.string().toUpperCase()), - }), - ) - .optional(), + eligibleAccelerators: z.array(Accelerator), + /** The ineligible machine accelerators. */ + ineligibleAccelerators: z.array(Accelerator), }); +/** Colab user information. */ +export type UserInfo = z.infer; -/** The schema of Colab Compute Units (CCU) information. */ -export const CcuInfoSchema = z.object({ - /** - * The current balance of the paid CCUs. - * - * Naming is unfortunate due to historical reasons and free CCU quota - * balance is made available in a separate field for the same reasons. - */ - currentBalance: z.number(), +/** + * The schema for top level information about a user's tier, usage and + * availability in Colab when CCU consumption info is requested (consumption + * fields are required). + */ +export const ConsumptionUserInfoSchema = UserInfoSchema.required({ + paidComputeUnitsBalance: true, +}).extend({ /** * The current rate of consumption of the user's CCUs (paid or free) based on * all assigned VMs. @@ -170,18 +171,6 @@ export const CcuInfoSchema = z.object({ * is positive. */ assignmentsCount: z.number(), - /** The list of eligible GPU accelerators. */ - eligibleGpus: z.array(z.string().toUpperCase()), - /** The list of ineligible GPU accelerators. */ - ineligibleGpus: z.array(z.string().toUpperCase()).optional(), - /** - * The list of eligible TPU accelerators. - */ - eligibleTpus: z.array(z.string().toUpperCase()), - /** - * The list of ineligible TPU accelerators. - */ - ineligibleTpus: z.array(z.string().toUpperCase()).optional(), /** Free CCU quota information if applicable. */ freeCcuQuotaInfo: z .object({ @@ -212,8 +201,8 @@ export const CcuInfoSchema = z.object({ }) .optional(), }); -/** Colab Compute Units (CCU) information. */ -export type CcuInfo = z.infer; +/** Colab consumption user information. */ +export type ConsumptionUserInfo = z.infer; /** The response when getting an assignment. */ export const GetAssignmentResponseSchema = z diff --git a/src/colab/client.ts b/src/colab/client.ts index b8b14cf5..c66123be 100644 --- a/src/colab/client.ts +++ b/src/colab/client.ts @@ -15,14 +15,14 @@ import { uuidToWebSafeBase64 } from '../utils/uuid'; import { Assignment, AuthType, - CcuInfo, Variant, GetAssignmentResponse, - CcuInfoSchema, AssignmentSchema, GetAssignmentResponseSchema, + UserInfo, UserInfoSchema, - SubscriptionTier, + ConsumptionUserInfo, + ConsumptionUserInfoSchema, PostAssignmentResponse, Outcome, PostAssignmentResponseSchema, @@ -92,31 +92,32 @@ export class ColabClient { } /** - * Gets the user's subscription tier. + * Gets the current user information. * * @param signal - Optional {@link AbortSignal} to cancel the request. - * @returns The user's subscription tier. */ - async getSubscriptionTier(signal?: AbortSignal): Promise { - const userInfo = await this.issueRequest( + async getUserInfo(signal?: AbortSignal): Promise { + return await this.issueRequest( new URL('v1/user-info', this.colabGapiDomain), { method: 'GET', signal }, UserInfoSchema, ); - return userInfo.subscriptionTier; } /** - * Gets the current Colab Compute Units (CCU) information. + * Gets the current user with Colab Compute Units (CCU) information. * * @param signal - Optional {@link AbortSignal} to cancel the request. - * @returns The current CCU information. */ - async getCcuInfo(signal?: AbortSignal): Promise { - return this.issueRequest( - new URL(`${TUN_ENDPOINT}/ccu-info`, this.colabDomain), + async getConsumptionUserInfo( + signal?: AbortSignal, + ): Promise { + const url = new URL('v1/user-info', this.colabGapiDomain); + url.searchParams.append('get_ccu_consumption_info', 'true'); + return await this.issueRequest( + url, { method: 'GET', signal }, - CcuInfoSchema, + ConsumptionUserInfoSchema, ); } diff --git a/src/colab/client.unit.test.ts b/src/colab/client.unit.test.ts index 4ca3e359..5c038739 100644 --- a/src/colab/client.unit.test.ts +++ b/src/colab/client.unit.test.ts @@ -14,7 +14,6 @@ import { ColabAssignedServer } from '../jupyter/servers'; import { TestUri } from '../test/helpers/uri'; import { uuidToWebSafeBase64 } from '../utils/uuid'; import { - CcuInfo, Assignment, Shape, SubscriptionState, @@ -25,6 +24,8 @@ import { RuntimeProxyToken, AuthType, ExperimentFlag, + ConsumptionUserInfo, + UserInfo, } from './api'; import { ColabClient, @@ -98,11 +99,20 @@ describe('ColabClient', () => { sinon.restore(); }); - it('successfully gets the subscription tier', async () => { + it('successfully gets user info', async () => { const mockResponse = { - subscriptionTier: 'SUBSCRIPTION_TIER_NONE', - paidComputeUnitsBalance: 0, - eligibleAccelerators: [{ variant: 'VARIANT_GPU', models: ['T4'] }], + subscriptionTier: 'SUBSCRIPTION_TIER_PRO', + eligibleAccelerators: [ + { + variant: 'VARIANT_GPU', + models: ['T4', 'A100', 'L4'], + }, + { + variant: 'VARIANT_TPU', + models: ['V5E1', 'V6E1', 'V28'], + }, + ], + ineligibleAccelerators: [], }; fetchStub .withArgs( @@ -117,22 +127,52 @@ describe('ColabClient', () => { new Response(withXSSI(JSON.stringify(mockResponse)), { status: 200 }), ); - await expect(client.getSubscriptionTier()).to.eventually.deep.equal( - SubscriptionTier.NONE, - ); + const response = client.getUserInfo(); + const expectedResponse: UserInfo = { + subscriptionTier: SubscriptionTier.PRO, + eligibleAccelerators: [ + { + variant: Variant.GPU, + models: ['T4', 'A100', 'L4'], + }, + { + variant: Variant.TPU, + models: ['V5E1', 'V6E1', 'V28'], + }, + ], + ineligibleAccelerators: [], + }; + await expect(response).to.eventually.deep.equal(expectedResponse); sinon.assert.calledOnce(fetchStub); }); - it('successfully gets CCU info', async () => { + it('successfully gets consumption user info', async () => { const mockResponse = { - currentBalance: 1, + subscriptionTier: 'SUBSCRIPTION_TIER_NONE', + paidComputeUnitsBalance: 1, consumptionRateHourly: 2, assignmentsCount: 3, - eligibleGpus: ['T4'], - ineligibleGpus: ['A100', 'L4'], - eligibleTpus: ['V6E1', 'V28'], - ineligibleTpus: ['V5E1'], + eligibleAccelerators: [ + { + variant: 'VARIANT_GPU', + models: ['T4'], + }, + { + variant: 'VARIANT_TPU', + models: ['V6E1', 'V28'], + }, + ], + ineligibleAccelerators: [ + { + variant: 'VARIANT_GPU', + models: ['A100', 'L4'], + }, + { + variant: 'VARIANT_TPU', + models: ['V5E1'], + }, + ], freeCcuQuotaInfo: { remainingTokens: '4', nextRefillTimestampSec: 5, @@ -142,25 +182,49 @@ describe('ColabClient', () => { .withArgs( urlMatcher({ method: 'GET', - host: COLAB_HOST, - path: '/tun/m/ccu-info', + host: GOOGLE_APIS_HOST, + path: '/v1/user-info', + queryParams: { get_ccu_consumption_info: 'true' }, + withAuthUser: false, }), ) .resolves( new Response(withXSSI(JSON.stringify(mockResponse)), { status: 200 }), ); - const expectedResponse: CcuInfo = { - ...mockResponse, + const response = client.getConsumptionUserInfo(); + + const expectedResponse: ConsumptionUserInfo = { + subscriptionTier: SubscriptionTier.NONE, + paidComputeUnitsBalance: mockResponse.paidComputeUnitsBalance, + consumptionRateHourly: mockResponse.consumptionRateHourly, + assignmentsCount: mockResponse.assignmentsCount, + eligibleAccelerators: [ + { + variant: Variant.GPU, + models: ['T4'], + }, + { + variant: Variant.TPU, + models: ['V6E1', 'V28'], + }, + ], + ineligibleAccelerators: [ + { + variant: Variant.GPU, + models: ['A100', 'L4'], + }, + { + variant: Variant.TPU, + models: ['V5E1'], + }, + ], freeCcuQuotaInfo: { ...mockResponse.freeCcuQuotaInfo, remainingTokens: Number(mockResponse.freeCcuQuotaInfo.remainingTokens), }, }; - await expect(client.getCcuInfo()).to.eventually.deep.equal( - expectedResponse, - ); - + await expect(response).to.eventually.deep.equal(expectedResponse); sinon.assert.calledOnce(fetchStub); }); @@ -646,57 +710,66 @@ describe('ColabClient', () => { }); it('supports non-XSSI responses', async () => { - const mockResponse = { - currentBalance: 1, - consumptionRateHourly: 2, - assignmentsCount: 3, - eligibleGpus: ['T4'], - ineligibleGpus: ['A100', 'L4'], - eligibleTpus: ['V6E1', 'V28'], - ineligibleTpus: ['V5E1'], - }; fetchStub .withArgs( urlMatcher({ method: 'GET', - host: COLAB_HOST, - path: '/tun/m/ccu-info', + host: GOOGLE_APIS_HOST, + path: '/v1/user-info', + withAuthUser: false, }), ) - .resolves(new Response(JSON.stringify(mockResponse), { status: 200 })); + .resolves( + new Response( + JSON.stringify({ + subscriptionTier: 'SUBSCRIPTION_TIER_NONE', + eligibleAccelerators: [], + ineligibleAccelerators: [], + }), + { status: 200 }, + ), + ); - await expect(client.getCcuInfo()).to.eventually.deep.equal(mockResponse); + await expect(client.getUserInfo()).to.eventually.deep.equal({ + subscriptionTier: SubscriptionTier.NONE, + eligibleAccelerators: [], + ineligibleAccelerators: [], + }); sinon.assert.calledOnce(fetchStub); }); it('retries request on 401 if onAuthError is provided', async () => { - const mockResponse = { - currentBalance: 1, - consumptionRateHourly: 2, - assignmentsCount: 3, - eligibleGpus: ['T4'], - ineligibleGpus: ['A100', 'L4'], - eligibleTpus: ['V6E1', 'V28'], - ineligibleTpus: ['V5E1'], - }; - fetchStub .withArgs( urlMatcher({ method: 'GET', - host: COLAB_HOST, - path: '/tun/m/ccu-info', + host: GOOGLE_APIS_HOST, + path: '/v1/user-info', + withAuthUser: false, }), ) .onFirstCall() .resolves(new Response('Unauthorized', { status: 401 })) .onSecondCall() .resolves( - new Response(withXSSI(JSON.stringify(mockResponse)), { status: 200 }), + new Response( + withXSSI( + JSON.stringify({ + subscriptionTier: 'SUBSCRIPTION_TIER_NONE', + eligibleAccelerators: [], + ineligibleAccelerators: [], + }), + ), + { status: 200 }, + ), ); - await expect(client.getCcuInfo()).to.eventually.deep.equal(mockResponse); + await expect(client.getUserInfo()).to.eventually.deep.equal({ + subscriptionTier: SubscriptionTier.NONE, + eligibleAccelerators: [], + ineligibleAccelerators: [], + }); sinon.assert.calledTwice(fetchStub); sinon.assert.calledOnce(onAuthErrorStub); @@ -707,7 +780,7 @@ describe('ColabClient', () => { .withArgs(sinon.match.any) .resolves(new Response('Unauthorized', { status: 401 })); - await expect(client.getCcuInfo()).to.eventually.be.rejectedWith( + await expect(client.getUserInfo()).to.eventually.be.rejectedWith( /Unauthorized/, ); @@ -726,7 +799,7 @@ describe('ColabClient', () => { .withArgs(sinon.match.any) .resolves(new Response('Unauthorized', { status: 401 })); - await expect(client.getCcuInfo()).to.eventually.be.rejectedWith( + await expect(client.getUserInfo()).to.eventually.be.rejectedWith( /Unauthorized/, ); sinon.assert.notCalled(onAuthErrorStub); @@ -737,8 +810,9 @@ describe('ColabClient', () => { .withArgs( urlMatcher({ method: 'GET', - host: COLAB_HOST, - path: '/tun/m/ccu-info', + host: GOOGLE_APIS_HOST, + path: '/v1/user-info', + withAuthUser: false, }), ) .resolves( @@ -748,7 +822,7 @@ describe('ColabClient', () => { }), ); - await expect(client.getCcuInfo()).to.eventually.be.rejectedWith( + await expect(client.getUserInfo()).to.eventually.be.rejectedWith( /Foo error/, ); }); @@ -758,37 +832,40 @@ describe('ColabClient', () => { .withArgs( urlMatcher({ method: 'GET', - host: COLAB_HOST, - path: '/tun/m/ccu-info', + host: GOOGLE_APIS_HOST, + path: '/v1/user-info', + withAuthUser: false, }), ) .resolves(new Response(withXSSI('not JSON eh?'), { status: 200 })); - await expect(client.getCcuInfo()).to.eventually.be.rejectedWith( + await expect(client.getUserInfo()).to.eventually.be.rejectedWith( /not JSON.+eh\?/, ); }); it('rejects response schema mismatches', async () => { - const mockResponse: Partial = { - currentBalance: 1, + const mockResponse = { + subscriptionTier: 'SUBSCRIPTION_TIER_NONE', + paidComputeUnitsBalance: 1, consumptionRateHourly: 2, - eligibleGpus: ['T4'], }; fetchStub .withArgs( urlMatcher({ method: 'GET', - host: COLAB_HOST, - path: '/tun/m/ccu-info', + host: GOOGLE_APIS_HOST, + path: '/v1/user-info', + queryParams: { get_ccu_consumption_info: 'true' }, + withAuthUser: false, }), ) .resolves( new Response(withXSSI(JSON.stringify(mockResponse)), { status: 200 }), ); - await expect(client.getCcuInfo()).to.eventually.be.rejectedWith( - /assignmentsCount.+received undefined/s, + await expect(client.getConsumptionUserInfo()).to.eventually.be.rejectedWith( + /eligibleAccelerators.+received undefined/s, ); }); diff --git a/src/colab/consumption/notifier.ts b/src/colab/consumption/notifier.ts index 195939fa..f8d4a68d 100644 --- a/src/colab/consumption/notifier.ts +++ b/src/colab/consumption/notifier.ts @@ -4,9 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -import vscode from 'vscode'; -import { CcuInfo, SubscriptionTier } from '../api'; -import { ColabClient } from '../client'; +import vscode, { Disposable, Event } from 'vscode'; +import { ConsumptionUserInfo, SubscriptionTier } from '../api'; import { openColabSignup } from '../commands/external'; const WARN_WHEN_LESS_THAN_MINUTES = 30; @@ -23,8 +22,8 @@ type Notify = * Monitors Colab Compute Units (CCU) balance and consumption rate, notifying * the user when their CCU-s are depleted or running low. */ -export class ConsumptionNotifier implements vscode.Disposable { - private ccuListener: vscode.Disposable; +export class ConsumptionNotifier implements Disposable { + private ccuListener: Disposable; private snoozeError = false; private snoozeWarn = false; private errorTimeout?: NodeJS.Timeout; @@ -32,8 +31,7 @@ export class ConsumptionNotifier implements vscode.Disposable { constructor( private readonly vs: typeof vscode, - private readonly colab: ColabClient, - onDidChangeCcuInfo: vscode.Event, + onDidChangeCcuInfo: Event, private readonly snoozeMinutes: number = DEFAULT_SNOOZE_MINUTES, ) { this.ccuListener = onDidChangeCcuInfo((e) => this.notifyCcuConsumption(e)); @@ -51,12 +49,13 @@ export class ConsumptionNotifier implements vscode.Disposable { * Gives the user an action to sign up, upgrade or purchase more CCU-s (link * to the signup page). */ - protected async notifyCcuConsumption(e: CcuInfo): Promise { + protected async notifyCcuConsumption(e: ConsumptionUserInfo): Promise { // When the user is not consuming any CCU-s, no need to notify. if (e.consumptionRateHourly <= 0) { return; } - const paidMinutesLeft = (e.currentBalance / e.consumptionRateHourly) * 60; + const paidMinutesLeft = + (e.paidComputeUnitsBalance / e.consumptionRateHourly) * 60; const freeMinutesLeft = calculateRoughMinutesLeft(e); // Quantize to 10 minutes. const totalMinutesLeft = ((paidMinutesLeft + freeMinutesLeft) / 10) * 10; @@ -71,7 +70,7 @@ export class ConsumptionNotifier implements vscode.Disposable { const action = notification.notify( notification.message, - await this.getTierRelevantAction(paidMinutesLeft > 0), + this.getTierRelevantAction(e.subscriptionTier, paidMinutesLeft > 0), ); this.setSnoozeTimeout(notification.notify); if (await action) { @@ -107,10 +106,10 @@ export class ConsumptionNotifier implements vscode.Disposable { return { message, notify }; } - private async getTierRelevantAction( + private getTierRelevantAction( + tier: SubscriptionTier, hasPaidBalance: boolean, - ): Promise { - const tier = await this.colab.getSubscriptionTier(); + ): SignupAction { switch (tier) { case SubscriptionTier.NONE: return hasPaidBalance @@ -146,14 +145,16 @@ export class ConsumptionNotifier implements vscode.Disposable { } } -function calculateRoughMinutesLeft(ccuInfo: CcuInfo): number { - const freeQuota = ccuInfo.freeCcuQuotaInfo; +function calculateRoughMinutesLeft( + consumptionUserInfo: ConsumptionUserInfo, +): number { + const freeQuota = consumptionUserInfo.freeCcuQuotaInfo; if (!freeQuota) { return 0; } // Free quota is in milli-CCUs. const freeCcu = freeQuota.remainingTokens / 1000; - return Math.floor((freeCcu / ccuInfo.consumptionRateHourly) * 60); + return Math.floor((freeCcu / consumptionUserInfo.consumptionRateHourly) * 60); } enum SignupAction { diff --git a/src/colab/consumption/notifier.unit.test.ts b/src/colab/consumption/notifier.unit.test.ts index 878e7851..2df0ac1a 100644 --- a/src/colab/consumption/notifier.unit.test.ts +++ b/src/colab/consumption/notifier.unit.test.ts @@ -5,11 +5,10 @@ */ import { assert, expect } from 'chai'; -import sinon, { SinonFakeTimers, SinonStubbedInstance } from 'sinon'; +import sinon, { SinonFakeTimers } from 'sinon'; import { TestEventEmitter } from '../../test/helpers/events'; import { newVsCodeStub, VsCodeStub } from '../../test/helpers/vscode'; -import { CcuInfo, SubscriptionTier } from '../api'; -import { ColabClient } from '../client'; +import { SubscriptionTier, ConsumptionUserInfo } from '../api'; import { ConsumptionNotifier } from './notifier'; const NOTIFICATION_SEVERITIES = ['warn', 'error'] as const; @@ -23,7 +22,7 @@ type NotificationSeverity = (typeof NOTIFICATION_SEVERITIES)[number]; // events each time it calculated the remaining minutes (e.g. for logging // purposes). class TestConsumptionNotifier extends ConsumptionNotifier { - override notifyCcuConsumption(e: CcuInfo): Promise { + override notifyCcuConsumption(e: ConsumptionUserInfo): Promise { return super.notifyCcuConsumption(e); } @@ -55,20 +54,16 @@ class TestConsumptionNotifier extends ConsumptionNotifier { describe('ConsumptionNotifier', () => { let vs: VsCodeStub; - let colabClient: SinonStubbedInstance; - let ccuEmitter: TestEventEmitter; + let ccuEmitter: TestEventEmitter; let consumptionNotifier: TestConsumptionNotifier; beforeEach(() => { vs = newVsCodeStub(); - colabClient = sinon.createStubInstance(ColabClient); - colabClient.getSubscriptionTier.resolves(SubscriptionTier.NONE); - ccuEmitter = new TestEventEmitter(); + ccuEmitter = new TestEventEmitter(); consumptionNotifier = new TestConsumptionNotifier( vs.asVsCode(), - colabClient, ccuEmitter.event, ); }); @@ -136,7 +131,10 @@ describe('ConsumptionNotifier', () => { type Consumption = ConsumptionByMinutes | ConsumptionByRate; - function createCcuInfo(c: Consumption): CcuInfo { + function createCcuInfo( + c: Consumption, + tier?: SubscriptionTier, + ): ConsumptionUserInfo { let hourlyConsumptionRate: number; let paidBalance: number; let freeTokens: number; @@ -150,7 +148,8 @@ describe('ConsumptionNotifier', () => { freeTokens = (c.freeMinutes / 60) * hourlyConsumptionRate * 1000; } return { - currentBalance: paidBalance, + subscriptionTier: tier ?? SubscriptionTier.NONE, + paidComputeUnitsBalance: paidBalance, consumptionRateHourly: hourlyConsumptionRate, freeCcuQuotaInfo: { remainingTokens: freeTokens, @@ -158,8 +157,8 @@ describe('ConsumptionNotifier', () => { }, // Irrelevant fields for SUT. assignmentsCount: 1, - eligibleGpus: [], - eligibleTpus: [], + eligibleAccelerators: [], + ineligibleAccelerators: [], }; } @@ -217,8 +216,7 @@ describe('ConsumptionNotifier', () => { ]; for (const t of nonNotifyingTests) { it(`should not notify when ${t.label}`, async () => { - colabClient.getSubscriptionTier.resolves(t.tier ?? SubscriptionTier.NONE); - const ccuInfo = createCcuInfo(t.consumption); + const ccuInfo = createCcuInfo(t.consumption, t.tier); const noOp = consumptionNotifier.nextConsumptionCalculation(); ccuEmitter.fire(ccuInfo); @@ -354,8 +352,7 @@ describe('ConsumptionNotifier', () => { for (const t of notifyingTests) { const action = t.should.action.toLowerCase(); it(`should ${t.should.severity} with a prompt to ${action} when ${t.label}`, async () => { - colabClient.getSubscriptionTier.resolves(t.tier); - const ccuInfo = createCcuInfo(t.consumption); + const ccuInfo = createCcuInfo(t.consumption, t.tier); const waitForNotification = nextNotification(t.should.severity); ccuEmitter.fire(ccuInfo); diff --git a/src/colab/consumption/poller.ts b/src/colab/consumption/poller.ts index 700ed7b6..b7f7fef5 100644 --- a/src/colab/consumption/poller.ts +++ b/src/colab/consumption/poller.ts @@ -4,14 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import vscode, { Disposable } from 'vscode'; +import vscode, { Disposable, Event, EventEmitter } from 'vscode'; import { OverrunPolicy, SequentialTaskRunner, StartMode, } from '../../common/task-runner'; import { Toggleable } from '../../common/toggleable'; -import { CcuInfo } from '../api'; +import { ConsumptionUserInfo } from '../api'; import { ColabClient } from '../client'; const POLL_INTERVAL_MS = 1000 * 60 * 5; // 5 minutes. @@ -24,9 +24,9 @@ const TASK_TIMEOUT_MS = 1000 * 10; // 10 seconds. * (single-threaded, no worker threads). */ export class ConsumptionPoller implements Toggleable, Disposable { - readonly onDidChangeCcuInfo: vscode.Event; - private readonly emitter: vscode.EventEmitter; - private ccuInfo?: CcuInfo; + readonly onDidChangeCcuInfo: Event; + private readonly emitter: EventEmitter; + private consumptionUserInfo?: ConsumptionUserInfo; private runner: SequentialTaskRunner; private isDisposed = false; @@ -34,7 +34,7 @@ export class ConsumptionPoller implements Toggleable, Disposable { private readonly vs: typeof vscode, private readonly client: ColabClient, ) { - this.emitter = new this.vs.EventEmitter(); + this.emitter = new this.vs.EventEmitter(); this.onDidChangeCcuInfo = this.emitter.event; this.runner = new SequentialTaskRunner( { @@ -76,13 +76,17 @@ export class ConsumptionPoller implements Toggleable, Disposable { * Checks the latests CCU info and emits an event when there is a change. */ private async poll(signal?: AbortSignal): Promise { - const ccuInfo = await this.client.getCcuInfo(signal); - if (JSON.stringify(ccuInfo) === JSON.stringify(this.ccuInfo)) { + const consumptionUserInfo = + await this.client.getConsumptionUserInfo(signal); + if ( + JSON.stringify(consumptionUserInfo) === + JSON.stringify(this.consumptionUserInfo) + ) { return; } - this.ccuInfo = ccuInfo; - this.emitter.fire(this.ccuInfo); + this.consumptionUserInfo = consumptionUserInfo; + this.emitter.fire(this.consumptionUserInfo); } private assertNotDisposed(): void { diff --git a/src/colab/consumption/poller.unit.test.ts b/src/colab/consumption/poller.unit.test.ts index 410ca264..a20539a0 100644 --- a/src/colab/consumption/poller.unit.test.ts +++ b/src/colab/consumption/poller.unit.test.ts @@ -12,20 +12,37 @@ import { createStubInstance, } from 'sinon'; import { newVsCodeStub, VsCodeStub } from '../../test/helpers/vscode'; -import { CcuInfo } from '../api'; +import { ConsumptionUserInfo, SubscriptionTier, Variant } from '../api'; import { ColabClient } from '../client'; import { ConsumptionPoller } from './poller'; const POLL_INTERVAL_MS = 1000 * 60 * 5; // 5 minutes. const TASK_TIMEOUT_MS = 1000 * 10; // 10 seconds. -const DEFAULT_CCU_INFO: CcuInfo = { - currentBalance: 1, +const DEFAULT_CCU_INFO: ConsumptionUserInfo = { + subscriptionTier: SubscriptionTier.NONE, + paidComputeUnitsBalance: 1, consumptionRateHourly: 2, assignmentsCount: 3, - eligibleGpus: ['T4'], - ineligibleGpus: ['A100', 'L4'], - eligibleTpus: ['V6E1', 'V28'], - ineligibleTpus: ['V5E1'], + eligibleAccelerators: [ + { + variant: Variant.GPU, + models: ['T4'], + }, + { + variant: Variant.TPU, + models: ['V6E1', 'V28'], + }, + ], + ineligibleAccelerators: [ + { + variant: Variant.GPU, + models: ['A100', 'L4'], + }, + { + variant: Variant.TPU, + models: ['V5E1'], + }, + ], freeCcuQuotaInfo: { remainingTokens: 4, nextRefillTimestampSec: 5, @@ -54,7 +71,7 @@ describe('ConsumptionPoller', () => { describe('lifecycle', () => { beforeEach(() => { - clientStub.getCcuInfo.resolves(DEFAULT_CCU_INFO); + clientStub.getConsumptionUserInfo.resolves(DEFAULT_CCU_INFO); }); afterEach(() => { @@ -62,12 +79,12 @@ describe('ConsumptionPoller', () => { }); it('disposes the runner', async () => { - clientStub.getCcuInfo.resetHistory(); + clientStub.getConsumptionUserInfo.resetHistory(); poller.dispose(); await fakeClock.tickAsync(POLL_INTERVAL_MS); - sinon.assert.notCalled(clientStub.getCcuInfo); + sinon.assert.notCalled(clientStub.getConsumptionUserInfo); }); it('throws when used after being disposed', () => { @@ -82,8 +99,8 @@ describe('ConsumptionPoller', () => { }); it('aborts slow calls to get CCU info', async () => { - clientStub.getCcuInfo.resetHistory(); - clientStub.getCcuInfo.onFirstCall().callsFake( + clientStub.getConsumptionUserInfo.resetHistory(); + clientStub.getConsumptionUserInfo.onFirstCall().callsFake( // eslint-disable-next-line @typescript-eslint/no-empty-function async () => new Promise(() => {}), ); @@ -91,23 +108,24 @@ describe('ConsumptionPoller', () => { await fakeClock.tickAsync(TASK_TIMEOUT_MS + 1); - sinon.assert.calledOnce(clientStub.getCcuInfo); - expect(clientStub.getCcuInfo.firstCall.args[0]?.aborted).to.be.true; + sinon.assert.calledOnce(clientStub.getConsumptionUserInfo); + expect(clientStub.getConsumptionUserInfo.firstCall.args[0]?.aborted).to.be + .true; }); }); describe('toggled on', () => { beforeEach(async () => { - clientStub.getCcuInfo.resolves(DEFAULT_CCU_INFO); + clientStub.getConsumptionUserInfo.resolves(DEFAULT_CCU_INFO); poller.on(); // Turning the poller on runs the task immediately. Wait past the task // timeout to ensure the immediate invocation runs to completion. await fakeClock.tickAsync(TASK_TIMEOUT_MS); - clientStub.getCcuInfo.resetHistory(); + clientStub.getConsumptionUserInfo.resetHistory(); }); describe('when the CCU info does not change', () => { - let onDidChangeCcuInfo: sinon.SinonStub<[CcuInfo]>; + let onDidChangeCcuInfo: sinon.SinonStub<[ConsumptionUserInfo]>; beforeEach(() => { onDidChangeCcuInfo = sinon.stub(); @@ -117,50 +135,60 @@ describe('ConsumptionPoller', () => { it('does not emit an event', async () => { await fakeClock.tickAsync(POLL_INTERVAL_MS); - sinon.assert.calledOnce(clientStub.getCcuInfo); + sinon.assert.calledOnce(clientStub.getConsumptionUserInfo); sinon.assert.notCalled(onDidChangeCcuInfo); }); }); describe('when the CCU info changes', () => { - const newCcuInfo: CcuInfo = { + const newCcuInfo: ConsumptionUserInfo = { ...DEFAULT_CCU_INFO, - eligibleGpus: [], + eligibleAccelerators: [], }; - let onDidChangeCcuInfo: sinon.SinonStub<[CcuInfo]>; + let onDidChangeCcuInfo: sinon.SinonStub<[ConsumptionUserInfo]>; beforeEach(() => { onDidChangeCcuInfo = sinon.stub(); poller.onDidChangeCcuInfo(onDidChangeCcuInfo); - clientStub.getCcuInfo.resolves(newCcuInfo); + clientStub.getConsumptionUserInfo.resolves(newCcuInfo); }); it('emits an event', async () => { await fakeClock.tickAsync(POLL_INTERVAL_MS); - sinon.assert.calledOnce(clientStub.getCcuInfo); + sinon.assert.calledOnce(clientStub.getConsumptionUserInfo); sinon.assert.calledOnce(onDidChangeCcuInfo); }); }); }); it('can be toggled on and off', async () => { - const onDidChangeCcuInfo: sinon.SinonStub<[CcuInfo]> = sinon.stub(); + const onDidChangeCcuInfo: sinon.SinonStub<[ConsumptionUserInfo]> = + sinon.stub(); poller.onDidChangeCcuInfo(onDidChangeCcuInfo); // On for 3. - clientStub.getCcuInfo.resolves({ ...DEFAULT_CCU_INFO, currentBalance: 3 }); + clientStub.getConsumptionUserInfo.resolves({ + ...DEFAULT_CCU_INFO, + paidComputeUnitsBalance: 3, + }); poller.on(); await fakeClock.tickAsync(POLL_INTERVAL_MS); // Off for 2. - clientStub.getCcuInfo.resolves({ ...DEFAULT_CCU_INFO, currentBalance: 2 }); + clientStub.getConsumptionUserInfo.resolves({ + ...DEFAULT_CCU_INFO, + paidComputeUnitsBalance: 2, + }); poller.off(); await fakeClock.tickAsync(POLL_INTERVAL_MS); // On for 1. - clientStub.getCcuInfo.resolves({ ...DEFAULT_CCU_INFO, currentBalance: 1 }); + clientStub.getConsumptionUserInfo.resolves({ + ...DEFAULT_CCU_INFO, + paidComputeUnitsBalance: 1, + }); poller.on(); await fakeClock.tickAsync(POLL_INTERVAL_MS); @@ -168,11 +196,11 @@ describe('ConsumptionPoller', () => { sinon.assert.calledWith( onDidChangeCcuInfo.firstCall, - sinon.match({ currentBalance: 3 }), + sinon.match({ paidComputeUnitsBalance: 3 }), ); sinon.assert.calledWith( onDidChangeCcuInfo.secondCall, - sinon.match({ currentBalance: 1 }), + sinon.match({ paidComputeUnitsBalance: 1 }), ); }); }); diff --git a/src/extension.ts b/src/extension.ts index 95edd37f..376a25c8 100644 --- a/src/extension.ts +++ b/src/extension.ts @@ -179,11 +179,7 @@ function watchConsumption(colab: ColabClient): { const disposables: Disposable[] = []; const poller = new ConsumptionPoller(vscode, colab); disposables.push(poller); - const notifier = new ConsumptionNotifier( - vscode, - colab, - poller.onDidChangeCcuInfo, - ); + const notifier = new ConsumptionNotifier(vscode, poller.onDidChangeCcuInfo); disposables.push(notifier); return { toggle: poller, disposables }; diff --git a/src/jupyter/assignments.ts b/src/jupyter/assignments.ts index 8e4d1e14..30688e4b 100644 --- a/src/jupyter/assignments.ts +++ b/src/jupyter/assignments.ts @@ -105,34 +105,27 @@ export class AssignmentManager implements vscode.Disposable { /** * Retrieves a list of available server descriptors that can be assigned. * - * @param subscriptionTier - The user's subscription tier. * @param signal - An optional {@link AbortSignal} to cancel the operation. * @returns A list of available server descriptors. */ // TODO: Consider communicating which machines are available, but not to the // user at their tier (in the "ineligible" list). async getAvailableServerDescriptors( - subscriptionTier: SubscriptionTier, signal?: AbortSignal, ): Promise { - const ccuInfo = await this.client.getCcuInfo(signal); - - const eligibleGpus = new Set(ccuInfo.eligibleGpus); - const gpus: ColabServerDescriptor[] = Array.from(eligibleGpus).map((e) => ({ - label: `Colab GPU ${e}`, - variant: Variant.GPU, - accelerator: e, - })); - - const eligibleTpus = new Set(ccuInfo.eligibleTpus); - const tpus: ColabServerDescriptor[] = Array.from(eligibleTpus).map((e) => ({ - label: `Colab TPU ${e}`, - variant: Variant.TPU, - accelerator: e, - })); - - const defaultDescriptors = [DEFAULT_CPU_SERVER, ...gpus, ...tpus]; - if (subscriptionTier === SubscriptionTier.NONE) { + const userInfo = await this.client.getUserInfo(signal); + + const eligibleDescriptors: ColabServerDescriptor[] = + userInfo.eligibleAccelerators.flatMap((acc) => + acc.models.map((model) => ({ + label: `Colab ${acc.variant} ${model}`, + variant: acc.variant, + accelerator: model, + })), + ); + + const defaultDescriptors = [DEFAULT_CPU_SERVER, ...eligibleDescriptors]; + if (userInfo.subscriptionTier === SubscriptionTier.NONE) { return defaultDescriptors; } diff --git a/src/jupyter/assignments.unit.test.ts b/src/jupyter/assignments.unit.test.ts index 39e31351..3592e5ed 100644 --- a/src/jupyter/assignments.unit.test.ts +++ b/src/jupyter/assignments.unit.test.ts @@ -15,6 +15,7 @@ import { Shape, SubscriptionState, SubscriptionTier, + UserInfo, Variant, } from '../colab/api'; import { @@ -150,18 +151,20 @@ describe('AssignmentManager', () => { }); describe('getAvailableServerDescriptors', () => { - const mockCcuInfo = { - currentBalance: 1, - consumptionRateHourly: 2, - assignmentsCount: 0, - eligibleGpus: ['T4', 'A100'], - ineligibleGpus: [], - eligibleTpus: ['V5E1', 'V6E1'], - ineligibleTpus: [], - freeCcuQuotaInfo: { - remainingTokens: 4, - nextRefillTimestampSec: 5, - }, + const mockUserInfo: UserInfo = { + subscriptionTier: SubscriptionTier.NONE, + paidComputeUnitsBalance: 1, + eligibleAccelerators: [ + { + variant: Variant.GPU, + models: ['T4', 'A100'], + }, + { + variant: Variant.TPU, + models: ['V5E1', 'V6E1'], + }, + ], + ineligibleAccelerators: [], }; const defaultGpuT4Descriptor = { @@ -189,11 +192,9 @@ describe('AssignmentManager', () => { }; it('returns the default CPU and the eligible servers', async () => { - colabClientStub.getCcuInfo.resolves(mockCcuInfo); + colabClientStub.getUserInfo.resolves(mockUserInfo); - const servers = await assignmentManager.getAvailableServerDescriptors( - SubscriptionTier.NONE, - ); + const servers = await assignmentManager.getAvailableServerDescriptors(); expect(servers).to.deep.equal([ DEFAULT_CPU_SERVER, @@ -205,11 +206,12 @@ describe('AssignmentManager', () => { }); it('returns the default CPU and the eligible servers for pro users', async () => { - colabClientStub.getCcuInfo.resolves(mockCcuInfo); + colabClientStub.getUserInfo.resolves({ + ...mockUserInfo, + subscriptionTier: SubscriptionTier.PRO, + }); - const tier = SubscriptionTier.PRO; - const servers = - await assignmentManager.getAvailableServerDescriptors(tier); + const servers = await assignmentManager.getAvailableServerDescriptors(); expect(servers).to.deep.equal([ { ...DEFAULT_CPU_SERVER, shape: Shape.STANDARD }, diff --git a/src/jupyter/provider.ts b/src/jupyter/provider.ts index 4fab5510..71545f9f 100644 --- a/src/jupyter/provider.ts +++ b/src/jupyter/provider.ts @@ -140,7 +140,7 @@ export class ColabJupyterServerProvider commands.push(AUTO_CONNECT, NEW_SERVER, OPEN_COLAB_WEB); if (this.isAuthorized) { try { - const tier = await this.client.getSubscriptionTier(); + const tier = (await this.client.getUserInfo()).subscriptionTier; if (tier === SubscriptionTier.NONE) { commands.push(UPGRADE_TO_PRO); } @@ -209,9 +209,8 @@ export class ColabJupyterServerProvider } private async assignServer(): Promise { - const tier = await this.client.getSubscriptionTier(); const serverType = await this.serverPicker.prompt( - await this.assignmentManager.getAvailableServerDescriptors(tier), + await this.assignmentManager.getAvailableServerDescriptors(), ); if (!serverType) { throw new this.vs.CancellationError(); diff --git a/src/jupyter/provider.unit.test.ts b/src/jupyter/provider.unit.test.ts index 13cd8c4c..8bb7449c 100644 --- a/src/jupyter/provider.unit.test.ts +++ b/src/jupyter/provider.unit.test.ts @@ -286,7 +286,7 @@ describe('ColabJupyterServerProvider', () => { }); it('excludes upgrade to pro command when getting the subscription tier fails', async () => { - colabClientStub.getSubscriptionTier.rejects(new Error('foo')); + colabClientStub.getUserInfo.rejects(new Error('foo')); const commands = await serverProvider.provideCommands( undefined, cancellationToken, @@ -301,7 +301,11 @@ describe('ColabJupyterServerProvider', () => { }); it('excludes upgrade to pro command for users with pro', async () => { - colabClientStub.getSubscriptionTier.resolves(SubscriptionTier.PRO); + colabClientStub.getUserInfo.resolves({ + subscriptionTier: SubscriptionTier.PRO, + eligibleAccelerators: [], + ineligibleAccelerators: [], + }); const commands = await serverProvider.provideCommands( undefined, @@ -317,9 +321,11 @@ describe('ColabJupyterServerProvider', () => { }); it('excludes upgrade to pro command for users with pro-plus', async () => { - colabClientStub.getSubscriptionTier.resolves( - SubscriptionTier.PRO_PLUS, - ); + colabClientStub.getUserInfo.resolves({ + subscriptionTier: SubscriptionTier.PRO_PLUS, + eligibleAccelerators: [], + ineligibleAccelerators: [], + }); const commands = await serverProvider.provideCommands( undefined, @@ -335,7 +341,11 @@ describe('ColabJupyterServerProvider', () => { }); it('returns commands to auto-connect, create a server, open Colab web and upgrade to pro for free users', async () => { - colabClientStub.getSubscriptionTier.resolves(SubscriptionTier.NONE); + colabClientStub.getUserInfo.resolves({ + subscriptionTier: SubscriptionTier.NONE, + eligibleAccelerators: [], + ineligibleAccelerators: [], + }); const commands = await serverProvider.provideCommands( undefined, @@ -486,8 +496,6 @@ describe('ColabJupyterServerProvider', () => { }); it('completes assigning a server', async () => { - colabClientStub.getSubscriptionTier.resolves(SubscriptionTier.PRO); - const availableServers = [DEFAULT_SERVER]; assignmentStub.getAvailableServerDescriptors.resolves( availableServers, @@ -512,10 +520,7 @@ describe('ColabJupyterServerProvider', () => { ).to.eventually.deep.equal(DEFAULT_SERVER); sinon.assert.calledOnce(serverPickerStub.prompt); - sinon.assert.calledOnceWithExactly( - assignmentStub.getAvailableServerDescriptors, - SubscriptionTier.PRO, - ); + sinon.assert.calledOnce(assignmentStub.getAvailableServerDescriptors); sinon.assert.calledOnce(assignmentStub.assignServer); }); });