-
Notifications
You must be signed in to change notification settings - Fork 0
/
bayesian_bandit_simu.py
86 lines (61 loc) · 2.03 KB
/
bayesian_bandit_simu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Created on Mon May 11 13:03:49 2020
@author: Vince
Exploration and exploitation play a key role in any business.
A good business will try to “explore” various opportunities where it can make a profit.
At the same time, it also tries to focus on a particular opportunity it has already found and to “exploit” it.
Thought experiment: assume that we have an unlimited number of slot machines.
Every slot machine has some win probability, but we don’t know these probability values.
This script simulates Thompson sampling (a Bayesian bandit strategy) on the machines listed in BANDIT_PROBABILITIES and plots how each machine's Beta posterior evolves as evidence accumulates.
"""
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import beta
# Default number of arm pulls simulated when the script is run directly.
NUM_TRIALS = 20000

# True win probability of each slot machine; one Bandit is built per entry.
BANDIT_PROBABILITIES = [0.20, 0.50, 0.75]
class Bandit:
    """One slot machine ("arm") with a Beta posterior over its win probability.

    The posterior starts as Beta(a, b), which by default is Beta(1, 1), the
    uniform distribution on [0, 1]. Each observed win/loss is folded in by
    ``update``.
    """

    def __init__(self, p, a=1, b=1, rng=None):
        """Create an arm.

        Args:
            p: true win probability used by ``pull``; hidden from the agent,
                which only sees outcomes.
            a: initial Beta alpha (prior pseudo-count of wins). Default 1.
            b: initial Beta beta (prior pseudo-count of losses). Default 1.
            rng: object providing ``random()`` and ``beta(a, b)``, e.g. a
                ``numpy.random.Generator`` for reproducible runs. ``None``
                (the default) uses NumPy's global ``np.random`` state.
        """
        self.p = p
        self.a = a
        self.b = b
        self.rng = np.random if rng is None else rng

    def pull(self):
        """Play the arm once; return True with probability ``p`` (a win)."""
        return self.rng.random() < self.p

    def sample(self):
        """Draw one plausible win probability from the current posterior."""
        return self.rng.beta(self.a, self.b)

    def update(self, x):
        """Fold one observed outcome into the posterior.

        Args:
            x: the result of ``pull``, i.e. True/1 for a win, False/0 for a loss.
        """
        # Conjugate Beta-Bernoulli update: a win bumps alpha, a loss bumps beta.
        self.a += x
        self.b += 1 - x
def plot(bandits, trial, num_points=200):
    """Plot every bandit's Beta posterior density in a fresh figure.

    Args:
        bandits: objects exposing ``a`` and ``b`` (Beta parameters) and ``p``
            (true win probability, shown in the legend).
        trial: number of trials completed so far; used only in the title.
        num_points: number of grid points on [0, 1] at which each pdf is
            evaluated. Default 200.
    """
    # Open a new figure for every snapshot. With an interactive backend
    # plt.show() returns immediately, and without this the next call would
    # draw its curves on top of the previous snapshot.
    plt.figure()
    grid = np.linspace(0, 1, num_points)
    for bandit in bandits:
        density = beta.pdf(grid, bandit.a, bandit.b)
        plt.plot(grid, density, label="real p: %.4f" % bandit.p)
    plt.title("Bandit distributions after %s trials" % trial)
    plt.xlabel("win probability")
    plt.ylabel("posterior density")
    plt.legend()
    plt.show()
def experiment(NUM_TRIALS, bandit_probabilities=None, num_checkpoints=10):
    """Run a Thompson-sampling simulation over Bernoulli bandits.

    On each trial one value is drawn from every bandit's Beta posterior, the
    bandit with the largest draw is pulled, and only that bandit's posterior
    is updated with the outcome. At evenly spaced checkpoints the draws of
    that trial are printed and all posteriors are plotted.

    Args:
        NUM_TRIALS: number of pulls to simulate. The upper-case name shadows
            the module constant; it is kept so keyword callers keep working.
        bandit_probabilities: true win probability of each bandit. Defaults
            to the module-level ``BANDIT_PROBABILITIES``.
        num_checkpoints: number of print/plot snapshots. The last one is
            taken after the final trial. A value of 0 or less disables them.

    Returns:
        The list of bandits in their final state.

    Raises:
        ValueError: if ``bandit_probabilities`` is empty.
    """
    if bandit_probabilities is None:
        bandit_probabilities = BANDIT_PROBABILITIES
    bandits = [Bandit(p) for p in bandit_probabilities]
    if not bandits:
        raise ValueError("bandit_probabilities must not be empty")

    # Checkpoints count *completed* trials, so the last one (NUM_TRIALS) is
    # reached on the final iteration and the end state is always plotted.
    # A set gives O(1) membership tests inside the hot loop.
    checkpoints = set()
    if num_checkpoints > 0:
        step = max(1, NUM_TRIALS // num_checkpoints)  # never 0 for short runs
        checkpoints = {k * step for k in range(1, num_checkpoints)}
        checkpoints.add(NUM_TRIALS)

    for trial in range(1, NUM_TRIALS + 1):
        # Thompson sampling: one posterior draw per bandit; play the bandit
        # whose draw is largest (np.argmax keeps the first one on ties).
        samples = [b.sample() for b in bandits]
        best = bandits[int(np.argmax(samples))]
        best.update(best.pull())
        if trial in checkpoints:
            print("current samples: %s" % ["%.4f" % s for s in samples])
            plot(bandits, trial)
    return bandits
if __name__ == "__main__":
    # Script entry point: simulate using the module-level trial count.
    experiment(NUM_TRIALS=NUM_TRIALS)
"""
Reference:
https://towardsdatascience.com/bayesian-bandits-explained-simply-a5b43d9d5e38
"""