diff --git a/rust/cubesql/cubesql/e2e/tests/postgres.rs b/rust/cubesql/cubesql/e2e/tests/postgres.rs index aa8f403578c73..e79d3578e27f3 100644 --- a/rust/cubesql/cubesql/e2e/tests/postgres.rs +++ b/rust/cubesql/cubesql/e2e/tests/postgres.rs @@ -1302,16 +1302,6 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite { ) .await?; - self.test_simple_query(r#"SET ROLE "cube""#.to_string(), |messages| { - assert_eq!(messages.len(), 1); - - // SET - if let SimpleQueryMessage::Row(_) = messages[0] { - panic!("Must be CommandComplete command, (SET is used)") - } - }) - .await?; - // Tableau Desktop self.test_simple_query( r#"SET DateStyle = 'ISO';SET extra_float_digits = 2;show transaction_isolation"# diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 98b18e4ce9c2d..6deda6ae7ba1a 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -5177,15 +5177,41 @@ ORDER BY ); insta::assert_snapshot!( - "pg_set_role_show", + "pg_set_role_good_user", execute_queries_with_flags( - vec!["SET ROLE NONE".to_string(), "SHOW ROLE".to_string()], + vec!["SET ROLE good_user".to_string(), "SHOW ROLE".to_string()], DatabaseProtocol::PostgreSQL ) .await? .0 ); + insta::assert_snapshot!( + "pg_set_role_none", + execute_queries_with_flags( + vec![ + "SET ROLE good_user".to_string(), + "SET ROLE NONE".to_string(), + "SHOW ROLE".to_string() + ], + DatabaseProtocol::PostgreSQL + ) + .await? + .0 + ); + + insta::assert_snapshot!( + "pg_set_role_bad_user", + execute_queries_with_flags( + vec!["SET ROLE bad_user".to_string()], + DatabaseProtocol::PostgreSQL + ) + .await + .err() + .unwrap() + .to_string() + ); + Ok(()) } diff --git a/rust/cubesql/cubesql/src/compile/router.rs b/rust/cubesql/cubesql/src/compile/router.rs index 1abe85e935fdf..4d5cbe28cd25b 100644 --- a/rust/cubesql/cubesql/src/compile/router.rs +++ b/rust/cubesql/cubesql/src/compile/router.rs @@ -102,7 +102,9 @@ impl QueryRouter { StatusFlags::empty(), Box::new(dataframe::DataFrame::new(vec![], vec![])), )), - (ast::Statement::SetRole { role_name, .. }, _) => self.set_role_to_plan(role_name), + (ast::Statement::SetRole { role_name, .. }, _) => { + self.set_role_to_plan(role_name).await + } (ast::Statement::SetVariable { key_values }, _) => { self.set_variable_to_plan(&key_values).await } @@ -283,19 +285,24 @@ impl QueryRouter { } } - fn set_role_to_plan( + async fn set_role_to_plan( &self, role_name: &Option, ) -> Result { let flags = StatusFlags::SERVER_STATE_CHANGED; - let role_name = role_name - .as_ref() - .map(|role_name| role_name.value.clone()) - .unwrap_or("none".to_string()); - let variable = - DatabaseVariable::system("role".to_string(), ScalarValue::Utf8(Some(role_name)), None); + let username = role_name.as_ref().map(|role_name| role_name.value.clone()); + let Some(to_user) = username.clone().or_else(|| self.state.original_user()) else { + return Err(CompilationError::user( + "Cannot reset role when original role has not been set".to_string(), + )); + }; + self.change_user(to_user).await?; + let variable = DatabaseVariable::system( + "role".to_string(), + ScalarValue::Utf8(Some(username.unwrap_or("none".to_string()))), + None, + ); self.state.set_variables(vec![variable]); - Ok(QueryPlan::MetaOk(flags, CommandCompletion::Set)) } @@ -419,11 +426,6 @@ impl QueryRouter { }); for v in user_variables { - self.reauthenticate_if_needed().await?; - - let auth_context = self.state.auth_context().ok_or(CompilationError::user( - "No auth context set but tried to set current user".to_string(), - ))?; let to_user = match v.value { ScalarValue::Utf8(Some(user)) => user, _ => { @@ -433,46 +435,7 @@ impl QueryRouter { ))) } }; - if self - .session_manager - .server - .transport - .can_switch_user_for_session(auth_context.clone(), to_user.clone()) - .await - .map_err(|e| { - CompilationError::internal(format!( - "Error calling can_switch_user_for_session: {}", - e - )) - })? - { - self.state.set_user(Some(to_user.clone())); - let sql_auth_request = SqlAuthServiceAuthenticateRequest { - protocol: "postgres".to_string(), - method: "password".to_string(), - }; - let authenticate_response = self - .session_manager - .server - .auth - .authenticate(sql_auth_request, Some(to_user.clone()), None) - .await - .map_err(|e| { - CompilationError::internal(format!("Error calling authenticate: {}", e)) - })?; - self.state - .set_auth_context(Some(authenticate_response.context)); - } else { - return Err(CompilationError::user(format!( - "user '{}' is not allowed to switch to '{}'", - auth_context - .user() - .as_ref() - .map(|v| v.as_str()) - .unwrap_or("not specified"), - to_user - ))); - } + self.change_user(to_user).await?; } if !session_columns_to_update.is_empty() { @@ -488,6 +451,56 @@ impl QueryRouter { Ok(QueryPlan::MetaOk(flags, CommandCompletion::Set)) } + async fn change_user(&self, username: String) -> Result<(), CompilationError> { + self.reauthenticate_if_needed().await?; + + let auth_context = self.state.auth_context().ok_or(CompilationError::user( + "No auth context set but tried to set current user".to_string(), + ))?; + + let can_switch_user = self + .session_manager + .server + .transport + .can_switch_user_for_session(auth_context.clone(), username.clone()) + .await + .map_err(|e| { + CompilationError::internal(format!( + "Error calling can_switch_user_for_session: {}", + e + )) + })?; + if !can_switch_user { + return Err(CompilationError::user(format!( + "user '{}' is not allowed to switch to '{}'", + auth_context + .user() + .as_ref() + .map(|v| v.as_str()) + .unwrap_or("not specified"), + username + ))); + } + + self.state.set_user(Some(username.clone())); + let sql_auth_request = SqlAuthServiceAuthenticateRequest { + protocol: "postgres".to_string(), + method: "password".to_string(), + }; + let authenticate_response = self + .session_manager + .server + .auth + .authenticate(sql_auth_request, Some(username), None) + .await + .map_err(|e| { + CompilationError::internal(format!("Error calling authenticate: {}", e)) + })?; + self.state + .set_auth_context(Some(authenticate_response.context)); + Ok(()) + } + async fn create_table_to_plan( &self, name: &ast::ObjectName, diff --git a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_bad_user.snap b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_bad_user.snap new file mode 100644 index 0000000000000..4bd28ba64dc67 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_bad_user.snap @@ -0,0 +1,5 @@ +--- +source: cubesql/src/compile/mod.rs +expression: "execute_queries_with_flags(vec![\"SET ROLE bad_user\".to_string(),],\nDatabaseProtocol::PostgreSQL).await.err().unwrap().to_string()" +--- +Error during planning: SQLCompilationError: User: user 'not specified' is not allowed to switch to 'bad_user' diff --git a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_good_user.snap b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_good_user.snap new file mode 100644 index 0000000000000..b74032de08d3b --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_good_user.snap @@ -0,0 +1,9 @@ +--- +source: cubesql/src/compile/mod.rs +expression: "execute_queries_with_flags(vec![\"SET ROLE good_user\".to_string(),\n\"SHOW ROLE\".to_string()], DatabaseProtocol::PostgreSQL).await? .0" +--- ++-----------+ +| setting | ++-----------+ +| good_user | ++-----------+ diff --git a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_none.snap b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_none.snap new file mode 100644 index 0000000000000..c90e3b452db80 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_none.snap @@ -0,0 +1,9 @@ +--- +source: cubesql/src/compile/mod.rs +expression: "execute_queries_with_flags(vec![\"SET ROLE good_user\".to_string(),\n\"SET ROLE NONE\".to_string(), \"SHOW ROLE\".to_string()],\nDatabaseProtocol::PostgreSQL).await? .0" +--- ++---------+ +| setting | ++---------+ +| none | ++---------+ diff --git a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_show.snap b/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_show.snap deleted file mode 100644 index 415d9c43d3efd..0000000000000 --- a/rust/cubesql/cubesql/src/compile/snapshots/cubesql__compile__tests__pg_set_role_show.snap +++ /dev/null @@ -1,9 +0,0 @@ ---- -source: cubesql/src/compile/mod.rs -expression: "execute_queries_with_flags(vec![\"SET ROLE NONE\".to_string(),\n \"SHOW ROLE\".to_string()],\n DatabaseProtocol::PostgreSQL).await?.0" ---- -+---------+ -| setting | -+---------+ -| none | -+---------+ diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 4f4dbf8f71aae..54589d7713db1 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -771,6 +771,7 @@ async fn get_test_session_with_config_and_transport( // Populate like shims session.state.set_database(Some(db_name.to_string())); session.state.set_user(Some("ovr".to_string())); + session.state.set_original_user(Some("ovr".to_string())); let auth_ctx = HttpAuthContext { access_token: "access_token".to_string(), @@ -938,7 +939,7 @@ impl TransportService for TestConnectionTransport { _ctx: AuthContextRef, to_user: String, ) -> Result { - if to_user == "good_user" { + if matches!(to_user.as_str(), "good_user" | "ovr") { Ok(true) } else { Ok(false) diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index e534087f70578..e542d7a6a3d20 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -806,7 +806,8 @@ impl AsyncPostgresShim { .cloned() .unwrap_or("db".to_string()); self.session.state.set_database(Some(database)); - self.session.state.set_user(Some(user)); + self.session.state.set_user(Some(user.clone())); + self.session.state.set_original_user(Some(user)); self.session.state.set_auth_context(Some(auth_context)); self.write(protocol::Authentication::new(AuthenticationRequest::Ok)) diff --git a/rust/cubesql/cubesql/src/sql/session.rs b/rust/cubesql/cubesql/src/sql/session.rs index abc909f09939b..dd7a7d39cfa85 100644 --- a/rust/cubesql/cubesql/src/sql/session.rs +++ b/rust/cubesql/cubesql/src/sql/session.rs @@ -82,6 +82,8 @@ pub struct SessionState { // @todo Remove RWLock after split of Connection & SQLWorker // Context for Transport auth_context: RwLockSync<(Option, SystemTime)>, + // Used to reset user with SET ROLE NONE + original_user: RwLockSync>, transaction: RwLockSync, query: RwLockSync, @@ -116,6 +118,7 @@ impl SessionState { temp_tables: Arc::new(TempTableManager::new(session_manager)), properties: RwLockSync::new(SessionProperties::new(None, None)), auth_context: RwLockSync::new((auth_context, SystemTime::now())), + original_user: RwLockSync::new(None), transaction: RwLockSync::new(TransactionState::None), query: RwLockSync::new(QueryState::None), statements: RWLockAsync::new(HashMap::new()), @@ -271,6 +274,25 @@ impl SessionState { guard.user = user; } + pub fn original_user(&self) -> Option { + let guard = self + .original_user + .read() + .expect("failed to unlock original_user for reading"); + guard.clone() + } + + pub fn set_original_user(&self, user: Option) { + let mut guard = self + .original_user + .write() + .expect("failed to unlock original_user for writing"); + if guard.is_none() { + // Silently ignore writing original user if it's already set + *guard = user; + } + } + pub fn database(&self) -> Option { let guard = self .properties