Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions rust/cubesql/cubesql/e2e/tests/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1302,16 +1302,6 @@ impl AsyncTestSuite for PostgresIntegrationTestSuite {
)
.await?;

self.test_simple_query(r#"SET ROLE "cube""#.to_string(), |messages| {
assert_eq!(messages.len(), 1);

// SET
if let SimpleQueryMessage::Row(_) = messages[0] {
panic!("Must be CommandComplete command, (SET is used)")
}
})
.await?;

// Tableau Desktop
self.test_simple_query(
r#"SET DateStyle = 'ISO';SET extra_float_digits = 2;show transaction_isolation"#
Expand Down
30 changes: 28 additions & 2 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5177,15 +5177,41 @@ ORDER BY
);

insta::assert_snapshot!(
"pg_set_role_show",
"pg_set_role_good_user",
execute_queries_with_flags(
vec!["SET ROLE NONE".to_string(), "SHOW ROLE".to_string()],
vec!["SET ROLE good_user".to_string(), "SHOW ROLE".to_string()],
DatabaseProtocol::PostgreSQL
)
.await?
.0
);

insta::assert_snapshot!(
"pg_set_role_none",
execute_queries_with_flags(
vec![
"SET ROLE good_user".to_string(),
"SET ROLE NONE".to_string(),
"SHOW ROLE".to_string()
],
DatabaseProtocol::PostgreSQL
)
.await?
.0
);

insta::assert_snapshot!(
"pg_set_role_bad_user",
execute_queries_with_flags(
vec!["SET ROLE bad_user".to_string()],
DatabaseProtocol::PostgreSQL
)
.await
.err()
.unwrap()
.to_string()
);

Ok(())
}

Expand Down
121 changes: 67 additions & 54 deletions rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ impl QueryRouter {
StatusFlags::empty(),
Box::new(dataframe::DataFrame::new(vec![], vec![])),
)),
(ast::Statement::SetRole { role_name, .. }, _) => self.set_role_to_plan(role_name),
(ast::Statement::SetRole { role_name, .. }, _) => {
self.set_role_to_plan(role_name).await
}
(ast::Statement::SetVariable { key_values }, _) => {
self.set_variable_to_plan(&key_values).await
}
Expand Down Expand Up @@ -283,19 +285,24 @@ impl QueryRouter {
}
}

fn set_role_to_plan(
async fn set_role_to_plan(
&self,
role_name: &Option<ast::Ident>,
) -> Result<QueryPlan, CompilationError> {
let flags = StatusFlags::SERVER_STATE_CHANGED;
let role_name = role_name
.as_ref()
.map(|role_name| role_name.value.clone())
.unwrap_or("none".to_string());
let variable =
DatabaseVariable::system("role".to_string(), ScalarValue::Utf8(Some(role_name)), None);
let username = role_name.as_ref().map(|role_name| role_name.value.clone());
let Some(to_user) = username.clone().or_else(|| self.state.original_user()) else {
return Err(CompilationError::user(
"Cannot reset role when original role has not been set".to_string(),
));
};
self.change_user(to_user).await?;
let variable = DatabaseVariable::system(
"role".to_string(),
ScalarValue::Utf8(Some(username.unwrap_or("none".to_string()))),
None,
);
self.state.set_variables(vec![variable]);

Ok(QueryPlan::MetaOk(flags, CommandCompletion::Set))
}

Expand Down Expand Up @@ -419,11 +426,6 @@ impl QueryRouter {
});

for v in user_variables {
self.reauthenticate_if_needed().await?;

let auth_context = self.state.auth_context().ok_or(CompilationError::user(
"No auth context set but tried to set current user".to_string(),
))?;
let to_user = match v.value {
ScalarValue::Utf8(Some(user)) => user,
_ => {
Expand All @@ -433,46 +435,7 @@ impl QueryRouter {
)))
}
};
if self
.session_manager
.server
.transport
.can_switch_user_for_session(auth_context.clone(), to_user.clone())
.await
.map_err(|e| {
CompilationError::internal(format!(
"Error calling can_switch_user_for_session: {}",
e
))
})?
{
self.state.set_user(Some(to_user.clone()));
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
protocol: "postgres".to_string(),
method: "password".to_string(),
};
let authenticate_response = self
.session_manager
.server
.auth
.authenticate(sql_auth_request, Some(to_user.clone()), None)
.await
.map_err(|e| {
CompilationError::internal(format!("Error calling authenticate: {}", e))
})?;
self.state
.set_auth_context(Some(authenticate_response.context));
} else {
return Err(CompilationError::user(format!(
"user '{}' is not allowed to switch to '{}'",
auth_context
.user()
.as_ref()
.map(|v| v.as_str())
.unwrap_or("not specified"),
to_user
)));
}
self.change_user(to_user).await?;
}

