Skip to content

Commit

Permalink
23055: Fixes issues where distance contributions would be returned as…
Browse files Browse the repository at this point in the history
… null (#365)
  • Loading branch information
howsohazard authored Mar 3, 2025
1 parent acc450f commit 38fbb18
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 56 deletions.
6 changes: 4 additions & 2 deletions src/Amalgam/Conviction.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ class ConvictionProcessor
knnCache->GetKnn(entity_reference, numNearestNeighbors, true,
buffers->neighbors, additional_holdout_reference);

return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_reference);
double entity_weight = distanceTransform->getEntityWeightFunction(entity_reference);
return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_weight);
}

//Like the other ComputeDistanceContribution, but only includes included_entities
Expand All @@ -64,7 +65,8 @@ class ConvictionProcessor
buffers->neighbors.clear();
knnCache->GetKnn(entity_reference, numNearestNeighbors, true, buffers->neighbors, included_entities);

return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_reference);
double entity_weight = distanceTransform->getEntityWeightFunction(entity_reference);
return distanceTransform->ComputeDistanceContribution(buffers->neighbors, entity_weight);
}

//Computes the Distance Contributions for each entity specified in entities_to_compute
Expand Down
5 changes: 5 additions & 0 deletions src/Amalgam/DistanceReferencePair.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ class DistanceReferencePair
return distance == drp.distance;
}

//compares only the stored distance against a raw distance value using exact equality;
// the reference is not consulted
constexpr bool operator ==(const double &value) const
{
	return value == distance;
}

constexpr bool SameReference(const DistanceReferencePair<ReferenceType> &drp) const
{
return reference == drp.reference;
Expand Down
14 changes: 10 additions & 4 deletions src/Amalgam/SeparableBoxFilterDataStore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include "Entity.h"
#include "SeparableBoxFilterDataStore.h"

//system headers
#include <limits>

#if defined(MULTITHREAD_SUPPORT) || defined(MULTITHREAD_INTERFACE)
thread_local
#endif
Expand Down Expand Up @@ -566,11 +569,14 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized

//cache kth smallest distance to target search node
double worst_candidate_distance = std::numeric_limits<double>::infinity();
//assume there's an error in each addition and subtraction
double distance_threshold_to_consider_zero = 2
* static_cast<double>(num_enabled_features) * std::numeric_limits<double>::epsilon();
if(sorted_results.Size() == top_k)
{
double top_distance = sorted_results.Top().distance;
//don't clamp top distance if we're expanding and only have 0 distances
if(! (expand_to_first_nonzero_distance && top_distance == 0.0) )
if(! (expand_to_first_nonzero_distance && top_distance <= distance_threshold_to_consider_zero) )
worst_candidate_distance = top_distance;
}

Expand All @@ -589,7 +595,7 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized
{
double top_distance = sorted_results.Top().distance;
//don't clamp top distance if we're expanding and only have 0 distances
if(!(expand_to_first_nonzero_distance && top_distance == 0.0))
if(!(expand_to_first_nonzero_distance && top_distance <= distance_threshold_to_consider_zero))
worst_candidate_distance = top_distance;
}

Expand All @@ -605,7 +611,7 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized
continue;

//if not expanding and pushing a zero distance onto the stack, then push and pop a value onto the stack
if(!(expand_to_first_nonzero_distance && distance == 0.0))
if(!(expand_to_first_nonzero_distance && distance <= distance_threshold_to_consider_zero))
worst_candidate_distance = sorted_results.PushAndPop(DistanceReferencePair(distance, entity_index)).distance;
else //adding a zero and need to expand beyond zeros
{
Expand All @@ -617,7 +623,7 @@ void SeparableBoxFilterDataStore::FindEntitiesNearestToIndexedEntity(Generalized
sorted_results.Pop();

//if the next largest distance is zero (within the zero threshold), then need to put the non-zero value back in sorted_results
if(sorted_results.Size() > 0 && sorted_results.Top().distance == 0)
if(sorted_results.Size() > 0 && sorted_results.Top().distance <= distance_threshold_to_consider_zero)
sorted_results.Push(drp);
}
}
Expand Down
13 changes: 6 additions & 7 deletions src/Amalgam/SeparableBoxFilterDataStore.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,24 +439,23 @@ class SeparableBoxFilterDataStore

//returns a function that takes an entity index and returns that entity's number value for the column, or 1.0 if the entity has no number value there
// assumes and requires column_index is a valid column (not a feature_id)
inline std::function<bool(size_t, double &)> GetNumberValueFromEntityIndexFunction(size_t column_index)
inline std::function<double(size_t)> GetNumberValueFromEntityIndexFunction(size_t column_index)
{
//if invalid column_index, then always return false
//if invalid column_index, then always return 1.0
if(column_index >= columnData.size())
return [](size_t i, double &value) { return false; };
return [](size_t i) { return 1.0; };

auto column_data = columnData[column_index].get();
auto number_indices_ptr = &column_data->numberIndices;
auto value_type = column_data->GetUnresolvedValueType(ENIVT_NUMBER);

return [&, number_indices_ptr, column_index, column_data, value_type]
(size_t i, double &value)
(size_t i)
{
if(!number_indices_ptr->contains(i))
return false;
return 1.0;

value = column_data->GetResolvedValue(value_type, GetValue(i, column_index)).number;
return true;
return column_data->GetResolvedValue(value_type, GetValue(i, column_index)).number;
};
}

Expand Down
16 changes: 8 additions & 8 deletions src/Amalgam/entity/EntityQueries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -725,12 +725,12 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain
it.distance = GetConditionDistanceMeasure(it.reference, true);
}

