Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions crates/schema/src/auto_migrate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,13 +1028,24 @@ fn auto_migrate_constraints(plan: &mut AutoMigratePlan, new_tables: &HashSet<&Id
// Because we can refer to many tables and fields on the row level-security query, we need to remove all of them,
// then add the new ones, instead of trying to track the graph of dependencies.
fn auto_migrate_row_level_security(plan: &mut AutoMigratePlan) -> Result<()> {
// Track if any RLS rules were changed.
let mut old_rls = HashSet::new();
let mut new_rls = HashSet::new();

for rls in plan.old.row_level_security() {
old_rls.insert(rls.key());
plan.steps.push(AutoMigrateStep::RemoveRowLevelSecurity(rls.key()));
}
for rls in plan.new.row_level_security() {
new_rls.insert(rls.key());
plan.steps.push(AutoMigrateStep::AddRowLevelSecurity(rls.key()));
}

// We can force flush the cache by force disconnecting all clients if an RLS rule has been added, removed, or updated.
if old_rls != new_rls && !plan.disconnects_all_users() {
plan.steps.push(AutoMigrateStep::DisconnectAllUsers);
}

Ok(())
}

Expand Down Expand Up @@ -2293,4 +2304,52 @@ mod tests {
);
}
}

#[test]
fn change_rls_disconnect_clients() {
let old_def = create_module_def(|_builder| {});

let new_def = create_module_def(|_builder| {});

let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
assert!(!plan.disconnects_all_users(), "{plan:#?}");

let old_def = create_module_def(|builder| {
builder.add_row_level_security("SELECT true;");
});
let new_def = create_module_def(|builder| {
builder.add_row_level_security("SELECT false;");
});

let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
assert!(plan.disconnects_all_users(), "{plan:#?}");

let old_def = create_module_def(|builder| {
builder.add_row_level_security("SELECT true;");
});

let new_def = create_module_def(|_builder| {
// Remove RLS
});
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
assert!(plan.disconnects_all_users(), "{plan:#?}");

let old_def = create_module_def(|_builder| {});

let new_def = create_module_def(|builder| {
builder.add_row_level_security("SELECT false;");
});
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
assert!(plan.disconnects_all_users(), "{plan:#?}");

let old_def = create_module_def(|builder| {
builder.add_row_level_security("SELECT true;");
});

let new_def = create_module_def(|builder| {
builder.add_row_level_security("SELECT true;");
});
let plan = ponder_auto_migrate(&old_def, &new_def).expect("auto migration should succeed");
assert!(!plan.disconnects_all_users(), "{plan:#?}");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,4 @@ expression: "plan.pretty_print(PrettyPrintStyle::AnsiColor).expect(\"should pret

▸ Created row level security policy:
`SELECT * FROM Apples`
!!! Warning: All clients will be disconnected due to breaking schema changes
30 changes: 1 addition & 29 deletions smoketests/tests/auto_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class AddTableAutoMigration(Smoketest):
x: f64,
y: f64,
}

#[spacetimedb::client_visibility_filter]
const PERSON_VISIBLE: spacetimedb::Filter = spacetimedb::Filter::Sql("SELECT * FROM person");
"""

MODULE_CODE = MODULE_CODE_INIT + """
Expand Down Expand Up @@ -100,9 +97,6 @@ class AddTableAutoMigration(Smoketest):
log::info!("{}: {}", prefix, book.isbn);
}
}

#[spacetimedb::client_visibility_filter]
const BOOK_VISIBLE: spacetimedb::Filter = spacetimedb::Filter::Sql("SELECT * FROM book");
"""
)

Expand All @@ -115,17 +109,6 @@ def assertSql(self, sql, expected):

def test_add_table_auto_migration(self):
"""This tests uploading a module with a schema change that should not require clearing the database."""

# Check the row-level SQL filter is created correctly
self.assertSql(
"SELECT sql FROM st_row_level_security",
"""\
sql
------------------------
"SELECT * FROM person"
""",
)

logging.info("Initial publish complete")

# Start a subscription before publishing the module, to test that the subscription remains intact after re-publishing.
Expand Down Expand Up @@ -154,18 +137,7 @@ def test_add_table_auto_migration(self):
# If subscription, we should get 4 rows corresponding to 4 reducer calls (including before and after update)
sub = sub();
self.assertEqual(len(sub), 4)

# Check the row-level SQL filter is added correctly
self.assertSql(
"SELECT sql FROM st_row_level_security",
"""\
sql
------------------------
"SELECT * FROM person"
"SELECT * FROM book"
""",
)


self.logs(100)

self.call("add_person", "Husserl", "Professor")
Expand Down
85 changes: 85 additions & 0 deletions smoketests/tests/rls.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from .. import Smoketest, random_string

class Rls(Smoketest):
Expand Down Expand Up @@ -78,3 +80,86 @@ def test_publish_fails_for_rls_on_private_table(self):

with self.assertRaises(Exception):
self.publish_module(name)

class DisconnectRls(Smoketest):
AUTOPUBLISH = False

MODULE_CODE = """
use spacetimedb::{Identity, ReducerContext, Table};

#[spacetimedb::table(name = users, public)]
pub struct Users {
name: String,
identity: Identity,
}

#[spacetimedb::reducer]
pub fn add_user(ctx: &ReducerContext, name: String) {
ctx.db.users().insert(Users { name, identity: ctx.sender });
}
"""

ADD_RLS = """
#[spacetimedb::client_visibility_filter]
const USER_FILTER: spacetimedb::Filter = spacetimedb::Filter::Sql(
"SELECT * FROM users WHERE identity = :sender"
);
"""

def assertSql(self, sql, expected):
self.maxDiff = None
sql_out = self.spacetime("sql", self.database_identity, sql)
sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()])
expected = "\n".join([line.rstrip() for line in expected.splitlines()])
self.assertMultiLineEqual(sql_out, expected)

def test_rls_disconnect_if_change(self):
"""This tests that changing the RLS rules disconnects existing clients"""

name = random_string()

self.write_module_code(self.MODULE_CODE)

self.publish_module(name)
logging.info("Initial publish complete")

# Now add the RLS rules
self.write_module_code(self.MODULE_CODE + self.ADD_RLS)
self.publish_module(name, clear=False, break_clients=True)

# Check the row-level SQL filter is added correctly
self.assertSql(
"SELECT sql FROM st_row_level_security",
"""\
sql
------------------------------------------------
"SELECT * FROM users WHERE identity = :sender"
""",
)

logging.info("Re-publish with RLS complete")

logs = self.logs(100)

# Validate disconnect + schema migration logs
self.assertIn("Disconnecting all users", logs)

def test_rls_no_disconnect(self):
"""This tests that not changing the RLS rules does not disconnect existing clients"""

name = random_string()

self.write_module_code(self.MODULE_CODE + self.ADD_RLS)

self.publish_module(name)
logging.info("Initial publish complete")

# Now re-publish the same module code
self.publish_module(name, clear=False, break_clients=False)

logging.info("Re-publish without RLS change complete")

logs = self.logs(100)

# Validate no disconnect logs
self.assertNotIn("Disconnecting all users", logs)
Loading