diff --git a/docs/tutorials/mab.ipynb b/docs/tutorials/mab.ipynb index 22c5666..c756b2e 100644 --- a/docs/tutorials/mab.ipynb +++ b/docs/tutorials/mab.ipynb @@ -102,14 +102,14 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, 
"metadata": {}, @@ -169,14 +169,14 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, @@ -422,14 +422,14 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n", - " 
\u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m2\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m3\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m3\u001B[0m\u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, @@ -494,14 +494,14 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: 
\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m337\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m369\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m4448\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m4315\u001b[0m\u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m246\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m296\u001b[0m\u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m337\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m369\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m4448\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m4315\u001B[0m\u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m246\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m296\u001B[0m\u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, diff --git a/docs/tutorials/smab_mo_cc.ipynb b/docs/tutorials/smab_mo_cc.ipynb index 880654c..3dbb86b 100644 --- a/docs/tutorials/smab_mo_cc.ipynb +++ b/docs/tutorials/smab_mo_cc.ipynb @@ -74,9 +74,9 @@ "source": [ "mab = SmabBernoulliMOCC(\n", " actions={\n", - " \"a1\": 
BetaMOCC(counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=30),\n", - " \"a2\": BetaMOCC(counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=10),\n", - " \"a3\": BetaMOCC(counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=20),\n", + " \"a1\": BetaMOCC(models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=30),\n", + " \"a2\": BetaMOCC(models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=10),\n", + " \"a3\": BetaMOCC(models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=20),\n", " }\n", ")" ] @@ -93,15 +93,15 @@ "
SmabBernoulliMOCC(\n", " actions={\n", " 'a1': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=30.0\n", " ),\n", " 'a2': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=10.0\n", " ),\n", " 'a3': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=20.0\n", " )\n", " },\n", @@ -110,23 +110,23 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulliMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m30\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " 
\u001b[33mcost\u001b[0m=\u001b[1;36m10\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m20\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mMultiObjectiveCostControlBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulliMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m30\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, 
\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m10\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m20\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mMultiObjectiveCostControlBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, @@ -180,15 +180,15 @@ "
SmabBernoulliMOCC(\n", " actions={\n", " 'a1': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=30.0\n", " ),\n", " 'a2': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=10.0\n", " ),\n", " 'a3': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=20.0\n", " )\n", " },\n", @@ -197,23 +197,23 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulliMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m30\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " 
\u001b[33mcost\u001b[0m=\u001b[1;36m10\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m20\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mMultiObjectiveCostControlBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulliMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m30\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, 
\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m10\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m20\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mMultiObjectiveCostControlBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, @@ -449,15 +449,15 @@ "
SmabBernoulliMOCC(\n", " actions={\n", " 'a1': BetaMOCC(\n", - " counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", + " models=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)],\n", " cost=30.0\n", " ),\n", " 'a2': BetaMOCC(\n", - " counters=[Beta(n_successes=7, n_failures=3), Beta(n_successes=7, n_failures=3)],\n", + " models=[Beta(n_successes=7, n_failures=3), Beta(n_successes=7, n_failures=3)],\n", " cost=10.0\n", " ),\n", " 'a3': BetaMOCC(\n", - " counters=[Beta(n_successes=3, n_failures=1), Beta(n_successes=3, n_failures=1)],\n", + " models=[Beta(n_successes=3, n_failures=1), Beta(n_successes=3, n_failures=1)],\n", " cost=20.0\n", " )\n", " },\n", @@ -466,23 +466,23 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulliMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m30\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m7\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m7\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " 
\u001b[33mcost\u001b[0m=\u001b[1;36m10\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m20\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mMultiObjectiveCostControlBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulliMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m30\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m7\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m3\u001B[0m\u001B[1m)\u001B[0m, 
\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m7\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m3\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m10\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m3\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m3\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m20\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mMultiObjectiveCostControlBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, @@ -541,15 +541,15 @@ "
SmabBernoulliMOCC(\n", " actions={\n", " 'a1': BetaMOCC(\n", - " counters=[Beta(n_successes=450, n_failures=488), Beta(n_successes=450, n_failures=488)],\n", + " models=[Beta(n_successes=450, n_failures=488), Beta(n_successes=450, n_failures=488)],\n", " cost=30.0\n", " ),\n", " 'a2': BetaMOCC(\n", - " counters=[Beta(n_successes=8541, n_failures=8325), Beta(n_successes=8541, n_failures=8325)],\n", + " models=[Beta(n_successes=8541, n_failures=8325), Beta(n_successes=8541, n_failures=8325)],\n", " cost=10.0\n", " ),\n", " 'a3': BetaMOCC(\n", - " counters=[Beta(n_successes=1110, n_failures=1102), Beta(n_successes=1110, n_failures=1102)],\n", + " models=[Beta(n_successes=1110, n_failures=1102), Beta(n_successes=1110, n_failures=1102)],\n", " cost=20.0\n", " )\n", " },\n", @@ -558,23 +558,23 @@ "\n" ], "text/plain": [ - "\u001b[1;35mSmabBernoulliMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n", - " \u001b[32m'a1'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m450\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m488\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m450\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m488\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m30\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a2'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m8541\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m8325\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m8541\u001b[0m, 
\u001b[33mn_failures\u001b[0m=\u001b[1;36m8325\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m10\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m,\n", - " \u001b[32m'a3'\u001b[0m: \u001b[1;35mBetaMOCC\u001b[0m\u001b[1m(\u001b[0m\n", - " \u001b[33mcounters\u001b[0m=\u001b[1m[\u001b[0m\u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1110\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1102\u001b[0m\u001b[1m)\u001b[0m, \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1110\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1102\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m,\n", - " \u001b[33mcost\u001b[0m=\u001b[1;36m20\u001b[0m\u001b[1;36m.0\u001b[0m\n", - " \u001b[1m)\u001b[0m\n", - " \u001b[1m}\u001b[0m,\n", - " \u001b[33mstrategy\u001b[0m=\u001b[1;35mMultiObjectiveCostControlBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[1m)\u001b[0m\n" + "\u001B[1;35mSmabBernoulliMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n", + " \u001B[32m'a1'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m450\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m488\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m450\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m488\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m30\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a2'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m8541\u001B[0m, 
\u001B[33mn_failures\u001B[0m=\u001B[1;36m8325\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m8541\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m8325\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m10\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m,\n", + " \u001B[32m'a3'\u001B[0m: \u001B[1;35mBetaMOCC\u001B[0m\u001B[1m(\u001B[0m\n", + " \u001B[33mmodels\u001B[0m=\u001B[1m[\u001B[0m\u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1110\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1102\u001B[0m\u001B[1m)\u001B[0m, \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1110\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1102\u001B[0m\u001B[1m)\u001B[0m\u001B[1m]\u001B[0m,\n", + " \u001B[33mcost\u001B[0m=\u001B[1;36m20\u001B[0m\u001B[1;36m.0\u001B[0m\n", + " \u001B[1m)\u001B[0m\n", + " \u001B[1m}\u001B[0m,\n", + " \u001B[33mstrategy\u001B[0m=\u001B[1;35mMultiObjectiveCostControlBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n", + "\u001B[1m)\u001B[0m\n" ] }, "metadata": {}, diff --git a/pybandits/cmab.py b/pybandits/cmab.py index a4f2246..e21c41d 100644 --- a/pybandits/cmab.py +++ b/pybandits/cmab.py @@ -19,7 +19,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
- +from abc import ABC from typing import Dict, List, Optional, Set, Union from numpy import array @@ -29,21 +29,31 @@ from pybandits.base import ActionId, BinaryReward, CmabPredictions from pybandits.mab import BaseMab -from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC +from pybandits.model import ( + BaseBayesianLogisticRegression, + BaseBayesianLogisticRegressionMO, + BayesianLogisticRegression, + BayesianLogisticRegressionCC, + BayesianLogisticRegressionMO, + BayesianLogisticRegressionMOCC, +) from pybandits.strategy import ( BestActionIdentificationBandit, ClassicBandit, CostControlBandit, + MultiObjectiveBandit, + MultiObjectiveCostControlBandit, + MultiObjectiveStrategy, ) -class BaseCmabBernoulli(BaseMab): +class BaseCmabBernoulli(BaseMab, ABC): """ Base model for a Contextual Multi-Armed Bandit for Bernoulli bandits with Thompson Sampling. Parameters ---------- - actions: Dict[ActionId, BayesianLogisticRegression] + actions: Dict[ActionId, Union[BayesianLogisticRegression, BayesianLogisticRegressionMO]] The list of possible actions, and their associated Model. strategy: Strategy The strategy used to select actions. @@ -54,25 +64,42 @@ class BaseCmabBernoulli(BaseMab): bandit strategy. 
""" - actions: Dict[ActionId, BayesianLogisticRegression] + actions: Dict[ActionId, Union[BaseBayesianLogisticRegression, BaseBayesianLogisticRegressionMO]] predict_with_proba: bool predict_actions_randomly: bool + @classmethod + def _check_single_bayesian_logistic_regression_model(cls, action, first_action_type, first_action): + if not isinstance(action, first_action_type): + raise AttributeError("All actions should follow the same type.") + if not len(action.betas) == len(first_action.betas): + raise AttributeError("All actions should have the same number of betas.") + if not action.update_method == first_action.update_method: + raise AttributeError("All actions should have the same update method.") + if not action.update_kwargs == first_action.update_kwargs: + raise AttributeError("All actions should have the same update kwargs.") + @field_validator("actions", mode="after") @classmethod def check_bayesian_logistic_regression_models(cls, v): action_models = list(v.values()) first_action = action_models[0] first_action_type = type(first_action) - for action in action_models[1:]: - if not isinstance(action, first_action_type): - raise AttributeError("All actions should follow the same type.") - if not len(action.betas) == len(first_action.betas): - raise AttributeError("All actions should have the same number of betas.") - if not action.update_method == first_action.update_method: - raise AttributeError("All actions should have the same update method.") - if not action.update_kwargs == first_action.update_kwargs: - raise AttributeError("All actions should have the same update kwargs.") + if isinstance(first_action, BaseBayesianLogisticRegression): + for action in action_models[1:]: + if not isinstance(action, first_action_type): + raise AttributeError("All actions should follow the same type.") + cls._check_single_bayesian_logistic_regression_model(action, first_action_type, first_action) + elif isinstance(first_action, BaseBayesianLogisticRegressionMO): + 
first_action = first_action.models[0] + first_action_type = type(first_action) + for action in action_models: + for model in action.models: + cls._check_single_bayesian_logistic_regression_model(model, first_action_type, first_action) + else: + raise NotImplementedError( + "Only BaseBayesianLogisticRegression and BaseBayesianLogisticRegressionMO are supported." + ) return v @validate_call(config=dict(arbitrary_types_allowed=True)) @@ -112,7 +139,16 @@ def predict( if self.predict_actions_randomly: # check that context has the expected number of columns - if context.shape[1] != len(list(self.actions.values())[0].betas): + first_action_model = list(self.actions.values())[0] + if isinstance(first_action_model, BaseBayesianLogisticRegression): + expected_length = len(first_action_model.betas) + elif isinstance(first_action_model, BaseBayesianLogisticRegressionMO): + expected_length = len(first_action_model.models[0].betas) + else: + raise NotImplementedError( + "Only BaseBayesianLogisticRegression and BaseBayesianLogisticRegressionMO are supported." + ) + if context.shape[1] != expected_length: raise AttributeError("Context must have {n_betas} columns") selected_actions = choice(list(valid_actions), size=len(context)).tolist() # predict actions randomly @@ -170,6 +206,7 @@ def update( rewards = [[1, 1], [1, 0], [1, 1], [1, 0], [1, 1], ...] """ self._validate_update_params(actions=actions, rewards=rewards) + if len(context) != len(rewards): raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.") @@ -272,3 +309,68 @@ class CmabBernoulliCC(BaseCmabBernoulli): strategy: CostControlBandit predict_with_proba: bool = True predict_actions_randomly: bool = False + + +class BaseCmabBernoulliMO(BaseCmabBernoulli, ABC): + """ + Base model for a Contextual Multi-Armed Bandit with Thompson Sampling implementation, and Multi-Objectives + strategy. 
+ + Parameters + ---------- + actions: Dict[ActionId, BetaMO] + The list of possible actions, and their associated Model. + strategy: Strategy + The strategy used to select actions. + """ + + actions: Dict[ActionId, BaseBayesianLogisticRegressionMO] + strategy: MultiObjectiveStrategy + + +class CmabBernoulliMO(BaseCmabBernoulliMO): + """ + Contextual Multi-Armed Bandit with Thompson Sampling, and Multi-Objectives strategy. + + The reward pertaining to an action is a multidimensional vector instead of a scalar value. In this setting, + different actions are compared according to Pareto order between their expected reward vectors, and those actions + whose expected rewards are not inferior to that of any other actions are called Pareto optimal actions, all of which + constitute the Pareto front. + + Reference: Thompson Sampling for Multi-Objective Multi-Armed Bandits Problem (Yahyaa and Manderick, 2015) + https://www.researchgate.net/publication/272823659_Thompson_Sampling_for_Multi-Objective_Multi-Armed_Bandits_Problem + + Parameters + ---------- + actions: Dict[ActionId, BayesianLogisticRegressionMO] + The list of possible actions, and their associated Model. + strategy: MultiObjectiveBandit + The strategy used to select actions. + """ + + actions: Dict[ActionId, BayesianLogisticRegressionMO] + strategy: MultiObjectiveBandit + predict_with_proba: bool = False + predict_actions_randomly: bool = False + + +class CmabBernoulliMOCC(BaseCmabBernoulliMO): + """ + Contextual Multi-Armed Bandit with Thompson Sampling implementation for Multi-Objective (MO) with Cost + Control (CC) strategy. + + This Bandit allows the reward to be a multidimensional vector and include a control of the action cost. It merges + the Multi-Objective and Cost Control strategies. + + Parameters + ---------- + actions: Dict[ActionId, BayesianLogisticRegressionMOCC] + The list of possible actions, and their associated Model. 
+ strategy: MultiObjectiveCostControlBandit + The strategy used to select actions. + """ + + actions: Dict[ActionId, BayesianLogisticRegressionMOCC] + strategy: MultiObjectiveCostControlBandit + predict_with_proba: bool = True + predict_actions_randomly: bool = False diff --git a/pybandits/mab.py b/pybandits/mab.py index 55de8d9..1a5d32a 100644 --- a/pybandits/mab.py +++ b/pybandits/mab.py @@ -37,7 +37,7 @@ Predictions, PyBanditsBaseModel, ) -from pybandits.model import Model +from pybandits.model import Model, ModelMO from pybandits.strategy import Strategy from pybandits.utils import extract_argument_names_from_function @@ -48,7 +48,7 @@ class BaseMab(PyBanditsBaseModel, ABC): Parameters ---------- - actions : Dict[ActionId, Model] + actions : Dict[ActionId, Union[Model,ModelMO]] The list of possible actions, and their associated Model. strategy : Strategy The strategy used to select actions. @@ -62,14 +62,14 @@ class BaseMab(PyBanditsBaseModel, ABC): which in turn will be used to instantiate the strategy. """ - actions: Dict[ActionId, Model] + actions: Dict[ActionId, Union[Model, ModelMO]] strategy: Strategy epsilon: Optional[Float01] = None default_action: Optional[ActionId] = None def __init__( self, - actions: Dict[ActionId, Model], + actions: Dict[ActionId, Union[Model, ModelMO]], epsilon: Optional[Float01] = None, default_action: Optional[ActionId] = None, **strategy_kwargs, @@ -88,27 +88,26 @@ def __init__( @field_validator("actions", mode="before") @classmethod - def at_least_one_action_is_defined(cls, v): + def validate_action_configurations(cls, v): # validate number of actions if len(v) == 0: raise AttributeError("At least one action should be defined.") elif len(v) == 1: warnings.warn("Only a single action was supplied. 
This MAB will be deterministic.") + # validate that all actions are of the same configuration action_models = list(v.values()) first_action = action_models[0] first_action_type = type(first_action) if any(not isinstance(action, first_action_type) for action in action_models[1:]): raise AttributeError("All actions should follow the same type.") - return v - @model_validator(mode="after") - def check_default_action(self): - if not self.epsilon and self.default_action: - raise AttributeError("A default action should only be defined when epsilon is defined.") - if self.default_action and self.default_action not in self.actions: - raise AttributeError("The default action must be valid action defined in the actions set.") - return self + # For multi-objective actions, validate that all actions have the same number of objectives + if isinstance(first_action, ModelMO): + n_objs_per_action = [len(action_model.models) for action_model in v.values()] + if len(set(n_objs_per_action)) != 1: + raise ValueError("All actions should have the same number of objectives") + return v @model_validator(mode="after") def validate_default_action(self): @@ -203,7 +202,7 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None) -> Predicti probs: List[Dict[ActionId, Probability]] of shape (n_samples,) The probabilities of getting a positive reward for each action ws : List[Dict[ActionId, float]], only relevant for some of the MABs - The weighted sum of logistic regression logits.. + The weighted sum of logistic regression logits. 
""" def get_state(self) -> (str, dict): @@ -224,7 +223,7 @@ def get_state(self) -> (str, dict): def _select_epsilon_greedy_action( self, p: ActionRewardLikelihood, - actions: Optional[Dict[ActionId, Model]] = None, + actions: Optional[Dict[ActionId, Union[Model, ModelMO]]] = None, ) -> ActionId: """ Wraps self.strategy.select_action function with epsilon-greedy strategy, diff --git a/pybandits/model.py b/pybandits/model.py index 32981f9..936302e 100644 --- a/pybandits/model.py +++ b/pybandits/model.py @@ -53,19 +53,72 @@ class Model(PyBanditsBaseModel, ABC): """ @abstractmethod - def sample_proba(self) -> Probability: + def sample_proba(self, **kwargs) -> Probability: """ Sample the probability of getting a positive reward. """ @abstractmethod - def update(self, rewards: List[Any]): + def update(self, rewards: List[Any], **kwargs): """ Update the model parameters. """ -class BaseBeta(Model): +class ModelMO(PyBanditsBaseModel): + """ + Multi-objective extension of Model. + + Parameters + ---------- + models : List[Model] + List of models. + """ + + models: List[Model] = Field(..., min_length=1) + + @validate_call + def sample_proba(self, **kwargs) -> List[Probability]: + """ + Sample the probability of getting a positive reward. + + Returns + ------- + prob: List[Probability] + Probabilities of getting a positive reward for each objective. + """ + return [x.sample_proba(**kwargs) for x in self.models] + + @validate_call + def update(self, rewards: List[List[BinaryReward]], **kwargs): + """ + Update the Beta model using the provided rewards. + + Parameters + ---------- + rewards: List[List[BinaryReward]] + A list of rewards, where each reward is in turn a list containing the reward of the Beta model + associated to each objective. + For example, `[[1, 1], [1, 0], [1, 1], [1, 0], [1, 1]]`. + kwargs: Dict[str, Any] + Additional arguments for the Bayesian Logistic Regression MO child model. 
+ """ + if any(len(x) != len(self.models) for x in rewards): + raise AttributeError("The shape of rewards is incorrect") + + for i, model in enumerate(self.models): + model.update(rewards=[r[i] for r in rewards], **kwargs) + + +class ModelCC(PyBanditsBaseModel, ABC): + """ + Class to augment prior distributions with cost control. + """ + + cost: NonNegativeFloat + + +class BaseBeta(Model, ABC): """ Beta Distribution model for Bernoulli multi-armed bandits. @@ -82,7 +135,7 @@ class BaseBeta(Model): @model_validator(mode="before") @classmethod - def both_or_neither_counters_are_defined(cls, values): + def both_or_neither_models_are_defined(cls, values): if hasattr(values, "n_successes") != hasattr(values, "n_failures"): raise ValueError("Either both or neither n_successes and n_failures should be specified.") return values @@ -136,63 +189,42 @@ def sample_proba(self) -> Probability: class Beta(BaseBeta): """ Beta Distribution model for Bernoulli multi-armed bandits. + + Parameters + ---------- + n_successes: PositiveInt = 1 + Counter of the number of successes. + n_failures: PositiveInt = 1 + Counter of the number of failures. """ -class BetaCC(BaseBeta): +class BetaCC(BaseBeta, ModelCC): """ Beta Distribution model for Bernoulli multi-armed bandits with cost control. Parameters ---------- + n_successes: PositiveInt = 1 + Counter of the number of successes. + n_failures: PositiveInt = 1 + Counter of the number of failures. cost: NonNegativeFloat Cost associated to the Beta distribution. """ - cost: NonNegativeFloat - -class BetaMO(Model): +class BetaMO(ModelMO): """ Beta Distribution model for Bernoulli multi-armed bandits with multi-objectives. Parameters ---------- - counters: List[Beta] of shape (n_objectives,) + models : List[Beta] of shape (n_objectives,) List of Beta distributions. """ - counters: List[Beta] - - @validate_call - def sample_proba(self) -> List[Probability]: - """ - Sample the probability of getting a positive reward. 
- - Returns - ------- - prob: List[Probability] - Probabilities of getting a positive reward for each objective. - """ - return [x.sample_proba() for x in self.counters] - - @validate_call - def update(self, rewards: List[List[BinaryReward]]): - """ - Update the Beta model using the provided rewards. - - Parameters - ---------- - rewards: List[List[BinaryReward]] - A list of rewards, where each reward is in turn a list containing the reward of the Beta model - associated to each objective. - For example, `[[1, 1], [1, 0], [1, 1], [1, 0], [1, 1]]`. - """ - if any(len(x) != len(self.counters) for x in rewards): - raise AttributeError("The shape of rewards is incorrect") - - for i, counter in enumerate(self.counters): - counter.update([r[i] for r in rewards]) + models: List[Beta] @classmethod def cold_start(cls, n_objectives: PositiveInt, **kwargs) -> "BetaMO": @@ -216,28 +248,26 @@ def cold_start(cls, n_objectives: PositiveInt, **kwargs) -> "BetaMO": Returns ------- - blr: BayesianLogisticRegrssion + beta_mo: BetaMO The Bayesian Logistic Regression model. """ - counters = n_objectives * [Beta()] - blr = cls(counters=counters, **kwargs) - return blr + models = n_objectives * [Beta()] + beta_mo = cls(models=models, **kwargs) + return beta_mo -class BetaMOCC(BetaMO): +class BetaMOCC(BetaMO, ModelCC): """ Beta Distribution model for Bernoulli multi-armed bandits with multi-objectives and cost control. Parameters ---------- - counters: List[BetaCC] of shape (n_objectives,) + models: List[Beta] of shape (n_objectives,) List of Beta distributions. cost: NonNegativeFloat Cost associated to the Beta distribution. """ - cost: NonNegativeFloat - class StudentT(PyBanditsBaseModel): """ @@ -258,9 +288,9 @@ class StudentT(PyBanditsBaseModel): nu: confloat(allow_inf_nan=False) = 5.0 -class BayesianLogisticRegression(Model): +class BaseBayesianLogisticRegression(Model, ABC): """ - Base Bayesian Logistic Regression model. + Bayesian Logistic Regression model. 
It is modeled as: @@ -465,7 +495,7 @@ def cold_start( update_method: UpdateMethods = "MCMC", update_kwargs: Optional[dict] = None, **kwargs, - ) -> "BayesianLogisticRegression": + ) -> "BaseBayesianLogisticRegression": """ Utility function to create a Bayesian Logistic Regression model or child model with cost control, with default parameters. @@ -492,7 +522,7 @@ def cold_start( Returns ------- - blr: BayesianLogisticRegrssion + blr: BayesianLogisticRegression The Bayesian Logistic Regression model. """ return cls( @@ -504,7 +534,32 @@ def cold_start( ) -class BayesianLogisticRegressionCC(BayesianLogisticRegression): +class BayesianLogisticRegression(BaseBayesianLogisticRegression): + """ + Bayesian Logistic Regression model. + + It is modeled as: + + y = sigmoid(alpha + beta1 * x1 + beta2 * x2 + ... + betaN * xN) + + where the alpha and betas coefficients are Student's t-distributions. + + Parameters + ---------- + alpha: StudentT + Student's t-distribution of the alpha coefficient. + betas: StudentT + Student's t-distributions of the betas coefficients. + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. + """ + + +class BayesianLogisticRegressionCC(BaseBayesianLogisticRegression, ModelCC): """ Bayesian Logistic Regression model with cost control. @@ -530,4 +585,107 @@ class BayesianLogisticRegressionCC(BayesianLogisticRegression): Cost associated to the Bayesian Logistic Regression model. """ + +class BaseBayesianLogisticRegressionMO(ModelMO, ABC): + """ + Bayesian Logistic Regression model with multi-objectives. + + It is modeled as: + + y = sigmoid(alpha + beta1 * x1 + beta2 * x2 + ... 
+ betaN * xN) + + where the alpha and betas coefficients are Student's t-distributions. + + Parameters + ---------- + models: List[BayesianLogisticRegression] of shape (n_objectives,) + List of Bayesian Logistic Regression. + """ + + models: List[BayesianLogisticRegression] = Field(..., min_length=1) + + @classmethod + def cold_start( + cls, + n_objectives: PositiveInt, + n_features: PositiveInt, + update_method: UpdateMethods = "MCMC", + update_kwargs: Optional[dict] = None, + **kwargs, + ) -> "BaseBayesianLogisticRegressionMO": + """ + Utility function to create a Multi-Objective Bayesian Logistic Regression model or child model with cost control, + with default parameters. + + It is modeled as: + + y = sigmoid(alpha + beta1 * x1 + beta2 * x2 + ... + betaN * xN) + + where the alpha and betas coefficients are Student's t-distributions. + + Parameters + ---------- + n_objectives : PositiveInt + The number of objectives. + n_features : PositiveInt + The number of betas of the Bayesian Logistic Regression model. This is also the number of features expected + after in the context matrix. + update_method : UpdateMethods, defaults to "MCMC" + The strategy for computing posterior quantities of the Bayesian models in the update function. Such as Markov + chain Monte Carlo ("MCMC") or Variational Inference ("VI"). Check UpdateMethods in pybandits.model for the + full list. + update_kwargs : Optional[dict], uses default values if not specified + Additional arguments to pass to the update method. + kwargs: Dict[str, Any] + Additional arguments for the Bayesian Logistic Regression child model. + + Returns + ------- + blr_mo: BayesianLogisticRegressionMO + The Multi-Objective Bayesian Logistic Regression model. 
+ """ + return cls( + models=[ + BayesianLogisticRegression.cold_start( + n_features=n_features, update_method=update_method, update_kwargs=update_kwargs + ) + for _ in range(n_objectives) + ], + **kwargs, + ) + + +class BayesianLogisticRegressionMO(BaseBayesianLogisticRegressionMO): + """ + Bayesian Logistic Regression model with multi-objectives. + + It is modeled as: + + y = sigmoid(alpha + beta1 * x1 + beta2 * x2 + ... + betaN * xN) + + where the alpha and betas coefficients are Student's t-distributions. + + Parameters + ---------- + models: List[BayesianLogisticRegression] of shape (n_objectives,) + List of Bayesian Logistic Regression. + """ + + +class BayesianLogisticRegressionMOCC(BaseBayesianLogisticRegressionMO, ModelCC): + """ + Bayesian Logistic Regression model with multi-objectives and cost control. + + It is modeled as: + + y = sigmoid(alpha + beta1 * x1 + beta2 * x2 + ... + betaN * xN) + + where the alpha and betas coefficients are Student's t-distributions. + + Parameters + ---------- + models: List[BayesianLogisticRegression] of shape (n_objectives,) + List of Bayesian Logistic Regression. cost: NonNegativeFloat + Cost associated to the Beta distribution. + """ diff --git a/pybandits/smab.py b/pybandits/smab.py index c1b2958..d240a19 100644 --- a/pybandits/smab.py +++ b/pybandits/smab.py @@ -19,12 +19,11 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
- - +from abc import ABC from collections import defaultdict from typing import Dict, List, Optional, Set, Union -from pydantic import PositiveInt, field_validator, validate_call +from pydantic import PositiveInt, validate_call from pybandits.base import ( ActionId, @@ -40,23 +39,23 @@ CostControlBandit, MultiObjectiveBandit, MultiObjectiveCostControlBandit, - Strategy, + MultiObjectiveStrategy, ) -class BaseSmabBernoulli(BaseMab): +class BaseSmabBernoulli(BaseMab, ABC): """ Base model for a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling. Parameters ---------- - actions: Dict[ActionId, BaseBeta] + actions: Dict[ActionId, Union[Beta, BetaMO]] The list of possible actions, and their associated Model. strategy: Strategy The strategy used to select actions. """ - actions: Dict[ActionId, BaseBeta] + actions: Dict[ActionId, Union[BaseBeta, BetaMO]] @validate_call def predict( @@ -189,7 +188,7 @@ class SmabBernoulliCC(BaseSmabBernoulli): strategy: CostControlBandit -class BaseSmabBernoulliMO(BaseSmabBernoulli): +class BaseSmabBernoulliMO(BaseSmabBernoulli, ABC): """ Base model for a Stochastic Bernoulli Multi-Armed Bandit with Thompson Sampling implementation, and Multi-Objectives strategy. @@ -203,15 +202,7 @@ class BaseSmabBernoulliMO(BaseSmabBernoulli): """ actions: Dict[ActionId, BetaMO] - strategy: Strategy - - @field_validator("actions", mode="after") - @classmethod - def all_actions_have_same_number_of_objectives(cls, actions: Dict[ActionId, BetaMO]): - n_objs_per_action = [len(beta.counters) for beta in actions.values()] - if len(set(n_objs_per_action)) != 1: - raise ValueError("All actions should have the same number of objectives") - return actions + strategy: MultiObjectiveStrategy class SmabBernoulliMO(BaseSmabBernoulliMO): @@ -234,7 +225,6 @@ class SmabBernoulliMO(BaseSmabBernoulliMO): The strategy used to select actions. 
""" - actions: Dict[ActionId, BetaMO] strategy: MultiObjectiveBandit diff --git a/pybandits/strategy.py b/pybandits/strategy.py index 014f753..ccf9a9e 100644 --- a/pybandits/strategy.py +++ b/pybandits/strategy.py @@ -30,7 +30,7 @@ from typing_extensions import Self from pybandits.base import ActionId, Float01, Probability, PyBanditsBaseModel -from pybandits.model import Beta, BetaMOCC, Model +from pybandits.model import Beta, BetaMOCC, Model, ModelMO class Strategy(PyBanditsBaseModel, ABC): @@ -243,7 +243,7 @@ def _average(cls, p_of_action: Union[Probability, List[Probability]]): def _evaluate_and_select( cls, p: Union[Dict[ActionId, Probability], Dict[ActionId, List[Probability]]], - actions: Dict[ActionId, Model], + actions: Dict[ActionId, Union[Model, ModelMO]], feasible_actions: List[ActionId], ) -> ActionId: """ diff --git a/pybandits/utils.py b/pybandits/utils.py index 62e6af7..1f03c07 100644 --- a/pybandits/utils.py +++ b/pybandits/utils.py @@ -3,7 +3,9 @@ from pydantic import validate_call -JSONSerializable = Union[str, int, float, bool, None, List["JSONSerializable"], Dict[str, "JSONSerializable"]] +Simple = Union[str, int, float, bool, None] + +JSONSerializable = Union[Simple, List["JSONSerializable"], Dict[str, "JSONSerializable"]] @validate_call @@ -21,6 +23,40 @@ def to_serializable_dict(d: Dict[str, Any]) -> Dict[str, JSONSerializable]: return json.loads(json.dumps(d, default=dict)) +@validate_call +def update_nested_struct( + d: Union[Dict[str, Any], List, Simple], other: Union[Dict[str, Any], List, Simple] +) -> Union[Dict[str, Any], List, Simple]: + """ + Update a nested combination of dictionaries and lists with another dictionary, recursively. + + Parameters + ---------- + d : Union[Dict[str, Any], List, Simple] + Nested combination of dictionaries and lists to update. + other : Union[Dict[str, Any], List, Simple] + Nested combination of dictionaries and lists to update with. 
+ + Returns + ------- + d : Union[Dict[str, Any], List, Simple] + Updated nested combination of dictionaries and lists. + + """ + if isinstance(d, dict) and isinstance(other, dict): + for key, value in other.items(): + if key in d: + d[key] = update_nested_struct(d[key], value) + else: + d[key] = value + elif isinstance(d, list) and isinstance(other, list): + assert len(d) == len(other) + for i, (d_value, other_value) in enumerate(zip(d, other)): + d[i] = update_nested_struct(d_value, other_value) + + return d + + @validate_call def extract_argument_names_from_function(function_handle: Callable, is_class_method: bool = False) -> List[str]: """ diff --git a/tests/test_cmab.py b/tests/test_cmab.py index 73d39a9..47d280c 100644 --- a/tests/test_cmab.py +++ b/tests/test_cmab.py @@ -30,10 +30,23 @@ from pydantic import NonNegativeFloat, ValidationError from pybandits.base import Float01 -from pybandits.cmab import CmabBernoulli, CmabBernoulliBAI, CmabBernoulliCC -from pybandits.model import BayesianLogisticRegression, BayesianLogisticRegressionCC, StudentT, UpdateMethods -from pybandits.strategy import BestActionIdentificationBandit, ClassicBandit, CostControlBandit -from pybandits.utils import to_serializable_dict +from pybandits.cmab import CmabBernoulli, CmabBernoulliBAI, CmabBernoulliCC, CmabBernoulliMO, CmabBernoulliMOCC +from pybandits.model import ( + BayesianLogisticRegression, + BayesianLogisticRegressionCC, + BayesianLogisticRegressionMO, + BayesianLogisticRegressionMOCC, + StudentT, + UpdateMethods, +) +from pybandits.strategy import ( + BestActionIdentificationBandit, + ClassicBandit, + CostControlBandit, + MultiObjectiveBandit, + MultiObjectiveCostControlBandit, +) +from pybandits.utils import to_serializable_dict, update_nested_struct from tests.test_utils import is_serializable literal_update_methods = get_args(UpdateMethods) @@ -199,7 +212,7 @@ def run_update(context): run_update(context=context) -@settings(deadline=10000) +@settings(deadline=30000) 
@given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) def test_cmab_update_not_all_actions(n_samples, n_feat, update_method): actions = np.random.choice(["a3", "a4"], size=n_samples).tolist() @@ -410,7 +423,7 @@ def test_cmab_from_state(state, update_method): assert isinstance(cmab, CmabBernoulli) actual_actions = to_serializable_dict(cmab.actions) # Normalize the dict - expected_actions = {k: {**v, **state["actions"][k]} for k, v in actual_actions.items()} + expected_actions = update_nested_struct(state["actions"], actual_actions) assert expected_actions == actual_actions # Ensure get_state and from_state compatibility @@ -540,7 +553,7 @@ def test_cmab_bai_predict(n_samples, n_features): assert len(selected_actions) == len(probs) == len(weighted_sums) == n_samples -@settings(deadline=10000) +@settings(deadline=30000) @given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) def test_cmab_bai_update(n_samples, n_features, update_method): actions = np.random.choice(["a1", "a2"], size=n_samples).tolist() @@ -641,7 +654,7 @@ def test_cmab_bai_from_state(state, update_method): assert isinstance(cmab, CmabBernoulliBAI) actual_actions = to_serializable_dict(cmab.actions) # Normalize the dict - expected_actions = {k: {**v, **state["actions"][k]} for k, v in actual_actions.items()} + expected_actions = update_nested_struct(state["actions"], actual_actions) assert expected_actions == actual_actions expected_exploit_p = cmab.strategy.get_expected_value_from_state(state, "exploit_p") @@ -776,7 +789,7 @@ def test_cmab_cc_predict(n_samples, n_features): assert len(selected_actions) == len(probs) == len(weighted_sums) == n_samples -@settings(deadline=10000) +@settings(deadline=20000) @given(st.just(100), st.just(3), st.sampled_from(literal_update_methods)) def test_cmab_cc_update(n_samples, n_features, update_method): actions = np.random.choice(["a1", "a2"], size=n_samples).tolist() @@ -888,7 +901,7 @@ def test_cmab_cc_from_state(state, 
update_method): assert isinstance(cmab, CmabBernoulliCC) actual_actions = to_serializable_dict(cmab.actions) # Normalize the dict - expected_actions = {k: {**v, **state["actions"][k]} for k, v in actual_actions.items()} + expected_actions = update_nested_struct(state["actions"], actual_actions) assert expected_actions == actual_actions expected_subsidy_factor = cmab.strategy.get_expected_value_from_state(state, "subsidy_factor") @@ -903,6 +916,422 @@ def test_cmab_cc_from_state(state, update_method): ######################################################################################################################## +# cmabBernoulli with strategy=MultiObjectiveBandit() + + +@given( + st.lists(st.integers(min_value=1), min_size=6, max_size=6), + st.integers(min_value=2, max_value=100), +) +def test_can_init_cmab_mo(a_list, n_features): + a, b, c, d, e, f = a_list + model = BayesianLogisticRegressionMO( + models=[ + BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ] + ) + + s = CmabBernoulliMO( + actions={ + "a1": model.model_copy(deep=True), + "a2": model.model_copy(deep=True), + }, + ) + assert s.actions["a1"] == model + assert s.actions["a2"] == model + assert s.strategy == MultiObjectiveBandit() + + +@given(st.lists(st.integers(min_value=1), min_size=7, max_size=7), st.integers(min_value=2, max_value=100)) +def test_bad_init_cmab_mo(a_list, n_features): + a, b, c, d, e, f, g = a_list + with pytest.raises(ValueError): + BayesianLogisticRegressionMO( + models=[ + BayesianLogisticRegressionCC(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()], cost=g), + BayesianLogisticRegressionCC(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()], cost=g), + BayesianLogisticRegressionCC(alpha=StudentT(mu=e, sigma=f), 
betas=n_features * [StudentT()], cost=g), + ] + ) + + +@settings(deadline=500) +@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=2, max_value=100)) +def test_all_actions_must_have_same_number_of_objectives_cmab_mo(mu, sigma, n_features): + blr = BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]) + with pytest.raises(ValueError): + CmabBernoulliMO( + actions={ + "a1": BayesianLogisticRegressionMO(models=[blr.model_copy(deep=True), blr.model_copy(deep=True)]), + "a2": BayesianLogisticRegressionMO(models=[blr.model_copy(deep=True), blr.model_copy(deep=True)]), + "a3": BayesianLogisticRegressionMO( + models=[blr.model_copy(deep=True), blr.model_copy(deep=True), blr.model_copy(deep=True)] + ), + }, + ) + + +def test_cmab_mo_predict(n_samples=1000, n_objectives=3, n_features=10): + s = CmabBernoulliMO.cold_start(action_ids={"a1", "a2"}, n_objectives=n_objectives, n_features=n_features) + context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) + forbidden = None + s.predict(context=context, forbidden_actions=forbidden) + + forbidden = ["a1"] + predicted_actions, _, _ = s.predict(context=context, forbidden_actions=forbidden) + + assert "a1" not in predicted_actions + + forbidden = ["a1", "a2"] + with pytest.raises(ValueError): + s.predict(context=context, forbidden_actions=forbidden) + + forbidden = ["a1", "a2", "a3"] + with pytest.raises(ValueError): + s.predict(context=context, forbidden_actions=forbidden) + + forbidden = ["a1", "a3"] + with pytest.raises(ValueError): + s.predict(context=context, forbidden_actions=forbidden) + + +def test_cmab_mo_update(action_ids={"a1", "a2"}, n_samples=10, n_objectives=3, n_features=10): + context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) + rewards = [np.random.choice([0, 1], size=n_objectives).tolist() for _ in range(n_samples)] + actions = np.random.choice(list(action_ids), size=n_samples).tolist() + mab = 
CmabBernoulliMO.cold_start(action_ids=action_ids, n_objectives=n_objectives, n_features=n_features) + assert all( + [ + mab.actions[a] == BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, n_features=n_features) + for a in action_ids + ] + ) + + mab.update(actions=actions, rewards=rewards, context=context) + assert all( + [ + mab.actions[a] != BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, n_features=n_features) + for a in set(action_ids) + ] + ) + + +@given(st.lists(st.integers(min_value=1), min_size=6, max_size=6), st.integers(min_value=2, max_value=100)) +def test_cmab_mo_get_state(a_list, n_features): + a, b, c, d, e, f = a_list + actions = { + "a1": BayesianLogisticRegressionMO( + models=[ + BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ] + ), + "a2": BayesianLogisticRegressionMO( + models=[ + BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ] + ), + } + cmab = CmabBernoulliMO(actions=actions) + expected_state = to_serializable_dict( + { + "actions": actions, + "strategy": {}, + "epsilon": None, + "default_action": None, + "predict_actions_randomly": False, + "predict_with_proba": False, + } + ) + + class_name, cmab_state = cmab.get_state() + assert class_name == "CmabBernoulliMO" + assert cmab_state == expected_state + + assert is_serializable(cmab_state), "Internal state is not serializable" + + +@settings(deadline=500) +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + 
values=st.fixed_dictionaries( + { + "models": st.lists( + st.fixed_dictionaries( + { + "alpha": st.fixed_dictionaries( + { + "mu": st.integers(min_value=1, max_value=100), + "sigma": st.integers(min_value=1, max_value=100), + }, + ), + "betas": st.lists( + st.fixed_dictionaries( + { + "mu": st.integers(min_value=1, max_value=100), + "sigma": st.integers(min_value=1, max_value=100), + }, + ), + min_size=2, + max_size=2, + ), + }, + ), + min_size=3, + max_size=3, + ) + } + ), + min_size=2, + ), + "strategy": st.fixed_dictionaries({}), + } + ) +) +def test_cmab_mo_from_state(state): + cmab = CmabBernoulliMO.from_state(state) + assert isinstance(cmab, CmabBernoulliMO) + + actual_actions = to_serializable_dict(cmab.actions) + expected_actions = update_nested_struct(state["actions"], actual_actions) + assert expected_actions == actual_actions + + # Ensure get_state and from_state compatibility + new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1]) + assert new_cmab == cmab + + +######################################################################################################################## + + +# cmabBernoulli with strategy=MultiObjectiveCostControlBandit() + + +@given(st.lists(st.integers(min_value=1), min_size=8, max_size=8), st.integers(min_value=2, max_value=100)) +def test_can_init_cmab_mo_cc(a_list, n_features): + a, b, c, d, e, f, g, h = a_list + model1 = BayesianLogisticRegressionMOCC( + models=[ + BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ], + cost=g, + ) + model2 = model1.model_copy(deep=True, update={"cost": h}) + s = CmabBernoulliMOCC( + actions={"a1": model1.model_copy(deep=True), "a2": model2.model_copy(deep=True)}, + ) + assert s.actions["a1"] == model1 + assert s.actions["a2"] == model2 
+ assert s.strategy == MultiObjectiveCostControlBandit() + + +@given(st.lists(st.integers(min_value=1), min_size=7, max_size=7), st.integers(min_value=2, max_value=100)) +def test_bad_init_cmab_mocc(a_list, n_features): + a, b, c, d, e, f, g = a_list + with pytest.raises(ValueError): + BayesianLogisticRegressionMOCC( + models=[ + BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ] + ) + with pytest.raises(ValueError): + BayesianLogisticRegressionMOCC( + models=[ + BayesianLogisticRegressionCC(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()], cost=g), + BayesianLogisticRegressionCC(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()], cost=g), + BayesianLogisticRegressionCC(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()], cost=g), + ] + ) + + +@settings(deadline=500) +@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=2, max_value=100), st.just(1)) +def test_all_actions_must_have_same_number_of_objectives_cmab_mo_cc(mu, sigma, n_features, cost): + blr = BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]) + with pytest.raises(ValueError): + CmabBernoulliMOCC( # the cost-control bandit is the class under test here + actions={ + "a1": BayesianLogisticRegressionMOCC( + models=[blr.model_copy(deep=True), blr.model_copy(deep=True)], cost=cost + ), + "a2": BayesianLogisticRegressionMOCC( + models=[blr.model_copy(deep=True), blr.model_copy(deep=True)], cost=cost + ), + "a3": BayesianLogisticRegressionMOCC( + models=[blr.model_copy(deep=True), blr.model_copy(deep=True), blr.model_copy(deep=True)], cost=cost + ), + }, + ) + + +def test_cmab_mo_cc_predict(n_samples=10, n_objectives=3, n_features=10): + context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) + + s = 
CmabBernoulliMOCC.cold_start( + action_ids_cost={"a1": 1, "a2": 2}, n_objectives=n_objectives, n_features=n_features + ) + + forbidden = None + s.predict(context=context, forbidden_actions=forbidden) + + forbidden = ["a1"] + predicted_actions, _, _ = s.predict(context=context, forbidden_actions=forbidden) + + assert "a1" not in predicted_actions + + forbidden = ["a1", "a2"] + with pytest.raises(ValueError): + s.predict(context=context, forbidden_actions=forbidden) + + forbidden = ["a1", "a2", "a3"] + with pytest.raises(ValueError): + s.predict(context=context, forbidden_actions=forbidden) + + forbidden = ["a1", "a3"] + with pytest.raises(ValueError): + s.predict(context=context, forbidden_actions=forbidden) + + +def test_cmab_mo_cc_update(action_ids_cost={"a1": 1, "a2": 2}, n_samples=10, n_objectives=3, n_features=10): + context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) + rewards = [np.random.choice([0, 1], size=n_objectives).tolist() for _ in range(n_samples)] + actions = np.random.choice(list(action_ids_cost), size=n_samples).tolist() + action_ids = set(action_ids_cost.keys()) + mab = CmabBernoulliMOCC.cold_start( + action_ids_cost=action_ids_cost, n_objectives=n_objectives, n_features=n_features + ) + assert all( + [ + mab.actions[a] + == BayesianLogisticRegressionMOCC.cold_start( + n_objectives=n_objectives, n_features=n_features, cost=action_ids_cost[a] + ) + for a in action_ids + ] + ) + + mab.update(actions=actions, rewards=rewards, context=context) + assert all( + [ + mab.actions[a] + != BayesianLogisticRegressionMOCC.cold_start( + n_objectives=n_objectives, n_features=n_features, cost=action_ids_cost[a] + ) + for a in set(actions) # only actions that were actually sampled receive feedback + ] + ) + + +@given(st.lists(st.integers(min_value=1), min_size=8, max_size=8), st.integers(min_value=2, max_value=100)) +def test_cmab_mo_cc_get_state(a_list, n_features): + a, b, c, d, e, f, g, h = a_list + + actions = { + "a1": BayesianLogisticRegressionMOCC( + models=[ + 
BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ], + cost=g, + ), + "a2": BayesianLogisticRegressionMOCC( + models=[ + BayesianLogisticRegression(alpha=StudentT(mu=a, sigma=b), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=c, sigma=d), betas=n_features * [StudentT()]), + BayesianLogisticRegression(alpha=StudentT(mu=e, sigma=f), betas=n_features * [StudentT()]), + ], + cost=h, + ), + } + cmab = CmabBernoulliMOCC(actions=actions) + expected_state = to_serializable_dict( + { + "actions": actions, + "strategy": {}, + "epsilon": None, + "default_action": None, + "predict_actions_randomly": False, + "predict_with_proba": True, + } + ) + + class_name, cmab_state = cmab.get_state() + assert class_name == "CmabBernoulliMOCC" + assert cmab_state == expected_state + + assert is_serializable(cmab_state), "Internal state is not serializable" + + +@settings(deadline=500) +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "models": st.lists( + st.fixed_dictionaries( + { + "alpha": st.fixed_dictionaries( + { + "mu": st.integers(min_value=1, max_value=100), + "sigma": st.integers(min_value=1, max_value=100), + }, + ), + "betas": st.lists( + st.fixed_dictionaries( + { + "mu": st.integers(min_value=1, max_value=100), + "sigma": st.integers(min_value=1, max_value=100), + }, + ), + min_size=2, + max_size=2, + ), + }, + ), + min_size=3, + max_size=3, + ), + "cost": st.floats(min_value=0), + } + ), + min_size=2, + ), + "strategy": st.fixed_dictionaries({}), + } + ) +) +def test_cmab_mo_cc_from_state(state): + cmab = CmabBernoulliMOCC.from_state(state) + assert isinstance(cmab, CmabBernoulliMOCC) + + actual_actions = 
to_serializable_dict(cmab.actions) + expected_actions = update_nested_struct(state["actions"], actual_actions) + assert expected_actions == actual_actions + + # Ensure get_state and from_state compatibility + new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1]) + assert new_cmab == cmab + + +######################################################################################################################## + + # Cmab with epsilon-greedy super strategy @@ -949,3 +1378,31 @@ def test_epsilon_greedy_cmab_cc_predict(n_samples, n_features): assert len(selected_actions) == n_samples assert probs == n_samples * [{"a1": 0.5, "a2": 0.5}] assert weighted_sums == n_samples * [{"a1": 0, "a2": 0}] + + +def test_epsilon_greddy_cmab_mo_predict( + action_ids={"a1", "a2"}, n_samples=10, n_objectives=3, n_features=10, epsilon=0.1, default_action="a1" +): + s = CmabBernoulliMO.cold_start( + action_ids=action_ids, + n_objectives=n_objectives, + n_features=n_features, + epsilon=epsilon, + default_action=default_action, + ) + context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) + s.predict(context=context) + + +def test_epsilon_greddy_smab_mo_cc_predict( + action_ids_cost={"a1": 1, "a2": 2}, n_samples=10, n_objectives=3, n_features=10, epsilon=0.1, default_action="a1" +): + s = CmabBernoulliMOCC.cold_start( + action_ids_cost=action_ids_cost, + n_objectives=n_objectives, + n_features=n_features, + epsilon=epsilon, + default_action=default_action, + ) + context = np.random.uniform(low=-1.0, high=1.0, size=(n_samples, n_features)) + s.predict(context=context) diff --git a/tests/test_model.py b/tests/test_model.py index 2041cf2..a241cee 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -30,6 +30,8 @@ from pybandits.model import ( BayesianLogisticRegression, BayesianLogisticRegressionCC, + BayesianLogisticRegressionMO, + BayesianLogisticRegressionMOCC, Beta, BetaCC, BetaMO, @@ -55,7 +57,7 @@ def 
test_can_init_beta(success_counter, failure_counter): assert (b.n_successes, b.n_failures) == (1, 1) -def test_both_or_neither_counters_are_defined(): +def test_both_or_neither_models_are_defined(): with pytest.raises(ValidationError): Beta(n_successes=0) with pytest.raises(ValidationError): @@ -112,21 +114,21 @@ def test_can_init_betaCC(a_float): def test_can_init_base_beta_mo(): # init with default params - b = BetaMO(counters=[Beta(), Beta()]) - assert b.counters[0].n_successes == 1 and b.counters[0].n_failures == 1 - assert b.counters[1].n_successes == 1 and b.counters[1].n_failures == 1 + b = BetaMO(models=[Beta(), Beta()]) + assert b.models[0].n_successes == 1 and b.models[0].n_failures == 1 + assert b.models[1].n_successes == 1 and b.models[1].n_failures == 1 # init with empty dict - b = BetaMO(counters=[{}, {}]) - assert b.counters[0] == Beta() + b = BetaMO(models=[{}, {}]) + assert b.models[0] == Beta() # invalid init with BetaCC instead of Beta with pytest.raises(ValidationError): - BetaMO(counters=[BetaCC(cost=1), BetaCC(cost=1)]) + BetaMO(models=[BetaCC(cost=1), BetaCC(cost=1)]) def test_calculate_proba_beta_mo(): - b = BetaMO(counters=[Beta(), Beta()]) + b = BetaMO(models=[Beta(), Beta()]) b.sample_proba() @@ -139,12 +141,12 @@ def test_beta_update_mo(rewards1, rewards2): rewards1, rewards2 = rewards1[:min_len], rewards2[:min_len] rewards = [[a, b] for a, b in zip(rewards1, rewards2)] - b = BetaMO(counters=[Beta(n_successes=11, n_failures=22), Beta(n_successes=33, n_failures=44)]) + b = BetaMO(models=[Beta(n_successes=11, n_failures=22), Beta(n_successes=33, n_failures=44)]) b.update(rewards=rewards) assert b == BetaMO( - counters=[ + models=[ Beta(n_successes=11 + sum(rewards1), n_failures=22 + len(rewards1) - sum(rewards1)), Beta(n_successes=33 + sum(rewards2), n_failures=44 + len(rewards2) - sum(rewards2)), ] @@ -157,26 +159,6 @@ def test_beta_update_mo(rewards1, rewards2): 
######################################################################################################################## -# BetaMO - - -def test_can_init_beta_mo(): - # init with default params - b = BetaMO(counters=[Beta(), Beta()]) - assert b.counters == [Beta(), Beta()] - - # init with empty dict - b = BetaMO(counters=[{}, {}]) - assert b.counters == [Beta(), Beta()] - - # invalid init with BetaCC instead of Beta - with pytest.raises(ValidationError): - BetaMO(counters=[BetaCC(cost=1), BetaCC(cost=1)]) - - -######################################################################################################################## - - # BetaMOCC @@ -184,21 +166,21 @@ def test_can_init_beta_mo(): def test_can_init_beta_mo_cc(a_float): if a_float < 0 or np.isnan(a_float): with pytest.raises(ValidationError): - BetaMOCC(counters=[Beta(), Beta()], cost=a_float) + BetaMOCC(models=[Beta(), Beta()], cost=a_float) else: # init with default params - b = BetaMOCC(counters=[Beta(), Beta()], cost=a_float) - assert b.counters == [Beta(), Beta()] + b = BetaMOCC(models=[Beta(), Beta()], cost=a_float) + assert b.models == [Beta(), Beta()] assert b.cost == a_float # init with empty dict - b = BetaMOCC(counters=[{}, {}], cost=a_float) - assert b.counters == [Beta(), Beta()] + b = BetaMOCC(models=[{}, {}], cost=a_float) + assert b.models == [Beta(), Beta()] assert b.cost == a_float # invalid init with BetaCC instead of Beta with pytest.raises(ValidationError): - BetaMOCC(counters=[BetaCC(cost=1), BetaCC(cost=1)], cost=a_float) + BetaMOCC(models=[BetaCC(cost=1), BetaCC(cost=1)], cost=a_float) ######################################################################################################################## @@ -382,3 +364,158 @@ def test_create_default_instance_bayesian_logistic_regression_cc(n_betas, cost): assert blr == BayesianLogisticRegressionCC( alpha=StudentT(), betas=[StudentT() for _ in range(n_betas)], cost=cost ) + + 
+######################################################################################################################## + + +# BayesianLogisticRegressionMO + + +@given(st.integers(max_value=10), st.integers(min_value=2, max_value=100)) +def test_can_init_bayesian_logistic_regression_mo(n_objectives, n_features): + # at least one blr must be specified + model = BayesianLogisticRegression(alpha=StudentT(), betas=[StudentT() for _ in range(n_features)]) + if n_objectives <= 0: + with pytest.raises(ValidationError): + BayesianLogisticRegressionMO(models=[model.model_copy(deep=True) for _ in range(n_objectives)]) + else: + blr_mo = BayesianLogisticRegressionMO(models=[model.model_copy(deep=True) for _ in range(n_objectives)]) + assert all(blr == model for blr in blr_mo.models) + + +@given(st.integers(max_value=10), st.integers(max_value=100)) +def test_create_default_instance_bayesian_logistic_regression_mo(n_objectives, n_features): + # at least one beta must be specified + if n_objectives <= 0 or n_features <= 0: + with pytest.raises(ValidationError): + BayesianLogisticRegressionMO.cold_start(n_features=n_features, n_objectives=n_objectives) + else: + blr_mo = BayesianLogisticRegressionMO.cold_start(n_features=n_features, n_objectives=n_objectives) + assert all( + blr == BayesianLogisticRegression(alpha=StudentT(), betas=[StudentT() for _ in range(n_features)]) + for blr in blr_mo.models + ) + + +@given( + st.integers(min_value=1, max_value=100), + st.integers(min_value=1, max_value=10), + st.integers(min_value=1, max_value=100), +) +def test_blr_mo_sample_proba(n_samples, n_objectives, n_features): + def sample_proba(context): + results = blr_mo.sample_proba(context=context) + prob, weighted_sum = zip(*results) # unpack the results + assert type(prob) is type(weighted_sum) is tuple # type of the returns must be np.ndarray + for p, ws in results: + assert type(p) is type(ws) is np.ndarray + assert len(p) == len(ws) == n_samples # return 1 sampled probability and ws 
per each sample + assert (np.clip(p, 0, 1) == p).all() # probs must be in the interval [0, 1] + + blr_mo = BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, n_features=n_features) + + # context is numpy array + context = np.random.uniform(low=-100.0, high=100.0, size=(n_samples, n_features)) + assert type(context) is np.ndarray + sample_proba(context=context) + + # context is python list + context = context.tolist() + assert type(context) is list + sample_proba(context=context) + + # context is pandas DataFrame + context = pd.DataFrame(context) + assert type(context) is pd.DataFrame + sample_proba(context=context) + + +def test_blr_mo_update(n_samples=10, n_objectives=3, n_features=3): + def update(context, rewards): + blr_mo = BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, n_features=n_features) + assert all( + [ + blr.alpha == StudentT(mu=0.0, sigma=10.0, nu=5.0) + and blr.betas == [StudentT(mu=0.0, sigma=10.0, nu=5.0)] * n_features # one prior per feature + for blr in blr_mo.models + ] + ) + + blr_mo.update(context=context, rewards=rewards) + + assert all( + blr.alpha != StudentT(mu=0.0, sigma=10.0, nu=5.0) + and blr.betas != [StudentT(mu=0.0, sigma=10.0, nu=5.0)] * n_features # one prior per feature + for blr in blr_mo.models + ) + + rewards = [np.random.choice([0, 1], size=n_objectives).tolist() for _ in range(n_samples)] + + # context is numpy array + context = np.random.uniform(low=-100.0, high=100.0, size=(n_samples, n_features)) + assert type(context) is np.ndarray + update(context=context, rewards=rewards) + + # context is python list + context = context.tolist() + assert type(context) is list + update(context=context, rewards=rewards) + + # context is pandas DataFrame + context = pd.DataFrame(context) + assert type(context) is pd.DataFrame + update(context=context, rewards=rewards) + + # raise an error if len(context) != len(rewards) + with pytest.raises(ValueError): + blr_mo = BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, 
n_features=n_features) + blr_mo.update(context=context, rewards=rewards[1:]) + + # raise an error if n_objectives != len(rewards[0]) + with pytest.raises(AttributeError): + blr_mo = BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, n_features=n_features) + blr_mo.update(context=context, rewards=[rewards[0][:1]] + rewards[1:]) + + # raise an error if n_objectives != len(rewards[*]) + with pytest.raises(AttributeError): + blr_mo = BayesianLogisticRegressionMO.cold_start(n_objectives=n_objectives, n_features=n_features) + blr_mo.update(context=context, rewards=[reward[:1] for reward in rewards]) + + +######################################################################################################################## + + +# BayesianLogisticRegressionMOCC + + +@given( + st.integers(max_value=10), st.integers(min_value=2, max_value=100), st.floats(allow_nan=False, allow_infinity=False) +) +def test_can_init_bayesian_logistic_regression_mocc(n_objectives, n_features, cost): + # at least one blr must be specified + model = BayesianLogisticRegression(alpha=StudentT(), betas=[StudentT() for _ in range(n_features)]) + if n_objectives <= 0 or cost < 0: + with pytest.raises(ValidationError): + BayesianLogisticRegressionMOCC(models=[model.model_copy(deep=True) for _ in range(n_objectives)], cost=cost) + else: + blr_mo = BayesianLogisticRegressionMOCC( + models=[model.model_copy(deep=True) for _ in range(n_objectives)], cost=cost + ) + assert all(blr == model for blr in blr_mo.models) + + +@given(st.integers(max_value=10), st.integers(max_value=100), st.floats(allow_nan=False, allow_infinity=False)) +def test_create_default_instance_bayesian_logistic_regression_mocc(n_objectives, n_features, cost): + # at least one beta must be specified + if n_objectives <= 0 or n_features <= 0 or cost < 0: + with pytest.raises(ValidationError): + BayesianLogisticRegressionMOCC.cold_start(n_objectives=n_objectives, n_features=n_features, cost=cost) + else: + blr_mocc = 
BayesianLogisticRegressionMOCC.cold_start( + n_objectives=n_objectives, n_features=n_features, cost=cost + ) + assert all( + blr == BayesianLogisticRegression(alpha=StudentT(), betas=[StudentT() for _ in range(n_features)]) + for blr in blr_mocc.models + ) diff --git a/tests/test_smab.py b/tests/test_smab.py index f11ec0c..00b216a 100644 --- a/tests/test_smab.py +++ b/tests/test_smab.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import json from copy import deepcopy from typing import List @@ -496,7 +495,7 @@ def test_smab_cc_from_state(state): assert isinstance(smab, SmabBernoulliCC) expected_actions = state["actions"] - actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + actual_actions = to_serializable_dict(smab.actions) # Normalize the dict assert expected_actions == actual_actions expected_subsidy_factor = smab.strategy.get_expected_value_from_state(state, "subsidy_factor") actual_subsidy_factor = smab.strategy.subsidy_factor @@ -520,14 +519,14 @@ def test_can_init_smab_mo(a_list): s = SmabBernoulliMO( actions={ "a1": BetaMO( - counters=[ + models=[ Beta(n_successes=a, n_failures=b), Beta(n_successes=c, n_failures=d), Beta(n_successes=e, n_failures=f), ] ), "a2": BetaMO( - counters=[ + models=[ Beta(n_successes=d, n_failures=a), Beta(n_successes=e, n_failures=b), Beta(n_successes=f, n_failures=c), @@ -536,14 +535,14 @@ def test_can_init_smab_mo(a_list): }, ) assert s.actions["a1"] == BetaMO( - counters=[ + models=[ Beta(n_successes=a, n_failures=b), Beta(n_successes=c, n_failures=d), Beta(n_successes=e, n_failures=f), ] ) assert s.actions["a2"] == BetaMO( - counters=[ + models=[ Beta(n_successes=d, n_failures=a), Beta(n_successes=e, n_failures=b), Beta(n_successes=f, n_failures=c), @@ -556,9 +555,9 @@ def test_all_actions_must_have_same_number_of_objectives_smab_mo(): with pytest.raises(ValueError): SmabBernoulliMO( actions={ - "a1": 
BetaMO(counters=[Beta(), Beta()]), - "a2": BetaMO(counters=[Beta(), Beta()]), - "a3": BetaMO(counters=[Beta(), Beta(), Beta()]), + "a1": BetaMO(models=[Beta(), Beta()]), + "a2": BetaMO(models=[Beta(), Beta()]), + "a3": BetaMO(models=[Beta(), Beta(), Beta()]), }, ) @@ -602,14 +601,14 @@ def test_smab_mo_get_state(a_list): actions = { "a1": BetaMO( - counters=[ + models=[ Beta(n_successes=a, n_failures=b), Beta(n_successes=c, n_failures=d), Beta(n_successes=e, n_failures=f), ] ), "a2": BetaMO( - counters=[ + models=[ Beta(n_successes=d, n_failures=a), Beta(n_successes=e, n_failures=b), Beta(n_successes=f, n_failures=c), @@ -640,7 +639,7 @@ def test_smab_mo_get_state(a_list): keys=st.text(min_size=1, max_size=10), values=st.fixed_dictionaries( { - "counters": st.lists( + "models": st.lists( st.fixed_dictionaries( { "n_successes": st.integers(min_value=1, max_value=100), @@ -663,7 +662,7 @@ def test_smab_mo_from_state(state): assert isinstance(smab, SmabBernoulliMO) expected_actions = state["actions"] - actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + actual_actions = to_serializable_dict(smab.actions) assert expected_actions == actual_actions # Ensure get_state and from_state compatibility @@ -684,7 +683,7 @@ def test_can_init_smab_mo_cc(a_list): s = SmabBernoulliMOCC( actions={ "a1": BetaMOCC( - counters=[ + models=[ Beta(n_successes=a, n_failures=b), Beta(n_successes=c, n_failures=d), Beta(n_successes=e, n_failures=f), @@ -692,7 +691,7 @@ def test_can_init_smab_mo_cc(a_list): cost=g, ), "a2": BetaMOCC( - counters=[ + models=[ Beta(n_successes=d, n_failures=a), Beta(n_successes=e, n_failures=b), Beta(n_successes=f, n_failures=c), @@ -702,7 +701,7 @@ def test_can_init_smab_mo_cc(a_list): }, ) assert s.actions["a1"] == BetaMOCC( - counters=[ + models=[ Beta(n_successes=a, n_failures=b), Beta(n_successes=c, n_failures=d), Beta(n_successes=e, n_failures=f), @@ -710,7 +709,7 @@ def test_can_init_smab_mo_cc(a_list): cost=g, ) 
assert s.actions["a2"] == BetaMOCC( - counters=[ + models=[ Beta(n_successes=d, n_failures=a), Beta(n_successes=e, n_failures=b), Beta(n_successes=f, n_failures=c), @@ -724,16 +723,14 @@ def test_all_actions_must_have_same_number_of_objectives_smab_mo_cc(): with pytest.raises(ValueError): SmabBernoulliMOCC( actions={ - "action 1": BetaMOCC(counters=[Beta(), Beta()], cost=1), - "action 2": BetaMOCC(counters=[Beta(), Beta()], cost=1), - "action 3": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=1), + "action 1": BetaMOCC(models=[Beta(), Beta()], cost=1), + "action 2": BetaMOCC(models=[Beta(), Beta()], cost=1), + "action 3": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=1), }, ) -def test_smab_mo_cc_predict(n_samples: int): - n_samples = 1000 - +def test_smab_mo_cc_predict(n_samples=1000): s = SmabBernoulliMOCC.cold_start(action_ids_cost={"a1": 1, "a2": 2}, n_objectives=2) forbidden = None @@ -782,7 +779,7 @@ def test_smab_mo_cc_get_state(a_list): actions = { "a1": BetaMOCC( - counters=[ + models=[ Beta(n_successes=a, n_failures=b), Beta(n_successes=c, n_failures=d), Beta(n_successes=e, n_failures=f), @@ -790,7 +787,7 @@ def test_smab_mo_cc_get_state(a_list): cost=g, ), "a2": BetaMOCC( - counters=[ + models=[ Beta(n_successes=d, n_failures=a), Beta(n_successes=e, n_failures=b), Beta(n_successes=f, n_failures=c), @@ -822,7 +819,7 @@ def test_smab_mo_cc_get_state(a_list): keys=st.text(min_size=1, max_size=10), values=st.fixed_dictionaries( { - "counters": st.lists( + "models": st.lists( st.fixed_dictionaries( { "n_successes": st.integers(min_value=1, max_value=100), @@ -877,9 +874,7 @@ def test_can_instantiate_epsilon_greddy_smab_with_params(a, b): assert s.actions["action1"] == s.actions["action2"] -def test_epsilon_greedy_smab_predict(n_samples: int): - n_samples = 1000 - +def test_epsilon_greedy_smab_predict(n_samples=1000): s = SmabBernoulli( actions={ "a0": Beta(), @@ -897,14 +892,12 @@ def test_epsilon_greedy_smab_predict(n_samples: int): _, _ = 
s.predict(n_samples=n_samples, forbidden_actions=forbidden_actions) -def test_epsilon_greddy_smabbai_predict(n_samples: int): - n_samples = 1000 +def test_epsilon_greddy_smabbai_predict(n_samples=1000): s = SmabBernoulliBAI(actions={"a1": Beta(), "a2": Beta()}, epsilon=0.1, default_action="a1") _, _ = s.predict(n_samples=n_samples) -def test_epsilon_greddy_smabcc_predict(n_samples: int): - n_samples = 1000 +def test_epsilon_greddy_smabcc_predict(n_samples=1000): s = SmabBernoulliCC( actions={ "a1": BetaCC(n_successes=1, n_failures=2, cost=10), @@ -917,18 +910,14 @@ def test_epsilon_greddy_smabcc_predict(n_samples: int): _, _ = s.predict(n_samples=n_samples) -def test_epsilon_greddy_smab_mo_predict(n_samples: int): - n_samples = 1000 - +def test_epsilon_greddy_smab_mo_predict(n_samples=1000): s = SmabBernoulliMO.cold_start(action_ids={"a1", "a2"}, n_objectives=3, epsilon=0.1, default_action="a1") forbidden = None s.predict(n_samples=n_samples, forbidden_actions=forbidden) -def test_epsilon_greddy_smab_mo_cc_predict(n_samples: int): - n_samples = 1000 - +def test_epsilon_greddy_smab_mo_cc_predict(n_samples=1000): s = SmabBernoulliMOCC.cold_start( action_ids_cost={"a1": 1, "a2": 2}, n_objectives=2, epsilon=0.1, default_action="a1" ) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 8c5165f..737886b 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -351,11 +351,11 @@ def test_select_action_mo_cc(): m = MultiObjectiveCostControlBandit() actions = { - "a1": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=8), - "a2": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=2), - "a3": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=5), - "a4": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=1), - "a5": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=7), + "a1": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=8), + "a2": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=2), + "a3": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=5), + "a4": 
BetaMOCC(models=[Beta(), Beta(), Beta()], cost=1), + "a5": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=7), } p = { "a1": [0.1, 0.3, 0.5], @@ -369,9 +369,9 @@ def test_select_action_mo_cc(): assert m.select_action(p=p, actions=actions) == "a4" actions = { - "a1": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=2), - "a2": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=2), - "a3": BetaMOCC(counters=[Beta(), Beta(), Beta()], cost=5), + "a1": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=2), + "a2": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=2), + "a3": BetaMOCC(models=[Beta(), Beta(), Beta()], cost=5), } p = { "a1": [0.6, 0.1, 0.1],