Skip to content

Commit

Permalink
Completed checks for derivatives of MLP log-target (xor)
Browse files Browse the repository at this point in the history
  • Loading branch information
Theodore Papamarkou committed Aug 31, 2019
1 parent e0aa432 commit 97e8770
Showing 1 changed file with 180 additions and 5 deletions.
185 changes: 180 additions & 5 deletions examples/checks/mlp/log_target_derivatives.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluation of MLP log-likelihood\n",
"# Evaluation of grad and of Hessian of MLP log-target\n",
"\n",
"Confirm PyTorch and manually coded MLP log-likelihood coincide"
"Confirm PyTorch and manually coded grad and metric tensor of MLP log-target coincide"
]
},
{
Expand All @@ -24,8 +24,9 @@
"from torch.distributions import Normal\n",
"from torch.autograd import grad\n",
"\n",
"from eeyore.models.mlp import Hyperparameters, MLP\n",
"from eeyore.data import XOR"
"from eeyore.data import XOR\n",
"from eeyore.stats import binary_cross_entropy\n",
"from eeyore.models.mlp import Hyperparameters, MLP"
]
},
{
Expand Down Expand Up @@ -145,7 +146,7 @@
" g2 = h1 @ w2.t() + b2\n",
" h2 = torch.sigmoid(g2)\n",
" \n",
" return -F.binary_cross_entropy(h2, y, reduction='sum')"
" return -binary_cross_entropy(h2, y, reduction='sum')"
]
},
{
Expand Down Expand Up @@ -419,6 +420,180 @@
"source": [
"[p for p in [glt_result01, glt_result02, glt_result03, glt_result04]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute metric of MLP log-target using eeyore API version"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(-65.8127, dtype=torch.float64, grad_fn=<AddBackward0>),\n",
" tensor([-3.1125e-01, -3.1041e-01, 2.7002e-04, 1.5006e-04, -3.6975e-01,\n",
" -3.4671e-04, -1.9098e+00, -1.9979e+00, -1.9992e+00],\n",
" dtype=torch.float64, grad_fn=<CatBackward>),\n",
" tensor([[-2.6380e-01, -2.6390e-01, -2.9198e-08, -2.6460e-08, -2.6390e-01,\n",
" -2.9198e-08, 7.0188e-02, 1.2025e-04, 1.2026e-04],\n",
" [-2.6390e-01, -2.6322e-01, -2.6460e-08, -9.5428e-08, -2.6332e-01,\n",
" -9.5428e-08, 7.0442e-02, 5.6173e-04, 5.6176e-04],\n",
" [-2.9198e-08, -2.6460e-08, 3.2999e-04, 2.3013e-04, -2.9198e-08,\n",
" 2.2999e-04, -2.1572e-07, 6.7431e-05, -2.2279e-07],\n",
" [-2.6460e-08, -9.5428e-08, 2.3013e-04, 3.3003e-04, -9.5428e-08,\n",
" 2.3003e-04, -1.5708e-07, 6.7483e-05, -1.8290e-07],\n",
" [-2.6390e-01, -2.6332e-01, -2.9198e-08, -9.5428e-08, -3.2027e-01,\n",
" -1.0380e-07, 8.3666e-02, 5.8348e-04, 5.8352e-04],\n",
" [-2.9198e-08, -9.5428e-08, 2.2999e-04, 2.3003e-04, -1.0380e-07,\n",
" 6.7319e-04, -3.8906e-07, 1.6820e-04, -4.1677e-07],\n",
" [ 7.0188e-02, 7.0442e-02, -2.1572e-07, -1.5708e-07, 8.3666e-02,\n",
" -3.8906e-07, 1.3623e-03, 1.3936e-03, 1.3937e-03],\n",
" [ 1.2025e-04, 5.6173e-04, 6.7431e-05, 6.7483e-05, 5.8348e-04,\n",
" 1.6820e-04, 1.3936e-03, 1.6519e-03, 1.5520e-03],\n",
" [ 1.2026e-04, 5.6176e-04, -2.2279e-07, -1.8290e-07, 5.8352e-04,\n",
" -4.1677e-07, 1.3937e-03, 1.5520e-03, 1.6521e-03]],\n",
" dtype=torch.float64))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta = theta0.clone().detach()\n",
"\n",
"lt_val05, glt_result05, mlt_result01 = model.upto_metric_log_target(theta, data, labels)\n",
"lt_val05, glt_result05, mlt_result01"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute metric of MLP log-target calling grad() on manually coded log_target()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-2.6380e-01, -2.6390e-01, -2.9198e-08, -2.6460e-08, -2.6390e-01,\n",
" -2.9198e-08, 7.0188e-02, 1.2025e-04, 1.2026e-04],\n",
" [-2.6390e-01, -2.6322e-01, -2.6460e-08, -9.5428e-08, -2.6332e-01,\n",
" -9.5428e-08, 7.0442e-02, 5.6173e-04, 5.6176e-04],\n",
" [-2.9198e-08, -2.6460e-08, 3.2999e-04, 2.3013e-04, -2.9198e-08,\n",
" 2.2999e-04, -2.1572e-07, 6.7431e-05, -2.2279e-07],\n",
" [-2.6460e-08, -9.5428e-08, 2.3013e-04, 3.3003e-04, -9.5428e-08,\n",
" 2.3003e-04, -1.5708e-07, 6.7483e-05, -1.8290e-07],\n",
" [-2.6390e-01, -2.6332e-01, -2.9198e-08, -9.5428e-08, -3.2027e-01,\n",
" -1.0380e-07, 8.3666e-02, 5.8348e-04, 5.8352e-04],\n",
" [-2.9198e-08, -9.5428e-08, 2.2999e-04, 2.3003e-04, -1.0380e-07,\n",
" 6.7319e-04, -3.8906e-07, 1.6820e-04, -4.1677e-07],\n",
" [ 7.0188e-02, 7.0442e-02, -2.1572e-07, -1.5708e-07, 8.3666e-02,\n",
" -3.8906e-07, 1.3623e-03, 1.3936e-03, 1.3937e-03],\n",
" [ 1.2025e-04, 5.6173e-04, 6.7431e-05, 6.7483e-05, 5.8348e-04,\n",
" 1.6820e-04, 1.3936e-03, 1.6519e-03, 1.5520e-03],\n",
" [ 1.2026e-04, 5.6176e-04, -2.2279e-07, -1.8290e-07, 5.8352e-04,\n",
" -4.1677e-07, 1.3937e-03, 1.5520e-03, 1.6521e-03]],\n",
" dtype=torch.float64)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"theta = theta0.clone().detach()\n",
"theta.requires_grad_(True)\n",
"\n",
"lt_val06 = log_target(theta, data, labels)\n",
"\n",
"glt_result06, = grad(lt_val06, theta, create_graph=True)\n",
"\n",
"hlt_val = []\n",
"for i in range(9):\n",
" deriv_i_wrt_grad = grad(glt_result06[i], theta, retain_graph=True)\n",
" hlt_val.append(torch.cat([h.view(-1) for h in deriv_i_wrt_grad]))\n",
"\n",
"mlt_result02 = -torch.cat(hlt_val, 0).reshape(9, 9)\n",
"mlt_result02"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Print out values of both metric log-target implementations"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([[-2.6380e-01, -2.6390e-01, -2.9198e-08, -2.6460e-08, -2.6390e-01,\n",
" -2.9198e-08, 7.0188e-02, 1.2025e-04, 1.2026e-04],\n",
" [-2.6390e-01, -2.6322e-01, -2.6460e-08, -9.5428e-08, -2.6332e-01,\n",
" -9.5428e-08, 7.0442e-02, 5.6173e-04, 5.6176e-04],\n",
" [-2.9198e-08, -2.6460e-08, 3.2999e-04, 2.3013e-04, -2.9198e-08,\n",
" 2.2999e-04, -2.1572e-07, 6.7431e-05, -2.2279e-07],\n",
" [-2.6460e-08, -9.5428e-08, 2.3013e-04, 3.3003e-04, -9.5428e-08,\n",
" 2.3003e-04, -1.5708e-07, 6.7483e-05, -1.8290e-07],\n",
" [-2.6390e-01, -2.6332e-01, -2.9198e-08, -9.5428e-08, -3.2027e-01,\n",
" -1.0380e-07, 8.3666e-02, 5.8348e-04, 5.8352e-04],\n",
" [-2.9198e-08, -9.5428e-08, 2.2999e-04, 2.3003e-04, -1.0380e-07,\n",
" 6.7319e-04, -3.8906e-07, 1.6820e-04, -4.1677e-07],\n",
" [ 7.0188e-02, 7.0442e-02, -2.1572e-07, -1.5708e-07, 8.3666e-02,\n",
" -3.8906e-07, 1.3623e-03, 1.3936e-03, 1.3937e-03],\n",
" [ 1.2025e-04, 5.6173e-04, 6.7431e-05, 6.7483e-05, 5.8348e-04,\n",
" 1.6820e-04, 1.3936e-03, 1.6519e-03, 1.5520e-03],\n",
" [ 1.2026e-04, 5.6176e-04, -2.2279e-07, -1.8290e-07, 5.8352e-04,\n",
" -4.1677e-07, 1.3937e-03, 1.5520e-03, 1.6521e-03]],\n",
" dtype=torch.float64),\n",
" tensor([[-2.6380e-01, -2.6390e-01, -2.9198e-08, -2.6460e-08, -2.6390e-01,\n",
" -2.9198e-08, 7.0188e-02, 1.2025e-04, 1.2026e-04],\n",
" [-2.6390e-01, -2.6322e-01, -2.6460e-08, -9.5428e-08, -2.6332e-01,\n",
" -9.5428e-08, 7.0442e-02, 5.6173e-04, 5.6176e-04],\n",
" [-2.9198e-08, -2.6460e-08, 3.2999e-04, 2.3013e-04, -2.9198e-08,\n",
" 2.2999e-04, -2.1572e-07, 6.7431e-05, -2.2279e-07],\n",
" [-2.6460e-08, -9.5428e-08, 2.3013e-04, 3.3003e-04, -9.5428e-08,\n",
" 2.3003e-04, -1.5708e-07, 6.7483e-05, -1.8290e-07],\n",
" [-2.6390e-01, -2.6332e-01, -2.9198e-08, -9.5428e-08, -3.2027e-01,\n",
" -1.0380e-07, 8.3666e-02, 5.8348e-04, 5.8352e-04],\n",
" [-2.9198e-08, -9.5428e-08, 2.2999e-04, 2.3003e-04, -1.0380e-07,\n",
" 6.7319e-04, -3.8906e-07, 1.6820e-04, -4.1677e-07],\n",
" [ 7.0188e-02, 7.0442e-02, -2.1572e-07, -1.5708e-07, 8.3666e-02,\n",
" -3.8906e-07, 1.3623e-03, 1.3936e-03, 1.3937e-03],\n",
" [ 1.2025e-04, 5.6173e-04, 6.7431e-05, 6.7483e-05, 5.8348e-04,\n",
" 1.6820e-04, 1.3936e-03, 1.6519e-03, 1.5520e-03],\n",
" [ 1.2026e-04, 5.6176e-04, -2.2279e-07, -1.8290e-07, 5.8352e-04,\n",
" -4.1677e-07, 1.3937e-03, 1.5520e-03, 1.6521e-03]],\n",
" dtype=torch.float64)]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[p for p in [mlt_result01, mlt_result02]]"
]
}
],
"metadata": {
Expand Down

0 comments on commit 97e8770

Please sign in to comment.