diff --git a/src/Amalgam/Conviction.h b/src/Amalgam/Conviction.h index f7336e20..fdbdb38f 100644 --- a/src/Amalgam/Conviction.h +++ b/src/Amalgam/Conviction.h @@ -54,7 +54,8 @@ class ConvictionProcessor knnCache->GetKnn(entity_reference, numNearestNeighbors, true, buffers->neighbors, additional_holdout_reference); - return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_reference); + double entity_weight = distanceTransform->getEntityWeightFunction(entity_reference); + return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_weight); } //Like the other ComputeDistanceContribution, but only includes included_entities @@ -64,7 +65,8 @@ class ConvictionProcessor buffers->neighbors.clear(); knnCache->GetKnn(entity_reference, numNearestNeighbors, true, buffers->neighbors, included_entities); - return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_reference); + double entity_weight = distanceTransform->getEntityWeightFunction(entity_reference); + return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_weight); } //Computes the Distance Contributions for each entity specified in entities_to_compute diff --git a/src/Amalgam/DistanceReferencePair.h b/src/Amalgam/DistanceReferencePair.h index f97d0592..719425fd 100644 --- a/src/Amalgam/DistanceReferencePair.h +++ b/src/Amalgam/DistanceReferencePair.h @@ -33,6 +33,11 @@ class DistanceReferencePair return distance == drp.distance; } + constexpr bool operator ==(const double &value) const + { + return distance == value; + } + constexpr bool SameReference(const DistanceReferencePair &drp) const { return reference == drp.reference; diff --git a/src/Amalgam/SeparableBoxFilterDataStore.cpp b/src/Amalgam/SeparableBoxFilterDataStore.cpp index 94c85ee0..943bb2cb 100644 --- a/src/Amalgam/SeparableBoxFilterDataStore.cpp +++ b/src/Amalgam/SeparableBoxFilterDataStore.cpp @@ -2,6 +2,9 @@ #include "Entity.h" #include "SeparableBoxFilterDataStore.h" +//system headers +#include + #if defined(MULTITHREAD_SUPPORT) || defined(MULTITHREAD_INTERFACE) thread_local #endif @@ -566,11 +569,14 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized //cache kth smallest distance to target search node double worst_candidate_distance = std::numeric_limits::infinity(); + //assume there's an error in each addition and subtraction + double distance_threshold_to_consider_zero = 2 + * static_cast(num_enabled_features) * std::numeric_limits::epsilon(); if(sorted_results.Size() == top_k) { double top_distance = sorted_results.Top().distance; //don't clamp top distance if we're expanding and only have 0 distances - if(! (expand_to_first_nonzero_distance && top_distance == 0.0) ) + if(! (expand_to_first_nonzero_distance && top_distance <= distance_threshold_to_consider_zero) ) worst_candidate_distance = top_distance; } @@ -589,7 +595,7 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized { double top_distance = sorted_results.Top().distance; //don't clamp top distance if we're expanding and only have 0 distances - if(!(expand_to_first_nonzero_distance && top_distance == 0.0)) + if(!(expand_to_first_nonzero_distance && top_distance <= distance_threshold_to_consider_zero)) worst_candidate_distance = top_distance; } @@ -605,7 +611,7 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized continue; //if not expanding and pushing a zero distance onto the stack, then push and pop a value onto the stack - if(!(expand_to_first_nonzero_distance && distance == 0.0)) + if(!(expand_to_first_nonzero_distance && distance <= distance_threshold_to_consider_zero)) worst_candidate_distance = sorted_results.PushAndPop(DistanceReferencePair(distance, entity_index)).distance; else //adding a zero and need to expand beyond zeros { @@ -617,7 +623,7 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized sorted_results.Pop(); //if the next largest size is zero, then need to put the non-zero value back in sorted_results - if(sorted_results.Size() > 0 && sorted_results.Top().distance == 0) + if(sorted_results.Size() > 0 && sorted_results.Top().distance <= distance_threshold_to_consider_zero) sorted_results.Push(drp); } } diff --git a/src/Amalgam/SeparableBoxFilterDataStore.h b/src/Amalgam/SeparableBoxFilterDataStore.h index 5329e90c..af0625ea 100644 --- a/src/Amalgam/SeparableBoxFilterDataStore.h +++ b/src/Amalgam/SeparableBoxFilterDataStore.h @@ -439,24 +439,23 @@ class SeparableBoxFilterDataStore //returns a function that will take in an entity index and reference to a double to store the value and return true if the value is found // assumes and requires column_index is a valid column (not a feature_id) - inline std::function GetNumberValueFromEntityIndexFunction(size_t column_index) + inline std::function GetNumberValueFromEntityIndexFunction(size_t column_index) { - //if invalid column_index, then always return false + //if invalid column_index, then always return 1.0 if(column_index >= columnData.size()) - return [](size_t i, double &value) { return false; }; + return [](size_t i) { return 1.0; }; auto column_data = columnData[column_index].get(); auto number_indices_ptr = &column_data->numberIndices; auto value_type = column_data->GetUnresolvedValueType(ENIVT_NUMBER); return [&, number_indices_ptr, column_index, column_data, value_type] - (size_t i, double &value) + (size_t i) { if(!number_indices_ptr->contains(i)) - return false; + return 1.0; - value = column_data->GetResolvedValue(value_type, GetValue(i, column_index)).number; - return true; + return column_data->GetResolvedValue(value_type, GetValue(i, column_index)).number; }; } diff --git a/src/Amalgam/entity/EntityQueries.cpp b/src/Amalgam/entity/EntityQueries.cpp index 7d7a1ecb..4d5d30b3 100644 --- a/src/Amalgam/entity/EntityQueries.cpp +++ b/src/Amalgam/entity/EntityQueries.cpp @@ -725,12 +725,12 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain it.distance = GetConditionDistanceMeasure(it.reference, true); } - auto weight_function = [this](Entity *e, double &weight_value) + auto weight_function = [this](Entity *e) { - auto [ret_val, found] = e->GetValueAtLabelAsNumber(weightLabel); + auto [weight, found] = e->GetValueAtLabelAsNumber(weightLabel); if(found) - weight_value = ret_val; - return found; + return weight; + return 1.0; }; //transform distances as appropriate @@ -774,12 +774,12 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain auto weight_function = [this] - (Entity *e, double &weight_value) + (Entity *e) { - auto [ret_val, found] = e->GetValueAtLabelAsNumber(weightLabel); + auto [weight, found] = e->GetValueAtLabelAsNumber(weightLabel); if(found) - weight_value = ret_val; - return found; + return weight; + return 1.0; }; //transform distances as appropriate diff --git a/src/Amalgam/entity/EntityQueriesStatistics.h b/src/Amalgam/entity/EntityQueriesStatistics.h index f62fba25..a1d49ccc 100644 --- a/src/Amalgam/entity/EntityQueriesStatistics.h +++ b/src/Amalgam/entity/EntityQueriesStatistics.h @@ -717,7 +717,7 @@ class EntityQueriesStatistics //if compute_surprisal is false, distance_weight_exponent is the exponent each distance is raised to //uses min_to_retrieve and max_to_retrieve to determine how many entities to keep, stopping when the first // entity's marginal probability falls below the num_to_retrieve_min_increment_prob threshold - //has_weight, if set, will use get_weight, taking in a function of an entity reference and a reference to an output + //has_weight, if set, will use the function get_weight to deterimne the given entity's weight // double to set the weight, and should return true if the entity has a weight, false if not template class DistanceTransform @@ -727,7 +727,7 @@ class EntityQueriesStatistics double distance_weight_exponent, size_t min_to_retrieve, size_t max_to_retrieve, double num_to_retrieve_min_increment_prob, - bool has_weight, double min_weight, std::function get_weight) + bool has_weight, double min_weight, std::function get_weight) { distanceWeightExponent = distance_weight_exponent; computeSurprisal = compute_surprisal; @@ -782,6 +782,10 @@ class EntityQueriesStatistics if(minToRetrieve < maxToRetrieve || numToRetrieveMinIncrementalProbability > 0.0) { + //if no elements, just return zero + if(entity_distance_pair_container_begin == entity_distance_pair_container_end) + return 0; + auto [first_weighted_value, first_unweighted_value, first_prob, first_prob_mass, first_weight] = transform_func(entity_distance_pair_container_begin); result_func(entity_distance_pair_container_begin, first_weighted_value, first_unweighted_value, first_prob_mass, first_weight); @@ -842,8 +846,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(prob, prob, prob, prob, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); double weighted_prob = prob * weight; return std::make_tuple(weighted_prob, prob, prob, weighted_prob, weight); @@ -860,8 +863,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(surprisal, surprisal, prob, prob, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); return std::make_tuple(surprisal, surprisal, prob, prob * weight, weight); }, result_func); @@ -879,8 +881,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(prob, prob, prob, prob, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); double weighted_prob = prob * weight; return std::make_tuple(weighted_prob, prob, prob, weighted_prob, weight); @@ -898,8 +899,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(iter->distance, iter->distance, prob, prob, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); return std::make_tuple(iter->distance, iter->distance, prob, weight * prob, weight); }, result_func); @@ -913,8 +913,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(1.0, 1.0, 1.0, 1.0, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); return std::make_tuple(weight, 1.0, 1.0, weight, weight); }, result_func); @@ -933,8 +932,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(iter->distance, iter->distance, prob, prob, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); return std::make_tuple(iter->distance, iter->distance, prob, weight * prob, weight); }, result_func); @@ -951,8 +949,7 @@ class EntityQueriesStatistics if(!hasWeight) return std::make_tuple(iter->distance, iter->distance, prob, prob, 1.0); - double weight = 1.0; - getEntityWeightFunction(iter->reference, weight); + double weight = getEntityWeightFunction(iter->reference); double weighted_prob = prob * weight; return std::make_tuple(weighted_prob, prob, prob, weighted_prob, weight); @@ -1111,11 +1108,14 @@ class EntityQueriesStatistics } //Computes the distance contribution as a type of generalized mean with special handling for distances of zero - // entity is the entity that the distance contribution is being performed on, and entity_distance_pair_container are the distances to - // its nearest entities + // entity_distance_pair_container are the distances to its nearest entities, + // and entity_weight is the weight of the entity for which this distance contribution is being computed // the functions get_entity and get_distance_ref return the entity and reference to the distance for an iterator of entity_distance_pair_container - double ComputeDistanceContribution(std::vector> &entity_distance_pair_container, EntityReference entity) + double ComputeDistanceContribution(std::vector> &entity_distance_pair_container, double entity_weight) { + if(entity_weight == 0.0) + return 0.0; + double distance_contribution = 0.0; //there's at least one entity in question size_t num_identical_entities = 1; @@ -1150,11 +1150,7 @@ class EntityQueriesStatistics if(entity_distance_iter->distance != 0.0) break; - double weight = 1.0; - if(getEntityWeightFunction(entity_distance_iter->reference, weight)) - weight_of_identical_entities += weight; - else - weight_of_identical_entities += 1.0; + weight_of_identical_entities += getEntityWeightFunction(entity_distance_iter->reference); } distance_contribution = TransformDistancesToExpectedValue(entity_distance_iter, end(entity_distance_pair_container)); @@ -1163,17 +1159,11 @@ class EntityQueriesStatistics if(FastIsNaN(distance_contribution)) return 0.0; - double entity_weight = 1.0; - if(getEntityWeightFunction(entity, entity_weight)) - { - if(entity_weight != 0) - distance_contribution *= entity_weight; - else - return 0.0; - } - //split the distance contribution among the identical entities - return distance_contribution * entity_weight / (weight_of_identical_entities + entity_weight); + double fraction_per_identical_entity = entity_weight / (weight_of_identical_entities + entity_weight); + + //return the distance contribution modified by weights and identical entities + return entity_weight * distance_contribution * fraction_per_identical_entity; } //exponent by which to scale the distances @@ -1197,6 +1187,7 @@ class EntityQueriesStatistics //if hasWeight is true, then will call getEntityWeightFunction and apply the respective entity weight to each distance bool hasWeight; - std::function getEntityWeightFunction; + //return the entity weight for the entity reference if it exists, 1.0 if it does not + std::function getEntityWeightFunction; }; };