|
34 | 34 | },
|
35 | 35 | {
|
36 | 36 | "cell_type": "code",
|
37 |
| - "execution_count": 1, |
| 37 | + "execution_count": null, |
38 | 38 | "metadata": {
|
39 | 39 | "pycharm": {
|
40 | 40 | "is_executing": false
|
|
56 | 56 | },
|
57 | 57 | {
|
58 | 58 | "cell_type": "code",
|
59 |
| - "execution_count": 2, |
| 59 | + "execution_count": null, |
60 | 60 | "metadata": {
|
61 | 61 | "pycharm": {
|
62 | 62 | "is_executing": false
|
63 | 63 | }
|
64 | 64 | },
|
65 |
| - "outputs": [ |
66 |
| - { |
67 |
| - "name": "stdout", |
68 |
| - "output_type": "stream", |
69 |
| - "text": [ |
70 |
| - "X: context matrix of shape (n_samples, n_features)\n", |
71 |
| - "[[-0.53211475 -0.40592956 0.05892565 -0.88067628 -0.84061481]\n", |
72 |
| - " [-0.95680954 -0.00540581 0.09148556 -0.82021004 -0.63425381]\n", |
73 |
| - " [-0.87792928 -0.51881823 -0.51767022 -0.05385187 -0.64499044]\n", |
74 |
| - " [-0.10569516 0.30847784 -0.353929 -0.94831998 -0.52175713]\n", |
75 |
| - " [-0.05088401 0.17155683 -0.4322128 -0.07509104 -0.78919832]\n", |
76 |
| - " [-0.88604157 0.55037109 0.42634479 -0.87179776 -0.69767766]\n", |
77 |
| - " [-0.0022063 0.99304089 0.76398198 -0.87343131 -0.12363411]\n", |
78 |
| - " [ 0.36371019 0.6660538 0.17177652 -0.08891719 -0.91070485]\n", |
79 |
| - " [-0.1056742 -0.72879406 -0.69367421 -0.8684397 0.70903817]\n", |
80 |
| - " [-0.15422305 0.31069811 -0.47487951 0.00853137 0.23793364]]\n" |
81 |
| - ] |
82 |
| - } |
83 |
| - ], |
| 65 | + "outputs": [], |
84 | 66 | "source": [
|
85 | 67 | "# context\n",
|
86 | 68 | "n_samples = 1000\n",
|
|
92 | 74 | },
|
93 | 75 | {
|
94 | 76 | "cell_type": "code",
|
95 |
| - "execution_count": 3, |
| 77 | + "execution_count": null, |
96 | 78 | "metadata": {},
|
97 | 79 | "outputs": [],
|
98 | 80 | "source": [
|
|
109 | 91 | },
|
110 | 92 | {
|
111 | 93 | "cell_type": "code",
|
112 |
| - "execution_count": 4, |
| 94 | + "execution_count": null, |
113 | 95 | "metadata": {},
|
114 | 96 | "outputs": [],
|
115 | 97 | "source": [
|
|
126 | 108 | },
|
127 | 109 | {
|
128 | 110 | "cell_type": "code",
|
129 |
| - "execution_count": 5, |
| 111 | + "execution_count": null, |
130 | 112 | "metadata": {},
|
131 |
| - "outputs": [ |
132 |
| - { |
133 |
| - "name": "stdout", |
134 |
| - "output_type": "stream", |
135 |
| - "text": [ |
136 |
| - "Recommended action: ['action C' 'action C' 'action B' 'action B' 'action C' 'action C'\n", |
137 |
| - " 'action B' 'action C' 'action B' 'action C']\n" |
138 |
| - ] |
139 |
| - } |
140 |
| - ], |
| 113 | + "outputs": [], |
141 | 114 | "source": [
|
142 | 115 | "# predict action\n",
|
143 | 116 | "pred_actions, _ = cmab.predict(X)\n",
|
|
153 | 126 | },
|
154 | 127 | {
|
155 | 128 | "cell_type": "code",
|
156 |
| - "execution_count": 6, |
| 129 | + "execution_count": null, |
157 | 130 | "metadata": {},
|
158 |
| - "outputs": [ |
159 |
| - { |
160 |
| - "name": "stdout", |
161 |
| - "output_type": "stream", |
162 |
| - "text": [ |
163 |
| - "Simulated rewards: [1 0 0 0 0 0 0 0 1 1]\n" |
164 |
| - ] |
165 |
| - } |
166 |
| - ], |
| 131 | + "outputs": [], |
167 | 132 | "source": [
|
168 | 133 | "# simulate reward from environment\n",
|
169 | 134 | "simulated_rewards = np.random.randint(2, size=n_samples)\n",
|
|
179 | 144 | },
|
180 | 145 | {
|
181 | 146 | "cell_type": "code",
|
182 |
| - "execution_count": 7, |
| 147 | + "execution_count": null, |
183 | 148 | "metadata": {},
|
184 |
| - "outputs": [ |
185 |
| - { |
186 |
| - "name": "stderr", |
187 |
| - "output_type": "stream", |
188 |
| - "text": [ |
189 |
| - "Auto-assigning NUTS sampler...\n", |
190 |
| - "Initializing NUTS using adapt_diag...\n", |
191 |
| - "Sequential sampling (2 chains in 1 job)\n", |
192 |
| - "NUTS: [beta4, beta3, beta2, beta1, beta0, alpha]\n", |
193 |
| - "Sampling 2 chains for 500 tune and 1_000 draw iterations (1_000 + 2_000 draws total) took 5 seconds.\n", |
194 |
| - "Auto-assigning NUTS sampler...\n", |
195 |
| - "Initializing NUTS using adapt_diag...\n", |
196 |
| - "Sequential sampling (2 chains in 1 job)\n", |
197 |
| - "NUTS: [beta4, beta3, beta2, beta1, beta0, alpha]\n", |
198 |
| - "Sampling 2 chains for 500 tune and 1_000 draw iterations (1_000 + 2_000 draws total) took 3 seconds.\n", |
199 |
| - "Auto-assigning NUTS sampler...\n", |
200 |
| - "Initializing NUTS using adapt_diag...\n", |
201 |
| - "Sequential sampling (2 chains in 1 job)\n", |
202 |
| - "NUTS: [beta4, beta3, beta2, beta1, beta0, alpha]\n", |
203 |
| - "Sampling 2 chains for 500 tune and 1_000 draw iterations (1_000 + 2_000 draws total) took 3 seconds.\n" |
204 |
| - ] |
205 |
| - } |
206 |
| - ], |
| 149 | + "outputs": [], |
207 | 150 | "source": [
|
208 | 151 | "# update model\n",
|
209 | 152 | "cmab.update(X, actions=pred_actions, rewards=simulated_rewards)"
|
|
0 commit comments