Skip to content

Commit

Permalink
Add ApiKeyCredentials to oAuth creds (#9283)
Browse files Browse the repository at this point in the history
* Add ApiKeyCredentials to postCredentials and getCredentials

* lint

* Rename ModjoCredentials and add oauth modjo support

* Lint

* snake_case + make errors generic

* SnowflakeCredentials Typeguard

* Add more snowflake typeguards
  • Loading branch information
albandum authored Dec 17, 2024
1 parent 324bf2d commit aaaaf5f
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 6 deletions.
29 changes: 28 additions & 1 deletion connectors/src/connectors/snowflake/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import type {
ContentNodesViewType,
Result,
} from "@dust-tt/types";
import { assertNever, Err, Ok } from "@dust-tt/types";
import { assertNever, Err, isSnowflakeCredentials, Ok } from "@dust-tt/types";

import type {
CreateConnectorErrorCode,
Expand Down Expand Up @@ -73,6 +73,15 @@ export class SnowflakeConnectorManager extends BaseConnectorManager<null> {
}
const credentials = credentialsRes.value.credentials;

if (!isSnowflakeCredentials(credentials)) {
return new Err(
new ConnectorManagerError(
"INVALID_CONFIGURATION",
"Invalid credentials type - expected snowflake credentials"
)
);
}

// Then we test the connection is successful.
const connectionRes = await testConnection({ credentials });
if (connectionRes.isErr()) {
Expand Down Expand Up @@ -130,6 +139,15 @@ export class SnowflakeConnectorManager extends BaseConnectorManager<null> {

const newCredentials = newCredentialsRes.value.credentials;

if (!isSnowflakeCredentials(newCredentials)) {
return new Err(
new ConnectorManagerError(
"INVALID_CONFIGURATION",
"Invalid credentials type - expected snowflake credentials"
)
);
}

const connectionRes = await testConnection({ credentials: newCredentials });
if (connectionRes.isErr()) {
return new Err(
Expand Down Expand Up @@ -262,6 +280,15 @@ export class SnowflakeConnectorManager extends BaseConnectorManager<null> {
});
}

if (!isSnowflakeCredentials(credentials)) {
return new Err(
new ConnectorManagerError(
"INVALID_CONFIGURATION",
"Invalid credentials type - expected snowflake credentials"
)
);
}

// We display all available nodes with our credentials.
return fetchAvailableChildrenInSnowflake({
connectorId: connector.id,
Expand Down
33 changes: 30 additions & 3 deletions connectors/src/connectors/snowflake/lib/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import type { ConnectionCredentials, ModelId, Result } from "@dust-tt/types";
import { Err, getConnectionCredentials, Ok } from "@dust-tt/types";
import {
Err,
getConnectionCredentials,
isSnowflakeCredentials,
Ok,
} from "@dust-tt/types";

import { apiConfig } from "@connectors/lib/api/config";
import type { Logger } from "@connectors/logger/logger";
Expand Down Expand Up @@ -51,8 +56,19 @@ export const getCredentials = async ({
logger.error({ credentialsId }, "Failed to retrieve credentials");
return new Err(Error("Failed to retrieve credentials"));
}
// Narrow the type of credentials to just the username/password variant
const credentials = credentialsRes.value.credential.content;
if (!isSnowflakeCredentials(credentials)) {
logger.error(
{ credentialsId },
"Invalid credentials type - expected snowflake credentials"
);
return new Err(
Error("Invalid credentials type - expected snowflake credentials")
);
}
return new Ok({
credentials: credentialsRes.value.credential.content,
credentials,
});
};

Expand Down Expand Up @@ -86,8 +102,19 @@ export const getConnectorAndCredentials = async ({
logger.error({ connectorId }, "Failed to retrieve credentials");
return new Err(Error("Failed to retrieve credentials"));
}
// Narrow the type of credentials to just the username/password variant
const credentials = credentialsRes.value.credential.content;
if (!isSnowflakeCredentials(credentials)) {
logger.error(
{ connectorId },
"Invalid credentials type - expected snowflake credentials"
);
return new Err(
Error("Invalid credentials type - expected snowflake credentials")
);
}
return new Ok({
connector,
credentials: credentialsRes.value.credential.content,
credentials,
});
};
7 changes: 7 additions & 0 deletions connectors/src/connectors/snowflake/temporal/activities.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ModelId } from "@dust-tt/types";
import { isSnowflakeCredentials } from "@dust-tt/types";

import {
connectToSnowflake,
Expand Down Expand Up @@ -36,6 +37,12 @@ export async function syncSnowflakeConnection(connectorId: ModelId) {

const { credentials, connector } = getConnectorAndCredentialsRes.value;

if (!isSnowflakeCredentials(credentials)) {
throw new Error(
"Invalid credentials type - expected snowflake credentials"
);
}

const connectionRes = await connectToSnowflake(credentials);
if (connectionRes.isErr()) {
throw connectionRes.error;
Expand Down
3 changes: 3 additions & 0 deletions core/src/databases/remote_databases/get_remote_database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,8 @@ pub async fn get_remote_database(
let db = SnowflakeRemoteDatabase::new(content)?;
Ok(Box::new(db) as Box<dyn RemoteDatabase + Sync + Send>)
}
_ => {
anyhow::bail!("{:?} is not a supported remote database provider", provider)
}
}
}
4 changes: 4 additions & 0 deletions core/src/databases/remote_databases/remote_database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,9 @@ pub async fn get_remote_database(
let db = SnowflakeRemoteDatabase::new(content)?;
Ok(Box::new(db) as Box<dyn RemoteDatabase + Sync + Send>)
}
_ => Err(anyhow::anyhow!(
"Provider {} is not a remote database",
provider
)),
}
}
4 changes: 4 additions & 0 deletions core/src/oauth/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use super::encryption::{seal_str, unseal_str};
#[serde(rename_all = "snake_case")]
pub enum CredentialProvider {
Snowflake,
Modjo,
}

impl fmt::Display for CredentialProvider {
Expand Down Expand Up @@ -107,6 +108,9 @@ impl Credential {
CredentialProvider::Snowflake => {
vec!["account", "warehouse", "username", "password", "role"]
}
CredentialProvider::Modjo => {
vec!["api_key"]
}
};

for key in keys_to_check {
Expand Down
16 changes: 14 additions & 2 deletions types/src/oauth/lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ export function isValidZendeskSubdomain(s: unknown): s is string {

// Credentials Providers

export const CREDENTIALS_PROVIDERS = ["snowflake"] as const;
export const CREDENTIALS_PROVIDERS = ["snowflake", "modjo"] as const;
export type CredentialsProvider = (typeof CREDENTIALS_PROVIDERS)[number];

export function isCredentialProvider(obj: unknown): obj is CredentialsProvider {
Expand All @@ -73,7 +73,19 @@ export const SnowflakeCredentialsSchema = t.type({
warehouse: t.string,
});
export type SnowflakeCredentials = t.TypeOf<typeof SnowflakeCredentialsSchema>;
export type ConnectionCredentials = SnowflakeCredentials;

export const ApiKeyCredentialsSchema = t.type({
api_key: t.string,
});
export type ModjoCredentials = t.TypeOf<typeof ApiKeyCredentialsSchema>;

export type ConnectionCredentials = SnowflakeCredentials | ModjoCredentials;

export function isSnowflakeCredentials(
credentials: ConnectionCredentials
): credentials is SnowflakeCredentials {
return "username" in credentials && "password" in credentials;
}

// POST Credentials

Expand Down

0 comments on commit aaaaf5f

Please sign in to comment.