diff --git a/connectors/src/connectors/snowflake/lib/content_nodes.ts b/connectors/src/connectors/snowflake/lib/content_nodes.ts index 0cc4f0ce95e8..f94659e8c03d 100644 --- a/connectors/src/connectors/snowflake/lib/content_nodes.ts +++ b/connectors/src/connectors/snowflake/lib/content_nodes.ts @@ -1,19 +1,11 @@ import type { ConnectorPermission, ContentNode } from "@dust-tt/types"; -/** - * Some databases and schemas are not useful to show in the content tree. - * We exclude them here. - */ -export const EXCLUDE_DATABASES = ["SNOWFLAKE"]; -export const EXCLUDE_SCHEMAS = ["INFORMATION_SCHEMA"]; - /** * 3 types of nodes in a remote database content tree: * - database: internalId = "database_name" * - schema: internalId = "database_name.schema_name" * - table: internalId = "database_name.schema_name.table_name" */ - export type REMOTE_DB_CONTENT_NODE_TYPES = "database" | "schema" | "table"; export const getContentNodeTypeFromInternalId = ( diff --git a/connectors/src/connectors/snowflake/lib/permissions.ts b/connectors/src/connectors/snowflake/lib/permissions.ts index a39b66d6324a..84b825e96093 100644 --- a/connectors/src/connectors/snowflake/lib/permissions.ts +++ b/connectors/src/connectors/snowflake/lib/permissions.ts @@ -4,11 +4,9 @@ import type { Result, SnowflakeCredentials, } from "@dust-tt/types"; -import { Err, Ok } from "@dust-tt/types"; +import { Err, EXCLUDE_DATABASES, EXCLUDE_SCHEMAS, Ok } from "@dust-tt/types"; import { - EXCLUDE_DATABASES, - EXCLUDE_SCHEMAS, getContentNodeFromInternalId, getContentNodeTypeFromInternalId, } from "@connectors/connectors/snowflake/lib/content_nodes"; diff --git a/connectors/src/connectors/snowflake/lib/snowflake_api.ts b/connectors/src/connectors/snowflake/lib/snowflake_api.ts index 75769b27da30..9bf1f7f7c7b9 100644 --- a/connectors/src/connectors/snowflake/lib/snowflake_api.ts +++ b/connectors/src/connectors/snowflake/lib/snowflake_api.ts @@ -1,6 +1,9 @@ import type { Result } from "@dust-tt/types"; import type { SnowflakeCredentials } from "@dust-tt/types"; -import { Err, Ok } from "@dust-tt/types"; +import { Err, EXCLUDE_DATABASES, EXCLUDE_SCHEMAS, Ok } from "@dust-tt/types"; +import { isLeft } from "fp-ts/lib/Either"; +import * as t from "io-ts"; +import * as reporter from "io-ts-reporters"; import type { Connection, RowStatement, SnowflakeError } from "snowflake-sdk"; import snowflake from "snowflake-sdk"; @@ -8,6 +11,31 @@ import snowflake from "snowflake-sdk"; export type SnowflakeRow = Record; export type SnowflakeRows = Array; +const snowflakeDatabaseCodec = t.type({ + name: t.string, +}); +type SnowflakeDatabase = t.TypeOf; + +const snowflakeSchemaCodec = t.type({ + name: t.string, + database_name: t.string, +}); +type SnowflakeSchema = t.TypeOf; + +const snowflakeTableCodec = t.type({ + name: t.string, + database_name: t.string, + schema_name: t.string, +}); +type SnowflakeTable = t.TypeOf; + +const snowflakeGrantCodec = t.type({ + privilege: t.string, + grant_on: t.string, + name: t.string, +}); +type SnowflakeGrant = t.TypeOf; + /** * Test the connection to Snowflake with the provided credentials. * Used to check if the credentials are valid and the connection is successful. @@ -17,44 +45,92 @@ export const testConnection = async ({ }: { credentials: SnowflakeCredentials; }): Promise> => { - const connection = snowflake.createConnection({ - ...credentials, - - // Use proxy if defined to have all requests coming from the same IP. - proxyHost: process.env.PROXY_HOST, - proxyPort: process.env.PROXY_PORT - ? parseInt(process.env.PROXY_PORT) - : undefined, - proxyUser: process.env.PROXY_USER_NAME, - proxyPassword: process.env.PROXY_USER_PASSWORD, - }); + // Connect to snowflake, fetch tables and grants, and close the connection. + const connectionRes = await connectToSnowflake(credentials); + if (connectionRes.isErr()) { + return connectionRes; + } + const connection = connectionRes.value; + const tablesRes = await fetchTables({ credentials, connection }); + const grantsRes = await isConnectionReadonly({ credentials, connection }); + const closeConnectionRes = await _closeConnection(connection); + if (closeConnectionRes.isErr()) { + return closeConnectionRes; + } + + if (grantsRes.isErr()) { + return grantsRes; + } + if (tablesRes.isErr()) { + return tablesRes; + } + + const tables = tablesRes.value.filter( + (t) => + !EXCLUDE_DATABASES.includes(t.database_name) && + !EXCLUDE_SCHEMAS.includes(t.schema_name) + ); + if (tables.length === 0) { + return new Err(new Error("No tables found or no access to any table.")); + } + + return new Ok("Connection successful"); +}; + +export async function connectToSnowflake( + credentials: SnowflakeCredentials +): Promise> { + snowflake.configure({ + // @ts-expect-error OFF is not in the types but it's a valid value. + logLevel: "OFF", + }); try { - const conn = await _connectToSnowflake(connection); - // TODO(SNOWFLAKE): Improve checks: we want to make sure we have read and read-only access. - const rows = await _executeQuery(conn, "SHOW TABLES"); + const connection = await new Promise((resolve, reject) => { + snowflake + .createConnection({ + ...credentials, - if (!rows || rows.length === 0) { - throw new Error("No tables found or no access to any tables"); - } + // Use proxy if defined to have all requests coming from the same IP. + proxyHost: process.env.PROXY_HOST, + proxyPort: process.env.PROXY_PORT + ? parseInt(process.env.PROXY_PORT) + : undefined, + proxyUser: process.env.PROXY_USER_NAME, + proxyPassword: process.env.PROXY_USER_PASSWORD, + }) + .connect((err: SnowflakeError | undefined, conn: Connection) => { + if (err) { + reject(err); + } else { + resolve(conn); + } + }); + }); - await _closeConnection(conn); - return new Ok("Connection successful"); + return new Ok(connection); } catch (error) { return new Err(error instanceof Error ? error : new Error(String(error))); } -}; +} /** * Fetch the tables available in the Snowflake account. */ export const fetchDatabases = async ({ credentials, + connection, }: { credentials: SnowflakeCredentials; -}): Promise> => { + connection?: Connection; +}): Promise, Error>> => { const query = "SHOW DATABASES"; - return _fetchRows({ credentials, query }); + return _fetchRows({ + credentials, + query, + codec: snowflakeDatabaseCodec, + connection, + }); }; /** @@ -63,14 +139,21 @@ export const fetchDatabases = async ({ export const fetchSchemas = async ({ credentials, fromDatabase, + connection, }: { credentials: SnowflakeCredentials; fromDatabase?: string; -}): Promise> => { + connection?: Connection; +}): Promise, Error>> => { const query = fromDatabase ? `SHOW SCHEMAS IN DATABASE ${fromDatabase}` : "SHOW SCHEMAS"; - return _fetchRows({ credentials, query }); + return _fetchRows({ + credentials, + query, + codec: snowflakeSchemaCodec, + connection, + }); }; /** @@ -79,112 +162,201 @@ export const fetchSchemas = async ({ export const fetchTables = async ({ credentials, fromSchema, + connection, }: { credentials: SnowflakeCredentials; fromSchema?: string; -}): Promise> => { + connection?: Connection; +}): Promise, Error>> => { const query = fromSchema ? `SHOW TABLES IN SCHEMA ${fromSchema}` : "SHOW TABLES"; - return _fetchRows({ credentials, query }); + return _fetchRows({ + credentials, + query, + codec: snowflakeTableCodec, + connection, + }); +}; + +/** + * Fetch the grants available for the Snowflake role, + * including future grants, then check if the connection is read-only. + */ +export const isConnectionReadonly = async ({ + credentials, + connection, +}: { + credentials: SnowflakeCredentials; + connection: Connection; +}): Promise, Error>> => { + const currentGrantsRes = await _fetchRows({ + credentials, + query: `SHOW GRANTS TO ROLE ${credentials.role}`, + codec: snowflakeGrantCodec, + connection, + }); + if (currentGrantsRes.isErr()) { + return currentGrantsRes; + } + + const futureGrantsRes = await _fetchRows({ + credentials, + query: `SHOW FUTURE GRANTS TO ROLE ${credentials.role}`, + codec: snowflakeGrantCodec, + connection, + }); + if (futureGrantsRes.isErr()) { + return futureGrantsRes; + } + + const allGrantsRows = [...currentGrantsRes.value, ...futureGrantsRes.value]; + + const grants: Array = []; + for (const row of allGrantsRows) { + const decoded = snowflakeGrantCodec.decode(row); + if (isLeft(decoded)) { + const pathError = reporter.formatValidationErrors(decoded.left); + return new Err(new Error(`Could not parse row: ${pathError}`)); + } + + grants.push(decoded.right); + } + + // We go ove each grant to greenlight them. + for (const g of grants) { + if (g.grant_on === "TABLE") { + // We only allow SELECT grants on tables. + if (g.privilege !== "SELECT") { + return new Err( + new Error( + `Non-select grant found on ${g.grant_on} "${g.name}": privilege=${g.privilege} (connection must be read-only).` + ) + ); + } + } else if (["SCHEMA", "DATABASE", "WAREHOUSE"].includes(g.grant_on)) { + // We only allow USAGE grants on schemas / databases / warehouses. + if (g.privilege !== "USAGE") { + return new Err( + new Error( + `Non-usage grant found on ${g.grant_on} "${g.name}": privilege=${g.privilege} (connection must be read-only).` + ) + ); + } + } else { + // We don't allow any other grants. + return new Err( + new Error( + `Unsupported grant found on ${g.grant_on} "${g.name}": privilege=${g.privilege} (connection must be read-only).` + ) + ); + } + } + + return new Ok(grants); }; // UTILS -async function _fetchRows({ +async function _fetchRows({ credentials, query, + codec, + connection, }: { credentials: SnowflakeCredentials; query: string; -}): Promise> { - snowflake.configure({ - // @ts-expect-error OFF is not in the types but it's a valid value. - logLevel: "OFF", - }); - - const connection = snowflake.createConnection({ - ...credentials, + codec: t.Type; + connection?: Connection; +}): Promise, Error>> { + const connRes = await (() => + connection ? new Ok(connection) : connectToSnowflake(credentials))(); + if (connRes.isErr()) { + return connRes; + } + const conn = connRes.value; - // Use proxy if defined to have all requests coming from the same IP. - proxyHost: process.env.PROXY_HOST, - proxyPort: process.env.PROXY_PORT - ? parseInt(process.env.PROXY_PORT) - : undefined, - proxyUser: process.env.PROXY_USER_NAME, - proxyPassword: process.env.PROXY_USER_PASSWORD, - }); + const rowsRes = await _executeQuery(conn, query); + if (rowsRes.isErr()) { + return rowsRes; + } + const rows = rowsRes.value; - try { - const conn = await _connectToSnowflake(connection); - const rows = await _executeQuery(conn, query); + // We close the connection if we created it. + if (!connection) { await _closeConnection(conn); + } + + if (!rows) { + return new Err(new Error("No tables found or no access to any table.")); + } - if (!rows) { - throw new Error("No tables found or no access to any table."); + const parsedRows: Array = []; + for (const row of rows) { + const decoded = codec.decode(row); + if (isLeft(decoded)) { + const pathError = reporter.formatValidationErrors(decoded.left); + return new Err(new Error(`Could not parse row: ${pathError}`)); } - return new Ok(rows); - } catch (error) { - return new Err(error instanceof Error ? error : new Error(String(error))); + parsedRows.push(decoded.right); } -} -/** - * Util: Connect to Snowflake. - */ -function _connectToSnowflake( - connection: snowflake.Connection -): Promise { - return new Promise((resolve, reject) => { - connection.connect((err: SnowflakeError | undefined, conn: Connection) => { - if (err) { - reject(err); - } else { - resolve(conn); - } - }); - }); + return new Ok(parsedRows); } /** - * Util: Execute a query on the Snowflake connection. + * Util: Close the Snowflake connection. */ -function _executeQuery( - conn: Connection, - sqlText: string -): Promise { - return new Promise((resolve, reject) => { - conn.execute({ - sqlText, - complete: ( - err: SnowflakeError | undefined, - stmt: RowStatement, - rows: SnowflakeRows | undefined - ) => { +async function _closeConnection( + conn: Connection +): Promise> { + try { + await new Promise((resolve, reject) => { + conn.destroy((err: SnowflakeError | undefined) => { if (err) { + console.error("Error closing connection:", err); reject(err); } else { - resolve(rows); + resolve(); } - }, + }); }); - }); + return new Ok(undefined); + } catch (error) { + return new Err(error instanceof Error ? error : new Error(String(error))); + } } /** - * Util: Close the Snowflake connection. + * Util: Execute a query on the Snowflake connection. */ -function _closeConnection(conn: Connection): Promise { - return new Promise((resolve, reject) => { - conn.destroy((err: SnowflakeError | undefined) => { - if (err) { - console.error("Error closing connection:", err); - reject(err); - } else { - resolve(); +async function _executeQuery( + conn: Connection, + sqlText: string +): Promise> { + try { + const r = await new Promise( + (resolve, reject) => { + conn.execute({ + sqlText, + complete: ( + err: SnowflakeError | undefined, + stmt: RowStatement, + rows: SnowflakeRows | undefined + ) => { + if (err) { + reject(err); + } else { + resolve(rows); + } + }, + }); } - }); - }); + ); + return new Ok(r); + } catch (error) { + return new Err(error instanceof Error ? error : new Error(String(error))); + } } diff --git a/connectors/src/connectors/snowflake/temporal/activities.ts b/connectors/src/connectors/snowflake/temporal/activities.ts index 2a93789396aa..450fc8d92608 100644 --- a/connectors/src/connectors/snowflake/temporal/activities.ts +++ b/connectors/src/connectors/snowflake/temporal/activities.ts @@ -1,9 +1,10 @@ import type { ModelId } from "@dust-tt/types"; -import { isLeft } from "fp-ts/lib/Either"; -import * as t from "io-ts"; -import * as reporter from "io-ts-reporters"; -import { fetchTables } from "@connectors/connectors/snowflake/lib/snowflake_api"; +import { + connectToSnowflake, + fetchTables, + isConnectionReadonly, +} from "@connectors/connectors/snowflake/lib/snowflake_api"; import { getConnectorAndCredentials } from "@connectors/connectors/snowflake/lib/utils"; import { dataSourceConfigFromConnector } from "@connectors/lib/api/data_source_config"; import { deleteTable, upsertTable } from "@connectors/lib/data_sources"; @@ -12,15 +13,9 @@ import { RemoteSchemaModel, RemoteTableModel, } from "@connectors/lib/models/remote_databases"; -import { syncSucceeded } from "@connectors/lib/sync_status"; +import { syncFailed, syncSucceeded } from "@connectors/lib/sync_status"; import logger from "@connectors/logger/logger"; -const snowflakeTableCodec = t.type({ - name: t.string, - database_name: t.string, - schema_name: t.string, -}); - export async function syncSnowflakeConnection(connectorId: ModelId) { const getConnectorAndCredentialsRes = await getConnectorAndCredentials({ connectorId, @@ -32,21 +27,11 @@ export async function syncSnowflakeConnection(connectorId: ModelId) { const { credentials, connector } = getConnectorAndCredentialsRes.value; - const tablesRes = await fetchTables({ credentials }); - if (tablesRes.isErr()) { - throw tablesRes.error; - } - const tablesValidation = t.array(snowflakeTableCodec).decode(tablesRes.value); - if (isLeft(tablesValidation)) { - const pathError = reporter.formatValidationErrors(tablesValidation.left); - throw new Error(`Invalid tables response: ${pathError}`); + const connectionRes = await connectToSnowflake(credentials); + if (connectionRes.isErr()) { + throw connectionRes.error; } - const tablesOnSnowflake = tablesValidation.right; - const internalIdsOnSnowflake = new Set( - tablesOnSnowflake.map( - (t) => `${t.database_name}.${t.schema_name}.${t.name}` - ) - ); + const connection = connectionRes.value; const [allDatabases, allSchemas, allTables] = await Promise.all([ RemoteDatabaseModel.findAll({ @@ -66,6 +51,41 @@ export async function syncSnowflakeConnection(connectorId: ModelId) { }), ]); + const readonlyConnectionCheck = await isConnectionReadonly({ + credentials, + connection, + }); + if (readonlyConnectionCheck.isErr()) { + // The connection is not read-only. + // We mark the connector as errored, and garbage collect all the tables that were synced. + await syncFailed(connectorId, "remote_database_connection_not_readonly"); + for (const t of allTables) { + await deleteTable({ + dataSourceConfig: dataSourceConfigFromConnector(connector), + tableId: t.internalId, + }); + if (t.permission === "inherited") { + await t.destroy(); + } else { + await t.update({ + lastUpsertedAt: null, + }); + } + } + return; + } + + const tablesOnSnowflakeRes = await fetchTables({ credentials, connection }); + if (tablesOnSnowflakeRes.isErr()) { + throw tablesOnSnowflakeRes.error; + } + const tablesOnSnowflake = tablesOnSnowflakeRes.value; + const internalIdsOnSnowflake = new Set( + tablesOnSnowflake.map( + (t) => `${t.database_name}.${t.schema_name}.${t.name}` + ) + ); + const readGrantedInternalIds = new Set([ ...allDatabases.map((db) => db.internalId), ...allSchemas.map((s) => s.internalId), diff --git a/connectors/src/lib/models/remote_databases.ts b/connectors/src/lib/models/remote_databases.ts index 749e1a1faa0f..473e89a0bc00 100644 --- a/connectors/src/lib/models/remote_databases.ts +++ b/connectors/src/lib/models/remote_databases.ts @@ -126,7 +126,7 @@ export class RemoteTableModel extends Model< declare id: CreationOptional; declare createdAt: CreationOptional; declare updatedAt: CreationOptional; - declare lastUpsertedAt: CreationOptional; + declare lastUpsertedAt: CreationOptional | null; declare internalId: string; declare name: string; diff --git a/core/src/databases/remote_databases/snowflake.rs b/core/src/databases/remote_databases/snowflake.rs index b5a947ecce00..0622b12cc873 100644 --- a/core/src/databases/remote_databases/snowflake.rs +++ b/core/src/databases/remote_databases/snowflake.rs @@ -37,9 +37,19 @@ struct SnowflakeSchemaColumn { r#type: String, } +#[derive(Debug, Deserialize)] +#[serde(rename_all = "UPPERCASE")] +struct SnowflakeQueryPlanEntry { + objects: Option, + operation: Option, +} + // TODO(SNOWFLAKE) actual limit TBD pub const MAX_QUERY_RESULT_SIZE_BYTES: usize = 128 * 1024 * 1024; // 128MB +// TODO(SNOWFLAKE) make sure we're not missing any +pub const FORBIDDEN_OPERATIONS: [&str; 3] = ["UPDATE", "DELETE", "INSERT"]; + impl TryFrom for TableSchemaColumn { type Error = anyhow::Error; @@ -118,6 +128,14 @@ impl TryFrom for QueryResult { } } +impl TryFrom for SnowflakeQueryPlanEntry { + type Error = anyhow::Error; + + fn try_from(result: QueryResult) -> Result { + serde_json::from_value(result.value).map_err(|e| anyhow!("Error deserializing row: {}", e)) + } +} + impl SnowflakeRemoteDatabase { pub fn new(credentials: serde_json::Map) -> Result { let connection_details: SnowflakeConnectionDetails = @@ -232,23 +250,52 @@ impl SnowflakeRemoteDatabase { Ok((all_rows, schema)) } - async fn get_tables_used_by_query( + async fn get_query_plan( &self, session: &SnowflakeSession, query: &str, - ) -> Result> { - let explain_query = format!("EXPLAIN {}", query); - let used_tables = session - .query(explain_query.clone()) - .await? + ) -> Result, QueryDatabaseError> { + let plan_query = format!("EXPLAIN {}", query); + let (res, _) = self.execute_query(session, &plan_query).await?; + + Ok(res + .into_iter() + .map(|r| r.try_into()) + .collect::>>()?) + } + + async fn authorize_query( + &self, + session: &SnowflakeSession, + tables: &Vec, + query: &str, + ) -> Result<()> { + // Ensure that query only uses tables that are allowed. + let plan = self.get_query_plan(&session, query).await?; + let used_tables: HashSet = plan .iter() - .filter_map(|row| match row.get::("objects") { - Ok(objects) => Some(objects), - _ => None, - }) + .filter_map(|entry| entry.objects.clone()) .collect(); + let allowed_tables: HashSet<&str> = tables.iter().map(|table| table.name()).collect(); + + if used_tables + .iter() + .any(|table| !allowed_tables.contains(table.as_str())) + { + Err(anyhow!("Query uses tables not allowed by the query plan"))? + } - Ok(used_tables) + // Ensure that query does not contain forbidden operations. + for operation in plan.iter().filter_map(|entry| entry.operation.clone()) { + if FORBIDDEN_OPERATIONS + .iter() + .any(|op| operation.to_lowercase() == *op) + { + Err(anyhow!("Query contains forbidden operations"))? + } + } + + Ok(()) } } @@ -265,18 +312,9 @@ impl RemoteDatabase for SnowflakeRemoteDatabase { ) -> Result<(Vec, TableSchema), QueryDatabaseError> { let session = self.get_session().await?; - // Ensure that query only uses tables that are allowed. - let used_tables = self.get_tables_used_by_query(&session, query).await?; - let allowed_tables: HashSet<&str> = tables.iter().map(|table| table.name()).collect(); - - if used_tables - .iter() - .any(|table| !allowed_tables.contains(table.as_str())) - { - Err(QueryDatabaseError::ExecutionError( - "Query uses tables not allowed by the query plan".to_string(), - ))? - } + // Authorize the query based on allowed tables, query plan, + // and forbidden operations. + let _ = self.authorize_query(&session, tables, query).await?; self.execute_query(&session, query).await } diff --git a/types/src/connectors/snowflake.ts b/types/src/connectors/snowflake.ts new file mode 100644 index 000000000000..e509fd05c058 --- /dev/null +++ b/types/src/connectors/snowflake.ts @@ -0,0 +1,6 @@ +/** + * Some databases and schemas are not useful to show in the content tree. + * We exclude them here. + */ +export const EXCLUDE_DATABASES = ["SNOWFLAKE", "SNOWFLAKE_SAMPLE_DATA"]; +export const EXCLUDE_SCHEMAS = ["INFORMATION_SCHEMA"]; diff --git a/types/src/front/lib/connectors_api.ts b/types/src/front/lib/connectors_api.ts index fe48a0b7f93b..c8a79bf8b5d0 100644 --- a/types/src/front/lib/connectors_api.ts +++ b/types/src/front/lib/connectors_api.ts @@ -18,6 +18,7 @@ const CONNECTORS_ERROR_TYPES = [ "oauth_token_revoked", "third_party_internal_error", "webcrawling_error", + "remote_database_connection_not_readonly", ] as const; export type ConnectorErrorType = (typeof CONNECTORS_ERROR_TYPES)[number]; diff --git a/types/src/index.ts b/types/src/index.ts index f70153ad6140..a94cf29d4a9e 100644 --- a/types/src/index.ts +++ b/types/src/index.ts @@ -11,6 +11,7 @@ export * from "./connectors/intercom"; export * from "./connectors/microsoft"; export * from "./connectors/notion"; export * from "./connectors/slack"; +export * from "./connectors/snowflake"; export * from "./connectors/webcrawler"; export * from "./core/data_source"; export * from "./front/acl";