Skip to content

Commit 8a61975

Browse files
committed
Simplify unsafe code a bit
1 parent 0cf0cc7 commit 8a61975

File tree

1 file changed

+63
-62
lines changed

1 file changed

+63
-62
lines changed

src/tracked_struct.rs

+63-62
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ where
160160
ingredient_index: IngredientIndex,
161161

162162
/// Phantom data: we fetch `Value<C>` out from `Table`
163-
phantom: PhantomData<fn() -> Value<C>>,
163+
phantom: PhantomData<fn() -> ValueWithLock<C>>,
164164

165165
/// Store freed ids
166166
free_list: SegQueue<Id>,
@@ -255,26 +255,10 @@ impl IdentityMap {
255255
}
256256

257257
// ANCHOR: ValueStruct
258-
#[derive(Debug)]
259-
pub struct Value<C>
258+
pub struct ValueWithLock<C>
260259
where
261260
C: Configuration,
262261
{
263-
/// The durability minimum durability of all inputs consumed
264-
/// by the creator query prior to creating this tracked struct.
265-
/// If any of those inputs changes, then the creator query may
266-
/// create this struct with different values.
267-
durability: Durability,
268-
269-
/// The revision in which the tracked struct was first created.
270-
///
271-
/// Unlike `updated_at`, which gets bumped on every read,
272-
/// `created_at` is updated whenever an untracked field is updated.
273-
/// This is necessary to detect reused tracked struct ids _after_
274-
/// they've been freed in a prior revision or tracked structs that have been updated
275-
/// in-place because of a bad `Hash` implementation.
276-
created_at: Revision,
277-
278262
/// The revision when this tracked struct was last updated.
279263
/// This field also acts as a kind of "lock". Once it is equal
280264
/// to `Some(current_revision)`, the fields are locked and
@@ -295,6 +279,27 @@ where
295279
/// This `None` value should never be observable by users unless they have
296280
/// leaked a reference across threads somehow.
297281
updated_at: OptionalAtomicRevision,
282+
value: Value<C>,
283+
}
284+
285+
pub struct Value<C>
286+
where
287+
C: Configuration,
288+
{
289+
/// The durability minimum durability of all inputs consumed
290+
/// by the creator query prior to creating this tracked struct.
291+
/// If any of those inputs changes, then the creator query may
292+
/// create this struct with different values.
293+
durability: Durability,
294+
295+
/// The revision in which the tracked struct was first created.
296+
///
297+
/// Unlike `updated_at`, which gets bumped on every read,
298+
/// `created_at` is updated whenever an untracked field is updated.
299+
/// This is necessary to detect reused tracked struct ids _after_
300+
/// they've been freed in a prior revision or tracked structs that have been updated
301+
/// in-place because of a bad `Hash` implementation.
302+
created_at: Revision,
298303

299304
/// Fields of this tracked struct. They can change across revisions,
300305
/// but they do not change within a particular revision.
@@ -427,14 +432,16 @@ where
427432
fields: C::Fields<'db>,
428433
) -> Id {
429434
let current_revision = zalsa.current_revision();
430-
let value = |_| Value {
431-
created_at: current_revision,
435+
let value = |_| ValueWithLock {
432436
updated_at: OptionalAtomicRevision::new(Some(current_revision)),
433-
durability: current_deps.durability,
434-
fields: unsafe { self.to_static(fields) },
435-
revisions: C::new_revisions(current_deps.changed_at),
436-
memos: Default::default(),
437-
syncs: Default::default(),
437+
value: Value {
438+
created_at: current_revision,
439+
durability: current_deps.durability,
440+
fields: unsafe { self.to_static(fields) },
441+
revisions: C::new_revisions(current_deps.changed_at),
442+
memos: Default::default(),
443+
syncs: Default::default(),
444+
},
438445
};
439446

440447
if let Some(id) = self.free_list.pop() {
@@ -450,7 +457,7 @@ where
450457

451458
id
452459
} else {
453-
zalsa_local.allocate::<Value<C>>(zalsa.table(), self.ingredient_index, value)
460+
zalsa_local.allocate::<ValueWithLock<C>>(zalsa.table(), self.ingredient_index, value)
454461
}
455462
}
456463

@@ -505,9 +512,6 @@ where
505512
// that is still live.
506513

507514
let current_revision = zalsa.current_revision();
508-
// UNSAFE: Marking as mut requires exclusive access for the duration of
509-
// the `mut`. We have now *claimed* this data by swapping in `None`,
510-
// any attempt to read concurrently will panic.
511515
let last_updated_at = unsafe { (*data_raw).updated_at.load() };
512516
assert!(
513517
last_updated_at.is_some(),
@@ -528,39 +532,36 @@ where
528532
"failed to acquire write lock, id `{id:?}` must have been leaked across threads"
529533
);
530534
}
531-
532535
// SAFETY: Marking as mut requires exclusive access for the duration of
533536
// the `mut`. We have now *claimed* this data by swapping in `None`,
534-
// any attempt to read concurrently will panic. Note that we cannot create
535-
// a `&mut` reference to the full `Value` though because
536-
// another thread may access `updated_at` concurrently.
537+
// any attempt to read concurrently will panic.
538+
let data = unsafe { &mut (*data_raw).value };
537539

