Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions examples/network-hhmodel/seir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ fn calculate_waiting_time(context: &Context, shape: f64, mean_period: f64) -> f6
}

fn expose_network<ET: EdgeType<Person>>(context: &mut Context, beta: f64) {
let mut infectious_people = Vec::new();
context.with_query_results((DiseaseStatus::I,), &mut |infected| {
infectious_people = infected.iter().copied().collect();
});
let infectious_people = context.query((DiseaseStatus::I,)).to_owned_vec();

for infectious in infectious_people {
let edges = context.get_matching_edges::<Person, ET>(infectious, |context, edge| {
Expand Down Expand Up @@ -191,9 +188,7 @@ mod tests {

let mut to_infect = Vec::<PersonId>::new();
context.with_query_results((Id(71),), &mut |people| {
for p in people {
to_infect.push(*p);
}
to_infect.extend(people);
});

init(&mut context, &to_infect);
Expand Down
6 changes: 3 additions & 3 deletions ixa-bench/criterion/index_benches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn criterion_benchmark(criterion: &mut Criterion) {
context.with_query_results(
black_box((Property10(number % 10),)),
&mut |entity_ids| {
black_box(entity_ids.len());
black_box(entity_ids.len_upper());
},
);
}
Expand All @@ -65,7 +65,7 @@ pub fn criterion_benchmark(criterion: &mut Criterion) {
Property100(*number),
)),
&mut |entity_ids| {
black_box(entity_ids.len());
black_box(entity_ids.len_upper());
},
);
}
Expand All @@ -85,7 +85,7 @@ pub fn criterion_benchmark(criterion: &mut Criterion) {
MultiProperty100(*number),
)),
&mut |entity_ids| {
black_box(entity_ids.len());
black_box(entity_ids.len_upper());
},
);
}
Expand Down
123 changes: 93 additions & 30 deletions src/entity/context_extension.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use std::any::{Any, TypeId};

use crate::entity::entity_set::EntitySetIterator;
use crate::entity::entity_set::{EntitySet, EntitySetIterator, SourceSet};
use crate::entity::events::{EntityCreatedEvent, PartialPropertyChangeEvent};
use crate::entity::index::{IndexCountResult, IndexSetResult, PropertyIndexType};
use crate::entity::property::Property;
use crate::entity::property_list::PropertyList;
use crate::entity::query::Query;
use crate::entity::{Entity, EntityId, PopulationIterator};
use crate::hashing::IndexSet;
use crate::rand::Rng;
use crate::random::sample_multiple_from_known_length;
use crate::{warn, Context, ContextRandomExt, IxaError, RngId};

