From aaaaf5f7eee961e913e295953c8905ff1dbeb4c8 Mon Sep 17 00:00:00 2001 From: Alban Dumouilla Date: Tue, 17 Dec 2024 09:41:49 +0100 Subject: [PATCH] Add ApiKeyCredentials to oAuth creds (#9283) * Add ApiKeyCredentials to postCredentials and getCredentials * lint * Rename ModjoCredentials and add oauth modjo support * Lint * snake_case + make errors generic * SnowflakeCredentials Typeguard * Add more snowflake typeguards --- connectors/src/connectors/snowflake/index.ts | 29 +++++++++++++++- .../src/connectors/snowflake/lib/utils.ts | 33 +++++++++++++++++-- .../snowflake/temporal/activities.ts | 7 ++++ .../remote_databases/get_remote_database.rs | 3 ++ .../remote_databases/remote_database.rs | 4 +++ core/src/oauth/credential.rs | 4 +++ types/src/oauth/lib.ts | 16 +++++++-- 7 files changed, 90 insertions(+), 6 deletions(-) diff --git a/connectors/src/connectors/snowflake/index.ts b/connectors/src/connectors/snowflake/index.ts index 4e9652317fd7..063d0ae9ca67 100644 --- a/connectors/src/connectors/snowflake/index.ts +++ b/connectors/src/connectors/snowflake/index.ts @@ -4,7 +4,7 @@ import type { ContentNodesViewType, Result, } from "@dust-tt/types"; -import { assertNever, Err, Ok } from "@dust-tt/types"; +import { assertNever, Err, isSnowflakeCredentials, Ok } from "@dust-tt/types"; import type { CreateConnectorErrorCode, @@ -73,6 +73,15 @@ export class SnowflakeConnectorManager extends BaseConnectorManager { } const credentials = credentialsRes.value.credentials; + if (!isSnowflakeCredentials(credentials)) { + return new Err( + new ConnectorManagerError( + "INVALID_CONFIGURATION", + "Invalid credentials type - expected snowflake credentials" + ) + ); + } + // Then we test the connection is successful. const connectionRes = await testConnection({ credentials }); if (connectionRes.isErr()) { @@ -130,6 +139,15 @@ export class SnowflakeConnectorManager extends BaseConnectorManager { const newCredentials = newCredentialsRes.value.credentials; + if (!isSnowflakeCredentials(newCredentials)) { + return new Err( + new ConnectorManagerError( + "INVALID_CONFIGURATION", + "Invalid credentials type - expected snowflake credentials" + ) + ); + } + const connectionRes = await testConnection({ credentials: newCredentials }); if (connectionRes.isErr()) { return new Err( @@ -262,6 +280,15 @@ export class SnowflakeConnectorManager extends BaseConnectorManager { }); } + if (!isSnowflakeCredentials(credentials)) { + return new Err( + new ConnectorManagerError( + "INVALID_CONFIGURATION", + "Invalid credentials type - expected snowflake credentials" + ) + ); + } + // We display all available nodes with our credentials. return fetchAvailableChildrenInSnowflake({ connectorId: connector.id, diff --git a/connectors/src/connectors/snowflake/lib/utils.ts b/connectors/src/connectors/snowflake/lib/utils.ts index 61020fc9e51a..cea22576a21f 100644 --- a/connectors/src/connectors/snowflake/lib/utils.ts +++ b/connectors/src/connectors/snowflake/lib/utils.ts @@ -1,5 +1,10 @@ import type { ConnectionCredentials, ModelId, Result } from "@dust-tt/types"; -import { Err, getConnectionCredentials, Ok } from "@dust-tt/types"; +import { + Err, + getConnectionCredentials, + isSnowflakeCredentials, + Ok, +} from "@dust-tt/types"; import { apiConfig } from "@connectors/lib/api/config"; import type { Logger } from "@connectors/logger/logger"; @@ -51,8 +56,19 @@ export const getCredentials = async ({ logger.error({ credentialsId }, "Failed to retrieve credentials"); return new Err(Error("Failed to retrieve credentials")); } + // Narrow the type of credentials to just the username/password variant + const credentials = credentialsRes.value.credential.content; + if (!isSnowflakeCredentials(credentials)) { + logger.error( + { credentialsId }, + "Invalid credentials type - expected snowflake credentials" + ); + return new Err( + Error("Invalid credentials type - expected snowflake credentials") + ); + } return new Ok({ - credentials: credentialsRes.value.credential.content, + credentials, }); }; @@ -86,8 +102,19 @@ export const getConnectorAndCredentials = async ({ logger.error({ connectorId }, "Failed to retrieve credentials"); return new Err(Error("Failed to retrieve credentials")); } + // Narrow the type of credentials to just the username/password variant + const credentials = credentialsRes.value.credential.content; + if (!isSnowflakeCredentials(credentials)) { + logger.error( + { connectorId }, + "Invalid credentials type - expected snowflake credentials" + ); + return new Err( + Error("Invalid credentials type - expected snowflake credentials") + ); + } return new Ok({ connector, - credentials: credentialsRes.value.credential.content, + credentials, }); }; diff --git a/connectors/src/connectors/snowflake/temporal/activities.ts b/connectors/src/connectors/snowflake/temporal/activities.ts index 5afc4dbad35c..cf2b806d8c44 100644 --- a/connectors/src/connectors/snowflake/temporal/activities.ts +++ b/connectors/src/connectors/snowflake/temporal/activities.ts @@ -1,4 +1,5 @@ import type { ModelId } from "@dust-tt/types"; +import { isSnowflakeCredentials } from "@dust-tt/types"; import { connectToSnowflake, @@ -36,6 +37,12 @@ export async function syncSnowflakeConnection(connectorId: ModelId) { const { credentials, connector } = getConnectorAndCredentialsRes.value; + if (!isSnowflakeCredentials(credentials)) { + throw new Error( + "Invalid credentials type - expected snowflake credentials" + ); + } + const connectionRes = await connectToSnowflake(credentials); if (connectionRes.isErr()) { throw connectionRes.error; diff --git a/core/src/databases/remote_databases/get_remote_database.rs b/core/src/databases/remote_databases/get_remote_database.rs index 5bce19650670..9dca1d73e06e 100644 --- a/core/src/databases/remote_databases/get_remote_database.rs +++ b/core/src/databases/remote_databases/get_remote_database.rs @@ -17,5 +17,8 @@ pub async fn get_remote_database( let db = SnowflakeRemoteDatabase::new(content)?; Ok(Box::new(db) as Box) } + _ => { + anyhow::bail!("{:?} is not a supported remote database provider", provider) + } } } diff --git a/core/src/databases/remote_databases/remote_database.rs b/core/src/databases/remote_databases/remote_database.rs index 6c1ae0f6b719..2f2ae5bb65d5 100644 --- a/core/src/databases/remote_databases/remote_database.rs +++ b/core/src/databases/remote_databases/remote_database.rs @@ -35,5 +35,9 @@ pub async fn get_remote_database( let db = SnowflakeRemoteDatabase::new(content)?; Ok(Box::new(db) as Box) } + _ => Err(anyhow::anyhow!( + "Provider {} is not a remote database", + provider + )), } } diff --git a/core/src/oauth/credential.rs b/core/src/oauth/credential.rs index b15789d0f097..06d9ffbdda67 100644 --- a/core/src/oauth/credential.rs +++ b/core/src/oauth/credential.rs @@ -13,6 +13,7 @@ use super::encryption::{seal_str, unseal_str}; #[serde(rename_all = "snake_case")] pub enum CredentialProvider { Snowflake, + Modjo, } impl fmt::Display for CredentialProvider { @@ -107,6 +108,9 @@ impl Credential { CredentialProvider::Snowflake => { vec!["account", "warehouse", "username", "password", "role"] } + CredentialProvider::Modjo => { + vec!["api_key"] + } }; for key in keys_to_check { diff --git a/types/src/oauth/lib.ts b/types/src/oauth/lib.ts index c133db1e32a3..62de1f17826d 100644 --- a/types/src/oauth/lib.ts +++ b/types/src/oauth/lib.ts @@ -56,7 +56,7 @@ export function isValidZendeskSubdomain(s: unknown): s is string { // Credentials Providers -export const CREDENTIALS_PROVIDERS = ["snowflake"] as const; +export const CREDENTIALS_PROVIDERS = ["snowflake", "modjo"] as const; export type CredentialsProvider = (typeof CREDENTIALS_PROVIDERS)[number]; export function isCredentialProvider(obj: unknown): obj is CredentialsProvider { @@ -73,7 +73,19 @@ export const SnowflakeCredentialsSchema = t.type({ warehouse: t.string, }); export type SnowflakeCredentials = t.TypeOf; -export type ConnectionCredentials = SnowflakeCredentials; + +export const ApiKeyCredentialsSchema = t.type({ + api_key: t.string, +}); +export type ModjoCredentials = t.TypeOf; + +export type ConnectionCredentials = SnowflakeCredentials | ModjoCredentials; + +export function isSnowflakeCredentials( + credentials: ConnectionCredentials +): credentials is SnowflakeCredentials { + return "username" in credentials && "password" in credentials; +} // POST Credentials