Skip to content

Commit

Permalink
fix case removal
Browse files Browse the repository at this point in the history
  • Loading branch information
howsoRes committed Feb 27, 2025
1 parent 476e8e9 commit d28a21c
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 9 deletions.
45 changes: 40 additions & 5 deletions howso/remove_cases.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -227,11 +227,46 @@
has_rebalance_features has_rebalance_features
))
)
;accumulate data mass change equivalent to one feature being changed for each case that satisfies the condition.
(accum_to_entities (assoc
!dataMassChangeSinceLastAnalyze (size cases)
!dataMassChangeSinceLastDataReduction (size cases)
))

(seq
;accumulate data mass change equivalent to one feature being changed for each case that satisfies the condition.
(accum_to_entities (assoc
!dataMassChangeSinceLastAnalyze (size cases)
!dataMassChangeSinceLastDataReduction (size cases)
))

;must reduce cached rebalance and probability masses for proper rebalancing
(if (or !continuousRebalanceFeatures !nominalRebalanceFeatures)
(let
(assoc
removed_probability_mass
(compute_on_contained_entities [
(query_in_entity_list cases)
(query_exists !internalLabelProbabilityMass)
(query_sum !internalLabelProbabilityMass)
])
removed_rebalance_weight
(compute_on_contained_entities [
(query_in_entity_list cases)
(query_exists ".case_weight")
(query_sum ".case_weight")
])
;approximate the scalar for all the reductions
scalar (/ !cachedTotalMass !cachedRebalanceTotalMass)
num_remaining_cases (- (call !GetNumTrainingCases) (size cases))
)
(assign_to_entities (assoc
!cachedTotalMass (- !cachedTotalMass removed_probability_mass)
!cachedRebalanceTotalMass
(max
(- !cachedRebalanceTotalMass (/ removed_rebalance_weight scalar) )
;prevent from setting the value from being too small unless there are no cases left, in which case 0 is correct
(if num_remaining_cases (/ 1 (* num_remaining_cases scalar)) 0)
)
))
)
)
)
)

(declare (assoc re_derivation_series_case_ids (null) ))
Expand Down
4 changes: 2 additions & 2 deletions howso/update_cases.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@
(replace (get hyperparam_map "featureDeviations"))
(replace (get hyperparam_map "p"))
(replace (get hyperparam_map "dt"))
(replace distribute_weight_feature)
(replace original_distribute_weight_feature)
;use a fixed random seed to guarantee deterministic behavior for reacts (named "fixed rand seed")
"fixed rand seed"
(null) ;radius
Expand Down Expand Up @@ -846,7 +846,7 @@
(+
old_neighbor_mass
;normalized portion of influence to accumulate to neighbor
(/ (get closest_cases_map neighbor_case_id) total_influence)
(* case_weight (/ (get closest_cases_map neighbor_case_id) total_influence))
)
))

Expand Down
77 changes: 75 additions & 2 deletions unit_tests/ut_h_rebalance_features.amlg
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@
))

(call_entity "howso" "set_auto_ablation_params" (assoc
auto_ablation_enabled (true)
auto_ablation_enabled (true)
min_num_cases 6
;auto_ablation_influence_weight_entropy_threshold 0.6
batch_size 1
Expand Down Expand Up @@ -225,7 +225,80 @@
]
percent .01
))
(call exit_if_failures (assoc msg "Ablation with rebalancing."))

(call_entity "howso" "remove_cases" (assoc
case_indices [ ["unit_test2" 4]]
distribute_weight_feature ".case_weight"
))

(assign (assoc
result
(call_entity "howso" "get_cases" (assoc
features [".session_training_index" ".case_weight" ".probability_mass"]
session "unit_test2"
))
))

(print "removed case with mass weight distribution: ")
(call assert_approximate (assoc
obs (get result (list 1 "payload" "cases"))
exp
[
[0 1.5 1]
[1 1.5 1]
[2 0.75 1]
;case index 4 is removed, its mass of 1.5 is split into:
;index 3: mass 1.5 * 0.5875 = 0.88125
;new mass: 0.88125 + 1.5 = 2.38125
;new weight: 1.068 * 2.38125 / 1.5 = 1.69574
[3 1.69574 2.38125]
;index 5: mass 1.5 * 0.4125 = 0.61875
;new mass: 0.61875 + 1.0 = 1.61875
;new weight: 0.75 * 1.61875 / 1.0 = 1.21406
[5 1.21406 1.61875]
]
percent .01
))

(call_entity "howso" "remove_cases" (assoc
case_indices [ ["unit_test2" 5]]
))

(assign (assoc
result
(call_entity "howso" "get_cases" (assoc
features [".session_training_index" ".case_weight" ".probability_mass"]
session "unit_test2"
))
))

(print "removed case without distributing: ")
(call assert_approximate (assoc
obs (get result (list 1 "payload" "cases"))
exp
[
[0 1.5 1]
[1 1.5 1]
[2 0.75 1]
[3 1.69574 2.38125]
]
percent .01
))

(print "Cached total mass is reduced: ")
(call assert_same (assoc
exp (- 7 1.61875) ;5.38125
obs (call_entity "howso" "debug_label" (assoc label "!cachedTotalMass"))
))

(print "Cached rebalance mass is reduced: ")
(call assert_approximate (assoc
exp (- 2.2 (/ 1.21406 3.1818)) ; 1.8184
obs (call_entity "howso" "debug_label" (assoc label "!cachedRebalanceTotalMass"))
percent .01
))

(call exit_if_failures (assoc msg "Ablation and removal with rebalancing."))

(call exit_if_failures (assoc msg unit_test_name ))
)
Expand Down

0 comments on commit d28a21c

Please sign in to comment.