diff --git a/docs/examples/notebooks/learn/neyman/NeymanConstruction.ipynb b/docs/examples/notebooks/learn/neyman/NeymanConstruction.ipynb index 261ff0f1c9..50d822c147 100644 --- a/docs/examples/notebooks/learn/neyman/NeymanConstruction.ipynb +++ b/docs/examples/notebooks/learn/neyman/NeymanConstruction.ipynb @@ -9,6 +9,7 @@ "import scipy.stats\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -92,8 +93,8 @@ "metadata": {}, "outputs": [], "source": [ - "def tmu_teststat(mu,muhat,sigma):\n", - " a = (mu-muhat)**2/sigma**2\n", + "def tmu_teststat(mu, muhat, sigma):\n", + " a = (mu - muhat) ** 2 / sigma ** 2\n", " return a" ] }, @@ -117,7 +118,8 @@ ], "source": [ "from utils import *\n", - "jointplot(mu = 1.0,sigma = 1.0,teststat = tmu_teststat)" + "\n", + "jointplot(mu=1.0, sigma=1.0, teststat=tmu_teststat)" ] }, { @@ -179,14 +181,15 @@ } ], "source": [ - "def tmu_tilde_teststat(mu,muhat,sigma):\n", - " a = tmu_teststat(mu,muhat,sigma)\n", + "def tmu_tilde_teststat(mu, muhat, sigma):\n", + " a = tmu_teststat(mu, muhat, sigma)\n", "\n", - " b = (0-muhat)**2/sigma**2\n", - " r = np.where(muhat<0,a-b,a)\n", + " b = (0 - muhat) ** 2 / sigma ** 2\n", + " r = np.where(muhat < 0, a - b, a)\n", " return r\n", "\n", - "jointplot(mu = 1.0,sigma = 1.0,teststat = tmu_tilde_teststat)" + "\n", + "jointplot(mu=1.0, sigma=1.0, teststat=tmu_tilde_teststat)" ] }, { @@ -229,13 +232,14 @@ } ], "source": [ - "def qmu_teststat(mu,muhat,sigma):\n", - " r = tmu_teststat(mu,muhat,sigma)\n", + "def qmu_teststat(mu, muhat, sigma):\n", + " r = tmu_teststat(mu, muhat, sigma)\n", " zero = np.zeros_like(r)\n", - " r = np.where(muhat>mu,zero,r)\n", + " r = np.where(muhat > mu, zero, r)\n", " return r\n", "\n", - "jointplot(mu = 0.0,sigma = 1.0,teststat = qmu_teststat)" + "\n", + "jointplot(mu=0.0, sigma=1.0, teststat=qmu_teststat)" ] }, { @@ -275,13 +279,14 @@ } ], "source": [ - "def qmu_tilde_teststat(mu,muhat,sigma):\n", - " r = tmu_tilde_teststat(mu,muhat,sigma)\n", + "def qmu_tilde_teststat(mu, muhat, sigma):\n", + " r = tmu_tilde_teststat(mu, muhat, sigma)\n", " zero = np.zeros_like(r)\n", - " r = np.where(muhat>mu,zero,r)\n", + " r = np.where(muhat > mu, zero, r)\n", " return r\n", "\n", - "jointplot(1.0,sigma = 1.0,teststat = qmu_tilde_teststat)" + "\n", + "jointplot(1.0, sigma=1.0, teststat=qmu_tilde_teststat)" ] }, { @@ -299,14 +304,19 @@ "source": [ "sigma = 1.0\n", "test_size = 0.05\n", - "min_mu,max_mu,myteststat = 0,5, tmu_tilde_teststat\n", - "# min_mu,max_mu,myteststat = 0,5, qmu_tilde_teststat \n", + "min_mu, max_mu, myteststat = 0, 5, tmu_tilde_teststat\n", + "# min_mu,max_mu,myteststat = 0,5, qmu_tilde_teststat\n", "# min_mu,max_mu,myteststat = -5,5, tmu_teststat\n", "# min_mu,max_mu,myteststat = -5,5, qmu_teststat\n", - "hypos_over_sigma = np.linspace(min_mu,max_mu,50)\n", - "scans = np.array([scan_tests_for_size(h*sigma,sigma = sigma,teststat = myteststat) for h in hypos_over_sigma])\n", - "cuts = np.argmax(scans[:,:,-1] > test_size,axis=1)\n", - "atcut = np.asarray([s[c] for s,c in zip(scans,cuts)])" + "hypos_over_sigma = np.linspace(min_mu, max_mu, 50)\n", + "scans = np.array(\n", + " [\n", + " scan_tests_for_size(h * sigma, sigma=sigma, teststat=myteststat)\n", + " for h in hypos_over_sigma\n", + " ]\n", + ")\n", + "cuts = np.argmax(scans[:, :, -1] > test_size, axis=1)\n", + "atcut = np.asarray([s[c] for s, c in zip(scans, cuts)])" ] }, { @@ -337,8 +347,8 @@ } ], "source": [ - "f = plot_teststat(min_mu,sigma = sigma, teststat = myteststat)\n", - "f.set_size_inches(3,6)\n", + "f = plot_teststat(min_mu, sigma=sigma, teststat=myteststat)\n", + "f.set_size_inches(3, 6)\n", "f.set_tight_layout(True)" ] }, @@ -370,12 +380,12 @@ } ], "source": [ - "f,axarr = plt.subplots(4,1)\n", - "plot_oneinterval(axarr[0],10.0,sigma,test_size,myteststat)\n", - "plot_oneinterval(axarr[1],1.0,sigma,test_size,myteststat)\n", - "plot_oneinterval(axarr[2],0.4,sigma,test_size,myteststat)\n", - "plot_oneinterval(axarr[3],0.0,sigma,test_size,myteststat)\n", - "f.set_size_inches(3,12)" + "f, axarr = plt.subplots(4, 1)\n", + "plot_oneinterval(axarr[0], 10.0, sigma, test_size, myteststat)\n", + "plot_oneinterval(axarr[1], 1.0, sigma, test_size, myteststat)\n", + "plot_oneinterval(axarr[2], 0.4, sigma, test_size, myteststat)\n", + "plot_oneinterval(axarr[3], 0.0, sigma, test_size, myteststat)\n", + "f.set_size_inches(3, 12)" ] }, { @@ -413,10 +423,10 @@ } ], "source": [ - "f,axarr = plt.subplots(2,1)\n", - "plot_neyman_construction(axarr[0],min_mu,max_mu,hypos_over_sigma,atcut,delta = True)\n", - "plot_neyman_construction(axarr[1],min_mu,max_mu,hypos_over_sigma,atcut,delta = False)\n", - "f.set_size_inches(3,6)\n", + "f, axarr = plt.subplots(2, 1)\n", + "plot_neyman_construction(axarr[0], min_mu, max_mu, hypos_over_sigma, atcut, delta=True)\n", + "plot_neyman_construction(axarr[1], min_mu, max_mu, hypos_over_sigma, atcut, delta=False)\n", + "f.set_size_inches(3, 6)\n", "f.set_tight_layout(True)" ] }, @@ -449,10 +459,10 @@ } ], "source": [ - "f,axarr = plt.subplots(2,1)\n", - "plot_cuts(axarr[0],hypos_over_sigma,atcut,sigma,myteststat)\n", - "plot_pvalue(axarr[1],hypos_over_sigma,scans,sigma,myteststat,atcut[:,0])\n", - "f.set_size_inches(3,6)\n", + "f, axarr = plt.subplots(2, 1)\n", + "plot_cuts(axarr[0], hypos_over_sigma, atcut, sigma, myteststat)\n", + "plot_pvalue(axarr[1], hypos_over_sigma, scans, sigma, myteststat, atcut[:, 0])\n", + "f.set_size_inches(3, 6)\n", "f.set_tight_layout(True)" ] },