auto weight_function = [this](Entity *e, double &weight_value)
auto weight_function = [this](Entity *e)
{
auto [ret_val, found] = e->GetValueAtLabelAsNumber(weightLabel);
auto [weight, found] = e->GetValueAtLabelAsNumber(weightLabel);
if(found)
weight_value = ret_val;
return found;
return weight;
return 1.0;
};

//transform distances as appropriate
Expand Down Expand Up @@ -774,12 +774,12 @@ EvaluableNodeReference EntityQueryCondition::GetMatchingEntities(Entity *contain


auto weight_function = [this]
(Entity *e, double &weight_value)
(Entity *e)
{
auto [ret_val, found] = e->GetValueAtLabelAsNumber(weightLabel);
auto [weight, found] = e->GetValueAtLabelAsNumber(weightLabel);
if(found)
weight_value = ret_val;
return found;
return weight;
return 1.0;
};

//transform distances as appropriate
Expand Down
61 changes: 26 additions & 35 deletions src/Amalgam/entity/EntityQueriesStatistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ class EntityQueriesStatistics
//if compute_surprisal is false, distance_weight_exponent is the exponent each distance is raised to
//uses min_to_retrieve and max_to_retrieve to determine how many entities to keep, stopping when the first
// entity's marginal probability falls below the num_to_retrieve_min_increment_prob threshold
//has_weight, if set, will use get_weight, taking in a function of an entity reference and a reference to an output
//has_weight, if set, will use the function get_weight to determine the given entity's weight;
// get_weight takes an entity reference and returns the entity's weight, or 1.0 if the entity has no weight
template<typename EntityReference>
class DistanceTransform
Expand All @@ -727,7 +727,7 @@ class EntityQueriesStatistics
double distance_weight_exponent,
size_t min_to_retrieve, size_t max_to_retrieve,
double num_to_retrieve_min_increment_prob,
bool has_weight, double min_weight, std::function<bool(EntityReference, double &)> get_weight)
bool has_weight, double min_weight, std::function<double(EntityReference)> get_weight)
{
distanceWeightExponent = distance_weight_exponent;
computeSurprisal = compute_surprisal;
Expand Down Expand Up @@ -782,6 +782,10 @@ class EntityQueriesStatistics

if(minToRetrieve < maxToRetrieve || numToRetrieveMinIncrementalProbability > 0.0)
{
//if no elements, just return zero
if(entity_distance_pair_container_begin == entity_distance_pair_container_end)
return 0;

auto [first_weighted_value, first_unweighted_value, first_prob, first_prob_mass, first_weight]
= transform_func(entity_distance_pair_container_begin);
result_func(entity_distance_pair_container_begin, first_weighted_value, first_unweighted_value, first_prob_mass, first_weight);
Expand Down Expand Up @@ -842,8 +846,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(prob, prob, prob, prob, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);
double weighted_prob = prob * weight;

return std::make_tuple(weighted_prob, prob, prob, weighted_prob, weight);
Expand All @@ -860,8 +863,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(surprisal, surprisal, prob, prob, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);

return std::make_tuple(surprisal, surprisal, prob, prob * weight, weight);
}, result_func);
Expand All @@ -879,8 +881,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(prob, prob, prob, prob, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);
double weighted_prob = prob * weight;

