Skip to content

Commit d88c0a7

Browse files
committed
fix count that was completely broken
1 parent 8fdfefc commit d88c0a7

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

worlds/stardew_valley/stardew_rule/base.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,15 +363,44 @@ def get_difficulty(self):
363363
class Count(BaseStardewRule):
364364
count: int
365365
rules: List[StardewRule]
366+
counter: Counter[StardewRule]
367+
evaluate: Callable[[CollectionState], bool]
368+
369+
total: Optional[int]
370+
rule_mapping: Optional[Dict[StardewRule, StardewRule]]
366371

367372
def __init__(self, rules: List[StardewRule], count: int):
368373
self.count = count
369374
self.counter = Counter(rules)
370-
self.total = sum(self.counter.values())
371-
self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
372-
self.rule_mapping = {}
375+
376+
if len(self.counter) / len(rules) < .66:
377+
# Checking if it's worth using the count operation with shortcircuit or not. Value should be fine-tuned when Count has more usage.
378+
self.total = sum(self.counter.values())
379+
self.rules = sorted(self.counter.keys(), key=lambda x: self.counter[x], reverse=True)
380+
self.rule_mapping = {}
381+
self.evaluate = self.evaluate_with_shortcircuit
382+
else:
383+
self.rules = rules
384+
self.evaluate = self.evaluate_without_shortcircuit
373385

374386
def __call__(self, state: CollectionState) -> bool:
387+
return self.evaluate(state)
388+
389+
def evaluate_without_shortcircuit(self, state: CollectionState) -> bool:
390+
c = 0
391+
for i in range(self.rules_count):
392+
self.rules[i], value = self.rules[i].evaluate_while_simplifying(state)
393+
if value:
394+
c += 1
395+
396+
if c >= self.count:
397+
return True
398+
if c + self.rules_count - i < self.count:
399+
break
400+
401+
return False
402+
403+
def evaluate_with_shortcircuit(self, state: CollectionState) -> bool:
375404
c = 0
376405
t = self.total
377406

@@ -395,7 +424,7 @@ def call_evaluate_while_simplifying_cached(self, rule: StardewRule, state: Colle
395424
try:
396425
# A mapping table with the original rule is used here because two rules could resolve to the same rule.
397426
# This would require to change the counter to merge both rules, and quickly become complicated.
398-
return self.rule_mapping[rule].evaluate_while_simplifying(state)
427+
return self.rule_mapping[rule](state)
399428
except KeyError:
400429
self.rule_mapping[rule], value = rule.evaluate_while_simplifying(state)
401430
return value

worlds/stardew_valley/test/TestStardewRule.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ class TestCount(unittest.TestCase):
250250
def test_duplicate_rule_count_double(self):
251251
expected_result = True
252252
collection_state = MagicMock()
253-
simplified_rule = MagicMock()
254-
other_rule = MagicMock(spec=StardewRule)
253+
simplified_rule = Mock()
254+
other_rule = Mock(spec=StardewRule)
255255
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
256256
rule = Count([other_rule, other_rule, other_rule], 2)
257257

@@ -261,10 +261,10 @@ def test_duplicate_rule_count_double(self):
261261
self.assertEqual(expected_result, actual_result)
262262

263263
def test_simplified_rule_is_reused(self):
264-
expected_result = True
264+
expected_result = False
265265
collection_state = MagicMock()
266-
simplified_rule = MagicMock(return_value=expected_result)
267-
other_rule = MagicMock(spec=StardewRule)
266+
simplified_rule = Mock(return_value=expected_result)
267+
other_rule = Mock(spec=StardewRule)
268268
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
269269
rule = Count([other_rule, other_rule, other_rule], 2)
270270

@@ -278,15 +278,15 @@ def test_simplified_rule_is_reused(self):
278278
actual_result = rule(collection_state)
279279

280280
other_rule.evaluate_while_simplifying.assert_not_called()
281-
simplified_rule.assert_not_called()
281+
simplified_rule.assert_called()
282282
self.assertEqual(expected_result, actual_result)
283283

284284
def test_break_if_not_enough_rule_to_complete(self):
285285
expected_result = False
286286
collection_state = MagicMock()
287-
simplified_rule = MagicMock()
288-
never_called_rule = MagicMock()
289-
other_rule = MagicMock(spec=StardewRule)
287+
simplified_rule = Mock()
288+
never_called_rule = Mock()
289+
other_rule = Mock(spec=StardewRule)
290290
other_rule.evaluate_while_simplifying = Mock(return_value=(simplified_rule, expected_result))
291291
rule = Count([other_rule, other_rule, other_rule, never_called_rule], 2)
292292

@@ -296,3 +296,8 @@ def test_break_if_not_enough_rule_to_complete(self):
296296
never_called_rule.assert_not_called()
297297
never_called_rule.evaluate_while_simplifying.assert_not_called()
298298
self.assertEqual(expected_result, actual_result)
299+
300+
def test_evaluate_without_shortcircuit_when_rules_are_all_different(self):
301+
rule = Count([Mock(), Mock(), Mock(), Mock()], 2)
302+
303+
self.assertEqual(rule.evaluate, rule.evaluate_without_shortcircuit)

0 commit comments

Comments
 (0)