Skip to content

Commit

Permalink
Merge pull request #17 from Bayer-Group/v0.1.2
Browse files Browse the repository at this point in the history
V0.1.2
  • Loading branch information
sprivite authored Mar 27, 2024
2 parents 7d3ac5a + db918b1 commit 0809bb9
Show file tree
Hide file tree
Showing 20 changed files with 2,941 additions and 1,773 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ mlruns/*
docs/*
.devcontainer/*
.idea
.idea/*
.idea/*
venv/*
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
BSD 3-Clause License

Copyright (c) 2022, Bayer AG.
Copyright (c) 2024, Bayer AG.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@ The `pybalance` library implements several routines for optimizing the balance
between non-random populations. In observational studies, this matching process
is a key step towards minimizing the potential effects of confounding
covariates. The official documentation is hosted [here](https://bayer-group.github.io/pybalance/).
An application of this library to matchng in the pharmaceutical setting is presented here: [here](https://onlinelibrary.wiley.com/doi/10.1002/pst.2352).

## Features

- Implements linear and non-linear optimization approaches for matching.
- Utilizes integer program solvers and evolutionary solvers for optimization.
- Includes implementation of propensity score matching for comparison.
- Offers a variety of balance calculators and matchers.
- Provides visualization tools for analysis.
- Supports simulation of datasets for testing and demonstration purposes.
5 changes: 2 additions & 3 deletions environments/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ seaborn>=0.12.2
torch>=2.0.0

# the two below are very particular: keep them set at a fixed version
cvxpy==1.3.1
ortools==9.4.1874
ortools==9.9.3963

# FIXME this isn't really a dependency, only needed for running in our
# infrastructure
boto3
fsspec
s3fs
s3fs
17 changes: 0 additions & 17 deletions pybalance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,3 @@
import pybalance.sim
import pybalance.visualization
import pybalance.lp

__version__ = "0.1.1"

import logging

logger = logging.getLogger(__name__)
logger.info(f"Loaded pybalance version {__version__}.")

# Logging is configured at the application level. To adjust logging level,
# configure logging as below before importing pybalance:
#
# import logging
# logging.basicConfig(
# format="%(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
# level='INFO',
# )
#
2 changes: 1 addition & 1 deletion pybalance/genetic/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

def _check_fitted(matcher):
if matcher.best_match is None:
raise (ValueError, "Matcher has not been fitted!")
raise ValueError("Matcher has not been fitted!")


def get_global_defaults(n_candidate_populations=5000):
Expand Down
4 changes: 2 additions & 2 deletions pybalance/lp/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def _check_fitted(matcher):
if matcher.best_match is None:
raise (ValueError, "Matcher has not been fitted!")
raise ValueError("Matcher has not been fitted!")


def compute_truncation_error(x: np.ndarray) -> float:
Expand Down Expand Up @@ -449,7 +449,7 @@ def match(self, hint: Optional[List[int]] = None) -> MatchingData:

if self.verbose:
solution_printer = SolutionPrinter(x, abs_deltas, self)
status = solver.SolveWithSolutionCallback(model, solution_printer)
status = solver.Solve(model, solution_printer)
logger.info("Status = %s" % solver.StatusName(status))
logger.info(
"Number of solutions found: %i" % solution_printer.solution_count()
Expand Down
2 changes: 1 addition & 1 deletion pybalance/propensity/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

def _check_fitted(matcher):
if matcher.best_match is None:
raise (ValueError, "Matcher has not been fitted!")
raise ValueError("Matcher has not been fitted!")


class PropensityScoreMatcher:
Expand Down
1 change: 1 addition & 0 deletions pybalance/sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
generate_random_feature_data_rct,
generate_random_feature_data_rwd,
generate_toy_dataset,
get_paper_dataset_path,
load_paper_dataset,
)
14 changes: 12 additions & 2 deletions pybalance/sim/rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,22 @@ def generate_toy_dataset(n_pool=10000, n_target=1000, seed=45):
return MatchingData(feature_data)


def load_paper_dataset():
def get_paper_dataset_path():
"""
Load the simulated matching dataset presented in the pybalance paper.
Get the path to the simulated matching dataset presented in the pybalance paper
(https://onlinelibrary.wiley.com/doi/10.1002/pst.2352).
"""
filepath = "pool250000-target25000-normal0-lognormal0-binary4.parquet"
resource = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "data", filepath
)
return resource


def load_paper_dataset():
"""
Load the simulated matching dataset presented in the pybalance paper
(https://onlinelibrary.wiley.com/doi/10.1002/pst.2352).
"""
resource = get_paper_dataset_path()
return MatchingData(resource)
75 changes: 36 additions & 39 deletions pybalance/visualization/distributions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Some helpful functions for plotting distribution.
"""

from collections import defaultdict
from typing import List, Optional
import itertools
Expand Down Expand Up @@ -243,9 +244,9 @@ def plot_binary_features(
frequencies.loc[:, "difference"] = frequencies.loc[:, "difference"] / np.sqrt(
variance
)
difference_label = "Standard Difference"
difference_label = "Std. Mean\nDifference"
else:
difference_label = "Abs Difference"
difference_label = "Abs. Mean\nDifference"

# Restrict to top features
frequencies = frequencies.sort_values(
Expand Down Expand Up @@ -275,75 +276,81 @@ def plot_binary_features(
fig, axes = plt.subplots(
nrows=2,
ncols=1,
figsize=(len(pool_frequencies) / 2, 16),
figsize=(len(pool_frequencies) / 2, 8),
gridspec_kw={"height_ratios": [1, 3]},
)

plt.subplot(2, 1, 1)
sns.barplot(data=frequencies, y="difference", x="index", **default_params)
xticks, labels = plt.xticks()
plt.gca().set_xticks(xticks + 0.5, minor=False)
plt.gca().set_xticks(xticks, minor=True)
plt.gca().set_xticklabels([""] * len(labels), minor=True, rotation=90)
ticks, labels = plt.xticks()
ticks = np.array(ticks)
plt.gca().set_xticks(ticks + 0.5, minor=False)
plt.gca().set_xticklabels([""] * len(labels), minor=True)
plt.gca().set_xticklabels([""] * len(labels), minor=False)
plt.grid(True)
plt.axhline(
y=0.1,
xmin=xticks.min(),
xmax=xticks.max(),
xmin=ticks.min(),
xmax=ticks.max(),
c="k",
lw=2.5,
zorder=10000,
linestyle="--",
)
plt.ylim([0, 0.25])
plt.ylabel("Abs Difference", fontsize=18)
plt.ylabel(difference_label, fontsize=14)
plt.gca().get_legend().remove()
plt.xlabel("")

plt.subplot(2, 1, 2)
sns.barplot(data=frequencies, y="value", x="index", **default_params)
xticks, labels = plt.xticks()
plt.gca().set_xticks(xticks + 0.5, minor=False)
plt.gca().set_xticks(xticks, minor=True)
plt.gca().set_xticklabels(
labels, minor=True, rotation=45, ha="right", fontsize=16
)
ticks, labels = plt.xticks()
ticks = np.array(ticks)
plt.gca().set_xticks(ticks, minor=True)
plt.gca().set_xticklabels(labels, minor=True)
plt.gca().set_xticks(ticks + 0.5, minor=False)
plt.gca().set_xticklabels([""] * len(labels), minor=False)
plt.xticks(rotation=90, fontsize=12, ha="right", minor=True)
plt.grid(True)
plt.ylabel("Frequency", fontsize=18)
plt.xlabel("Feature")
plt.ylabel("Frequency", fontsize=14)
plt.xlabel("Feature", fontsize=14)

else:
fig, axes = plt.subplots(
nrows=1,
ncols=2,
figsize=(16, len(pool_frequencies) / 2),
figsize=(8, len(pool_frequencies) / 2),
gridspec_kw={"width_ratios": [3, 1]},
)

plt.subplot(1, 2, 2)
sns.barplot(data=frequencies, x="difference", y="index", **default_params)
ticks, labels = plt.yticks()
ticks = np.array(ticks)
plt.gca().set_yticks(ticks + 0.5, minor=False)
plt.gca().set_yticks(ticks, minor=True)
plt.gca().set_yticklabels([""] * len(labels), minor=True)
plt.gca().set_yticklabels([""] * len(labels), minor=False)
plt.grid(True)
plt.axvline(
x=0.1, ymin=ticks.min(), ymax=ticks.max(), c="k", lw=2.5, linestyle="--"
)
plt.xlim([0, 0.25])
plt.xlabel(difference_label, fontsize=18)
plt.xlabel(difference_label, fontsize=14)
plt.gca().get_legend().remove()
plt.ylabel("")

plt.subplot(1, 2, 1)
sns.barplot(data=frequencies, x="value", y="index", **default_params)
ticks, labels = plt.yticks()
plt.gca().set_yticks(ticks + 0.5, minor=False)
ticks = np.array(ticks)
plt.gca().set_yticks(ticks, minor=True)
plt.gca().set_yticklabels(labels, minor=True, fontsize=16)
plt.gca().set_yticklabels(labels, minor=True)
plt.gca().set_yticks(ticks + 0.5, minor=False)
plt.gca().set_yticklabels([""] * len(labels), minor=False)
plt.yticks(rotation=0, fontsize=12, minor=True)
plt.grid(True)
plt.xlabel("Frequency", fontsize=18)
plt.ylabel("Feature")
plt.xlabel("Frequency", fontsize=14)
plt.ylabel("Feature", fontsize=14)

plt.tight_layout()
return fig
Expand Down Expand Up @@ -457,8 +464,8 @@ def plot_per_feature_loss(

plt.ylim(ymin=0)
ymin, ymax = fig.gca().get_ylim()
plt.vlines(plt.xticks()[0] + 0.5, ymin, ymax, linewidth=0.5, color="k")
plt.vlines(plt.xticks()[0][0] - 0.5, ymin, ymax, linewidth=0.5, color="k")
plt.vlines(np.array(plt.xticks()[0]) + 0.5, ymin, ymax, linewidth=0.5, color="k")
plt.vlines(np.array(plt.xticks()[0]) - 0.5, ymin, ymax, linewidth=0.5, color="k")

plt.xticks(rotation=90)
xmin, xmax = plt.xticks()[0][0] - 0.5, plt.xticks()[0][-1] + 0.5
Expand Down Expand Up @@ -493,7 +500,7 @@ def plot_joint_numeric_categoric_distributions(
g = sns.JointGrid(data=matching_data.data, x=x, y=y, **default_params)
grids.append(g)

g.plot_joint(sns.violinplot, s=25, split=True, saturation=0.9, dodge=True)
g.plot_joint(sns.violinplot, split=True, saturation=0.9, dodge=True)
g.ax_joint.grid(True)

sns.histplot(
Expand Down Expand Up @@ -551,17 +558,7 @@ def plot_joint_numeric_distributions(
g.plot_joint(sns.kdeplot, levels=5)

elif joint_kind == "scatter":
# Give larger populations more alpha, so they don't overwhelm the plot
counts = matching_data.counts()
weights = (1 - counts / counts.sum()).reset_index()
alpha = np.clip(
matching_data.data.merge(weights, on=matching_data.population_col)[
"N"
].values,
0.25,
1,
)
g.plot_joint(sns.scatterplot, s=25, alpha=alpha)
g.plot_joint(sns.scatterplot, s=25)

else:
raise NotImplementedError(f"Unsupported joint_kind: {joint_kind}.")
Expand Down
12 changes: 7 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import setuptools
from pybalance import __version__

__version__ = "0.1.1"

with open("README.md", "r") as fh:
long_description = fh.read()
Expand All @@ -10,8 +11,8 @@
setuptools.setup(
name="pybalance",
version=__version__,
author="IEG Data Science",
author_email="author@example.com",
author="Stephen Privitera",
author_email="stephen.privitera@bayer.com",
description="Population Matching",
long_description=long_description,
install_requires=requirements,
Expand All @@ -20,8 +21,9 @@
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"License :: OSI Approved :: BSD 3-Clause License",
"Operating System :: OS Independent",
],
python_requires=">=3.6",
python_requires=">=3.8",
package_data={"pybalance": ["sim/data/*parquet", "sim/data/*csv"]},
)
7 changes: 4 additions & 3 deletions sphinx/00_introduction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ searches the solution space. An evolutionary solver is also implemented in
`pybalance`, together with a number of heuristics for efficiently
searching the space.

For completeness and ease of comparison, `pybalance` also implements
matching based on propensity score. For greater technical detail, see our
publication (FIXME currently in review).
For completeness and ease of comparison, `pybalance` also implements matching
based on propensity score. For greater technical detail as well as applications,
see our publication `here
<https://onlinelibrary.wiley.com/doi/10.1002/pst.2352>`_.
18 changes: 16 additions & 2 deletions sphinx/01_installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,23 @@ Installation instructions
Install from PIP
=========================

For now, everyone must install from source. Installation from PIP will be made
available soon.
We reccomend using virtualenv and installing pybalance using pip:

>>> python3.9 -m venv venv/pybalance
>>> source venv/pybalance/bin/activate
>>> pip install --upgrade pip
>>> pip install pybalance

If you wish you use pybalance within a jupyter notebook, you will also need to
install jupyter:

>>> pip install jupyter

and then register your enviroment with jupyter:

>>> python -m ipykernel install --user --name=pybalance

Make sure to select the pybalance kernel when running the notebook.

Install from source
=========================
Expand Down
4 changes: 2 additions & 2 deletions sphinx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

# -- Project information -----------------------------------------------------

project = "Population Matching"
project = "PyBalance"
author = "Stephen Privitera, Hooman Sedghamiz, Alex Hartenstein"
copyright = f"2022 - Bayer AG - {author}"
copyright = f"2024 - Bayer AG - {author}"

# The full version, including alpha/beta/rc tags
release = "0.0.1"
Expand Down
Loading

0 comments on commit 0809bb9

Please sign in to comment.