diff --git a/deepaas/api/__init__.py b/deepaas/api/__init__.py index b4db5aee..eb1cd738 100644 --- a/deepaas/api/__init__.py +++ b/deepaas/api/__init__.py @@ -54,6 +54,7 @@ async def get_app( swagger=True, + enable_doc=True, doc="/ui", prefix="", static_path="/static/swagger", @@ -101,7 +102,7 @@ async def get_app( basePath=base_path, version=deepaas.__version__, url="/swagger.json", - swagger_path=doc if doc else None, + swagger_path=doc if enable_doc else None, prefix=prefix, static_path=static_path, in_place=True, diff --git a/deepaas/cmd/run.py b/deepaas/cmd/run.py index 62906f2b..b68e46b8 100644 --- a/deepaas/cmd/run.py +++ b/deepaas/cmd/run.py @@ -91,7 +91,11 @@ def main(): log.info("Starting DEEPaaS version %s", deepaas.__version__) - app = api.get_app(doc="/ui") + app = api.get_app( + enable_doc=CONF.doc_endpoint, + enable_train=CONF.train_endpoint, + enable_predict=CONF.predict_endpoint, + ) web.run_app( app, host=CONF.listen_ip, diff --git a/deepaas/config.py b/deepaas/config.py index fc6b16be..18b079a0 100644 --- a/deepaas/config.py +++ b/deepaas/config.py @@ -26,14 +26,36 @@ warnings.simplefilter("default", DeprecationWarning) opts = [ + cfg.BoolOpt( + "train-endpoint", + default=True, + help=""" +Specify whether DEEPaaS should provide a train endpoint (default: True). +""", + ), + cfg.BoolOpt( + "predict-endpoint", + default=True, + help=""" +Specify whether DEEPaaS should provide a predict endpoint (default: True). +""", + ), cfg.BoolOpt( "debug-endpoint", - default="false", + default=False, help=""" Enable debug endpoint. If set we will provide all the information that you print to the standard output and error (i.e. stdout and stderr) through the "/debug" endpoint. Default is to not provide this information. This will not provide logging information about the API itself. +""", + ), + cfg.BoolOpt( + "doc-endpoint", + default=True, + help=""" +Enable documentation endpoint. If set we will provide the documentation +through the "/ui" endpoint. Default is to provide this information. """, ), cfg.IntOpt( diff --git a/deepaas/tests/test_v2_api.py b/deepaas/tests/test_v2_api.py index 69fbba42..e7d90dc5 100644 --- a/deepaas/tests/test_v2_api.py +++ b/deepaas/tests/test_v2_api.py @@ -42,6 +42,128 @@ class Fake(object): assert response is responses.Prediction +class TestApiV2NoTrain(base.TestCase): + async def get_application(self): + app = web.Application(debug=True) + app.middlewares.append(web.normalize_path_middleware()) + + deepaas.model.v2.register_models(app) + + v2app = v2.get_app(enable_train=False) + app.add_subapp("/v2", v2app) + + return app + + def setUp(self): + super(TestApiV2NoTrain, self).setUp() + + self.maxDiff = None + + self.flags(debug=True) + + def assert_ok(self, response): + self.assertIn(response.status, [200, 201]) + + async def test_not_found(self): + ret = await self.client.post("/v2/models/deepaas-test/train") + self.assertEqual(404, ret.status) + + async def test_predict_data(self): + f = six.BytesIO(b"foo") + ret = await self.client.post( + "/v2/models/deepaas-test/predict/", + data={"data": (f, "foo.txt"), "parameter": 1}, + ) + json = await ret.json() + self.assertEqual(200, ret.status) + self.assertDictEqual(fake_responses.deepaas_test_predict, json) + + async def test_get_metadata(self): + meta = fake_responses.models_meta + + ret = await self.client.get("/v2/models/") + self.assert_ok(ret) + self.assertDictEqual(meta, await ret.json()) + + ret = await self.client.get("/v2/models/deepaas-test/") + self.assert_ok(ret) + self.assertDictEqual(meta["models"][0], await ret.json()) + + +class TestApiV2NoPredict(base.TestCase): + async def get_application(self): + app = web.Application(debug=True) + app.middlewares.append(web.normalize_path_middleware()) + + deepaas.model.v2.register_models(app) + + v2app = v2.get_app(enable_predict=False) + app.add_subapp("/v2", v2app) + + return app + + def setUp(self): + super(TestApiV2NoPredict, self).setUp() + + self.maxDiff = None + + self.flags(debug=True) + + def assert_ok(self, response): + self.assertIn(response.status, [200, 201]) + + async def test_not_found(self): + ret = await self.client.post("/v2/models/deepaas-test/predict/") + self.assertEqual(404, ret.status) + + async def test_train(self): + ret = await self.client.post( + "/v2/models/deepaas-test/train/", data={"sleep": 0} + ) + self.assertEqual(200, ret.status) + json = await ret.json() + json.pop("date") + self.assertDictEqual(fake_responses.deepaas_test_train, json) + + async def test_get_metadata(self): + meta = fake_responses.models_meta + + ret = await self.client.get("/v2/models/") + self.assert_ok(ret) + self.assertDictEqual(meta, await ret.json()) + + ret = await self.client.get("/v2/models/deepaas-test/") + self.assert_ok(ret) + self.assertDictEqual(meta["models"][0], await ret.json()) + + +class TestApiV2NoDoc(base.TestCase): + async def get_application(self): + app = web.Application(debug=True) + app.middlewares.append(web.normalize_path_middleware()) + + deepaas.model.v2.register_models(app) + + v2app = v2.get_app(enable_doc=False) + app.add_subapp("/v2", v2app) + + return app + + def setUp(self): + super(TestApiV2NoDoc, self).setUp() + + self.maxDiff = None + + self.flags(debug=True) + + def assert_ok(self, response): + self.assertIn(response.status, [200, 201]) + + async def test_not_found(self): + ret = await self.client.get("/ui") + self.assertEqual(404, ret.status) + + class TestApiV2(base.TestCase): async def get_application(self): app = web.Application(debug=True) diff --git a/doc/source/cli/deepaas-run.rst b/doc/source/cli/deepaas-run.rst index 1ae03d9e..421d4d48 100644 --- a/doc/source/cli/deepaas-run.rst +++ b/doc/source/cli/deepaas-run.rst @@ -21,6 +21,14 @@ Options If set to true, the logging level will be set to DEBUG instead of the default INFO level. +.. option:: --predict-endpoint + + Specify whether DEEPaaS should provide a train endpoint (default: True). + +.. option:: --train-endpoint + + Specify whether DEEPaaS should provide a train endpoint (default: True). + .. option:: --debug-endpoint Enable debug endpoint. If set we will provide all the information that you diff --git a/tox.ini b/tox.ini index e9bbe08b..3ba83d82 100644 --- a/tox.ini +++ b/tox.ini @@ -39,6 +39,7 @@ deps = pytest-aiohttp pytest-cov>=4.0.0,<4.1 fixtures + reno mock testtools -r{toxinidir}/requirements.txt