From 169f7bc8da9d899053eeb3b45a1810c767aa2563 Mon Sep 17 00:00:00 2001 From: Henry Fontanier Date: Tue, 8 Oct 2024 17:46:12 +0200 Subject: [PATCH] fix: return correct error types for snowflake remote DBs (#7953) Co-authored-by: Henry Fontanier --- .../databases/remote_databases/snowflake.rs | 99 ++++++++++++------- 1 file changed, 62 insertions(+), 37 deletions(-) diff --git a/core/src/databases/remote_databases/snowflake.rs b/core/src/databases/remote_databases/snowflake.rs index bc7cb8cdcd0e..4f71256d194c 100644 --- a/core/src/databases/remote_databases/snowflake.rs +++ b/core/src/databases/remote_databases/snowflake.rs @@ -140,9 +140,13 @@ impl TryFrom for SnowflakeQueryPlanEntry { } impl SnowflakeRemoteDatabase { - pub fn new(credentials: serde_json::Map) -> Result { + pub fn new( + credentials: serde_json::Map, + ) -> Result { let connection_details: SnowflakeConnectionDetails = - serde_json::from_value(serde_json::Value::Object(credentials))?; + serde_json::from_value(serde_json::Value::Object(credentials)).map_err(|e| { + QueryDatabaseError::GenericError(anyhow!("Error deserializing credentials: {}", e)) + })?; let mut client = SnowflakeClient::new( &connection_details.username, @@ -155,7 +159,10 @@ impl SnowflakeRemoteDatabase { schema: None, timeout: Some(std::time::Duration::from_secs(30)), }, - )?; + ) + .map_err(|e| { + QueryDatabaseError::GenericError(anyhow!("Error creating Snowflake client: {}", e)) + })?; if let (Ok(proxy_host), Ok(proxy_port), Ok(proxy_user_name), Ok(proxy_user_password)) = ( env::var("PROXY_HOST"), @@ -163,13 +170,19 @@ impl SnowflakeRemoteDatabase { env::var("PROXY_USER_NAME"), env::var("PROXY_USER_PASSWORD"), ) { - let proxy_port = proxy_port.parse::()?; - client = client.with_proxy( - &proxy_host, - proxy_port, - &proxy_user_name, - &proxy_user_password, - )?; + let proxy_port = proxy_port.parse::().map_err(|e| { + QueryDatabaseError::GenericError(anyhow!("Error parsing proxy port: {}", e)) + })?; + client = client + .with_proxy( + &proxy_host, + proxy_port, + &proxy_user_name, + &proxy_user_password, + ) + .map_err(|e| { + QueryDatabaseError::GenericError(anyhow!("Error setting proxy: {}", e)) + })?; } Ok(Self { @@ -178,24 +191,22 @@ impl SnowflakeRemoteDatabase { }) } - async fn try_get_session(&self) -> Result { + async fn try_get_session(&self) -> Result { let session = self.client.create_session().await.map_err(|e| { - QueryDatabaseError::ExecutionError(anyhow!("Error creating session: {}", e).to_string()) + QueryDatabaseError::GenericError(anyhow!("Error creating session: {}", e)) })?; let _ = session .execute(format!("USE WAREHOUSE {}", self.warehouse)) .await .map_err(|e| { - QueryDatabaseError::ExecutionError( - anyhow!("Error setting warehouse: {}", e).to_string(), - ) + QueryDatabaseError::GenericError(anyhow!("Error setting warehouse: {}", e)) })?; Ok(session) } - async fn get_session(&self) -> Result { + async fn get_session(&self) -> Result { let mut tries = 0; let mut backoff = tokio::time::Duration::from_millis(100); @@ -224,9 +235,10 @@ impl SnowflakeRemoteDatabase { Err(snowflake_connector_rs::Error::TimedOut) => Err( QueryDatabaseError::ExecutionError("Query execution timed out".to_string()), ), - Err(e) => Err(QueryDatabaseError::ExecutionError( - anyhow!("Error executing query: {}", e).to_string(), - )), + Err(e) => Err(QueryDatabaseError::ExecutionError(format!( + "Error executing query: {}", + e + ))), }?; let mut query_result_size: usize = 0; @@ -237,9 +249,7 @@ impl SnowflakeRemoteDatabase { // Stop fetching when chunk is None. 'fetch_rows: loop { match executor.fetch_next_chunk().await.map_err(|e| { - QueryDatabaseError::ExecutionError( - anyhow!("Error fetching rows: {}", e).to_string(), - ) + QueryDatabaseError::GenericError(anyhow!("Error fetching rows: {}", e)) })? { Some(snowflake_rows) => { // Convert SnowflakeRow to QueryResult. @@ -291,7 +301,7 @@ impl SnowflakeRemoteDatabase { session: &SnowflakeSession, tables: &Vec, query: &str, - ) -> Result<()> { + ) -> Result<(), QueryDatabaseError> { // Ensure that query only uses tables that are allowed. let plan = self.get_query_plan(&session, query).await?; let used_tables: HashSet<&str> = plan @@ -302,22 +312,37 @@ impl SnowflakeRemoteDatabase { }) .collect(); let allowed_tables: HashSet<&str> = tables.iter().map(|table| table.name()).collect(); - - if used_tables - .iter() - .any(|table| !allowed_tables.contains(table)) - { - Err(anyhow!("Query uses tables not allowed by the query plan"))? + let used_forbidden_tables = used_tables + .into_iter() + .filter(|table| !allowed_tables.contains(*table)) + .collect::>(); + + if !used_forbidden_tables.is_empty() { + Err(QueryDatabaseError::ExecutionError(format!( + "Query uses tables that are not allowed: {}", + used_forbidden_tables.join(", ") + )))? } - // Ensure that query does not contain forbidden operations. - for operation in plan.into_iter().filter_map(|entry| entry.operation) { - if FORBIDDEN_OPERATIONS - .iter() - .any(|op| operation.to_lowercase() == *op) - { - Err(anyhow!("Query contains forbidden operations"))? - } + let used_forbidden_operations = plan + .into_iter() + .filter_map(|entry| match entry.operation { + Some(op) + if FORBIDDEN_OPERATIONS + .iter() + .any(|forbidden_op| op.to_lowercase() == *forbidden_op) => + { + Some(op) + } + _ => None, + }) + .collect::>(); + + if !used_forbidden_operations.is_empty() { + Err(QueryDatabaseError::ExecutionError(format!( + "Query contains forbidden operations: {}", + used_forbidden_operations.join(", ") + )))? } Ok(())