Skip to content

Commit 0041eca

Browse files
committed
Disconnect clients when updating RLS rules
1 parent 5c42b09 commit 0041eca

File tree

4 files changed

+134
-3
lines changed

4 files changed

+134
-3
lines changed

crates/schema/src/auto_migrate.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,13 +1019,24 @@ fn auto_migrate_constraints(plan: &mut AutoMigratePlan, new_tables: &HashSet<&Id
10191019
// Because we can refer to many tables and fields on the row level-security query, we need to remove all of them,
10201020
// then add the new ones, instead of trying to track the graph of dependencies.
10211021
fn auto_migrate_row_level_security(plan: &mut AutoMigratePlan) -> Result<()> {
1022+
// Track if any RLS rules were changed.
1023+
let mut old_rls = HashSet::new();
1024+
let mut new_rls = HashSet::new();
1025+
10221026
for rls in plan.old.row_level_security() {
1027+
old_rls.insert(rls.key());
10231028
plan.steps.push(AutoMigrateStep::RemoveRowLevelSecurity(rls.key()));
10241029
}
10251030
for rls in plan.new.row_level_security() {
1031+
new_rls.insert(rls.key());
10261032
plan.steps.push(AutoMigrateStep::AddRowLevelSecurity(rls.key()));
10271033
}
10281034

1035+
// We can force flush the cache by force disconnecting all clients if an RLS rule has been added, removed, or updated.
1036+
if old_rls != new_rls && !plan.disconnects_all_users() {
1037+
plan.steps.push(AutoMigrateStep::DisconnectAllUsers);
1038+
}
1039+
10291040
Ok(())
10301041
}
10311042

@@ -2245,4 +2256,52 @@ mod tests {
22452256
);
22462257
}
22472258
}
2259+
2260+
#[test]
2261+
fn change_rls_disconnect_clients() {
2262+
let old_def = create_module_def(|_builder| {});
2263+
2264+
let new_def = create_module_def(|_builder| {});
2265+
2266+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2267+
assert!(!plan.disconnects_all_users(), "{plan:#?}");
2268+
2269+
let old_def = create_module_def(|builder| {
2270+
builder.add_row_level_security("SELECT true;");
2271+
});
2272+
let new_def = create_module_def(|builder| {
2273+
builder.add_row_level_security("SELECT false;");
2274+
});
2275+
2276+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2277+
assert!(plan.disconnects_all_users(), "{plan:#?}");
2278+
2279+
let old_def = create_module_def(|builder| {
2280+
builder.add_row_level_security("SELECT true;");
2281+
});
2282+
2283+
let new_def = create_module_def(|_builder| {
2284+
// Remove RLS
2285+
});
2286+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2287+
assert!(plan.disconnects_all_users(), "{plan:#?}");
2288+
2289+
let old_def = create_module_def(|_builder| {});
2290+
2291+
let new_def = create_module_def(|builder| {
2292+
builder.add_row_level_security("SELECT false;");
2293+
});
2294+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2295+
assert!(plan.disconnects_all_users(), "{plan:#?}");
2296+
2297+
let old_def = create_module_def(|builder| {
2298+
builder.add_row_level_security("SELECT true;");
2299+
});
2300+
2301+
let new_def = create_module_def(|builder| {
2302+
builder.add_row_level_security("SELECT true;");
2303+
});
2304+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2305+
assert!(!plan.disconnects_all_users(), "{plan:#?}");
2306+
}
22482307
}

crates/schema/src/snapshots/spacetimedb_schema__auto_migrate__tests__empty_to_populated_migration.snap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,4 @@ expression: "plan.pretty_print(PrettyPrintStyle::AnsiColor).expect(\"should pret
6161

6262
▸ Created row level security policy:
6363
`SELECT * FROM Apples`
64+
!!! Warning: All clients will be disconnected due to breaking schema changes

smoketests/tests/auto_migration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,15 @@ def test_add_table_auto_migration(self):
146146
)
147147

148148
self.write_module_code(self.MODULE_CODE_UPDATED)
149-
self.publish_module(self.database_identity, clear=False)
149+
self.publish_module(self.database_identity, clear=False, break_clients=True)
150150

151151
logging.info("Updated")
152152
self.call("add_person", "Husserl", "Student")
153153

154-
# If subscription, we should get 4 rows corresponding to 4 reducer calls (including before and after update)
154+
# If subscription, we should get 3 rows corresponding to 3 reducer calls (including before and after update)
155+
# minus the `print` one
155156
sub = sub();
156-
self.assertEqual(len(sub), 4)
157+
self.assertEqual(len(sub), 3)
157158

158159
# Check the row-level SQL filter is added correctly
159160
self.assertSql(

smoketests/tests/rls.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from .. import Smoketest, random_string
24

35
class Rls(Smoketest):
@@ -78,3 +80,71 @@ def test_publish_fails_for_rls_on_private_table(self):
7880

7981
with self.assertRaises(Exception):
8082
self.publish_module(name)
83+
84+
class DisconnectRls(Smoketest):
85+
AUTOPUBLISH = False
86+
87+
MODULE_CODE = """
88+
use spacetimedb::{Identity, ReducerContext, Table};
89+
90+
#[spacetimedb::table(name = users, public)]
91+
pub struct Users {
92+
name: String,
93+
identity: Identity,
94+
}
95+
96+
#[spacetimedb::reducer]
97+
pub fn add_user(ctx: &ReducerContext, name: String) {
98+
ctx.db.users().insert(Users { name, identity: ctx.sender });
99+
}
100+
101+
// RLS
102+
"""
103+
104+
ADD_RLS = """
105+
#[spacetimedb::client_visibility_filter]
106+
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
107+
"SELECT * FROM users WHERE identity = :sender"
108+
);
109+
"""
110+
111+
def test_rls_disconnect_if_change(self):
112+
"""This tests that changing the RLS rules disconnects existing clients"""
113+
114+
name = random_string()
115+
116+
self.write_module_code(self.MODULE_CODE)
117+
118+
self.publish_module(name)
119+
logging.info("Initial publish complete")
120+
121+
# Now add the RLS rules
122+
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
123+
self.publish_module(name, clear=False, break_clients=True)
124+
125+
logging.info("Re-publish with RLS complete")
126+
127+
logs = self.logs(100)
128+
129+
# Validate disconnect + schema migration logs
130+
self.assertIn("Disconnecting all users", logs)
131+
132+
def test_rls_disconnect_no(self):
133+
"""This tests that not changing the RLS rules does not disconnect existing clients"""
134+
135+
name = random_string()
136+
137+
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
138+
139+
self.publish_module(name)
140+
logging.info("Initial publish complete")
141+
142+
# Now re-publish the same module code
143+
self.publish_module(name, clear=False, break_clients=False)
144+
145+
logging.info("Re-publish without RLS change complete")
146+
147+
logs = self.logs(100)
148+
149+
# Validate no disconnect logs
150+
self.assertNotIn("Disconnecting all users", logs)

0 commit comments

Comments
 (0)