if !session_columns_to_update.is_empty() {
Expand All @@ -488,6 +451,56 @@ impl QueryRouter {
Ok(QueryPlan::MetaOk(flags, CommandCompletion::Set))
}

async fn change_user(&self, username: String) -> Result<(), CompilationError> {
self.reauthenticate_if_needed().await?;

let auth_context = self.state.auth_context().ok_or(CompilationError::user(
"No auth context set but tried to set current user".to_string(),
))?;

let can_switch_user = self
.session_manager
.server
.transport
.can_switch_user_for_session(auth_context.clone(), username.clone())
.await
.map_err(|e| {
CompilationError::internal(format!(
"Error calling can_switch_user_for_session: {}",
e
))
})?;
if !can_switch_user {
return Err(CompilationError::user(format!(
"user '{}' is not allowed to switch to '{}'",
auth_context
.user()
.as_ref()
.map(|v| v.as_str())
.unwrap_or("not specified"),
username
)));
}

self.state.set_user(Some(username.clone()));
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
protocol: "postgres".to_string(),
method: "password".to_string(),
};
let authenticate_response = self
.session_manager
.server
.auth
.authenticate(sql_auth_request, Some(username), None)
.await
.map_err(|e| {
CompilationError::internal(format!("Error calling authenticate: {}", e))
})?;
self.state
.set_auth_context(Some(authenticate_response.context));
Ok(())
}

async fn create_table_to_plan(
&self,
name: &ast::ObjectName,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
source: cubesql/src/compile/mod.rs
expression: "execute_queries_with_flags(vec![\"SET ROLE bad_user\".to_string(),],\nDatabaseProtocol::PostgreSQL).await.err().unwrap().to_string()"
---
Error during planning: SQLCompilationError: User: user 'not specified' is not allowed to switch to 'bad_user'
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: cubesql/src/compile/mod.rs
expression: "execute_queries_with_flags(vec![\"SET ROLE good_user\".to_string(),\n\"SHOW ROLE\".to_string()], DatabaseProtocol::PostgreSQL).await? .0"
---
+-----------+
| setting |
+-----------+
| good_user |
+-----------+
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
source: cubesql/src/compile/mod.rs
expression: "execute_queries_with_flags(vec![\"SET ROLE good_user\".to_string(),\n\"SET ROLE NONE\".to_string(), \"SHOW ROLE\".to_string()],\nDatabaseProtocol::PostgreSQL).await? .0"
---
+---------+
| setting |
+---------+
| none |
+---------+

This file was deleted.

3 changes: 2 additions & 1 deletion rust/cubesql/cubesql/src/compile/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ async fn get_test_session_with_config_and_transport(
// Populate like shims
session.state.set_database(Some(db_name.to_string()));
session.state.set_user(Some("ovr".to_string()));
session.state.set_original_user(Some("ovr".to_string()));

let auth_ctx = HttpAuthContext {
access_token: "access_token".to_string(),
Expand Down Expand Up @@ -938,7 +939,7 @@ impl TransportService for TestConnectionTransport {
_ctx: AuthContextRef,
to_user: String,
) -> Result<bool, CubeError> {
if to_user == "good_user" {
if matches!(to_user.as_str(), "good_user" | "ovr") {
Ok(true)
} else {
Ok(false)
Expand Down
3 changes: 2 additions & 1 deletion rust/cubesql/cubesql/src/sql/postgres/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,8 @@ impl AsyncPostgresShim {
.cloned()
.unwrap_or("db".to_string());
self.session.state.set_database(Some(database));
self.session.state.set_user(Some(user));
self.session.state.set_user(Some(user.clone()));
self.session.state.set_original_user(Some(user));
self.session.state.set_auth_context(Some(auth_context));

self.write(protocol::Authentication::new(AuthenticationRequest::Ok))
Expand Down
22 changes: 22 additions & 0 deletions rust/cubesql/cubesql/src/sql/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub struct SessionState {
// @todo Remove RWLock after split of Connection & SQLWorker
// Context for Transport
auth_context: RwLockSync<(Option<AuthContextRef>, SystemTime)>,
// Used to reset user with SET ROLE NONE
original_user: RwLockSync<Option<String>>,

transaction: RwLockSync<TransactionState>,
query: RwLockSync<QueryState>,
Expand Down Expand Up @@ -116,6 +118,7 @@ impl SessionState {
temp_tables: Arc::new(TempTableManager::new(session_manager)),
properties: RwLockSync::new(SessionProperties::new(None, None)),
auth_context: RwLockSync::new((auth_context, SystemTime::now())),
original_user: RwLockSync::new(None),
transaction: RwLockSync::new(TransactionState::None),
query: RwLockSync::new(QueryState::None),
statements: RWLockAsync::new(HashMap::new()),
Expand Down Expand Up @@ -271,6 +274,25 @@ impl SessionState {
guard.user = user;
}

pub fn original_user(&self) -> Option<String> {
let guard = self
.original_user
.read()
.expect("failed to unlock original_user for reading");
guard.clone()
}

pub fn set_original_user(&self, user: Option<String>) {
let mut guard = self
.original_user
.write()
.expect("failed to unlock original_user for writing");
if guard.is_none() {
// Silently ignore writing original user if it's already set
*guard = user;
}
}

pub fn database(&self) -> Option<String> {
let guard = self
.properties
Expand Down
Loading