Skip to content

Commit

Permalink
feat: update types
Browse files Browse the repository at this point in the history
  • Loading branch information
himanshu-dixit committed Dec 24, 2024
1 parent 41954a3 commit f9b8813
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 87 deletions.
6 changes: 3 additions & 3 deletions js/src/frameworks/vercel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { z } from "zod";
import { ComposioToolSet as BaseComposioToolSet } from "../sdk/base.toolset";
import { TELEMETRY_LOGGER } from "../sdk/utils/telemetry";
import { TELEMETRY_EVENTS } from "../sdk/utils/telemetry/events";
import { TRawActionData } from "../types/base_toolset";
import { RawActionData } from "../types/base_toolset";
import { jsonSchemaToModel } from "../utils/shared";
type Optional<T> = T | null;

Expand Down Expand Up @@ -37,7 +37,7 @@ export class VercelAIToolSet extends BaseComposioToolSet {
);
}

private generateVercelTool(schema: TRawActionData) {
private generateVercelTool(schema: RawActionData) {
const parameters = jsonSchemaToModel(schema.parameters);
return tool({
description: schema.description,
Expand All @@ -62,7 +62,7 @@ export class VercelAIToolSet extends BaseComposioToolSet {
useCase?: Optional<string>;
usecaseLimit?: Optional<number>;
filterByAvailableApps?: Optional<boolean>;
}): Promise<{ [key: string]: TRawActionData }> {
}): Promise<{ [key: string]: RawActionData }> {
TELEMETRY_LOGGER.manualTelemetry(TELEMETRY_EVENTS.SDK_METHOD_INVOKED, {
method: "getTools",
file: this.fileName,
Expand Down
47 changes: 31 additions & 16 deletions js/src/sdk/actionRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@ import { ZodObject, ZodOptional, ZodString, z } from "zod";
import { JsonSchema7Type, zodToJsonSchema } from "zod-to-json-schema";
import { Composio } from ".";
import apiClient from "../sdk/client/client";
import { TRawActionData } from "../types/base_toolset";
import { ActionProxyRequestConfigDTO } from "./client";
import { RawActionData } from "../types/base_toolset";
import { ActionProxyRequestConfigDTO, Parameter } from "./client";
import { ActionExecuteResponse } from "./models/actions";
import { CEG } from "./utils/error";
import { SDK_ERROR_CODES } from "./utils/errors/src/constants";

type RawExecuteRequestParam = {
connectedAccountId: string;
endpoint: string;
method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE";
parameters: Array<Parameter>;
body?: {
[key: string]: unknown;
};
};

type ExecuteRequest = Omit<ActionProxyRequestConfigDTO, "connectedAccountId">;
export type CreateActionOptions = {
actionName?: string;
toolName?: string;
Expand All @@ -15,8 +26,10 @@ export type CreateActionOptions = {
callback: (
inputParams: Record<string, string>,
authCredentials: Record<string, string> | undefined,
executeRequest: (data: ExecuteRequest) => Promise<Record<string, unknown>>
) => Promise<Record<string, unknown>>;
executeRequest: (
data: RawExecuteRequestParam
) => Promise<ActionExecuteResponse>
) => Promise<ActionExecuteResponse>;
};

interface ParamsSchema {
Expand Down Expand Up @@ -45,7 +58,7 @@ export class ActionRegistry {
this.customActions = new Map();
}

async createAction(options: CreateActionOptions): Promise<TRawActionData> {
async createAction(options: CreateActionOptions): Promise<RawActionData> {
const { callback } = options;
if (typeof callback !== "function") {
throw new Error("Callback must be a function");
Expand Down Expand Up @@ -82,36 +95,36 @@ export class ActionRegistry {
metadata: options,
schema: composioSchema,
});
return composioSchema as unknown as TRawActionData;
return composioSchema as unknown as RawActionData;
}

async getActions({
actions,
}: {
actions: Array<string>;
}): Promise<Array<TRawActionData>> {
const actionsArr: Array<TRawActionData> = [];
}): Promise<Array<RawActionData>> {
const actionsArr: Array<RawActionData> = [];
for (const name of actions) {
const lowerCaseName = name.toLowerCase();
if (this.customActions.has(lowerCaseName)) {
const action = this.customActions.get(lowerCaseName);
actionsArr.push(action!.schema as TRawActionData);
actionsArr.push(action!.schema as RawActionData);
}
}
return actionsArr;
}

async getAllActions(): Promise<Array<TRawActionData>> {
async getAllActions(): Promise<Array<RawActionData>> {
return Array.from(this.customActions.values()).map(
(action) => action.schema as TRawActionData
(action) => action.schema as RawActionData
);
}

async executeAction(
name: string,
inputParams: Record<string, unknown>,
metadata: ExecuteMetadata
): Promise<Record<string, unknown>> {
): Promise<ActionExecuteResponse | Record<string, unknown>> {
const lowerCaseName = name.toLocaleLowerCase();
if (!this.customActions.has(lowerCaseName)) {
throw new Error(`Action with name ${name} does not exist`);
Expand Down Expand Up @@ -145,10 +158,12 @@ export class ActionRegistry {
};
}
if (typeof callback !== "function") {
throw new Error("Callback must be a function");
throw CEG.getCustomError(SDK_ERROR_CODES.COMMON.INVALID_PARAMS_PASSED, {
message: "Callback must be a function",
});
}

const executeRequest = async (data: ExecuteRequest) => {
const executeRequest = async (data: RawExecuteRequestParam) => {
try {
const { data: res } = await apiClient.actionsV2.executeWithHttpClient({
body: {
Expand All @@ -165,7 +180,7 @@ export class ActionRegistry {
return await callback(
inputParams as Record<string, string>,
authCredentials,
(data: ExecuteRequest) => executeRequest(data)
(data: RawExecuteRequestParam) => executeRequest(data)
);
}
}
16 changes: 10 additions & 6 deletions js/src/sdk/base.toolset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ import { z } from "zod";
import { Composio } from "../sdk";
import type { Optional, Sequence } from "../types/base";
import {
RawActionData,
TPostProcessor,
TPreProcessor,
TRawActionData,
TSchemaProcessor,
ZExecuteActionParams,
ZToolSchemaFilter,
Expand All @@ -13,13 +13,15 @@ import { getEnvVariable } from "../utils/shared";
import { ActionRegistry, CreateActionOptions } from "./actionRegistry";
import { COMPOSIO_BASE_URL } from "./client/core/OpenAPI";
import { ActionExecutionResDto } from "./client/types.gen";
import { ActionExecuteResponse } from "./models/actions";
import { getUserDataJson } from "./utils/config";
import {
fileInputProcessor,
fileResponseProcessor,
fileSchemaProcessor,
} from "./utils/processor/file";

export type ExecuteActionParams = z.infer<typeof ZExecuteActionParams>;
export class ComposioToolSet {
client: Composio;
apiKey: string;
Expand Down Expand Up @@ -80,7 +82,7 @@ export class ComposioToolSet {
async getToolsSchema(
filters: z.infer<typeof ZToolSchemaFilter>,
_entityId?: Optional<string>
): Promise<TRawActionData[]> {
): Promise<RawActionData[]> {
const parsedFilters = ZToolSchemaFilter.parse(filters);

const apps = await this.client.actions.list({
Expand Down Expand Up @@ -120,7 +122,7 @@ export class ComposioToolSet {
];

return toolsActions.map((tool) => {
let schema = tool as TRawActionData;
let schema = tool as RawActionData;
allSchemaProcessor.forEach((processor) => {
schema = processor({
actionName: schema?.metadata?.actionName || "",
Expand All @@ -142,7 +144,9 @@ export class ComposioToolSet {
.then((actions) => actions.length > 0);
}

async executeAction(functionParams: z.infer<typeof ZExecuteActionParams>) {
async executeAction(
functionParams: ExecuteActionParams
): Promise<ActionExecuteResponse> {
const {
action,
params: inputParams = {},
Expand Down Expand Up @@ -189,11 +193,11 @@ export class ComposioToolSet {
});
}

const data = (await this.client.getEntity(entityId).execute({
const data = await this.client.getEntity(entityId).execute({
actionName: action,
params: params,
text: nlaText,
})) as ActionExecutionResDto;
});

return this.processResponse(data, {
action: action,
Expand Down
2 changes: 1 addition & 1 deletion js/src/sdk/client/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,7 @@ export type ActionExecutionResDto = {
* Whether the action execution was successfully executed or not. If this is false, error field will be populated with the error message.
* @deprecated
*/
successfull: boolean;
successfull?: boolean;
/**
* Whether the action execution was successfully executed or not. If this is false, error field will be populated with the error message.
*/
Expand Down
66 changes: 19 additions & 47 deletions js/src/sdk/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import axios from "axios";
import { z } from "zod";
import logger from "../utils/logger";
import { GetConnectorInfoResDTO } from "./client";
import { Entity } from "./models/Entity";
import { Actions } from "./models/actions";
import { ActiveTriggers } from "./models/activeTriggers";
Expand All @@ -18,6 +18,17 @@ import { getPackageJsonDir } from "./utils/projectUtils";
import { TELEMETRY_LOGGER } from "./utils/telemetry";
import { TELEMETRY_EVENTS } from "./utils/telemetry/events";

import {
ZGetExpectedParamsForUserParams,
ZGetExpectedParamsRes,
} from "../types/composio";
import { ZAuthMode } from "./types/integration";

export type ComposioInputFieldsParams = z.infer<
typeof ZGetExpectedParamsForUserParams
>;
export type ComposioInputFieldsRes = z.infer<typeof ZGetExpectedParamsRes>;

export class Composio {
/**
* The Composio class serves as the main entry point for interacting with the Composio SDK.
Expand Down Expand Up @@ -138,29 +149,8 @@ export class Composio {
}

async getExpectedParamsForUser(
params: {
app?: string;
integrationId?: string;
entityId?: string;
authScheme?:
| "OAUTH2"
| "OAUTH1"
| "API_KEY"
| "BASIC"
| "BEARER_TOKEN"
| "BASIC_WITH_JWT";
} = {}
): Promise<{
expectedInputFields: GetConnectorInfoResDTO["expectedInputFields"];
integrationId: string;
authScheme:
| "OAUTH2"
| "OAUTH1"
| "API_KEY"
| "BASIC"
| "BEARER_TOKEN"
| "BASIC_WITH_JWT";
}> {
params: ComposioInputFieldsParams
): Promise<ComposioInputFieldsRes> {
TELEMETRY_LOGGER.manualTelemetry(TELEMETRY_EVENTS.SDK_METHOD_INVOKED, {
method: "getExpectedParamsForUser",
file: this.fileName,
Expand Down Expand Up @@ -199,13 +189,7 @@ export class Composio {
return {
expectedInputFields: integration.expectedInputFields,
integrationId: integration.id!,
authScheme: integration.authScheme as
| "OAUTH2"
| "OAUTH1"
| "API_KEY"
| "BASIC"
| "BEARER_TOKEN"
| "BASIC_WITH_JWT",
authScheme: integration.authScheme as z.infer<typeof ZAuthMode>,
};
}

Expand Down Expand Up @@ -268,21 +252,15 @@ export class Composio {
integration = await this.integrations.create({
appId: appInfo.appId,
name: `integration_${timestamp}`,
authScheme: schema,
authScheme: schema as z.infer<typeof ZAuthMode>,
authConfig: {},
useComposioAuth: true,
});

return {
expectedInputFields: integration?.expectedInputFields!,
integrationId: integration?.id!,
authScheme: integration?.authScheme as
| "OAUTH2"
| "OAUTH1"
| "API_KEY"
| "BASIC"
| "BEARER_TOKEN"
| "BASIC_WITH_JWT",
authScheme: integration?.authScheme as z.infer<typeof ZAuthMode>,
};
}

Expand All @@ -297,7 +275,7 @@ export class Composio {
integration = await this.integrations.create({
appId: appInfo.appId,
name: `integration_${timestamp}`,
authScheme: schema,
authScheme: schema as z.infer<typeof ZAuthMode>,
authConfig: {},
useComposioAuth: false,
});
Expand All @@ -310,13 +288,7 @@ export class Composio {
return {
expectedInputFields: integration.expectedInputFields,
integrationId: integration.id!,
authScheme: integration.authScheme as
| "OAUTH2"
| "OAUTH1"
| "API_KEY"
| "BASIC"
| "BEARER_TOKEN"
| "BASIC_WITH_JWT",
authScheme: integration.authScheme as z.infer<typeof ZAuthMode>,
};
}
}
7 changes: 3 additions & 4 deletions js/src/sdk/models/Entity.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { z } from "zod";
import logger from "../../utils/logger";
import { ActionExecutionResDto, GetConnectionsResponseDto } from "../client";
import { GetConnectionsResponseDto } from "../client";
import {
ZConnectionParams,
ZExecuteActionParams,
Expand All @@ -12,7 +12,7 @@ import { CEG } from "../utils/error";
import { SDK_ERROR_CODES } from "../utils/errors/src/constants";
import { TELEMETRY_LOGGER } from "../utils/telemetry";
import { TELEMETRY_EVENTS } from "../utils/telemetry/events";
import { Actions } from "./actions";
import { ActionExecuteResponse, Actions } from "./actions";
import { ActiveTriggers } from "./activeTriggers";
import { Apps } from "./apps";
import { BackendClient } from "./backendClient";
Expand All @@ -35,7 +35,6 @@ type InitiateConnectionParams = z.infer<typeof ZInitiateConnectionParams>;
type ExecuteActionParams = z.infer<typeof ZExecuteActionParams>;

// type from API
export type ExecuteActionRes = ActionExecutionResDto;
export type ConnectedAccountListRes = GetConnectionsResponseDto;

export class Entity {
Expand Down Expand Up @@ -66,7 +65,7 @@ export class Entity {
params,
text,
connectedAccountId,
}: ExecuteActionParams): Promise<ExecuteActionRes> {
}: ExecuteActionParams): Promise<ActionExecuteResponse> {
TELEMETRY_LOGGER.manualTelemetry(TELEMETRY_EVENTS.SDK_METHOD_INVOKED, {
method: "execute",
file: this.fileName,
Expand Down
Loading

0 comments on commit f9b8813

Please sign in to comment.