Skip to content

Commit

Permalink
fix: return correct error types for snowflake remote DBs (#7953)
Browse files: browse the repository at this point in the history
Co-authored-by: Henry Fontanier <henry@dust.tt>
  • Loading branch information
fontanierh and Henry Fontanier authored Oct 8, 2024
1 parent 326cb4f commit 169f7bc
Showing 1 changed file with 62 additions and 37 deletions.
99 changes: 62 additions & 37 deletions core/src/databases/remote_databases/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,13 @@ impl TryFrom<QueryResult> for SnowflakeQueryPlanEntry {
}

impl SnowflakeRemoteDatabase {
pub fn new(credentials: serde_json::Map<String, serde_json::Value>) -> Result<Self> {
pub fn new(
credentials: serde_json::Map<String, serde_json::Value>,
) -> Result<Self, QueryDatabaseError> {
let connection_details: SnowflakeConnectionDetails =
serde_json::from_value(serde_json::Value::Object(credentials))?;
serde_json::from_value(serde_json::Value::Object(credentials)).map_err(|e| {
QueryDatabaseError::GenericError(anyhow!("Error deserializing credentials: {}", e))
})?;

let mut client = SnowflakeClient::new(
&connection_details.username,
Expand All @@ -155,21 +159,30 @@ impl SnowflakeRemoteDatabase {
schema: None,
timeout: Some(std::time::Duration::from_secs(30)),
},
)?;
)
.map_err(|e| {
QueryDatabaseError::GenericError(anyhow!("Error creating Snowflake client: {}", e))
})?;

if let (Ok(proxy_host), Ok(proxy_port), Ok(proxy_user_name), Ok(proxy_user_password)) = (
env::var("PROXY_HOST"),
env::var("PROXY_PORT"),
env::var("PROXY_USER_NAME"),
env::var("PROXY_USER_PASSWORD"),
) {
let proxy_port = proxy_port.parse::<u16>()?;
client = client.with_proxy(
&proxy_host,
proxy_port,
&proxy_user_name,
&proxy_user_password,
)?;
let proxy_port = proxy_port.parse::<u16>().map_err(|e| {
QueryDatabaseError::GenericError(anyhow!("Error parsing proxy port: {}", e))
})?;
client = client
.with_proxy(
&proxy_host,
proxy_port,
&proxy_user_name,
&proxy_user_password,
)
.map_err(|e| {
QueryDatabaseError::GenericError(anyhow!("Error setting proxy: {}", e))
})?;
}

Ok(Self {
Expand All @@ -178,24 +191,22 @@ impl SnowflakeRemoteDatabase {
})
}

async fn try_get_session(&self) -> Result<SnowflakeSession> {
async fn try_get_session(&self) -> Result<SnowflakeSession, QueryDatabaseError> {
let session = self.client.create_session().await.map_err(|e| {
QueryDatabaseError::ExecutionError(anyhow!("Error creating session: {}", e).to_string())
QueryDatabaseError::GenericError(anyhow!("Error creating session: {}", e))
})?;

let _ = session
.execute(format!("USE WAREHOUSE {}", self.warehouse))
.await
.map_err(|e| {
QueryDatabaseError::ExecutionError(
anyhow!("Error setting warehouse: {}", e).to_string(),
)
QueryDatabaseError::GenericError(anyhow!("Error setting warehouse: {}", e))
})?;

Ok(session)
}

async fn get_session(&self) -> Result<SnowflakeSession> {
async fn get_session(&self) -> Result<SnowflakeSession, QueryDatabaseError> {
let mut tries = 0;
let mut backoff = tokio::time::Duration::from_millis(100);

Expand Down Expand Up @@ -224,9 +235,10 @@ impl SnowflakeRemoteDatabase {
Err(snowflake_connector_rs::Error::TimedOut) => Err(
QueryDatabaseError::ExecutionError("Query execution timed out".to_string()),
),
Err(e) => Err(QueryDatabaseError::ExecutionError(
anyhow!("Error executing query: {}", e).to_string(),
)),
Err(e) => Err(QueryDatabaseError::ExecutionError(format!(
"Error executing query: {}",
e
))),
}?;

let mut query_result_size: usize = 0;
Expand All @@ -237,9 +249,7 @@ impl SnowflakeRemoteDatabase {
// Stop fetching when chunk is None.
'fetch_rows: loop {
match executor.fetch_next_chunk().await.map_err(|e| {
QueryDatabaseError::ExecutionError(
anyhow!("Error fetching rows: {}", e).to_string(),
)
QueryDatabaseError::GenericError(anyhow!("Error fetching rows: {}", e))
})? {
Some(snowflake_rows) => {
// Convert SnowflakeRow to QueryResult.
Expand Down Expand Up @@ -291,7 +301,7 @@ impl SnowflakeRemoteDatabase {
session: &SnowflakeSession,
tables: &Vec<Table>,
query: &str,
) -> Result<()> {
) -> Result<(), QueryDatabaseError> {
// Ensure that query only uses tables that are allowed.
let plan = self.get_query_plan(&session, query).await?;
let used_tables: HashSet<&str> = plan
Expand All @@ -302,22 +312,37 @@ impl SnowflakeRemoteDatabase {
})
.collect();
let allowed_tables: HashSet<&str> = tables.iter().map(|table| table.name()).collect();

if used_tables
.iter()
.any(|table| !allowed_tables.contains(table))
{
Err(anyhow!("Query uses tables not allowed by the query plan"))?
let used_forbidden_tables = used_tables
.into_iter()
.filter(|table| !allowed_tables.contains(*table))
.collect::<Vec<_>>();

if !used_forbidden_tables.is_empty() {
Err(QueryDatabaseError::ExecutionError(format!(
"Query uses tables that are not allowed: {}",
used_forbidden_tables.join(", ")
)))?
}

// Ensure that query does not contain forbidden operations.
for operation in plan.into_iter().filter_map(|entry| entry.operation) {
if FORBIDDEN_OPERATIONS
.iter()
.any(|op| operation.to_lowercase() == *op)
{
Err(anyhow!("Query contains forbidden operations"))?
}
let used_forbidden_operations = plan
.into_iter()
.filter_map(|entry| match entry.operation {
Some(op)
if FORBIDDEN_OPERATIONS
.iter()
.any(|forbidden_op| op.to_lowercase() == *forbidden_op) =>
{
Some(op)
}
_ => None,
})
.collect::<Vec<_>>();

if !used_forbidden_operations.is_empty() {
Err(QueryDatabaseError::ExecutionError(format!(
"Query contains forbidden operations: {}",
used_forbidden_operations.join(", ")
)))?
}

Ok(())
Expand Down

0 comments on commit 169f7bc

Please sign in to comment.