Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 22 additions & 33 deletions src/colab/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ function normalizeVariant(variant: ColabGapiVariant): Variant {
}
}

export const Accelerator = z.object({
/** The variant of the assignment. */
variant: z.enum(ColabGapiVariant).transform(normalizeVariant),
/** The assigned accelerator. */
models: z.array(z.string().toUpperCase()),
});

/**
* The schema for top level information about a user's tier, usage and
* availability in Colab.
Expand All @@ -139,27 +146,21 @@ export const UserInfoSchema = z.object({
/** The paid Colab Compute Units balance. */
paidComputeUnitsBalance: z.number().optional(),
/** The eligible machine accelerators. */
eligibleAccelerators: z
.array(
z.object({
/** The variant of the assignment. */
variant: z.enum(ColabGapiVariant).transform(normalizeVariant),
/** The assigned accelerator. */
models: z.array(z.string().toUpperCase()),
}),
)
.optional(),
eligibleAccelerators: z.array(Accelerator),
/** The ineligible machine accelerators. */
ineligibleAccelerators: z.array(Accelerator),
});
/** Colab user information. */
export type UserInfo = z.infer<typeof UserInfoSchema>;

/** The schema of Colab Compute Units (CCU) information. */
export const CcuInfoSchema = z.object({
/**
* The current balance of the paid CCUs.
*
* Naming is unfortunate due to historical reasons and free CCU quota
* balance is made available in a separate field for the same reasons.
*/
currentBalance: z.number(),
/**
* The schema for top level information about a user's tier, usage and
* availability in Colab when CCU consumption info is requested (consumption
* fields are required).
*/
export const ConsumptionUserInfoSchema = UserInfoSchema.required({
paidComputeUnitsBalance: true,
}).extend({
/**
* The current rate of consumption of the user's CCUs (paid or free) based on
* all assigned VMs.
Expand All @@ -170,18 +171,6 @@ export const CcuInfoSchema = z.object({
* is positive.
*/
assignmentsCount: z.number(),
/** The list of eligible GPU accelerators. */
eligibleGpus: z.array(z.string().toUpperCase()),
/** The list of ineligible GPU accelerators. */
ineligibleGpus: z.array(z.string().toUpperCase()).optional(),
/**
* The list of eligible TPU accelerators.
*/
eligibleTpus: z.array(z.string().toUpperCase()),
/**
* The list of ineligible TPU accelerators.
*/
ineligibleTpus: z.array(z.string().toUpperCase()).optional(),
/** Free CCU quota information if applicable. */
freeCcuQuotaInfo: z
.object({
Expand Down Expand Up @@ -212,8 +201,8 @@ export const CcuInfoSchema = z.object({
})
.optional(),
});
/** Colab Compute Units (CCU) information. */
export type CcuInfo = z.infer<typeof CcuInfoSchema>;
/** Colab consumption user information. */
export type ConsumptionUserInfo = z.infer<typeof ConsumptionUserInfoSchema>;

/** The response when getting an assignment. */
export const GetAssignmentResponseSchema = z
Expand Down
29 changes: 15 additions & 14 deletions src/colab/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import { uuidToWebSafeBase64 } from '../utils/uuid';
import {
Assignment,
AuthType,
CcuInfo,
Variant,
GetAssignmentResponse,
CcuInfoSchema,
AssignmentSchema,
GetAssignmentResponseSchema,
UserInfo,
UserInfoSchema,
SubscriptionTier,
ConsumptionUserInfo,
ConsumptionUserInfoSchema,
PostAssignmentResponse,
Outcome,
PostAssignmentResponseSchema,
Expand Down Expand Up @@ -92,31 +92,32 @@ export class ColabClient {
}

/**
* Gets the user's subscription tier.
* Gets the current user information.
*
* @param signal - Optional {@link AbortSignal} to cancel the request.
* @returns The user's subscription tier.
*/
async getSubscriptionTier(signal?: AbortSignal): Promise<SubscriptionTier> {
const userInfo = await this.issueRequest(
async getUserInfo(signal?: AbortSignal): Promise<UserInfo> {
return await this.issueRequest(
new URL('v1/user-info', this.colabGapiDomain),
{ method: 'GET', signal },
UserInfoSchema,
);
return userInfo.subscriptionTier;
}

/**
* Gets the current Colab Compute Units (CCU) information.
* Gets the current user with Colab Compute Units (CCU) information.
*
* @param signal - Optional {@link AbortSignal} to cancel the request.
* @returns The current CCU information.
*/
async getCcuInfo(signal?: AbortSignal): Promise<CcuInfo> {
return this.issueRequest(
new URL(`${TUN_ENDPOINT}/ccu-info`, this.colabDomain),
async getConsumptionUserInfo(
signal?: AbortSignal,
): Promise<ConsumptionUserInfo> {
const url = new URL('v1/user-info', this.colabGapiDomain);
url.searchParams.append('get_ccu_consumption_info', 'true');
return await this.issueRequest(
url,
{ method: 'GET', signal },
CcuInfoSchema,
ConsumptionUserInfoSchema,
);
}

Expand Down
Loading
Loading