/// A trait extension for [`Context`] that exposes entity-related
Expand Down Expand Up @@ -59,13 +59,13 @@ pub trait ContextEntitiesExt {
#[cfg(test)]
fn is_property_indexed<E: Entity, P: Property<E>>(&self) -> bool;

/// This method gives client code direct immutable access to the fully realized set of
/// entity IDs. This is especially efficient for indexed queries, as this method reduces
/// to a simple lookup of a hash bucket. Otherwise, the set is allocated and computed.
fn with_query_results<E: Entity, Q: Query<E>>(
&self,
/// This method gives client code direct access to the query result as an `EntitySet`.
/// This is especially efficient for indexed queries, as this method can reduce to wrapping
/// a single indexed source.
fn with_query_results<'a, E: Entity, Q: Query<E>>(
&'a self,
query: Q,
callback: &mut dyn FnMut(&IndexSet<EntityId<E>>),
callback: &mut dyn FnMut(EntitySet<'a, E>),
);

/// Gives the count of distinct entity IDs satisfying the query. This is especially
Expand Down Expand Up @@ -103,6 +103,9 @@ pub trait ContextEntitiesExt {
/// Returns an iterator over all created entities of type `E`.
fn get_entity_iterator<E: Entity>(&self) -> PopulationIterator<E>;

/// Generates an `EntitySet` representing the query results.
fn query<E: Entity, Q: Query<E>>(&self, query: Q) -> EntitySet<E>;

/// Generates an iterator over the results of the query.
fn query_result_iterator<E: Entity, Q: Query<E>>(&self, query: Q) -> EntitySetIterator<E>;

Expand Down Expand Up @@ -239,14 +242,14 @@ impl ContextEntitiesExt for Context {
property_store.is_property_indexed::<P>()
}

fn with_query_results<E: Entity, Q: Query<E>>(
&self,
fn with_query_results<'a, E: Entity, Q: Query<E>>(
&'a self,
query: Q,
callback: &mut dyn FnMut(&IndexSet<EntityId<E>>),
callback: &mut dyn FnMut(EntitySet<'a, E>),
) {
// The fast path for indexed queries.

// This mirrors the indexed case in `SourceSet<'a, E>::new()` and `Query:: new_query_result_iterator`.
// This mirrors the indexed case in `SourceSet<'a, E>::new()` and `Query::new_query_result`.
// The difference is, we access the index set if we find it.
if let Some(multi_property_id) = query.multi_property_id() {
let property_store = self.entity_store.get_property_store::<E>();
Expand All @@ -256,12 +259,11 @@ impl ContextEntitiesExt for Context {
query.multi_property_value_hash(),
) {
IndexSetResult::Set(people_set) => {
callback(&people_set);
callback(EntitySet::from_source(SourceSet::IndexSet(people_set)));
return;
}
IndexSetResult::Empty => {
let people_set = IndexSet::default();
callback(&people_set);
callback(EntitySet::empty());
return;
}
IndexSetResult::Unsupported => {}
Expand All @@ -272,25 +274,23 @@ impl ContextEntitiesExt for Context {
// Special case the empty query, which creates a set containing the entire population.
if query.type_id() == TypeId::of::<()>() {
warn!("Called Context::with_query_results() with an empty query. Prefer Context::get_entity_iterator::<E>() for working with the entire population.");
let entity_set = self.get_entity_iterator::<E>().collect::<IndexSet<_>>();
callback(&entity_set);
callback(EntitySet::from_source(SourceSet::Population(
self.get_entity_count::<E>(),
)));
return;
}

// The slow path of computing the full query set.
warn!("Called Context::with_query_results() with an unindexed query. It's almost always better to use Context::query_result_iterator() for unindexed queries.");

// Fall back to `EntitySetIterator`.
let people_set = query
.new_query_result_iterator(self)
.collect::<IndexSet<_>>();
callback(&people_set);
// Fall back to the query's `EntitySet`.
callback(self.query(query));
}

fn query_entity_count<E: Entity, Q: Query<E>>(&self, query: Q) -> usize {
// The fast path for indexed queries.
//
// This mirrors the indexed case in `SourceSet<'a, E>::new()` and `Query:: new_query_result_iterator`.
// This mirrors the indexed case in `SourceSet<'a, E>::new()` and `Query::new_query_result`.
if let Some(multi_property_id) = query.multi_property_id() {
let property_store = self.entity_store.get_property_store::<E>();
match property_store.get_index_count_with_hash_for_property_id(
Expand All @@ -313,6 +313,22 @@ impl ContextEntitiesExt for Context {
R: RngId + 'static,
R::RngType: Rng,
{
if query.type_id() == TypeId::of::<()>() {
let population = self.get_entity_count::<E>();
return self.sample(rng_id, move |rng| {
if population == 0 {
warn!("Requested a sample entity from an empty population");
return None;
}
let index = if population <= u32::MAX as usize {
rng.random_range(0..population as u32) as usize
} else {
rng.random_range(0..population)
};
Some(EntityId::new(index))
});
}

let query_result = self.query_result_iterator(query);
self.sample(rng_id, move |rng| query_result.sample_entity(rng))
}
Expand All @@ -324,6 +340,20 @@ impl ContextEntitiesExt for Context {
R: RngId + 'static,
R::RngType: Rng,
{
if query.type_id() == TypeId::of::<()>() {
let population = self.get_entity_count::<E>();
return self.sample(rng_id, move |rng| {
if population == 0 {
warn!("Requested a sample of entities from an empty population");
return vec![];
}
if n >= population {
return PopulationIterator::<E>::new(population).collect();
}
sample_multiple_from_known_length(rng, PopulationIterator::<E>::new(population), n)
});
}

let query_result = self.query_result_iterator(query);
self.sample(rng_id, move |rng| query_result.sample_entities(rng, n))
}
Expand All @@ -336,6 +366,10 @@ impl ContextEntitiesExt for Context {
self.entity_store.get_entity_iterator::<E>()
}

fn query<E: Entity, Q: Query<E>>(&self, query: Q) -> EntitySet<E> {
query.new_query_result(self)
}

fn query_result_iterator<E: Entity, Q: Query<E>>(&self, query: Q) -> EntitySetIterator<E> {
query.new_query_result_iterator(self)
}
Expand Down Expand Up @@ -524,13 +558,13 @@ mod tests {

let mut existing_len = 0;
context.with_query_results((existing_value,), &mut |people_set| {
existing_len = people_set.len();
existing_len = people_set.into_iter().count();
});
assert_eq!(existing_len, 2);

let mut missing_len = 0;
context.with_query_results((missing_value,), &mut |people_set| {
missing_len = people_set.len();
missing_len = people_set.into_iter().count();
});
assert_eq!(missing_len, 0);

Expand Down Expand Up @@ -786,12 +820,12 @@ mod tests {
// Force an index build by running a query.
let _ = context.query_result_iterator((InfectionStatus::Susceptible, Vaccinated(true)));

// Capture the address of the has set given by `with_query_result`
let mut address: *const IndexSet<EntityId<Person>> = std::ptr::null();
// Capture the set given by `with_query_results`.
let mut result_entities: IndexSet<EntityId<Person>> = IndexSet::default();
context.with_query_results(
(InfectionStatus::Susceptible, Vaccinated(true)),
&mut |result_set| {
address = result_set as *const _;
result_entities = result_set.into_iter().collect::<IndexSet<_>>();
},
);

Expand All @@ -818,8 +852,37 @@ mod tests {
)
.unwrap();

let address2 = &*bucket as *const _;
assert_eq!(address2, address);
let expected_entities = bucket.iter().copied().collect::<IndexSet<_>>();
assert_eq!(expected_entities, result_entities);
}

#[test]
fn query_returns_entity_set_and_query_result_iterator_remains_compatible() {
let mut context = Context::new();
let p1 = context
.add_entity((Age(21), InfectionStatus::Susceptible, Vaccinated(true)))
.unwrap();
let _p2 = context
.add_entity((Age(22), InfectionStatus::Susceptible, Vaccinated(false)))
.unwrap();
let p3 = context
.add_entity((Age(23), InfectionStatus::Infected, Vaccinated(true)))
.unwrap();

let query = (Vaccinated(true),);

let from_set = context
.query::<Person, _>(query)
.into_iter()
.collect::<IndexSet<_>>();
let from_iterator = context
.query_result_iterator(query)
.collect::<IndexSet<_>>();

assert_eq!(from_set, from_iterator);
assert!(from_set.contains(&p1));
assert!(from_set.contains(&p3));
assert_eq!(from_set.len(), 2);
}

#[test]
Expand Down
6 changes: 6 additions & 0 deletions src/entity/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ impl<E: Entity> PopulationIterator<E> {
_phantom: PhantomData,
}
}

#[must_use]
#[allow(dead_code)]
pub(crate) fn population(&self) -> usize {
self.population
}
}

impl<E: Entity> Iterator for PopulationIterator<E> {
Expand Down
Loading