Skip to content

Commit

Permalink
Add utilities for interfacing with Neon HANA (#497)
Browse files Browse the repository at this point in the history
* Add `hana_utils` with unit tests
Update unit tests to test entire module in one step

* Revert bad unit test automation refactor

* Fix bad refactor

* Refactor tests to minimize auth requests

* Refactor token logic to internal methods to ensure stable API

* Refactor to better support `server` configuration

* Refactor to skip Configuration parsing

* Update patch missed in refactor

---------

Co-authored-by: Daniel McKnight <daniel@neon.ai>
  • Loading branch information
NeonDaniel and NeonDaniel authored Jan 24, 2024
1 parent 3cd6c64 commit 4389594
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,12 @@ jobs:
with:
name: language-util-test-results
path: tests/language-util-test-results.xml

- name: Test Hana Utils
run: |
pytest tests/hana_util_tests.py --doctest-modules --junitxml=tests/hana-util-test-results.xml
- name: Upload hana utils test results
uses: actions/upload-artifact@v2
with:
name: hana-util-test-results
path: tests/hana-util-test-results.xml
132 changes: 132 additions & 0 deletions neon_utils/hana_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2022 Neongecko.com Inc.
# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds,
# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo
# BSD-3 License
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import requests
import json

from os.path import join, isfile
from time import time
from typing import Optional
from ovos_utils.log import LOG
from ovos_utils.xdg_utils import xdg_cache_home

_DEFAULT_BACKEND_URL = "https://hana.neonaialpha.com"
_client_config = {}
_client_config_path = join(xdg_cache_home(), "neon", "hana_token.json")
_headers = {}


class ServerException(Exception):
"""Exception class representing a backend server communication error"""


def _init_client(backend_address: str):
"""
Initialize request headers for making backend requests. If a local cache is
available it will be used, otherwise an auth request will be made to the
specified backend server
@param backend_address: Hana server URL to connect to
"""
global _client_config
global _headers

if not _client_config:
if isfile(_client_config_path):
with open(_client_config_path) as f:
_client_config = json.load(f)
else:
_get_token(backend_address)

if not _headers:
_headers = {"Authorization": f"Bearer {_client_config['access_token']}"}


def _get_token(backend_address: str, username: str = "guest",
password: str = "password"):
"""
Get new auth tokens from the specified server. This will cache the returned
token, overwriting any previous data at the cache path.
@param backend_address: Hana server URL to connect to
@param username: Username to authorize
@param password: Password for specified username
"""
global _client_config
# TODO: username/password from configuration
resp = requests.post(f"{backend_address}/auth/login",
json={"username": username,
"password": password})
if not resp.ok:
raise ServerException(f"Error logging into {backend_address}. "
f"{resp.status_code}: {resp.text}")
_client_config = resp.json()
with open(_client_config_path, "w+") as f:
json.dump(_client_config, f, indent=2)


def _refresh_token(backend_address: str):
"""
Get new tokens from the specified server using an existing refresh token
(if it exists). This will update the cached tokens and associated metadata.
@param backend_address: Hana server URL to connect to
"""
global _client_config
_init_client(backend_address)
update = requests.post(f"{backend_address}/auth/refresh", json={
"access_token": _client_config.get("access_token"),
"refresh_token": _client_config.get("refresh_token"),
"client_id": _client_config.get("client_id")})
if not update.ok:
raise ServerException(f"Error updating token from {backend_address}. "
f"{update.status_code}: {update.text}")
_client_config = update.json()
with open(_client_config_path, "w+") as f:
json.dump(_client_config, f, indent=2)


def request_backend(endpoint: str, request_data: dict,
server_url: str = _DEFAULT_BACKEND_URL) -> dict:
"""
Make a request to a Hana backend server and return the json response
@param endpoint: server endpoint to query
@param request_data: dict data to send in request body
@param server_url: Base URL of Hana server to query
@returns: dict response
"""
_init_client(server_url)
if time() >= _client_config.get("expiration", 0):
try:
_refresh_token(server_url)
except ServerException as e:
LOG.error(e)
_get_token(server_url)
resp = requests.post(f"{server_url}/{endpoint.lstrip('/')}",
json=request_data, headers=_headers)
if resp.ok:
return resp.json()
else:
raise ServerException(f"Error response {resp.status_code}: {resp.text}")
128 changes: 128 additions & 0 deletions tests/hana_util_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# NEON AI (TM) SOFTWARE, Software Development Kit & Application Framework
# All trademark and other rights reserved by their respective owners
# Copyright 2008-2022 Neongecko.com Inc.
# Contributors: Daniel McKnight, Guy Daniels, Elon Gasper, Richard Leeds,
# Regina Bloomstine, Casimiro Ferreira, Andrii Pernatii, Kirill Hrymailo
# BSD-3 License
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import json
import unittest
from os import remove
from os.path import join, dirname, isfile
from unittest.mock import patch


valid_config = {}
valid_headers = {}


class HanaUtilTests(unittest.TestCase):
import neon_utils.hana_utils
test_server = "https://hana.neonaialpha.com"
test_path = join(dirname(__file__), "hana_test.json")
neon_utils.hana_utils._client_config_path = test_path

def tearDown(self) -> None:
global valid_config
global valid_headers
import neon_utils.hana_utils
if isfile(self.test_path):
remove(self.test_path)
if neon_utils.hana_utils._client_config:
valid_config = neon_utils.hana_utils._client_config
if neon_utils.hana_utils._headers:
valid_headers = neon_utils.hana_utils._headers
neon_utils.hana_utils._client_config = {}
neon_utils.hana_utils._headers = {}

def test_request_backend(self):
# Use a valid config and skip extra auth
import neon_utils.hana_utils
neon_utils.hana_utils._client_config = valid_config
neon_utils.hana_utils._headers = valid_headers
from neon_utils.hana_utils import request_backend
resp = request_backend("/neon/get_response",
{"lang_code": "en-us",
"utterance": "who are you",
"user_profile": {}}, self.test_server)
self.assertEqual(resp['lang_code'], "en-us")
self.assertIsInstance(resp['answer'], str)
# TODO: Test invalid route, invalid request data

def test_00_get_token(self):
from neon_utils.hana_utils import _get_token

# Test valid request
_get_token(self.test_server)
from neon_utils.hana_utils import _client_config
self.assertTrue(isfile(self.test_path))
with open(self.test_path) as f:
credentials_on_disk = json.load(f)
self.assertEqual(credentials_on_disk, _client_config)
# TODO: Test invalid request, rate-limited request

@patch("neon_utils.hana_utils._get_token")
def test_refresh_token(self, get_token):
import neon_utils.hana_utils

def _write_token(*_, **__):
with open(self.test_path, 'w+') as c:
json.dump(valid_config, c)
neon_utils.hana_utils._client_config = valid_config

from neon_utils.hana_utils import _refresh_token
get_token.side_effect = _write_token

self.assertFalse(isfile(self.test_path))

# Test valid request (auth + refresh)
_refresh_token(self.test_server)
get_token.assert_called_once()
from neon_utils.hana_utils import _client_config
self.assertTrue(isfile(self.test_path))
with open(self.test_path) as f:
credentials_on_disk = json.load(f)
self.assertEqual(credentials_on_disk, _client_config)

# Test refresh of existing token (no auth)
_refresh_token(self.test_server)
get_token.assert_called_once()
with open(self.test_path) as f:
new_credentials = json.load(f)
self.assertNotEqual(credentials_on_disk, new_credentials)
self.assertEqual(credentials_on_disk['client_id'],
new_credentials['client_id'])
self.assertEqual(credentials_on_disk['username'],
new_credentials['username'])
self.assertGreater(new_credentials['expiration'],
credentials_on_disk['expiration'])
self.assertNotEqual(credentials_on_disk['access_token'],
new_credentials['access_token'])
self.assertNotEqual(credentials_on_disk['refresh_token'],
new_credentials['refresh_token'])

# TODO: Test invalid refresh


if __name__ == '__main__':
unittest.main()

0 comments on commit 4389594

Please sign in to comment.