return std::make_tuple(weighted_prob, prob, prob, weighted_prob, weight);
Expand All @@ -898,8 +899,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(iter->distance, iter->distance, prob, prob, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);

return std::make_tuple(iter->distance, iter->distance, prob, weight * prob, weight);
}, result_func);
Expand All @@ -913,8 +913,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(1.0, 1.0, 1.0, 1.0, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);

return std::make_tuple(weight, 1.0, 1.0, weight, weight);
}, result_func);
Expand All @@ -933,8 +932,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(iter->distance, iter->distance, prob, prob, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);

return std::make_tuple(iter->distance, iter->distance, prob, weight * prob, weight);
}, result_func);
Expand All @@ -951,8 +949,7 @@ class EntityQueriesStatistics
if(!hasWeight)
return std::make_tuple(iter->distance, iter->distance, prob, prob, 1.0);

double weight = 1.0;
getEntityWeightFunction(iter->reference, weight);
double weight = getEntityWeightFunction(iter->reference);
double weighted_prob = prob * weight;

return std::make_tuple(weighted_prob, prob, prob, weighted_prob, weight);
Expand Down Expand Up @@ -1111,11 +1108,14 @@ class EntityQueriesStatistics
}

//Computes the distance contribution as a type of generalized mean with special handling for distances of zero
// entity is the entity that the distance contribution is being performed on, and entity_distance_pair_container are the distances to
// its nearest entities
// entity_distance_pair_container are the distances to its nearest entities,
// and entity_weight is the weight of the entity for which this distance contribution is being computed
// the functions get_entity and get_distance_ref return the entity and reference to the distance for an iterator of entity_distance_pair_container
double ComputeDistanceContribution(std::vector<DistanceReferencePair<EntityReference>> &entity_distance_pair_container, EntityReference entity)
double ComputeDistanceContribution(std::vector<DistanceReferencePair<EntityReference>> &entity_distance_pair_container, double entity_weight)
{
if(entity_weight == 0.0)
return 0.0;

double distance_contribution = 0.0;
//there's at least one entity in question
size_t num_identical_entities = 1;
Expand Down Expand Up @@ -1150,11 +1150,7 @@ class EntityQueriesStatistics
if(entity_distance_iter->distance != 0.0)
break;

double weight = 1.0;
if(getEntityWeightFunction(entity_distance_iter->reference, weight))
weight_of_identical_entities += weight;
else
weight_of_identical_entities += 1.0;
weight_of_identical_entities += getEntityWeightFunction(entity_distance_iter->reference);
}

distance_contribution = TransformDistancesToExpectedValue(entity_distance_iter, end(entity_distance_pair_container));
Expand All @@ -1163,17 +1159,11 @@ class EntityQueriesStatistics
if(FastIsNaN(distance_contribution))
return 0.0;

double entity_weight = 1.0;
if(getEntityWeightFunction(entity, entity_weight))
{
if(entity_weight != 0)
distance_contribution *= entity_weight;
else
return 0.0;
}

//split the distance contribution among the identical entities
return distance_contribution * entity_weight / (weight_of_identical_entities + entity_weight);
double fraction_per_identical_entity = entity_weight / (weight_of_identical_entities + entity_weight);

//return the distance contribution modified by weights and identical entities
return entity_weight * distance_contribution * fraction_per_identical_entity;
}

//exponent by which to scale the distances
Expand All @@ -1197,6 +1187,7 @@ class EntityQueriesStatistics

//if hasWeight is true, then will call getEntityWeightFunction and apply the respective entity weight to each distance
bool hasWeight;
std::function<bool(EntityReference, double &)> getEntityWeightFunction;
//return the entity weight for the entity reference if it exists, 1.0 if it does not
std::function<double(EntityReference)> getEntityWeightFunction;
};
};

0 comments on commit 38fbb18

Please sign in to comment.