Skip to content

Commit 8dab9b7

Browse files
committed
Disconnect clients when updating RLS rules
1 parent 30b8eac commit 8dab9b7

File tree

4 files changed

+148
-29
lines changed

4 files changed

+148
-29
lines changed

crates/schema/src/auto_migrate.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,13 +1019,24 @@ fn auto_migrate_constraints(plan: &mut AutoMigratePlan, new_tables: &HashSet<&Id
10191019
// Because we can refer to many tables and fields on the row level-security query, we need to remove all of them,
10201020
// then add the new ones, instead of trying to track the graph of dependencies.
10211021
fn auto_migrate_row_level_security(plan: &mut AutoMigratePlan) -> Result<()> {
1022+
// Track if any RLS rules were changed.
1023+
let mut old_rls = HashSet::new();
1024+
let mut new_rls = HashSet::new();
1025+
10221026
for rls in plan.old.row_level_security() {
1027+
old_rls.insert(rls.key());
10231028
plan.steps.push(AutoMigrateStep::RemoveRowLevelSecurity(rls.key()));
10241029
}
10251030
for rls in plan.new.row_level_security() {
1031+
new_rls.insert(rls.key());
10261032
plan.steps.push(AutoMigrateStep::AddRowLevelSecurity(rls.key()));
10271033
}
10281034

1035+
// We can force flush the cache by force disconnecting all clients if an RLS rule has been added, removed, or updated.
1036+
if old_rls != new_rls && !plan.disconnects_all_users() {
1037+
plan.steps.push(AutoMigrateStep::DisconnectAllUsers);
1038+
}
1039+
10291040
Ok(())
10301041
}
10311042

@@ -2245,4 +2256,52 @@ mod tests {
22452256
);
22462257
}
22472258
}
2259+
2260+
#[test]
2261+
fn change_rls_disconnect_clients() {
2262+
let old_def = create_module_def(|_builder| {});
2263+
2264+
let new_def = create_module_def(|_builder| {});
2265+
2266+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2267+
assert!(!plan.disconnects_all_users(), "{plan:#?}");
2268+
2269+
let old_def = create_module_def(|builder| {
2270+
builder.add_row_level_security("SELECT true;");
2271+
});
2272+
let new_def = create_module_def(|builder| {
2273+
builder.add_row_level_security("SELECT false;");
2274+
});
2275+
2276+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2277+
assert!(plan.disconnects_all_users(), "{plan:#?}");
2278+
2279+
let old_def = create_module_def(|builder| {
2280+
builder.add_row_level_security("SELECT true;");
2281+
});
2282+
2283+
let new_def = create_module_def(|_builder| {
2284+
// Remove RLS
2285+
});
2286+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2287+
assert!(plan.disconnects_all_users(), "{plan:#?}");
2288+
2289+
let old_def = create_module_def(|_builder| {});
2290+
2291+
let new_def = create_module_def(|builder| {
2292+
builder.add_row_level_security("SELECT false;");
2293+
});
2294+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2295+
assert!(plan.disconnects_all_users(), "{plan:#?}");
2296+
2297+
let old_def = create_module_def(|builder| {
2298+
builder.add_row_level_security("SELECT true;");
2299+
});
2300+
2301+
let new_def = create_module_def(|builder| {
2302+
builder.add_row_level_security("SELECT true;");
2303+
});
2304+
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
2305+
assert!(!plan.disconnects_all_users(), "{plan:#?}");
2306+
}
22482307
}

crates/schema/src/snapshots/spacetimedb_schema__auto_migrate__tests__empty_to_populated_migration.snap

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,4 @@ expression: "plan.pretty_print(PrettyPrintStyle::AnsiColor).expect(\"should pret
6161

6262
▸ Created row level security policy:
6363
`SELECT * FROM Apples`
64+
!!! Warning: All clients will be disconnected due to breaking schema changes

