Skip to content

Commit

Permalink
Apply pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfeickert committed Jul 6, 2023
1 parent 9879a27 commit 7fbafd0
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions docs/examples/notebooks/learn/neyman/NeymanConstruction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"import scipy.stats\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
Expand Down Expand Up @@ -92,8 +93,8 @@
"metadata": {},
"outputs": [],
"source": [
"def tmu_teststat(mu,muhat,sigma):\n",
" a = (mu-muhat)**2/sigma**2\n",
"def tmu_teststat(mu, muhat, sigma):\n",
" a = (mu - muhat) ** 2 / sigma ** 2\n",
" return a"
]
},
Expand All @@ -117,7 +118,8 @@
],
"source": [
"from utils import *\n",
"jointplot(mu = 1.0,sigma = 1.0,teststat = tmu_teststat)"
"\n",
"jointplot(mu=1.0, sigma=1.0, teststat=tmu_teststat)"
]
},
{
Expand Down Expand Up @@ -179,14 +181,15 @@
}
],
"source": [
"def tmu_tilde_teststat(mu,muhat,sigma):\n",
" a = tmu_teststat(mu,muhat,sigma)\n",
"def tmu_tilde_teststat(mu, muhat, sigma):\n",
" a = tmu_teststat(mu, muhat, sigma)\n",
"\n",
" b = (0-muhat)**2/sigma**2\n",
" r = np.where(muhat<0,a-b,a)\n",
" b = (0 - muhat) ** 2 / sigma ** 2\n",
" r = np.where(muhat < 0, a - b, a)\n",
" return r\n",
"\n",
"jointplot(mu = 1.0,sigma = 1.0,teststat = tmu_tilde_teststat)"
"\n",
"jointplot(mu=1.0, sigma=1.0, teststat=tmu_tilde_teststat)"
]
},
{
Expand Down Expand Up @@ -229,13 +232,14 @@
}
],
"source": [
"def qmu_teststat(mu,muhat,sigma):\n",
" r = tmu_teststat(mu,muhat,sigma)\n",
"def qmu_teststat(mu, muhat, sigma):\n",
" r = tmu_teststat(mu, muhat, sigma)\n",
" zero = np.zeros_like(r)\n",
" r = np.where(muhat>mu,zero,r)\n",
" r = np.where(muhat > mu, zero, r)\n",
" return r\n",
"\n",
"jointplot(mu = 0.0,sigma = 1.0,teststat = qmu_teststat)"
"\n",
"jointplot(mu=0.0, sigma=1.0, teststat=qmu_teststat)"
]
},
{
Expand Down Expand Up @@ -275,13 +279,14 @@
}
],
"source": [
"def qmu_tilde_teststat(mu,muhat,sigma):\n",
" r = tmu_tilde_teststat(mu,muhat,sigma)\n",
"def qmu_tilde_teststat(mu, muhat, sigma):\n",
" r = tmu_tilde_teststat(mu, muhat, sigma)\n",
" zero = np.zeros_like(r)\n",
" r = np.where(muhat>mu,zero,r)\n",
" r = np.where(muhat > mu, zero, r)\n",
" return r\n",
"\n",
"jointplot(1.0,sigma = 1.0,teststat = qmu_tilde_teststat)"
"\n",
"jointplot(1.0, sigma=1.0, teststat=qmu_tilde_teststat)"
]
},
{
Expand All @@ -299,14 +304,19 @@
"source": [
"sigma = 1.0\n",
"test_size = 0.05\n",
"min_mu,max_mu,myteststat = 0,5, tmu_tilde_teststat\n",
"# min_mu,max_mu,myteststat = 0,5, qmu_tilde_teststat \n",
"min_mu, max_mu, myteststat = 0, 5, tmu_tilde_teststat\n",
"# min_mu,max_mu,myteststat = 0,5, qmu_tilde_teststat\n",
"# min_mu,max_mu,myteststat = -5,5, tmu_teststat\n",
"# min_mu,max_mu,myteststat = -5,5, qmu_teststat\n",
"hypos_over_sigma = np.linspace(min_mu,max_mu,50)\n",
"scans = np.array([scan_tests_for_size(h*sigma,sigma = sigma,teststat = myteststat) for h in hypos_over_sigma])\n",
"cuts = np.argmax(scans[:,:,-1] > test_size,axis=1)\n",
"atcut = np.asarray([s[c] for s,c in zip(scans,cuts)])"
"hypos_over_sigma = np.linspace(min_mu, max_mu, 50)\n",
"scans = np.array(\n",
" [\n",
" scan_tests_for_size(h * sigma, sigma=sigma, teststat=myteststat)\n",
" for h in hypos_over_sigma\n",
" ]\n",
")\n",
"cuts = np.argmax(scans[:, :, -1] > test_size, axis=1)\n",
"atcut = np.asarray([s[c] for s, c in zip(scans, cuts)])"
]
},
{
Expand Down Expand Up @@ -337,8 +347,8 @@
}
],
"source": [
"f = plot_teststat(min_mu,sigma = sigma, teststat = myteststat)\n",
"f.set_size_inches(3,6)\n",
"f = plot_teststat(min_mu, sigma=sigma, teststat=myteststat)\n",
"f.set_size_inches(3, 6)\n",
"f.set_tight_layout(True)"
]
},
Expand Down Expand Up @@ -370,12 +380,12 @@
}
],
"source": [
"f,axarr = plt.subplots(4,1)\n",
"plot_oneinterval(axarr[0],10.0,sigma,test_size,myteststat)\n",
"plot_oneinterval(axarr[1],1.0,sigma,test_size,myteststat)\n",
"plot_oneinterval(axarr[2],0.4,sigma,test_size,myteststat)\n",
"plot_oneinterval(axarr[3],0.0,sigma,test_size,myteststat)\n",
"f.set_size_inches(3,12)"
"f, axarr = plt.subplots(4, 1)\n",
"plot_oneinterval(axarr[0], 10.0, sigma, test_size, myteststat)\n",
"plot_oneinterval(axarr[1], 1.0, sigma, test_size, myteststat)\n",
"plot_oneinterval(axarr[2], 0.4, sigma, test_size, myteststat)\n",
"plot_oneinterval(axarr[3], 0.0, sigma, test_size, myteststat)\n",
"f.set_size_inches(3, 12)"
]
},
{
Expand Down Expand Up @@ -413,10 +423,10 @@
}
],
"source": [
"f,axarr = plt.subplots(2,1)\n",
"plot_neyman_construction(axarr[0],min_mu,max_mu,hypos_over_sigma,atcut,delta = True)\n",
"plot_neyman_construction(axarr[1],min_mu,max_mu,hypos_over_sigma,atcut,delta = False)\n",
"f.set_size_inches(3,6)\n",
"f, axarr = plt.subplots(2, 1)\n",
"plot_neyman_construction(axarr[0], min_mu, max_mu, hypos_over_sigma, atcut, delta=True)\n",
"plot_neyman_construction(axarr[1], min_mu, max_mu, hypos_over_sigma, atcut, delta=False)\n",
"f.set_size_inches(3, 6)\n",
"f.set_tight_layout(True)"
]
},
Expand Down Expand Up @@ -449,10 +459,10 @@
}
],
"source": [
"f,axarr = plt.subplots(2,1)\n",
"plot_cuts(axarr[0],hypos_over_sigma,atcut,sigma,myteststat)\n",
"plot_pvalue(axarr[1],hypos_over_sigma,scans,sigma,myteststat,atcut[:,0])\n",
"f.set_size_inches(3,6)\n",
"f, axarr = plt.subplots(2, 1)\n",
"plot_cuts(axarr[0], hypos_over_sigma, atcut, sigma, myteststat)\n",
"plot_pvalue(axarr[1], hypos_over_sigma, scans, sigma, myteststat, atcut[:, 0])\n",
"f.set_size_inches(3, 6)\n",
"f.set_tight_layout(True)"
]
},
Expand Down

0 comments on commit 7fbafd0

Please sign in to comment.