Skip to content

Commit

Permalink
Vectorized Predict for Non-contextual Policies and Scaler Refactor (#53
Browse files Browse the repository at this point in the history
…-54)
  • Loading branch information
bkleyn authored Mar 18, 2022
1 parent e03d242 commit 329125d
Show file tree
Hide file tree
Showing 43 changed files with 672 additions and 332 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
MABWiser CHANGELOG
=====================

March, 17, 2022 2.4.0
-------------------------------------------------------------------------------
major:
- Implement vectorized functions for non-contextual policies to speed-up prediction for multiple decisions.
- Change MAB predict and predict_expectations to allow empty contexts to be specified for non-contextual policies.
- Update scaler use in Linear policies so that standard scaler can be fit directly instead of pre-trained scalers.
- Change scaler argument from pre-trained `arm_to_scaler` input to a boolean scale flag.

March, 8, 2022 2.3.0
-------------------------------------------------------------------------------
major:
Expand Down
2 changes: 1 addition & 1 deletion docs/.buildinfo
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Sphinx build info version 1
# This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
config: be8d7ae06467e3a22a202d69cb2b24ab
config: 3d55ee819e2d2bd60e727943c52c4b22
tags: 645f666f9bcd5a90fca523b33c5a78b7
2 changes: 1 addition & 1 deletion docs/_static/documentation_options.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
VERSION: '2.3.0',
VERSION: '2.4.0',
LANGUAGE: 'None',
COLLAPSE_INDEX: false,
BUILDER: 'html',
Expand Down
2 changes: 1 addition & 1 deletion docs/about.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>About Multi-Armed Bandits &mdash; MABWiser 2.3.0 documentation</title>
<title>About Multi-Armed Bandits &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
119 changes: 55 additions & 64 deletions docs/api.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/contributing.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Contributing &mdash; MABWiser 2.3.0 documentation</title>
<title>Contributing &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
2 changes: 1 addition & 1 deletion docs/examples.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Usage Examples &mdash; MABWiser 2.3.0 documentation</title>
<title>Usage Examples &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
32 changes: 16 additions & 16 deletions docs/genindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Index &mdash; MABWiser 2.3.0 documentation</title>
<title>Index &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down Expand Up @@ -111,21 +111,13 @@ <h2 id="A">A</h2>
<li><a href="api.html#mabwiser.utils.argmax">argmax() (in module mabwiser.utils)</a>
</li>
<li><a href="api.html#mabwiser.utils.argmin">argmin() (in module mabwiser.utils)</a>
</li>
<li><a href="api.html#mabwiser.utils.Arm">Arm (in module mabwiser.utils)</a>
</li>
<li><a href="api.html#mabwiser.base_mab.BaseMAB.arm_to_expectation">arm_to_expectation (mabwiser.base_mab.BaseMAB attribute)</a>
</li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="api.html#id1">arm_to_scaler (mabwiser.mab.LearningPolicy.LinGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinGreedy.arm_to_scaler">[1]</a>

<ul>
<li><a href="api.html#id5">(mabwiser.mab.LearningPolicy.LinTS attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinTS.arm_to_scaler">[1]</a>
<li><a href="api.html#mabwiser.utils.Arm">Arm (in module mabwiser.utils)</a>
</li>
<li><a href="api.html#id8">(mabwiser.mab.LearningPolicy.LinUCB attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinUCB.arm_to_scaler">[1]</a>
<li><a href="api.html#mabwiser.base_mab.BaseMAB.arm_to_expectation">arm_to_expectation (mabwiser.base_mab.BaseMAB attribute)</a>
</li>
</ul></li>
<li><a href="api.html#mabwiser.simulator.Simulator.arm_to_stats_test">arm_to_stats_test (mabwiser.simulator.Simulator attribute)</a>
</li>
<li><a href="api.html#mabwiser.simulator.Simulator.arm_to_stats_total">arm_to_stats_total (mabwiser.simulator.Simulator attribute)</a>
Expand Down Expand Up @@ -223,7 +215,7 @@ <h2 id="E">E</h2>
<li><a href="api.html#id0">epsilon (mabwiser.mab.LearningPolicy.EpsilonGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.EpsilonGreedy.epsilon">[1]</a>

<ul>
<li><a href="api.html#id2">(mabwiser.mab.LearningPolicy.LinGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinGreedy.epsilon">[1]</a>
<li><a href="api.html#id1">(mabwiser.mab.LearningPolicy.LinGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinGreedy.epsilon">[1]</a>
</li>
</ul></li>
</ul></td>
Expand Down Expand Up @@ -284,12 +276,12 @@ <h2 id="K">K</h2>
<h2 id="L">L</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="api.html#id3">l2_lambda (mabwiser.mab.LearningPolicy.LinGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinGreedy.l2_lambda">[1]</a>
<li><a href="api.html#id2">l2_lambda (mabwiser.mab.LearningPolicy.LinGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinGreedy.l2_lambda">[1]</a>

<ul>
<li><a href="api.html#id6">(mabwiser.mab.LearningPolicy.LinTS attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinTS.l2_lambda">[1]</a>
<li><a href="api.html#id5">(mabwiser.mab.LearningPolicy.LinTS attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinTS.l2_lambda">[1]</a>
</li>
<li><a href="api.html#id9">(mabwiser.mab.LearningPolicy.LinUCB attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinUCB.l2_lambda">[1]</a>
<li><a href="api.html#id8">(mabwiser.mab.LearningPolicy.LinUCB attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinUCB.l2_lambda">[1]</a>
</li>
</ul></li>
<li><a href="api.html#mabwiser.mab.MAB.learning_policy">learning_policy (mabwiser.mab.MAB attribute)</a>
Expand Down Expand Up @@ -483,10 +475,18 @@ <h2 id="R">R</h2>
<h2 id="S">S</h2>
<table style="width: 100%" class="indextable genindextable"><tr>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="api.html#mabwiser.simulator.Simulator.scaler">scaler (mabwiser.simulator.Simulator attribute)</a>
<li><a href="api.html#id3">scale (mabwiser.mab.LearningPolicy.LinGreedy attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinGreedy.scale">[1]</a>

<ul>
<li><a href="api.html#id6">(mabwiser.mab.LearningPolicy.LinTS attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinTS.scale">[1]</a>
</li>
<li><a href="api.html#id9">(mabwiser.mab.LearningPolicy.LinUCB attribute)</a>, <a href="api.html#mabwiser.mab.LearningPolicy.LinUCB.scale">[1]</a>
</li>
</ul></li>
</ul></td>
<td style="width: 33%; vertical-align: top;"><ul>
<li><a href="api.html#mabwiser.simulator.Simulator.scaler">scaler (mabwiser.simulator.Simulator attribute)</a>
</li>
<li><a href="api.html#mabwiser.mab.MAB.seed">seed (mabwiser.mab.MAB attribute)</a>
</li>
<li><a href="api.html#mabwiser.simulator.Simulator">Simulator (class in mabwiser.simulator)</a>
Expand Down
2 changes: 1 addition & 1 deletion docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>MABWiser Contextual Multi-Armed Bandits &mdash; MABWiser 2.3.0 documentation</title>
<title>MABWiser Contextual Multi-Armed Bandits &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Installation &mdash; MABWiser 2.3.0 documentation</title>
<title>Installation &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
2 changes: 1 addition & 1 deletion docs/new_bandit.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Adding a New Bandit &mdash; MABWiser 2.3.0 documentation</title>
<title>Adding a New Bandit &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
Binary file modified docs/objects.inv
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/py-modindex.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Python Module Index &mdash; MABWiser 2.3.0 documentation</title>
<title>Python Module Index &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
2 changes: 1 addition & 1 deletion docs/quick.html
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<meta charset="utf-8" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Quick Start &mdash; MABWiser 2.3.0 documentation</title>
<title>Quick Start &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
<!--[if lt IE 9]>
Expand Down
2 changes: 1 addition & 1 deletion docs/search.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Search &mdash; MABWiser 2.3.0 documentation</title>
<title>Search &mdash; MABWiser 2.4.0 documentation</title>
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />

Expand Down
2 changes: 1 addition & 1 deletion docs/searchindex.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion mabwiser/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@

__author__ = "FMR LLC"
__email__ = "opensource@fmr.com"
__version__ = "2.3.0"
__version__ = "2.4.0"
__copyright__ = "Copyright (C), FMR LLC"
15 changes: 8 additions & 7 deletions mabwiser/base_mab.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import abc
from itertools import chain
from typing import Callable, Dict, List, NoReturn, Optional
from typing import Callable, Dict, List, NoReturn, Optional, Union
import multiprocessing as mp

from joblib import Parallel, delayed
Expand Down Expand Up @@ -57,7 +57,7 @@ class BaseMAB(metaclass=abc.ABCMeta):
If set to -2, all CPUs but one are used, and so on.
backend: str, optional
Specify a parallelization backend implementation supported in the joblib library. Supported options are:
- “loky” used by default, can induce some communication and memory overhead when exchanging input and output data with the worker Python processes.
- “loky” used by default, can induce some communication and memory overhead when exchanging input and output.
- “multiprocessing” previous process-based backend based on multiprocessing.Pool. Less robust than loky.
- “threading” is a very low-overhead backend but it suffers from the Python Global Interpreter Lock if the
called function relies a lot on Python objects.
Expand Down Expand Up @@ -86,14 +86,14 @@ def __init__(self, rng: _BaseRNG, arms: List[Arm], n_jobs: int, backend: str = N
self.cold_arm_to_warm_arm: Dict[Arm, Arm] = dict()
self.trained_arms: List[Arm] = list()

def add_arm(self, arm: Arm, binarizer: Callable = None, scaler: Callable = None) -> NoReturn:
def add_arm(self, arm: Arm, binarizer: Callable = None) -> NoReturn:
"""Introduces a new arm to the bandit.
Adds the new arm with zero expectations and
calls the ``_uptake_new_arm()`` function of the sub-class.
"""
self.arm_to_expectation[arm] = 0
self._uptake_new_arm(arm, binarizer, scaler)
self._uptake_new_arm(arm, binarizer)

def remove_arm(self, arm: Arm) -> NoReturn:
"""Removes arm from the bandit.
Expand Down Expand Up @@ -122,15 +122,16 @@ def partial_fit(self, decisions: np.ndarray, rewards: np.ndarray,
pass

@abc.abstractmethod
def predict(self, contexts: Optional[np.ndarray] = None) -> Arm:
def predict(self, contexts: Optional[np.ndarray] = None) -> Union[Arm, List[Arm]]:
"""Abstract method.
Returns the predicted arm.
"""
pass

@abc.abstractmethod
def predict_expectations(self, contexts: Optional[np.ndarray] = None) -> Dict[Arm, Num]:
def predict_expectations(self, contexts: Optional[np.ndarray] = None) -> Union[Dict[Arm, Num],
List[Dict[Arm, Num]]]:
"""Abstract method.
Returns a dictionary from arms (keys) to their expected rewards (values).
Expand All @@ -146,7 +147,7 @@ def _copy_arms(self, cold_arm_to_warm_arm: Dict[Arm, Arm]) -> NoReturn:
pass

@abc.abstractmethod
def _uptake_new_arm(self, arm: Arm, binarizer: Callable = None, scaler: Callable = None) -> NoReturn:
def _uptake_new_arm(self, arm: Arm, binarizer: Callable = None) -> NoReturn:
"""Abstract method.
Updates the multi-armed bandit with the new arm.
Expand Down
4 changes: 2 additions & 2 deletions mabwiser/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def partial_fit(self, decisions: np.ndarray, rewards: np.ndarray,

self._fit_operation()

def predict(self, contexts: Optional[np.ndarray] = None):
def predict(self, contexts: np.ndarray = None) -> Union[Arm, List[Arm]]:
# Return predict within the cluster
return self._parallel_predict(contexts, is_predict=True)

def predict_expectations(self, contexts: Optional[np.ndarray] = None):
def predict_expectations(self, contexts: np.ndarray = None) -> Union[Dict[Arm, Num], List[Dict[Arm, Num]]]:
# Return predict expectations within the cluster
return self._parallel_predict(contexts, is_predict=False)

Expand Down
32 changes: 23 additions & 9 deletions mabwiser/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Callable, Dict, List, NoReturn, Optional
from typing import Callable, Dict, List, NoReturn, Optional, Union

import numpy as np

Expand Down Expand Up @@ -32,19 +32,33 @@ def fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray =
def partial_fit(self, decisions: np.ndarray, rewards: np.ndarray, contexts: np.ndarray = None) -> NoReturn:
self._parallel_fit(decisions, rewards, contexts)

def predict(self, contexts: np.ndarray = None) -> Arm:
def predict(self, contexts: Optional[np.ndarray] = None) -> Union[Arm, List[Arm]]:

# Return the first arm with maximum expectation
return argmax(self.predict_expectations())
# Return the arm with maximum expectation
expectations = self.predict_expectations(contexts)
if isinstance(expectations, dict):
return argmax(expectations)
else:
return [argmax(exp) for exp in expectations]

def predict_expectations(self, contexts: np.ndarray = None) -> Dict[Arm, Num]:
def predict_expectations(self, contexts: Optional[np.ndarray] = None) -> Union[Dict[Arm, Num],
List[Dict[Arm, Num]]]:

# Return a random expectation (between 0 and 1) for each arm with epsilon probability,
# and the actual arm expectations otherwise
if self.rng.rand() < self.epsilon:
return dict((arm, self.rng.rand()) for arm in self.arms).copy()
# and the actual arm expectations otherwise.
# If contexts is None or has length of 1 generate single arm to expectations,
# otherwise use vectorized functions to generate a list of arm to expectations with same length as contexts.
if contexts is None or len(contexts) == 1:
if self.rng.rand() < self.epsilon:
return dict((arm, self.rng.rand()) for arm in self.arms).copy()
else:
return self.arm_to_expectation.copy()
else:
return self.arm_to_expectation.copy()
probability = self.rng.rand(len(contexts))
random_values = self.rng.rand((len(contexts), len(self.arms)))
expectations = [dict(zip(self.arms, exp)).copy() if probability[index] < self.epsilon
else self.arm_to_expectation.copy() for index, exp in enumerate(random_values)]
return expectations

def _copy_arms(self, cold_arm_to_warm_arm):
for cold_arm, warm_arm in cold_arm_to_warm_arm.items():
Expand Down
Loading

0 comments on commit 329125d

Please sign in to comment.