Skip to content

Commit 0cf0cc7

Browse files
committed
Do not alias fields of tracked_struct Values when updating
1 parent 4d92253 commit 0cf0cc7

File tree

4 files changed

+71
-85
lines changed

4 files changed

+71
-85
lines changed

src/attach.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,10 @@ impl Attached {
4141
// Already attached? Assert that the database has not changed.
4242
// NOTE: It's important to use `addr_eq` here because `NonNull::eq`
4343
// not only compares the address but also the type's metadata.
44-
if !std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()) {
45-
panic!(
46-
"Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}",
47-
);
48-
}
44+
assert!(
45+
std::ptr::addr_eq(current_db.as_ptr(), new_db.as_ptr()),
46+
"Cannot change database mid-query. current: {current_db:?}, new: {new_db:?}"
47+
);
4948

5049
Self { state: None }
5150
} else {

src/revision.rs

+7
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,13 @@ impl OptionalAtomicRevision {
104104
)
105105
}
106106

107+
pub(crate) fn swap_mut(&mut self, val: Option<Revision>) -> Option<Revision> {
108+
Revision::from_opt(std::mem::replace(
109+
self.data.get_mut(),
110+
val.map_or(0, |r| r.as_usize()),
111+
))
112+
}
113+
107114
pub(crate) fn compare_exchange(
108115
&self,
109116
current: Option<Revision>,

src/tracked_struct.rs

+59-79
Original file line numberDiff line numberDiff line change
@@ -400,18 +400,17 @@ where
400400
disambiguator,
401401
};
402402

