Commit
Merge pull request #93 from danielward27/formatting
black format examples
danielward27 authored Aug 26, 2023
2 parents f1970df + 44e9457 commit c3af9b1
Showing 6 changed files with 102 additions and 79 deletions.
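Note: a formatting pass like this one can be reproduced with black's Jupyter support. As a sketch (assuming black is installed with its jupyter extra; the exact invocation used for this commit is not recorded on this page):

    pip install "black[jupyter]"
    black docs/examples

With the jupyter extra, black rewrites the notebooks' code cells in place, which accounts for the mechanical changes in the hunks below: single quotes normalized to double quotes, long argument lists re-wrapped, and trailing whitespace stripped.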
41 changes: 25 additions & 16 deletions docs/examples/bounded.ipynb
@@ -67,7 +67,7 @@
],
"source": [
"key, x_key = jr.split(jr.PRNGKey(0))\n",
"x = jr.beta(x_key, a=0.4, b=0.4, shape=(5000, 2)) # Supported on the interval [0, 1]^2 "
"x = jr.beta(x_key, a=0.4, b=0.4, shape=(5000, 2)) # Supported on the interval [0, 1]^2"
]
},
{
@@ -97,20 +97,24 @@
"source": [
"eps = 1e-7 # Avoid potential numerical issues\n",
"\n",
"preprocess = Chain([\n",
" Affine(loc=-jnp.ones(2) + eps, scale=(1-eps)*jnp.array([2, 2])), # to [-1+eps, 1-eps]\n",
" Invert(Tanh(shape=(2,))) # arctanh (to unbounded)\n",
" ])\n",
"preprocess = Chain(\n",
" [\n",
" Affine(\n",
" loc=-jnp.ones(2) + eps, scale=(1 - eps) * jnp.array([2, 2])\n",
" ), # [-1+eps, 1-eps]\n",
" Invert(Tanh(shape=(2,))), # arctanh (to unbounded)\n",
" ]\n",
")\n",
"\n",
"x_preprocessed = jax.vmap(preprocess.transform)(x)\n",
"\n",
"# Plot the data\n",
"fig, axes = plt.subplots(ncols=2)\n",
"for (k, v), ax in zip({\"Raw\": x, \"Preprocessed\": x_preprocessed}.items(), axes):\n",
"plot_data = {\"Raw\": x, \"Preprocessed\": x_preprocessed}\n",
"for (k, v), ax in zip(plot_data.items(), axes, strict=True):\n",
" ax.scatter(v[:, 0], v[:, 1], s=0.3)\n",
" ax.set_title(k)\n",
" ax.set_aspect('equal')\n",
" "
" ax.set_aspect(\"equal\")"
]
},
{
@@ -148,11 +152,16 @@
"\n",
"# Train on the unbounded space\n",
"flow, losses = fit_to_data(\n",
" subkey, untrained_flow, x_preprocessed, learning_rate=5e-3, max_patience=10, max_epochs=70\n",
" )\n",
" key=subkey,\n",
" dist=untrained_flow,\n",
" x=x_preprocessed,\n",
" learning_rate=5e-3,\n",
" max_patience=10,\n",
" max_epochs=70,\n",
")\n",
"\n",
"# Transform flow back to bounded space\n",
"flow = Transformed(flow, Invert(preprocess))\n"
"flow = Transformed(flow, Invert(preprocess))"
]
},
{
@@ -186,7 +195,7 @@
"source": [
"naive_flow, losses = fit_to_data(\n",
" subkey, untrained_flow, x, learning_rate=5e-3, max_patience=10, max_epochs=70\n",
" )"
")"
]
},
{
@@ -217,16 +226,16 @@
"key, *subkeys = jr.split(key, 3)\n",
"samples = {\n",
" \"True Distribution\": x,\n",
" \"Bounded Flow\": flow.sample(subkeys[0], (x.shape[0], )),\n",
" \"Naive Flow\": naive_flow.sample(subkeys[1], (x.shape[0], )),\n",
" \"Bounded Flow\": flow.sample(subkeys[0], (x.shape[0],)),\n",
" \"Naive Flow\": naive_flow.sample(subkeys[1], (x.shape[0],)),\n",
"}\n",
"\n",
"fig, axes = plt.subplots(ncols=3, sharex=True, sharey=True)\n",
"\n",
"for (k, v), ax in zip(samples.items(), axes):\n",
"for (k, v), ax in zip(samples.items(), axes, strict=True):\n",
" ax.scatter(v[:, 0], v[:, 1], s=0.3)\n",
" ax.set_title(k)\n",
" ax.set_aspect('equal')\n",
" ax.set_aspect(\"equal\")\n",
"\n",
"ax.set_xlim((-0.2, 1.2))\n",
"ax.set_ylim((-0.2, 1.2))\n",
13 changes: 8 additions & 5 deletions docs/examples/conditional.ipynb
@@ -27,12 +27,12 @@
{
"cell_type": "code",
"execution_count": 1,
"id": null,
"metadata": {},
"outputs": [],
"source": [
"import jax.random as jr\n",
"import jax.numpy as jnp\n",
"import numpy as onp\n",
"from flowjax.flows import BlockNeuralAutoregressiveFlow\n",
"from flowjax.distributions import Normal\n",
"from flowjax.train import fit_to_data\n",
@@ -50,6 +50,7 @@
{
"cell_type": "code",
"execution_count": 2,
"id": null,
"metadata": {},
"outputs": [
{
@@ -77,6 +78,7 @@
{
"cell_type": "code",
"execution_count": 3,
"id": null,
"metadata": {},
"outputs": [
{
@@ -104,7 +106,7 @@
" condition=u,\n",
" learning_rate=5e-3,\n",
" max_patience=10,\n",
" )"
")"
]
},
{
@@ -118,6 +120,7 @@
{
"cell_type": "code",
"execution_count": 4,
"id": null,
"metadata": {},
"outputs": [
{
@@ -136,9 +139,8 @@
"test_u = jnp.array([1.0, 3])\n",
"\n",
"xgrid, ygrid = jnp.meshgrid(\n",
" jnp.linspace(-1, 4, resolution),\n",
" jnp.linspace(-1, 4, resolution)\n",
" )\n",
" jnp.linspace(-1, 4, resolution), jnp.linspace(-1, 4, resolution)\n",
")\n",
"xyinput = jnp.column_stack((xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)))\n",
"zgrid = jnp.exp(flow.log_prob(xyinput, test_u).reshape(resolution, resolution))\n",
"plt.contourf(xgrid, ygrid, zgrid, levels=50)\n",
@@ -148,6 +150,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"metadata": {},
"outputs": [],
"source": []
68 changes: 36 additions & 32 deletions docs/examples/snpe.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions docs/examples/unconditional.ipynb
@@ -60,7 +60,7 @@
"n_samples = 10000\n",
"rng = jr.PRNGKey(0)\n",
"x = two_moons(rng, n_samples)\n",
"x = (x - x.mean(axis=0))/x.std(axis=0) # Standardize"
"x = (x - x.mean(axis=0)) / x.std(axis=0) # Standardize"
]
},
{
@@ -181,7 +181,7 @@
"for ax in axs:\n",
" ax.set_xlim(lims)\n",
" ax.set_ylim(lims)\n",
" ax.set_aspect('equal')\n",
" ax.set_aspect(\"equal\")\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
51 changes: 27 additions & 24 deletions docs/examples/variational_inference.ipynb
@@ -64,7 +64,7 @@
"def unormalized_posterior(w):\n",
" likelihood = norm.logpdf(y, w[0] + x * w[1]).sum()\n",
" prior = norm.logpdf(w).sum() # Standard normal prior\n",
" return (likelihood + prior).sum()\n"
" return (likelihood + prior).sum()"
]
},
{
@@ -92,21 +92,13 @@
"loss = ElboLoss(unormalized_posterior, num_samples=100)\n",
"\n",
"key, flow_key, train_key = jr.split(key, 3)\n",
"from flowjax.bijections import Exp\n",
"flow = MaskedAutoregressiveFlow(\n",
" flow_key,\n",
" base_dist=StandardNormal((2,)),\n",
" transformer=Affine(),\n",
" invert=False\n",
" flow_key, base_dist=StandardNormal((2,)), transformer=Affine(), invert=False\n",
")\n",
"\n",
"# Train the flow variationally\n",
"flow, losses = fit_to_variational_target(\n",
" train_key,\n",
" flow,\n",
" loss,\n",
" learning_rate=1e-3,\n",
" steps=200\n",
" train_key, flow, loss, learning_rate=1e-3, steps=200\n",
")"
]
},
@@ -145,7 +137,7 @@
}
],
"source": [
"def plot_density(ax, density_fn, xmin=-5, xmax=5, ymin=-5, ymax=5, n=100, levels=None, cmap=\"Blues\"):\n",
"def plot_density(\n",
" ax, density_fn, xmin=-5, xmax=5, ymin=-5, ymax=5, n=100, levels=None, cmap=\"Blues\"\n",
"):\n",
" xvalues = jnp.linspace(xmin, xmax, n)\n",
" yvalues = jnp.linspace(ymin, ymax, n)\n",
" X, Y = jnp.meshgrid(xvalues, yvalues)\n",
@@ -155,25 +149,33 @@
" log_prob = density_fn(points).reshape(n, n)\n",
" prob = jnp.exp(log_prob)\n",
"\n",
" ax.contour(prob, levels=levels, extent=[xmin, xmax, ymin, ymax], origin=\"lower\", cmap=cmap)\n",
" ax.contour(\n",
" prob, levels=levels, extent=[xmin, xmax, ymin, ymax], origin=\"lower\", cmap=cmap\n",
" )\n",
"\n",
" ax.set_xlim(xmin, xmax)\n",
" ax.set_ylim(ymin, ymax)\n",
"\n",
"\n",
"fig, axes = plt.subplots(ncols=2, figsize=(9, 3))\n",
"axes[0].set_title('Density plot')\n",
"axes[0].set_title(\"Density plot\")\n",
"\n",
"kwargs = dict(xmin=0.25, xmax=1.25, ymin=-1, ymax=0, levels=5)\n",
"plot_density(axes[0], flow.log_prob, cmap=\"Blues\", **kwargs)\n",
"\n",
"# True posterior for comparison\n",
"_x = jnp.vstack([jnp.ones_like(x), x]) # full design matrix\n",
"_x = jnp.vstack([jnp.ones_like(x), x]) # full design matrix\n",
"cov = jnp.linalg.inv(_x.dot(_x.T) + jnp.eye(2))\n",
"mean = cov.dot(_x).dot(y)\n",
"true_posterior_log_prob = lambda theta: multivariate_normal.logpdf(theta, mean, cov)\n",
"\n",
"\n",
"def true_posterior_log_prob(theta):\n",
" return multivariate_normal.logpdf(theta, mean, cov)\n",
"\n",
"\n",
"plot_density(axes[0], true_posterior_log_prob, cmap=\"Reds\", **kwargs)\n",
"axes[1].set_title('losses')\n",
"axes[1].plot(losses)\n"
"axes[1].set_title(\"losses\")\n",
"axes[1].plot(losses)"
]
},
{
@@ -202,8 +204,8 @@
"source": [
"x_inspect = jnp.linspace(2, -2, n)\n",
"plots = [\n",
- " ('prior', StandardNormal((2,)), 'tab:green'),\n",
- " ('trained', flow, 'tab:orange'),\n",
+ " (\"prior\", StandardNormal((2,)), \"tab:green\"),\n",
+ " (\"trained\", flow, \"tab:orange\"),\n",
"]\n",
"n_samples = 25\n",
"\n",
@@ -212,10 +214,11 @@
" w = flow.sample(sample_key, (n_samples,))\n",
" for ix, (w_0, w_1) in enumerate(w):\n",
" y_inspect = w_0 + w_1 * x_inspect\n",
" plt.plot(x_inspect, y_inspect, alpha=0.3, c=colour, label=label if ix == 0 else None)\n",
" \n",
"plt.scatter(x, y, label='samples')\n",
"plt.title('Sample Fits')\n",
" lab = label if ix == 0 else None\n",
" plt.plot(x_inspect, y_inspect, alpha=0.3, c=colour, label=lab)\n",
"\n",
"plt.scatter(x, y, label=\"samples\")\n",
"plt.title(\"Sample Fits\")\n",
"plt.legend()\n",
"plt.show()"
]
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -45,3 +45,7 @@ pythonpath = ["."]

[tool.ruff]
select = ["E", "F", "B"]
include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"]

[tool.ruff.pydocstyle]
convention = "google"
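
Note on the added configuration, as an aside: the include pattern extends ruff's linting to notebook files, and the pydocstyle convention tells ruff to interpret docstrings in the Google style when docstring (D) rules are enabled (they are not yet in the select list above). A minimal sketch of a Google-style docstring, using a hypothetical function rather than code from this repository:

    def scale(x, factor):
        """Scale an array by a constant.

        Args:
            x: Input array.
            factor: Multiplicative constant.

        Returns:
            The scaled array.
        """
        return x * factor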