smoketests/tests/auto_migration.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,6 @@ class AddTableAutoMigration(Smoketest):
4040
x: f64,
4141
y: f64,
4242
}
43-
44-
#[spacetimedb::client_visibility_filter]
45-
const PERSON_VISIBLE: spacetimedb::Filter = spacetimedb::Filter::Sql("SELECT * FROM person");
4643
"""
4744

4845
MODULE_CODE = MODULE_CODE_INIT + """
@@ -100,9 +97,6 @@ class AddTableAutoMigration(Smoketest):
10097
log::info!("{}: {}", prefix, book.isbn);
10198
}
10299
}
103-
104-
#[spacetimedb::client_visibility_filter]
105-
const BOOK_VISIBLE: spacetimedb::Filter = spacetimedb::Filter::Sql("SELECT * FROM book");
106100
"""
107101
)
108102

@@ -115,17 +109,6 @@ def assertSql(self, sql, expected):
115109

116110
def test_add_table_auto_migration(self):
117111
"""This tests uploading a module with a schema change that should not require clearing the database."""
118-
119-
# Check the row-level SQL filter is created correctly
120-
self.assertSql(
121-
"SELECT sql FROM st_row_level_security",
122-
"""\
123-
sql
124-
------------------------
125-
"SELECT * FROM person"
126-
""",
127-
)
128-
129112
logging.info("Initial publish complete")
130113

131114
# Start a subscription before publishing the module, to test that the subscription remains intact after re-publishing.
@@ -154,18 +137,7 @@ def test_add_table_auto_migration(self):
154137
# If subscription, we should get 4 rows corresponding to 4 reducer calls (including before and after update)
155138
sub = sub();
156139
self.assertEqual(len(sub), 4)
157-
158-
# Check the row-level SQL filter is added correctly
159-
self.assertSql(
160-
"SELECT sql FROM st_row_level_security",
161-
"""\
162-
sql
163-
------------------------
164-
"SELECT * FROM person"
165-
"SELECT * FROM book"
166-
""",
167-
)
168-
140+
169141
self.logs(100)
170142

171143
self.call("add_person", "Husserl", "Professor")

smoketests/tests/rls.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from .. import Smoketest, random_string
24

35
class Rls(Smoketest):
@@ -78,3 +80,88 @@ def test_publish_fails_for_rls_on_private_table(self):
7880

7981
with self.assertRaises(Exception):
8082
self.publish_module(name)
83+
84+
class DisconnectRls(Smoketest):
85+
AUTOPUBLISH = False
86+
87+
MODULE_CODE = """
88+
use spacetimedb::{Identity, ReducerContext, Table};
89+
90+
#[spacetimedb::table(name = users, public)]
91+
pub struct Users {
92+
name: String,
93+
identity: Identity,
94+
}
95+
96+
#[spacetimedb::reducer]
97+
pub fn add_user(ctx: &ReducerContext, name: String) {
98+
ctx.db.users().insert(Users { name, identity: ctx.sender });
99+
}
100+
101+
// RLS
102+
"""
103+
104+
ADD_RLS = """
105+
#[spacetimedb::client_visibility_filter]
106+
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
107+
"SELECT * FROM users WHERE identity = :sender"
108+
);
109+
"""
110+
111+
def assertSql(self, sql, expected):
112+
self.maxDiff = None
113+
sql_out = self.spacetime("sql", self.database_identity, sql)
114+
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
115+
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
116+
self.assertMultiLineEqual(sql_out, expected)
117+
118+
def test_rls_disconnect_if_change(self):
119+
"""This tests that changing the RLS rules disconnects existing clients"""
120+
121+
name = random_string()
122+
123+
self.write_module_code(self.MODULE_CODE)
124+
125+
self.publish_module(name)
126+
logging.info("Initial publish complete")
127+
128+
# Now add the RLS rules
129+
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
130+
self.publish_module(name, clear=False, break_clients=True)
131+
132+
# Check the row-level SQL filter is added correctly
133+
self.assertSql(
134+
"SELECT sql FROM st_row_level_security",
135+
"""\
136+
sql
137+
------------------------------------------------
138+
"SELECT * FROM users WHERE identity = :sender"
139+
""",
140+
)
141+
142+
logging.info("Re-publish with RLS complete")
143+
144+
logs = self.logs(100)
145+
146+
# Validate disconnect + schema migration logs
147+
self.assertIn("Disconnecting all users", logs)
148+
149+
def test_rls_disconnect_no(self):
150+
"""This tests that not changing the RLS rules does not disconnect existing clients"""
151+
152+
name = random_string()
153+
154+
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
155+
156+
self.publish_module(name)
157+
logging.info("Initial publish complete")
158+
159+
# Now re-publish the same module code
160+
self.publish_module(name, clear=False, break_clients=False)
161+
162+
logging.info("Re-publish without RLS change complete")
163+
164+
logs = self.logs(100)
165+
166+
# Validate no disconnect logs
167+
self.assertNotIn("Disconnecting all users", logs)

0 commit comments

Comments
 (0)