403-
let current_revision = zalsa.current_revision();
404403
match zalsa_local.tracked_struct_id(&identity) {
405404
Some(id) => {
406405
// The struct already exists in the intern map.
407406
zalsa_local.add_output(self.database_key_index(id).into());
408-
self.update(zalsa, current_revision, id, &current_deps, fields);
407+
self.update(zalsa, id, &current_deps, fields);
409408
C::struct_from_id(id)
410409
}
411410

412411
None => {
413412
// This is a new tracked struct, so create an entry in the struct map.
414-
let id = self.allocate(zalsa, zalsa_local, current_revision, &current_deps, fields);
413+
let id = self.allocate(zalsa, zalsa_local, &current_deps, fields);
415414
let key = self.database_key_index(id);
416415
zalsa_local.add_output(key.into());
417416
zalsa_local.store_tracked_struct_id(identity, id);
@@ -424,10 +423,10 @@ where
424423
&'db self,
425424
zalsa: &'db Zalsa,
426425
zalsa_local: &'db ZalsaLocal,
427-
current_revision: Revision,
428426
current_deps: &StampedValue<()>,
429427
fields: C::Fields<'db>,
430428
) -> Id {
429+
let current_revision = zalsa.current_revision();
431430
let value = |_| Value {
432431
created_at: current_revision,
433432
updated_at: OptionalAtomicRevision::new(Some(current_revision)),
@@ -440,16 +439,14 @@ where
440439

441440
if let Some(id) = self.free_list.pop() {
442441
let data_raw = Self::data_raw(zalsa.table(), id);
443-
assert!(
442+
debug_assert!(
444443
unsafe { (*data_raw).updated_at.load().is_none() },
445-
"free list entry for `{id:?}` does not have `None` for `updated_at`"
444+
"free list entry for `{id:?}` should not be locked"
446445
);
447446

448447
// Overwrite the free-list entry. Use `*foo = ` because the entry
449448
// has been previously initialized and we want to free the old contents.
450-
unsafe {
451-
*data_raw = value(id);
452-
}
449+
unsafe { *data_raw = value(id) };
453450

454451
id
455452
} else {
@@ -467,7 +464,6 @@ where
467464
fn update<'db>(
468465
&'db self,
469466
zalsa: &'db Zalsa,
470-
current_revision: Revision,
471467
id: Id,
472468
current_deps: &StampedValue<()>,
473469
fields: C::Fields<'db>,
@@ -508,6 +504,7 @@ where
508504
// during the current revision and thus obtained an `&` reference to those fields
509505
// that is still live.
510506

507+
let current_revision = zalsa.current_revision();
511508
// UNSAFE: Marking as mut requires exclusive access for the duration of
512509
// the `mut`. We have now *claimed* this data by swapping in `None`,
513510
// any attempt to read concurrently will panic.
@@ -524,17 +521,19 @@ where
524521
// Acquire the write-lock. This can only fail if there is a parallel thread
525522
// reading from this same `id`, which can only happen if the user has leaked it.
526523
// Tsk tsk.
527-
let swapped_out = unsafe { (*data_raw).updated_at.swap(None) };
528-
if swapped_out != last_updated_at {
524+
525+
let swapped = unsafe { (*data_raw).updated_at.swap(None) };
526+
if last_updated_at != swapped {
529527
panic!(
530528
"failed to acquire write lock, id `{id:?}` must have been leaked across threads"
531529
);
532530
}
533531

534-
// UNSAFE: Marking as mut requires exclusive access for the duration of
532+
// SAFETY: Marking as mut requires exclusive access for the duration of
535533
// the `mut`. We have now *claimed* this data by swapping in `None`,
536-
// any attempt to read concurrently will panic.
537-
let data = unsafe { &mut *data_raw };
534+
// any attempt to read concurrently will panic. Note that we cannot create
535+
// a `&mut` reference to the full `Value` though because
536+
// another thread may access `updated_at` concurrently.
538537

539538
// SAFETY: We assert that the pointer to `data.revisions`
540539
// is a pointer into the database referencing a value
@@ -544,8 +543,8 @@ where
544543
unsafe {
545544
if C::update_fields(
546545
current_revision,
547-
&mut data.revisions,
548-
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
546+
&mut (*data_raw).revisions,
547+
self.to_self_ptr(std::ptr::addr_of_mut!((*data_raw).fields)),
549548
fields,
550549
) {
551550
// Consider this a new tracked-struct (even though it still uses the same id)
@@ -554,22 +553,25 @@ where
554553
// which makes Salsa consider two tracked structs to still be the same
555554
// even though the fields are different.
556555
// See `tracked-struct-id-field-bad-hash` for more details.
557-
data.created_at = current_revision;
556+
(*data_raw).revisions = C::new_revisions(current_revision);
557+
(*data_raw).created_at = current_revision;
558+
} else if current_deps.durability < (*data_raw).durability {
559+
(*data_raw).revisions = C::new_revisions(current_revision);
560+
(*data_raw).created_at = current_revision;
558561
}
562+
(*data_raw).durability = current_deps.durability;
559563
}
560-
if current_deps.durability < data.durability {
561-
data.revisions = C::new_revisions(current_revision);
562-
data.created_at = current_revision;
563-
}
564-
data.durability = current_deps.durability;
565-
let swapped_out = data.updated_at.swap(Some(current_revision));
566-
assert!(swapped_out.is_none());
564+
let swapped_out = unsafe { (*data_raw).updated_at.swap_mut(Some(current_revision)) };
565+
assert!(swapped_out.is_none(), "lock was acquired twice!");
567566
}
568567

569568
/// Fetch the data for a given id created by this ingredient from the table,
570569
/// -giving it the appropriate type.
571-
fn data(table: &Table, id: Id) -> &Value<C> {
572-
table.get(id)
570+
fn data(table: &Table, id: Id, current_revision: Revision) -> &Value<C> {
571+
let val = Self::data_raw(table, id);
572+
acquire_read_lock(unsafe { &(*val).updated_at }, current_revision);
573+
// We have acquired the read lock, so it is safe to return a reference to the data.
574+
unsafe { &*val }
573575
}
574576

575577
fn data_raw(table: &Table, id: Id) -> *mut Value<C> {
@@ -594,29 +596,23 @@ where
594596
});
595597

596598
let zalsa = db.zalsa();
597-
let current_revision = zalsa.current_revision();
598599
let data = Self::data_raw(zalsa.table(), id);
599600

600601
// We want to set `updated_at` to `None`, signalling that other field values
601602
// cannot be read. The current value should be `Some(R0)` for some older revision.
602-
let data_ref = unsafe { &*data };
603-
match data_ref.updated_at.load() {
603+
match unsafe { (*data).updated_at.swap(None) }{
604604
None => {
605605
panic!("cannot delete write-locked id `{id:?}`; value leaked across threads");
606606
}
607-
Some(r) if r == current_revision => panic!(
607+
Some(r) if r == zalsa.current_revision() => panic!(
608608
"cannot delete read-locked id `{id:?}`; value leaked across threads or user functions not deterministic"
609609
),
610-
Some(r) => {
611-
if data_ref.updated_at.compare_exchange(Some(r), None).is_err() {
612-
panic!("race occurred when deleting value `{id:?}`")
613-
}
614-
}
610+
Some(_) => ()
615611
}
616612

617613
// Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
618-
// and the code that references the memo-table has a read-lock.
619-
let memo_table = unsafe { (*data).take_memo_table() };
614+
// signalling that we have acquired the write lock
615+
let memo_table = std::mem::take(unsafe { &mut (*data).memos });
620616

621617
// SAFETY: We have verified that no more references to these memos exist and so we are good
622618
// to drop them.
@@ -648,7 +644,7 @@ where
648644
s: C::Struct<'db>,
649645
) -> &'db C::Fields<'db> {
650646
let id = C::deref_struct(s);
651-
let value = Self::data(db.zalsa().table(), id);
647+
let value = Self::data(db.zalsa().table(), id, db.zalsa().current_revision());
652648
unsafe { self.to_self_ref(&value.fields) }
653649
}
654650

@@ -670,9 +666,7 @@ where
670666
let (zalsa, zalsa_local) = db.zalsas();
671667
let id = C::deref_struct(s);
672668
let field_ingredient_index = self.ingredient_index.successor(relative_tracked_index);
673-
let data = Self::data(zalsa.table(), id);
674-
675-
data.read_lock(zalsa.current_revision());
669+
let data = Self::data(zalsa.table(), id, zalsa.current_revision());
676670

677671
let field_changed_at = data.revisions[relative_tracked_index];
678672

@@ -697,9 +691,7 @@ where
697691
) -> &'db C::Fields<'db> {
698692
let (zalsa, zalsa_local) = db.zalsas();
699693
let id = C::deref_struct(s);
700-
let data = Self::data(zalsa.table(), id);
701-
702-
data.read_lock(zalsa.current_revision());
694+
let data = Self::data(zalsa.table(), id, zalsa.current_revision());
703695

704696
// Add a dependency on the tracked struct itself.
705697
zalsa_local.report_tracked_read(
@@ -742,7 +734,7 @@ where
742734
revision: Revision,
743735
) -> MaybeChangedAfter {
744736
let zalsa = db.zalsa();
745-
let data = Self::data(zalsa.table(), input);
737+
let data = Self::data(zalsa.table(), input, zalsa.current_revision());
746738

747739
MaybeChangedAfter::from(data.created_at > revision)
748740
}
@@ -761,9 +753,7 @@ where
761753
_executor: DatabaseKeyIndex,
762754
_output_key: crate::Id,
763755
) {
764-
// we used to update `update_at` field but now we do it lazilly when data is accessed
765-
//
766-
// FIXME: delete this method
756+
// we used to update `update_at` field but now we do it lazily when data is accessed
767757
}
768758

769759
fn remove_stale_output(
@@ -776,7 +766,7 @@ where
776766
// `executor` creates a tracked struct `salsa_output_key`,
777767
// but it did not in the current revision.
778768
// In that case, we can delete `stale_output_key` and any data associated with it.
779-
self.delete_entity(db.as_dyn_database(), stale_output_key);
769+
self.delete_entity(db, stale_output_key);
780770
}
781771

782772
fn fmt_index(&self, index: Option<crate::Id>, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -811,34 +801,22 @@ where
811801
pub fn fields(&self) -> &C::Fields<'static> {
812802
&self.fields
813803
}
804+
}
814805

815-
fn take_memo_table(&mut self) -> MemoTable {
816-
// This fn is only called after `updated_at` has been set to `None`;
817-
// this ensures that there is no concurrent access
818-
// (and that the `&mut self` is accurate...).
819-
assert!(self.updated_at.load().is_none());
820-
821-
std::mem::take(&mut self.memos)
822-
}
823-
824-
fn read_lock(&self, current_revision: Revision) {
825-
loop {
826-
match self.updated_at.load() {
827-
None => {
828-
panic!("access to field whilst the value is being initialized");
829-
}
830-
Some(r) => {
831-
if r == current_revision {
832-
return;
833-
}
834-
835-
if self
836-
.updated_at
837-
.compare_exchange(Some(r), Some(current_revision))
838-
.is_ok()
839-
{
840-
break;
841-
}
806+
fn acquire_read_lock(updated_at: &OptionalAtomicRevision, current_revision: Revision) {
807+
loop {
808+
match updated_at.load() {
809+
None => panic!(
810+
"write lock taken; value leaked across threads or user functions not deterministic"
811+
),
812+
// the read lock was taken by someone else, so we also succeed
813+
Some(r) if r == current_revision => return,
814+
Some(r) => {
815+
if updated_at
816+
.compare_exchange(Some(r), Some(current_revision))
817+
.is_ok()
818+
{
819+
break;
842820
}
843821
}
844822
}
@@ -849,23 +827,25 @@ impl<C> Slot for Value<C>
849827
where
850828
C: Configuration,
851829
{
830+
// FIXME: `&self` may alias here before the lock is taken?
852831
unsafe fn memos(&self, current_revision: Revision) -> &crate::table::memo::MemoTable {
853832
// Acquiring the read lock here with the current revision
854833
// ensures that there is no danger of a race
855834
// when deleting a tracked struct.
856-
self.read_lock(current_revision);
835+
acquire_read_lock(&self.updated_at, current_revision);
857836
&self.memos
858837
}
859838

860839
fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable {
861840
&mut self.memos
862841
}
863842

843+
// FIXME: `&self` may alias here?
864844
unsafe fn syncs(&self, current_revision: Revision) -> &crate::table::sync::SyncTable {
865845
// Acquiring the read lock here with the current revision
866846
// ensures that there is no danger of a race
867847
// when deleting a tracked struct.
868-
self.read_lock(current_revision);
848+
acquire_read_lock(&self.updated_at, current_revision);
869849
&self.syncs
870850
}
871851
}

src/tracked_struct/tracked_field.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ where
6262
revision: crate::Revision,
6363
) -> MaybeChangedAfter {
6464
let zalsa = db.zalsa();
65-
let data = <super::IngredientImpl<C>>::data(zalsa.table(), input);
65+
let data = <super::IngredientImpl<C>>::data(zalsa.table(), input, zalsa.current_revision());
6666
let field_changed_at = data.revisions[self.field_index];
6767
MaybeChangedAfter::from(field_changed_at > revision)
6868
}

0 commit comments

Comments
 (0)