Skip to content

Commit 0d34235

Browse files
committed
Implement Walker's alias method
1 parent 3444feb commit 0d34235

File tree

5 files changed

+105
-0
lines changed

5 files changed

+105
-0
lines changed

walker_alias/__init__.py

Whitespace-only changes.

walker_alias/src/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .lib import WalkerAlias
2+
3+
__all__ = ["WalkerAlias"]

walker_alias/src/lib.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any
2+
3+
4+
class WalkerAlias:
5+
"""Usage example:
6+
```python
7+
from random import random
8+
9+
walker = WalkerAlias({"A": 0.001, "B": 0.3, "C": (1 - 0.001 - 0.3)}, random)
10+
print(walker.get_random())
11+
```
12+
"""
13+
14+
def __init__(self, weighted: dict[Any, float], random) -> None:
15+
"""Performs calculations necessary for Walker's alias method.
16+
Takes a dictionary of items and their weights as floats between 0 and 1,
17+
as well as a function returning random values between 0 and 1"""
18+
19+
if sum(weighted.values()) != 1:
20+
raise Exception("sum of probabilities must be 1")
21+
22+
self.keys = list(weighted.keys())
23+
weights = [int(i * 10000) for i in weighted.values()]
24+
weight_sum = sum(weights)
25+
length = len(weights)
26+
probabilities = []
27+
indices = []
28+
over = []
29+
under = []
30+
31+
for w in weights:
32+
indices.append(-1)
33+
probabilities.append(w * length / weight_sum)
34+
35+
for n, p in enumerate(probabilities):
36+
if p < 1:
37+
under.append(n)
38+
else:
39+
over.append(n)
40+
41+
while len(under) > 0 and len(over) > 0:
42+
i = over[-1]
43+
j = under.pop()
44+
probabilities[i] -= 1 - probabilities[i]
45+
indices[j] = i
46+
47+
if probabilities[i] < 1:
48+
under.append(i)
49+
_ = over.pop()
50+
51+
self.probabilities = probabilities
52+
self.indices = indices
53+
self.random = random
54+
self.keys = list(weighted.keys())
55+
self.length = length
56+
57+
def get_random(self) -> Any:
58+
random = self.random()
59+
i = int(random * self.length)
60+
61+
return (
62+
self.keys[i]
63+
if random <= self.probabilities[i]
64+
else self.keys[self.indices[i]]
65+
)
66+
67+
68+
if __name__ == "__main__":
69+
from random import random
70+
71+
walker = WalkerAlias({"A": 0.001, "B": 0.3, "C": (1 - 0.001 - 0.3)}, random)
72+
print(walker.get_random())

walker_alias/tests/__init__.py

Whitespace-only changes.

walker_alias/tests/test.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from random import random
2+
3+
import pytest
4+
5+
from ..src import WalkerAlias
6+
7+
8+
def test_edge():
9+
walker = WalkerAlias({"A": 0, "B": 0.3, "C": 0.7}, random)
10+
11+
for _ in range(100):
12+
assert walker.get_random() != "A"
13+
14+
15+
def test_validity():
16+
with pytest.raises(Exception) as e:
17+
# 0.1 + 0.5 + 0.5 > 1
18+
walker = WalkerAlias({"A": 0.1, "B": 0.5, "C": 0.5}, random)
19+
walker.get_random()
20+
21+
assert e == "sum of probabilities must be 1"
22+
23+
24+
def test_values():
25+
keys = ["A", "B", "C"]
26+
weights = [0.1, 0.3, 0.6]
27+
walker = WalkerAlias(dict(zip(keys, weights)), random)
28+
29+
for _ in range(100):
30+
assert walker.get_random() in keys

0 commit comments

Comments
 (0)