diff --git a/desktop/core/src/desktop/conf.py b/desktop/core/src/desktop/conf.py index 4e447d03401..71a8dc52110 100644 --- a/desktop/core/src/desktop/conf.py +++ b/desktop/core/src/desktop/conf.py @@ -2623,9 +2623,8 @@ def is_cm_managed(): def is_gs_enabled(): from desktop.lib.idbroker import conf as conf_idbroker # Circular dependencies desktop.conf -> idbroker.conf -> desktop.conf - return ('default' in list(GC_ACCOUNTS.keys()) and GC_ACCOUNTS['default'].JSON_CREDENTIALS.get()) or \ - conf_idbroker.is_idbroker_enabled('gs') or \ - is_raz_gs() + return ('default' in list(GC_ACCOUNTS.keys()) and GC_ACCOUNTS['default'].JSON_CREDENTIALS.get()) or is_raz_gs() or \ + conf_idbroker.is_idbroker_enabled('gs') def has_gs_access(user): from desktop.auth.backend import is_admin diff --git a/desktop/core/src/desktop/lib/idbroker/client.py b/desktop/core/src/desktop/lib/idbroker/client.py index 6fc4cad74eb..66001a5786b 100644 --- a/desktop/core/src/desktop/lib/idbroker/client.py +++ b/desktop/core/src/desktop/lib/idbroker/client.py @@ -13,8 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import - from builtins import object import logging @@ -48,8 +46,11 @@ def from_core_site(cls, fs=None, user=None): def __init__(self, user=None, address=None, dt_path=None, path=None, security=None): - self.user=user - self.address=address + self.user = user + self.address = address + if not self.address: + raise PopupException('Failed to connect to IDBroker: No active or healthy instance was found.') + self.dt_path = dt_path self.path = path self.security = security @@ -60,9 +61,9 @@ def __init__(self, user=None, address=None, dt_path=None, path=None, security=No def _knox_token_params(self): if self.user: if self.security['type'] == 'kerberos': - return { 'doAs': self.user } + return {'doAs': self.user} else: - return { 'user.name': self.user } + return {'user.name': self.user} else: return None @@ -73,7 +74,8 @@ def get_auth_token(self): elif self.security['type'] == 'basic': self._client.set_basic_auth(self.security['params']['username'], self.security['params']['password']) try: - res = self._root.invoke("GET", self.dt_path + _KNOX_TOKEN_API, self._knox_token_params(), allow_redirects=True, log_response=False) # Can't log response because returns credentials + # Can't log response because returns credentials + res = self._root.invoke("GET", self.dt_path + _KNOX_TOKEN_API, self._knox_token_params(), allow_redirects=True, log_response=False) return res.get('access_token') except Exception as e: raise PopupException('Failed to authenticate to IDBroker with error: %s' % e.message) @@ -82,6 +84,7 @@ def get_auth_token(self): def get_cab(self): self._client.set_bearer_auth(self.get_auth_token()) try: - return self._root.invoke("GET", self.path + _CAB_API_CREDENTIALS_GLOBAL, allow_redirects=True, log_response=False) # Can't log response because returns credentials + # Can't log response because returns credentials + return self._root.invoke("GET", self.path + _CAB_API_CREDENTIALS_GLOBAL, allow_redirects=True, log_response=False) except Exception as e: raise PopupException('Failed to obtain storage credentials from IDBroker with error: %s' % e.message) diff --git a/desktop/core/src/desktop/lib/idbroker/conf.py b/desktop/core/src/desktop/lib/idbroker/conf.py index 80b0ce33952..efa0ad73098 100644 --- a/desktop/core/src/desktop/lib/idbroker/conf.py +++ b/desktop/core/src/desktop/lib/idbroker/conf.py @@ -13,61 +13,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import - import logging -import sys import requests from requests_kerberos import HTTPKerberosAuth from hadoop.core_site import get_conf -if sys.version_info[0] > 2: - from django.utils.translation import gettext_lazy as _t -else: - from django.utils.translation import ugettext_lazy as _t +from django.utils.translation import gettext_lazy as _t + LOG = logging.getLogger() + _CNF_CAB_ADDRESS = 'fs.%s.ext.cab.address' # http://host:8444/gateway _CNF_CAB_ADDRESS_DT_PATH = 'fs.%s.ext.cab.dt.path' # dt _CNF_CAB_ADDRESS_PATH = 'fs.%s.ext.cab.path' # aws-cab _CNF_CAB_USERNAME = 'fs.%s.ext.cab.username' # when not using kerberos _CNF_CAB_PASSWORD = 'fs.%s.ext.cab.password' + SUPPORTED_FS = {'s3a': 's3a', 'adl': 'azure', 'abfs': 'azure', 'azure': 'azure', 'gs': 'gs'} def validate_fs(fs=None): if fs in SUPPORTED_FS: return SUPPORTED_FS[fs] else: - LOG.warning('Selected FS %s is not supported by Hue IDBroker client' % fs) + LOG.warning('Selected filesystem %s is not supported by Hue IDBroker client.' % fs) return None def _handle_idbroker_ha(fs=None): - fs = validate_fs(fs) - idbrokeraddr = get_conf().get(_CNF_CAB_ADDRESS % fs) if fs else None - response = None + idbroker_addr_list = [] if fs: - id_broker_addr = get_conf().get(_CNF_CAB_ADDRESS % fs) - if id_broker_addr: - id_broker_addr_list = id_broker_addr.split(',') - for id_broker_addr in id_broker_addr_list: - try: - response = requests.get(id_broker_addr.rstrip('/') + '/dt/knoxtoken/api/v1/token', auth=HTTPKerberosAuth(), verify=False) - except Exception as e: - if 'Name or service not known' in str(e): - LOG.warn('IDBroker %s is not available for use' % id_broker_addr) - # Check response for None and if response code is successful (200) or authentication needed (401) - if (response is not None) and (response.status_code in (200, 401)): - idbrokeraddr = id_broker_addr - break - return idbrokeraddr - else: - return idbrokeraddr - else: - return idbrokeraddr + idbroker_addr = get_conf().get(_CNF_CAB_ADDRESS % fs, '') + idbroker_addr_list = idbroker_addr.split(',') + + response = None + for idb in idbroker_addr_list: + try: + response = requests.get(idb.rstrip('/') + '/dt/knoxtoken/api/v1/token', auth=HTTPKerberosAuth(), verify=False) + except Exception as e: + if 'Failed to establish a new connection' in str(e): + LOG.warning('IDBroker URL %s is not available.' % idb) + + # Check response for None and if response code is successful (200) or authentication needed (401) + if (response is not None) and (response.status_code in (200, 401)): + return idb + def get_cab_address(fs=None): + fs = validate_fs(fs) return _handle_idbroker_ha(fs) def get_cab_dt_path(fs=None): @@ -89,7 +82,12 @@ def get_cab_password(fs=None): def is_idbroker_enabled(fs=None): from desktop.conf import RAZ # Must be imported dynamically in order to have proper value - return get_cab_address(fs) is not None and not RAZ.IS_ENABLED.get() # Skipping IDBroker for FS when RAZ is present + fs = validate_fs(fs) + idbroker_addr_from_coresite = get_conf().get(_CNF_CAB_ADDRESS % fs) + + # When RAZ is configured, skip checking for IDBroker configs from core-site. + # RAZ gets precedence over IDBroker when both are configured in Hue. + return (not RAZ.IS_ENABLED.get() and bool(idbroker_addr_from_coresite)) def config_validator(): res = [] diff --git a/desktop/core/src/desktop/lib/idbroker/tests.py b/desktop/core/src/desktop/lib/idbroker/tests.py index ac725d55bf0..a26bf1a217a 100644 --- a/desktop/core/src/desktop/lib/idbroker/tests.py +++ b/desktop/core/src/desktop/lib/idbroker/tests.py @@ -13,63 +13,65 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import - import logging import unittest -import sys -from nose.tools import assert_equal, assert_true +from nose.tools import assert_equal +from unittest.mock import patch from desktop.lib.idbroker.client import IDBroker -if sys.version_info[0] > 2: - from unittest.mock import patch -else: - from mock import patch LOG = logging.getLogger() + class TestIDBroker(unittest.TestCase): def test_username_authentication(self): with patch('desktop.lib.idbroker.conf.get_conf') as conf: with patch('desktop.lib.idbroker.client.resource.Resource.invoke') as invoke: with patch('desktop.lib.idbroker.client.http_client.HttpClient.set_basic_auth') as set_basic_auth: - conf.return_value = { - 'fs.s3a.ext.cab.address': 'address', - 'fs.s3a.ext.cab.dt.path': 'dt_path', - 'fs.s3a.ext.cab.path': 'path', - 'fs.s3a.ext.cab.username': 'username', - 'fs.s3a.ext.cab.password': 'password' - } - invoke.return_value = { - 'Credentials': 'Credentials' - } - client = IDBroker.from_core_site('s3a', 'test') - - cab = client.get_cab() - assert_equal(invoke.call_count, 2) # get_cab calls twice - assert_equal(cab.get('Credentials'), 'Credentials') - assert_equal(set_basic_auth.call_count, 1) - - def test_kerberos_authentication(self): - with patch('desktop.lib.idbroker.conf.get_conf') as conf: - with patch('desktop.lib.idbroker.client.is_kerberos_enabled') as is_kerberos_enabled: - with patch('desktop.lib.idbroker.client.resource.Resource.invoke') as invoke: - with patch('desktop.lib.idbroker.client.http_client.HttpClient.set_kerberos_auth') as set_kerberos_auth: - is_kerberos_enabled.return_value = True + with patch('desktop.lib.idbroker.conf.get_cab_address') as get_cab_address: conf.return_value = { 'fs.s3a.ext.cab.address': 'address', 'fs.s3a.ext.cab.dt.path': 'dt_path', 'fs.s3a.ext.cab.path': 'path', - 'hadoop.security.authentication': 'kerberos', + 'fs.s3a.ext.cab.username': 'username', + 'fs.s3a.ext.cab.password': 'password' } invoke.return_value = { 'Credentials': 'Credentials' } - client = IDBroker.from_core_site('s3a', 'test') + get_cab_address.return_value = 'address' + client = IDBroker.from_core_site('s3a', 'test') cab = client.get_cab() + assert_equal(invoke.call_count, 2) # get_cab calls twice assert_equal(cab.get('Credentials'), 'Credentials') - assert_equal(set_kerberos_auth.call_count, 1) + assert_equal(set_basic_auth.call_count, 1) + + + def test_kerberos_authentication(self): + with patch('desktop.lib.idbroker.conf.get_conf') as conf: + with patch('desktop.lib.idbroker.client.is_kerberos_enabled') as is_kerberos_enabled: + with patch('desktop.lib.idbroker.client.resource.Resource.invoke') as invoke: + with patch('desktop.lib.idbroker.client.http_client.HttpClient.set_kerberos_auth') as set_kerberos_auth: + with patch('desktop.lib.idbroker.conf.get_cab_address') as get_cab_address: + is_kerberos_enabled.return_value = True + conf.return_value = { + 'fs.s3a.ext.cab.address': 'address', + 'fs.s3a.ext.cab.dt.path': 'dt_path', + 'fs.s3a.ext.cab.path': 'path', + 'hadoop.security.authentication': 'kerberos', + } + invoke.return_value = { + 'Credentials': 'Credentials' + } + get_cab_address.return_value = 'address' + + client = IDBroker.from_core_site('s3a', 'test') + cab = client.get_cab() + + assert_equal(invoke.call_count, 2) # get_cab calls twice + assert_equal(cab.get('Credentials'), 'Credentials') + assert_equal(set_kerberos_auth.call_count, 1) diff --git a/desktop/libs/aws/src/aws/conf.py b/desktop/libs/aws/src/aws/conf.py index 45055d609d1..26d5512af45 100644 --- a/desktop/libs/aws/src/aws/conf.py +++ b/desktop/libs/aws/src/aws/conf.py @@ -273,8 +273,8 @@ def get_default_get_environment_credentials(): def is_enabled(): return ('default' in list(AWS_ACCOUNTS.keys()) and AWS_ACCOUNTS['default'].get_raw() and AWS_ACCOUNTS['default'].ACCESS_KEY_ID.get()) or \ has_iam_metadata() or \ - conf_idbroker.is_idbroker_enabled('s3a') or \ - is_raz_s3() + is_raz_s3() or \ + conf_idbroker.is_idbroker_enabled('s3a') def is_ec2_instance(): diff --git a/desktop/libs/aws/src/aws/tests.py b/desktop/libs/aws/src/aws/tests.py index a11594b7427..dc260e1ae0d 100644 --- a/desktop/libs/aws/src/aws/tests.py +++ b/desktop/libs/aws/src/aws/tests.py @@ -13,10 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import - import logging -import sys import unittest from nose.tools import assert_equal, assert_true, assert_not_equal @@ -28,10 +25,8 @@ from desktop.lib.python_util import current_ms_from_utc from desktop.conf import RAZ -if sys.version_info[0] > 2: - from unittest.mock import patch -else: - from mock import patch +from unittest.mock import patch + LOG = logging.getLogger() @@ -54,86 +49,102 @@ def test_with_credentials(self): clear_cache() conf.clear_cache() + def test_with_idbroker(self): try: finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set with patch('aws.client.conf_idbroker.get_conf') as get_conf: - with patch('aws.client.Client.get_s3_connection'): - with patch('aws.client.IDBroker.get_cab') as get_cab: - with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata: - get_conf.return_value = { - 'fs.s3a.ext.cab.address': 'address' - } - get_cab.return_value = { - 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0} - } - has_iam_metadata.return_value = True - provider = get_credential_provider('default', 'hue') - assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId') - client1 = get_client(name='default', fs='s3a', user='hue') - client2 = get_client(name='default', fs='s3a', user='hue') - assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal - - get_cab.return_value = { - 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': int(current_ms_from_utc()) + 10*1000} - } - client3 = get_client(name='default', fs='s3a', user='hue') - client4 = get_client(name='default', fs='s3a', user='hue') - client5 = get_client(name='default', fs='s3a', user='test') - assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal - assert_not_equal(client4, client5) # Test different user have different clients + with patch('aws.client.conf_idbroker.get_cab_address') as get_cab_address: + with patch('aws.client.Client.get_s3_connection'): + with patch('aws.client.IDBroker.get_cab') as get_cab: + with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata: + get_conf.return_value = { + 'fs.s3a.ext.cab.address': 'address' + } + get_cab_address.return_value = 'address' + get_cab.return_value = { + 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0} + } + has_iam_metadata.return_value = True + provider = get_credential_provider('default', 'hue') + + assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId') + + client1 = get_client(name='default', fs='s3a', user='hue') + client2 = get_client(name='default', fs='s3a', user='hue') + + assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal + + get_cab.return_value = { + 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': int(current_ms_from_utc()) + 10*1000} + } + client3 = get_client(name='default', fs='s3a', user='hue') + client4 = get_client(name='default', fs='s3a', user='hue') + client5 = get_client(name='default', fs='s3a', user='test') + + assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal + assert_not_equal(client4, client5) # Test different user have different clients finally: finish() clear_cache() conf.clear_cache() + def test_with_idbroker_and_config(self): try: finish = conf.AWS_ACCOUNTS.set_for_testing({'default': {'region': 'ap-northeast-1'}}) with patch('aws.client.conf_idbroker.get_conf') as get_conf: - with patch('aws.client.Client.get_s3_connection'): - with patch('aws.client.IDBroker.get_cab') as get_cab: - with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata: - get_conf.return_value = { - 'fs.s3a.ext.cab.address': 'address' - } - get_cab.return_value = { - 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0} - } - has_iam_metadata.return_value = True - provider = get_credential_provider('default', 'hue') - assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId') - - client = Client.from_config(conf.AWS_ACCOUNTS['default'], get_credential_provider('default', 'hue')) - assert_equal(client._region, 'ap-northeast-1') - finally: - finish() - clear_cache() - conf.clear_cache() - - def test_with_idbroker_on_ec2(self): - try: - finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set - with patch('aws.client.aws_conf.get_region') as get_region: - with patch('aws.client.conf_idbroker.get_conf') as get_conf: + with patch('aws.client.conf_idbroker.get_cab_address') as get_cab_address: with patch('aws.client.Client.get_s3_connection'): with patch('aws.client.IDBroker.get_cab') as get_cab: with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata: - get_region.return_value = 'us-west-1' get_conf.return_value = { 'fs.s3a.ext.cab.address': 'address' } + get_cab_address.return_value = 'address' get_cab.return_value = { 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0} } has_iam_metadata.return_value = True - client = Client.from_config(None, get_credential_provider('default', 'hue')) - assert_equal(client._region, 'us-west-1') # Test different user have different clients + + provider = get_credential_provider('default', 'hue') + assert_equal(provider.get_credentials().get('AccessKeyId'), 'AccessKeyId') + + client = Client.from_config(conf.AWS_ACCOUNTS['default'], get_credential_provider('default', 'hue')) + assert_equal(client._region, 'ap-northeast-1') finally: finish() clear_cache() conf.clear_cache() + + def test_with_idbroker_on_ec2(self): + try: + finish = conf.AWS_ACCOUNTS.set_for_testing({}) # Set empty to test when no configs are set + with patch('aws.client.aws_conf.get_region') as get_region: + with patch('aws.client.conf_idbroker.get_conf') as get_conf: + with patch('aws.client.conf_idbroker.get_cab_address') as get_cab_address: + with patch('aws.client.Client.get_s3_connection'): + with patch('aws.client.IDBroker.get_cab') as get_cab: + with patch('aws.client.aws_conf.has_iam_metadata') as has_iam_metadata: + get_region.return_value = 'us-west-1' + get_conf.return_value = { + 'fs.s3a.ext.cab.address': 'address' + } + get_cab_address.return_value = 'address' + get_cab.return_value = { + 'Credentials': {'AccessKeyId': 'AccessKeyId', 'Expiration': 0} + } + has_iam_metadata.return_value = True + client = Client.from_config(None, get_credential_provider('default', 'hue')) + + assert_equal(client._region, 'us-west-1') # Test different user have different clients + finally: + finish() + clear_cache() + conf.clear_cache() + + def test_with_raz_enabled(self): with patch('aws.client.RazS3Connection') as raz_s3_connection: resets = [ diff --git a/desktop/libs/azure/src/azure/conf.py b/desktop/libs/azure/src/azure/conf.py index 496a2b7b27e..1beb1b496ae 100644 --- a/desktop/libs/azure/src/azure/conf.py +++ b/desktop/libs/azure/src/azure/conf.py @@ -166,9 +166,9 @@ def is_adls_enabled(): or (conf_idbroker.is_idbroker_enabled('azure') and has_azure_metadata())) and 'default' in list(ADLS_CLUSTERS.keys()) def is_abfs_enabled(): - return ('default' in list(AZURE_ACCOUNTS.keys()) and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get() \ - or (conf_idbroker.is_idbroker_enabled('azure') and has_azure_metadata())) and 'default' in list(ABFS_CLUSTERS.keys()) \ - or is_raz_abfs() + return is_raz_abfs() or \ + ('default' in list(AZURE_ACCOUNTS.keys()) and AZURE_ACCOUNTS['default'].get_raw() and AZURE_ACCOUNTS['default'].CLIENT_ID.get() or \ + (conf_idbroker.is_idbroker_enabled('azure') and has_azure_metadata())) and 'default' in list(ABFS_CLUSTERS.keys()) def has_adls_access(user): from desktop.conf import RAZ # Must be imported dynamically in order to have proper value diff --git a/desktop/libs/azure/src/azure/tests.py b/desktop/libs/azure/src/azure/tests.py index 9f306b93ac1..54f72c829bd 100644 --- a/desktop/libs/azure/src/azure/tests.py +++ b/desktop/libs/azure/src/azure/tests.py @@ -13,25 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import - import logging -import sys import unittest -from nose.plugins.skip import SkipTest -from nose.tools import assert_equal, assert_true, assert_not_equal +from nose.tools import assert_equal, assert_not_equal +from unittest.mock import patch from azure import conf from azure.client import get_credential_provider -from desktop.lib.fsmanager import get_client, clear_cache, is_enabled +from desktop.lib.fsmanager import get_client, clear_cache from desktop.lib.python_util import current_ms_from_utc -if sys.version_info[0] > 2: - from unittest.mock import patch -else: - from mock import patch LOG = logging.getLogger() @@ -48,7 +41,11 @@ def test_with_core_site(self): with patch('azure.conf.core_site.get_conf') as core_site_get_conf: get_token.return_value = {'access_token': 'access_token', 'token_type': '', 'expires_on': None} get_conf.return_value = {} - core_site_get_conf.return_value = {'dfs.adls.oauth2.client.id': 'client_id', 'dfs.adls.oauth2.credential': 'client_secret', 'dfs.adls.oauth2.refresh.url': 'refresh_url'} + core_site_get_conf.return_value = { + 'dfs.adls.oauth2.client.id': 'client_id', + 'dfs.adls.oauth2.credential': 'client_secret', + 'dfs.adls.oauth2.refresh.url': 'refresh_url' + } client1 = get_client(name='default', fs='adl') client2 = get_client(name='default', fs='adl', user='test') @@ -60,10 +57,15 @@ def test_with_core_site(self): f() clear_cache() + def test_with_credentials(self): try: - finish = (conf.AZURE_ACCOUNTS.set_for_testing({'default': {'client_id':'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}}), - conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})) + finish = ( + conf.AZURE_ACCOUNTS.set_for_testing({ + 'default': {'client_id': 'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'} + }), + conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}) + ) with patch('azure.client.conf_idbroker.get_conf') as get_conf: with patch('azure.client.WebHdfs.get_client'): with patch('azure.client.ActiveDirectory.get_token') as get_token: @@ -83,31 +85,41 @@ def test_with_credentials(self): def test_with_idbroker(self): try: - finish = (conf.AZURE_ACCOUNTS.set_for_testing({}), - conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})) + finish = ( + conf.AZURE_ACCOUNTS.set_for_testing({}), + conf.ADLS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}) + ) with patch('azure.client.conf_idbroker.get_conf') as get_conf: - with patch('azure.client.WebHdfs.get_client'): - with patch('azure.client.IDBroker.get_cab') as get_cab: - with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata: - get_conf.return_value = { - 'fs.azure.ext.cab.address': 'address' - } - has_azure_metadata.return_value = True - get_cab.return_value = { 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0 } - provider = get_credential_provider('default', 'hue') - assert_equal(provider.get_credentials().get('access_token'), 'access_token') - client1 = get_client(name='default', fs='adl', user='hue') - client2 = get_client(name='default', fs='adl', user='hue') - assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal - - get_cab.return_value = { - 'Credentials': {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000} - } - client3 = get_client(name='default', fs='adl', user='hue') - client4 = get_client(name='default', fs='adl', user='hue') - client5 = get_client(name='default', fs='adl', user='test') - assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal - assert_not_equal(client4, client5) # Test different user have different clients + with patch('azure.client.conf_idbroker.get_cab_address') as get_cab_address: + with patch('azure.client.WebHdfs.get_client'): + with patch('azure.client.IDBroker.get_cab') as get_cab: + with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata: + get_conf.return_value = { + 'fs.azure.ext.cab.address': 'address' + } + get_cab_address.return_value = 'address' + has_azure_metadata.return_value = True + get_cab.return_value = {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0} + provider = get_credential_provider('default', 'hue') + + assert_equal(provider.get_credentials().get('access_token'), 'access_token') + + client1 = get_client(name='default', fs='adl', user='hue') + client2 = get_client(name='default', fs='adl', user='hue') + + assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal + + get_cab.return_value = { + 'Credentials': { + 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000 + } + } + client3 = get_client(name='default', fs='adl', user='hue') + client4 = get_client(name='default', fs='adl', user='hue') + client5 = get_client(name='default', fs='adl', user='test') + + assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal + assert_not_equal(client4, client5) # Test different user have different clients finally: for f in finish: f() @@ -126,7 +138,11 @@ def test_with_core_site(self): with patch('azure.conf.core_site.get_conf') as core_site_get_conf: get_token.return_value = {'access_token': 'access_token', 'token_type': '', 'expires_on': None} get_conf.return_value = {} - core_site_get_conf.return_value = {'fs.azure.account.oauth2.client.id': 'client_id', 'fs.azure.account.oauth2.client.secret': 'client_secret', 'fs.azure.account.oauth2.client.endpoint': 'refresh_url'} + core_site_get_conf.return_value = { + 'fs.azure.account.oauth2.client.id': 'client_id', + 'fs.azure.account.oauth2.client.secret': 'client_secret', + 'fs.azure.account.oauth2.client.endpoint': 'refresh_url' + } client1 = get_client(name='default', fs='abfs') client2 = get_client(name='default', fs='abfs', user='test') @@ -138,10 +154,15 @@ def test_with_core_site(self): f() clear_cache() + def test_with_credentials(self): try: - finish = (conf.AZURE_ACCOUNTS.set_for_testing({'default': {'client_id':'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'}}), - conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})) + finish = ( + conf.AZURE_ACCOUNTS.set_for_testing({ + 'default': {'client_id': 'client_id', 'client_secret': 'client_secret', 'tenant_id': 'tenant_id'} + }), + conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}) + ) with patch('azure.client.conf_idbroker.get_conf') as get_conf: with patch('azure.client.ABFS.get_client'): with patch('azure.client.ActiveDirectory.get_token') as get_token: @@ -161,32 +182,42 @@ def test_with_credentials(self): def test_with_idbroker(self): try: - finish = (conf.AZURE_ACCOUNTS.set_for_testing({}), - conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}})) + finish = ( + conf.AZURE_ACCOUNTS.set_for_testing({}), + conf.ABFS_CLUSTERS.set_for_testing({'default': {'fs_defaultfs': 'fs_defaultfs', 'webhdfs_url': 'webhdfs_url'}}) + ) with patch('azure.client.conf_idbroker.get_conf') as get_conf: - with patch('azure.client.ABFS.get_client'): - with patch('azure.client.IDBroker.get_cab') as get_cab: - with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata: - get_conf.return_value = { - 'fs.azure.ext.cab.address': 'address' - } - has_azure_metadata.return_value = True - get_cab.return_value = { 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0 } - provider = get_credential_provider('default', 'hue') - assert_equal(provider.get_credentials().get('access_token'), 'access_token') - client1 = get_client(name='default', fs='abfs', user='hue') - client2 = get_client(name='default', fs='abfs', user='hue') - assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal - - get_cab.return_value = { - 'Credentials': {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000} - } - client3 = get_client(name='default', fs='abfs', user='hue') - client4 = get_client(name='default', fs='abfs', user='hue') - client5 = get_client(name='default', fs='abfs', user='test') - assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal - assert_not_equal(client4, client5) # Test different user have different clients + with patch('azure.client.conf_idbroker.get_cab_address') as get_cab_address: + with patch('azure.client.ABFS.get_client'): + with patch('azure.client.IDBroker.get_cab') as get_cab: + with patch('azure.client.conf.has_azure_metadata') as has_azure_metadata: + get_conf.return_value = { + 'fs.azure.ext.cab.address': 'address' + } + get_cab_address.return_value = 'address' + has_azure_metadata.return_value = True + get_cab.return_value = {'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': 0} + provider = get_credential_provider('default', 'hue') + + assert_equal(provider.get_credentials().get('access_token'), 'access_token') + + client1 = get_client(name='default', fs='abfs', user='hue') + client2 = get_client(name='default', fs='abfs', user='hue') + + assert_not_equal(client1, client2) # Test that with Expiration 0 clients not equal + + get_cab.return_value = { + 'Credentials': { + 'access_token': 'access_token', 'token_type': 'token_type', 'expires_on': int(current_ms_from_utc()) + 10*1000 + } + } + client3 = get_client(name='default', fs='abfs', user='hue') + client4 = get_client(name='default', fs='abfs', user='hue') + client5 = get_client(name='default', fs='abfs', user='test') + + assert_equal(client3, client4) # Test that with 10 sec expiration, clients equal + assert_not_equal(client4, client5) # Test different user have different clients finally: for f in finish: f() - clear_cache() \ No newline at end of file + clear_cache()