|
| 1 | +from typing import Any |
| 2 | + |
| 3 | + |
| 4 | +class WalkerAlias: |
| 5 | + """Usage example: |
| 6 | + ```python |
| 7 | + from random import random |
| 8 | +
|
| 9 | + walker = WalkerAlias({"A": 0.001, "B": 0.3, "C": (1 - 0.001 - 0.3)}, random) |
| 10 | + print(walker.get_random()) |
| 11 | + ``` |
| 12 | + """ |
| 13 | + |
| 14 | + def __init__(self, weighted: dict[Any, float], random) -> None: |
| 15 | + """Performs calculations necessary for Walker's alias method. |
| 16 | + Takes a dictionary of items and their weights as floats between 0 and 1, |
| 17 | + as well as a function returning random values between 0 and 1""" |
| 18 | + |
| 19 | + if sum(weighted.values()) != 1: |
| 20 | + raise Exception("sum of probabilities must be 1") |
| 21 | + |
| 22 | + self.keys = list(weighted.keys()) |
| 23 | + weights = [int(i * 10000) for i in weighted.values()] |
| 24 | + weight_sum = sum(weights) |
| 25 | + length = len(weights) |
| 26 | + probabilities = [] |
| 27 | + indices = [] |
| 28 | + over = [] |
| 29 | + under = [] |
| 30 | + |
| 31 | + for w in weights: |
| 32 | + indices.append(-1) |
| 33 | + probabilities.append(w * length / weight_sum) |
| 34 | + |
| 35 | + for n, p in enumerate(probabilities): |
| 36 | + if p < 1: |
| 37 | + under.append(n) |
| 38 | + else: |
| 39 | + over.append(n) |
| 40 | + |
| 41 | + while len(under) > 0 and len(over) > 0: |
| 42 | + i = over[-1] |
| 43 | + j = under.pop() |
| 44 | + probabilities[i] -= 1 - probabilities[i] |
| 45 | + indices[j] = i |
| 46 | + |
| 47 | + if probabilities[i] < 1: |
| 48 | + under.append(i) |
| 49 | + _ = over.pop() |
| 50 | + |
| 51 | + self.probabilities = probabilities |
| 52 | + self.indices = indices |
| 53 | + self.random = random |
| 54 | + self.keys = list(weighted.keys()) |
| 55 | + self.length = length |
| 56 | + |
| 57 | + def get_random(self) -> Any: |
| 58 | + random = self.random() |
| 59 | + i = int(random * self.length) |
| 60 | + |
| 61 | + return ( |
| 62 | + self.keys[i] |
| 63 | + if random <= self.probabilities[i] |
| 64 | + else self.keys[self.indices[i]] |
| 65 | + ) |
| 66 | + |
| 67 | + |
| 68 | +if __name__ == "__main__": |
| 69 | + from random import random |
| 70 | + |
| 71 | + walker = WalkerAlias({"A": 0.001, "B": 0.3, "C": (1 - 0.001 - 0.3)}, random) |
| 72 | + print(walker.get_random()) |
0 commit comments