538540
// SAFETY: We assert that the pointer to `data.revisions`
539541
// is a pointer into the database referencing a value
540542
// from a previous revision. As such, it continues to meet
541543
// its validity invariant and any owned content also continues
542544
// to meet its safety invariant.
543-
unsafe {
544-
if C::update_fields(
545+
let updated = unsafe {
546+
C::update_fields(
545547
current_revision,
546-
&mut (*data_raw).revisions,
547-
self.to_self_ptr(std::ptr::addr_of_mut!((*data_raw).fields)),
548+
&mut data.revisions,
549+
self.to_self_ptr(std::ptr::addr_of_mut!(data.fields)),
548550
fields,
549-
) {
550-
// Consider this a new tracked-struct (even though it still uses the same id)
551-
// when any non-tracked field got updated.
552-
// This should be rare and only ever happen if there's a hash collision
553-
// which makes Salsa consider two tracked structs to still be the same
554-
// even though the fields are different.
555-
// See `tracked-struct-id-field-bad-hash` for more details.
556-
(*data_raw).revisions = C::new_revisions(current_revision);
557-
(*data_raw).created_at = current_revision;
558-
} else if current_deps.durability < (*data_raw).durability {
559-
(*data_raw).revisions = C::new_revisions(current_revision);
560-
(*data_raw).created_at = current_revision;
561-
}
562-
(*data_raw).durability = current_deps.durability;
551+
)
552+
};
553+
if updated || current_deps.durability < data.durability {
554+
// If `updated`, consider this a new tracked-struct (even though it still uses the same id)
555+
// when any non-tracked field got updated.
556+
// This should be rare and only ever happen if there's a hash collision
557+
// which makes Salsa consider two tracked structs to still be the same
558+
// even though the fields are different.
559+
// See `tracked-struct-id-field-bad-hash` for more details.
560+
data.revisions = C::new_revisions(current_revision);
561+
data.created_at = current_revision;
563562
}
563+
data.durability = current_deps.durability;
564+
// release the lock
564565
let swapped_out = unsafe { (*data_raw).updated_at.swap_mut(Some(current_revision)) };
565566
assert!(swapped_out.is_none(), "lock was acquired twice!");
566567
}
@@ -571,10 +572,10 @@ where
571572
let val = Self::data_raw(table, id);
572573
acquire_read_lock(unsafe { &(*val).updated_at }, current_revision);
573574
// We have acquired the read lock, so it is safe to return a reference to the data.
574-
unsafe { &*val }
575+
unsafe { &(*val).value }
575576
}
576577

577-
fn data_raw(table: &Table, id: Id) -> *mut Value<C> {
578+
fn data_raw(table: &Table, id: Id) -> *mut ValueWithLock<C> {
578579
table.get_raw(id)
579580
}
580581

@@ -612,7 +613,7 @@ where
612613

613614
// Take the memo table. This is safe because we have modified `data_ref.updated_at` to `None`
614615
// signalling that we have acquired the write lock
615-
let memo_table = std::mem::take(unsafe { &mut (*data).memos });
616+
let memo_table = std::mem::take(unsafe { &mut (*data).value.memos });
616617

617618
// SAFETY: We have verified that no more references to these memos exist and so we are good
618619
// to drop them.
@@ -709,12 +710,12 @@ where
709710
pub fn entries<'db>(
710711
&'db self,
711712
db: &'db dyn crate::Database,
712-
) -> impl Iterator<Item = &'db Value<C>> {
713+
) -> impl Iterator<Item = &'db ValueWithLock<C>> {
713714
db.zalsa()
714715
.table()
715716
.pages
716717
.iter()
717-
.filter_map(|(_, page)| page.cast_type::<crate::table::Page<Value<C>>>())
718+
.filter_map(|(_, page)| page.cast_type::<crate::table::Page<ValueWithLock<C>>>())
718719
.flat_map(|page| page.slots())
719720
}
720721
}
@@ -789,7 +790,7 @@ where
789790
}
790791
}
791792

792-
impl<C> Value<C>
793+
impl<C> ValueWithLock<C>
793794
where
794795
C: Configuration,
795796
{
@@ -799,7 +800,7 @@ where
799800
/// a particular revision.
800801
#[cfg(feature = "salsa_unstable")]
801802
pub fn fields(&self) -> &C::Fields<'static> {
802-
&self.fields
803+
&self.value.fields
803804
}
804805
}
805806

@@ -823,7 +824,7 @@ fn acquire_read_lock(updated_at: &OptionalAtomicRevision, current_revision: Revi
823824
}
824825
}
825826

826-
impl<C> Slot for Value<C>
827+
impl<C> Slot for ValueWithLock<C>
827828
where
828829
C: Configuration,
829830
{
@@ -833,11 +834,11 @@ where
833834
// ensures that there is no danger of a race
834835
// when deleting a tracked struct.
835836
acquire_read_lock(&self.updated_at, current_revision);
836-
&self.memos
837+
&self.value.memos
837838
}
838839

839840
fn memos_mut(&mut self) -> &mut crate::table::memo::MemoTable {
840-
&mut self.memos
841+
&mut self.value.memos
841842
}
842843

843844
// FIXME: `&self` may alias here?
@@ -846,7 +847,7 @@ where
846847
// ensures that there is no danger of a race
847848
// when deleting a tracked struct.
848849
acquire_read_lock(&self.updated_at, current_revision);
849-
&self.syncs
850+
&self.value.syncs
850851
}
851852
}
852853

0 commit comments

Comments
 (0)