diff --git a/neon_utils/hana_utils.py b/neon_utils/hana_utils.py index e113c8b8..2da152d6 100644 --- a/neon_utils/hana_utils.py +++ b/neon_utils/hana_utils.py @@ -156,7 +156,7 @@ def request_backend(endpoint: str, request_data: dict, _client_config = {} _headers = {} _init_client(server_url) - if time() >= _client_config.get("expiration", 0): + if _client_config.get("expiration", 0) - time() < 30: try: _refresh_token(server_url) except ServerException as e: @@ -172,4 +172,16 @@ def request_backend(endpoint: str, request_data: dict, return resp.json() else: + try: + error = resp.json()["detail"] + # Token is actually expired, refresh and retry + if error == "Invalid or expired token.": + LOG.warning(f"Token is expired. time={time()}|" + f"expiration={_client_config.get('expiration')}") + _refresh_token(server_url) + resp = requests.post(**request_kwargs) + if resp.ok: + return resp.json() + except Exception as e: + LOG.error(e) raise ServerException(f"Error response {resp.status_code}: {resp.text}") diff --git a/tests/hana_util_tests.py b/tests/hana_util_tests.py index f7488168..2944c046 100644 --- a/tests/hana_util_tests.py +++ b/tests/hana_util_tests.py @@ -25,10 +25,13 @@ # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + import json import unittest + from os import remove from os.path import join, dirname, isfile +from time import time from unittest.mock import patch @@ -70,6 +73,28 @@ def test_request_backend(self, config_path): self.assertIsInstance(resp['answer'], str) # TODO: Test invalid route, invalid request data + @patch("neon_utils.hana_utils._get_client_config_path") + @patch("neon_utils.hana_utils._refresh_token") + def test_request_backend_refresh_token(self, refresh_token, config_path): + config_path.return_value = self.test_path + + import neon_utils.hana_utils + from neon_utils.hana_utils import request_backend + neon_utils.hana_utils.set_default_backend_url(self.test_server) + neon_utils.hana_utils._init_client(self.test_server) + real_client_config = neon_utils.hana_utils._client_config + neon_utils.hana_utils._client_config['expiration'] = time() + 29 + neon_utils.hana_utils._refresh_token = refresh_token + resp = request_backend("/neon/get_response", + {"lang_code": "en-us", + "utterance": "how are you", + "user_profile": {}}, self.test_server) + self.assertEqual(resp['lang_code'], "en-us") + self.assertIsInstance(resp['answer'], str) + refresh_token.assert_called_once_with(self.test_server) + + neon_utils.hana_utils._client_config = real_client_config + @patch("neon_utils.hana_utils._get_client_config_path") def test_00_get_token(self, config_path): config_path.return_value = self.test_path