Skip to content

Commit

Permalink
Take plan out of WorkspaceType and access it from auth.plan() (#2244)
Browse files Browse the repository at this point in the history
  • Loading branch information
PopDaph authored Oct 25, 2023
1 parent 698df45 commit dee5671
Show file tree
Hide file tree
Showing 22 changed files with 138 additions and 96 deletions.
14 changes: 8 additions & 6 deletions front/components/assistant_builder/AssistantBuilder.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import { PostOrPatchAgentConfigurationRequestBodySchema } from "@app/pages/api/w
import { AppType } from "@app/types/app";
import { TimeframeUnit } from "@app/types/assistant/actions/retrieval";
import { DataSourceType } from "@app/types/data_source";
import { UserType, WorkspaceType } from "@app/types/user";
import { PlanType, UserType, WorkspaceType } from "@app/types/user";

import DataSourceResourceSelectorTree from "../DataSourceResourceSelectorTree";
import AssistantBuilderDustAppModal from "./AssistantBuilderDustAppModal";
Expand Down Expand Up @@ -159,6 +159,7 @@ export type AssistantBuilderInitialState = {
type AssistantBuilderProps = {
user: UserType;
owner: WorkspaceType;
plan: PlanType;
gaTrackingId: string;
dataSources: DataSourceType[];
dustApps: AppType[];
Expand Down Expand Up @@ -203,6 +204,7 @@ const getCreativityLevelFromTemperature = (temperature: number) => {
export default function AssistantBuilder({
user,
owner,
plan,
gaTrackingId,
dataSources,
dustApps,
Expand All @@ -219,7 +221,7 @@ export default function AssistantBuilder({
...DEFAULT_ASSISTANT_STATE,
generationSettings: {
...DEFAULT_ASSISTANT_STATE.generationSettings,
modelSettings: owner.plan.limits.largeModels
modelSettings: plan.limits.largeModels
? GPT_4_32K_MODEL_CONFIG
: GPT_3_5_TURBO_16K_MODEL_CONFIG,
},
Expand Down Expand Up @@ -844,7 +846,7 @@ export default function AssistantBuilder({
/>
</div>
<AdvancedSettings
owner={owner}
plan={plan}
generationSettings={builderState.generationSettings}
setGenerationSettings={(generationSettings) => {
setEdited(true);
Expand Down Expand Up @@ -1338,11 +1340,11 @@ function AssistantBuilderTextArea({
}

function AdvancedSettings({
owner,
plan,
generationSettings,
setGenerationSettings,
}: {
owner: WorkspaceType;
plan: PlanType;
generationSettings: AssistantBuilderState["generationSettings"];
setGenerationSettings: (
generationSettingsSettings: AssistantBuilderState["generationSettings"]
Expand Down Expand Up @@ -1382,7 +1384,7 @@ function AdvancedSettings({
{usedModelConfigs
.filter(
(modelConfig) =>
!modelConfig.largeModel || owner.plan.limits.largeModels
!modelConfig.largeModel || plan.limits.largeModels
)
.map((modelConfig) => (
<DropdownMenu.Item
Expand Down
5 changes: 2 additions & 3 deletions front/lib/api/assistant/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ export async function generateActionInputs(

const MIN_GENERATION_TOKENS = 2048;

const useLargeModels = auth.workspace()?.plan.limits.largeModels
? true
: false;
const plan = auth.plan();
const useLargeModels = plan && plan.limits.largeModels ? true : false;

let model: { providerId: string; modelId: string } = useLargeModels
? {
Expand Down
18 changes: 10 additions & 8 deletions front/lib/api/assistant/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ export async function getAgentConfiguration(
if (!owner) {
throw new Error("Unexpected `auth` without `workspace`.");
}
const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}

if (isGlobalAgentId(agentId)) {
return await getGlobalAgent(auth, agentId);
Expand Down Expand Up @@ -187,10 +191,7 @@ export async function getAgentConfiguration(
};

// Enforce plan limits: check if large models are allowed and act accordingly
if (
!owner.plan.limits.largeModels &&
getSupportedModelConfig(model).largeModel
) {
if (!plan.limits.largeModels && getSupportedModelConfig(model).largeModel) {
return null;
}
}
Expand Down Expand Up @@ -396,14 +397,15 @@ export async function createAgentGenerationConfiguration(
if (!owner) {
throw new Error("Unexpected `auth` without `workspace`.");
}
const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}

if (temperature < 0) {
throw new Error("Temperature must be positive.");
}
if (
getSupportedModelConfig(model).largeModel &&
!owner.plan.limits.largeModels
) {
if (getSupportedModelConfig(model).largeModel && !plan.limits.largeModels) {
throw new Error("You need to upgrade your plan to use large models.");
}

Expand Down
12 changes: 10 additions & 2 deletions front/lib/api/assistant/global_agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ async function _getHelperGlobalAgent(
if (!owner) {
throw new Error("Unexpected `auth` without `workspace`.");
}
const model = owner.plan.limits.largeModels
const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}
const model = plan.limits.largeModels
? {
providerId: GPT_4_32K_MODEL_CONFIG.providerId,
modelId: GPT_4_32K_MODEL_CONFIG.modelId,
Expand Down Expand Up @@ -538,6 +542,10 @@ export async function getGlobalAgent(
if (!owner) {
throw new Error("Cannot find Global Agent Configuration: no workspace.");
}
const plan = auth.plan();
if (!plan) {
throw new Error("Unexpected `auth` without `plan`.");
}

const settings = await GlobalAgentSettings.findOne({
where: { workspaceId: owner.id, agentId: sId },
Expand Down Expand Up @@ -584,7 +592,7 @@ export async function getGlobalAgent(

// Enforce plan limits: check if large models are allowed and act accordingly
if (
!owner.plan.limits.largeModels &&
!plan.limits.largeModels &&
agentConfiguration.generation &&
getSupportedModelConfig(agentConfiguration.generation?.model).largeModel
) {
Expand Down
3 changes: 1 addition & 2 deletions front/lib/api/workspace.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Authenticator, planForWorkspace } from "@app/lib/auth";
import { Authenticator } from "@app/lib/auth";
import { RoleType } from "@app/lib/auth";
import {
Membership,
Expand Down Expand Up @@ -28,7 +28,6 @@ export async function getWorkspaceInfos(
name: workspace.name,
allowedDomain: workspace.allowedDomain,
role: "none",
plan: planForWorkspace(workspace),
upgradedAt: workspace.upgradedAt?.getTime() || null,
};
}
Expand Down
33 changes: 25 additions & 8 deletions front/lib/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,19 @@ export class Authenticator {
_workspace: Workspace | null;
_user: User | null;
_role: RoleType;
_plan: PlanType | null;

// Should only be called from the static methods below.
constructor(workspace: Workspace | null, user: User | null, role: RoleType) {
constructor(
workspace: Workspace | null,
user: User | null,
role: RoleType,
plan: PlanType | null
) {
this._workspace = workspace;
this._user = user;
this._role = role;
this._plan = plan;
}

/**
Expand Down Expand Up @@ -98,7 +105,9 @@ export class Authenticator {
}
}

return new Authenticator(workspace, user, role);
const plan = workspace ? planForWorkspace(workspace) : null;

return new Authenticator(workspace, user, role, plan);
}

/**
Expand Down Expand Up @@ -136,11 +145,13 @@ export class Authenticator {
})(),
]);

const plan = workspace ? planForWorkspace(workspace) : null;

if (!user || !user.isDustSuperUser) {
return new Authenticator(workspace, user, "none");
return new Authenticator(workspace, user, "none", plan);
}

return new Authenticator(workspace, user, "admin");
return new Authenticator(workspace, user, "admin", plan);
}

/**
Expand Down Expand Up @@ -182,8 +193,10 @@ export class Authenticator {
}
}

const plan = workspace ? planForWorkspace(workspace) : null;

return {
auth: new Authenticator(workspace, null, role),
auth: new Authenticator(workspace, null, role, plan),
keyWorkspaceId: keyWorkspace.sId,
};
}
Expand All @@ -204,7 +217,9 @@ export class Authenticator {
if (!workspace) {
throw new Error(`Could not find workspace with sId ${workspaceId}`);
}
return new Authenticator(workspace, null, "builder");
const plan = workspace ? planForWorkspace(workspace) : null;

return new Authenticator(workspace, null, "builder", plan);
}

role(): RoleType {
Expand Down Expand Up @@ -249,12 +264,15 @@ export class Authenticator {
name: this._workspace.name,
allowedDomain: this._workspace.allowedDomain || null,
role: this._role,
plan: planForWorkspace(this._workspace),
upgradedAt: this._workspace.upgradedAt?.getTime() || null,
}
: null;
}

plan(): PlanType | null {
return this._plan;
}

/**
* This is a convenience method to get the user from the Authenticator. The returned UserType
* object won't have the user's workspaces set.
Expand Down Expand Up @@ -354,7 +372,6 @@ export async function getUserFromSession(
name: w.name,
allowedDomain: w.allowedDomain || null,
role,
plan: planForWorkspace(w),
upgradedAt: w.upgradedAt?.getTime() || null,
};
}),
Expand Down
2 changes: 1 addition & 1 deletion front/migrations/20230919_workspace_upgraded_at.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async function main() {
}

async function markWorkspaceAsUpgraded(workspace: Workspace) {
if (!workspace.upgradedAt && workspace.plan) {
if (!workspace.upgradedAt) {
const updatedAt = workspace.updatedAt;
await workspace.update({
upgradedAt: updatedAt,
Expand Down
9 changes: 1 addition & 8 deletions front/pages/api/poke/workspaces/[wId]/downgrade.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import { NextApiRequest, NextApiResponse } from "next";

import { downgradeWorkspace } from "@app/lib/api/workspace";
import {
getSession,
getUserFromSession,
planForWorkspace,
} from "@app/lib/auth";
import { getSession, getUserFromSession } from "@app/lib/auth";
import { ReturnedAPIErrorType } from "@app/lib/error";
import { Workspace } from "@app/lib/models";
import { apiError, withLogging } from "@app/logger/withlogging";
Expand Down Expand Up @@ -74,16 +70,13 @@ async function handler(

await downgradeWorkspace(workspace.id);

const plan = await planForWorkspace(workspace);

return res.status(200).json({
workspace: {
id: workspace.id,
sId: workspace.sId,
name: workspace.name,
allowedDomain: workspace.allowedDomain || null,
role: "admin",
plan,
upgradedAt: workspace.upgradedAt?.getTime() || null,
},
});
Expand Down
8 changes: 1 addition & 7 deletions front/pages/api/poke/workspaces/[wId]/upgrade.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import { NextApiRequest, NextApiResponse } from "next";

import { upgradeWorkspace } from "@app/lib/api/workspace";
import {
getSession,
getUserFromSession,
planForWorkspace,
} from "@app/lib/auth";
import { getSession, getUserFromSession } from "@app/lib/auth";
import { ReturnedAPIErrorType } from "@app/lib/error";
import { Workspace } from "@app/lib/models";
import { apiError, withLogging } from "@app/logger/withlogging";
Expand Down Expand Up @@ -73,7 +69,6 @@ async function handler(
}

await upgradeWorkspace(workspace.id);
const plan = await planForWorkspace(workspace);

return res.status(200).json({
workspace: {
Expand All @@ -82,7 +77,6 @@ async function handler(
name: workspace.name,
allowedDomain: workspace.allowedDomain || null,
role: "admin",
plan,
upgradedAt: workspace.upgradedAt?.getTime() || null,
},
});
Expand Down
11 changes: 3 additions & 8 deletions front/pages/api/poke/workspaces/index.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import { NextApiRequest, NextApiResponse } from "next";
import { FindOptions, Op, WhereOptions } from "sequelize";

import {
getSession,
getUserFromSession,
planForWorkspace,
} from "@app/lib/auth";
import { getSession, getUserFromSession } from "@app/lib/auth";
import { ReturnedAPIErrorType } from "@app/lib/error";
import { Workspace } from "@app/lib/models";
import { apiError, withLogging } from "@app/logger/withlogging";
Expand Down Expand Up @@ -102,13 +98,13 @@ async function handler(
if (upgraded !== undefined) {
if (upgraded) {
conditions.push({
plan: {
upgradedAt: {
[Op.not]: null,
},
});
} else {
conditions.push({
plan: null,
upgradedAt: null,
});
}
}
Expand Down Expand Up @@ -144,7 +140,6 @@ async function handler(
sId: ws.sId,
name: ws.name,
allowedDomain: ws.allowedDomain || null,
plan: planForWorkspace(ws),
role: "admin",
upgradedAt: ws.upgradedAt?.getTime() || null,
})),
Expand Down
Loading

0 comments on commit dee5671

Please sign in to comment.