Skip to content

Commit

Permalink
Merge pull request #110 from datastax/feature/#109-ops-py-client-param
Browse files Browse the repository at this point in the history
adapt ops to the new client-requiring make_request
  • Loading branch information
erichare authored Nov 16, 2023
2 parents 4d1adc0 + 4e3639a commit c3e29f6
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
38 changes: 22 additions & 16 deletions astrapy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
from astrapy.defaults import DEFAULT_DEV_OPS_API_VERSION, DEFAULT_DEV_OPS_URL

import logging
import httpx


logger = logging.getLogger(__name__)


class AstraDBOps:
# Initialize the shared httpx client as a class attribute
client = httpx.Client()

def __init__(self, token, dev_ops_url=None, dev_ops_api_version=None):
dev_ops_url = (dev_ops_url or DEFAULT_DEV_OPS_URL).strip("/")
dev_ops_api_version = (
Expand All @@ -34,6 +39,7 @@ def _ops_request(self, method, path, options=None, json_data=None):
options = {} if options is None else options

return make_request(
client=self.client,
base_url=self.base_url,
method=method,
auth_header="Authorization",
Expand Down Expand Up @@ -465,7 +471,7 @@ def get_available_classic_regions(self):
dict: A list of available classic regions.
"""
return self._ops_request(
method=http_methods.GET, path=f"/availableRegions"
method=http_methods.GET, path="/availableRegions"
).json()

def get_available_regions(self):
Expand All @@ -476,7 +482,7 @@ def get_available_regions(self):
dict: A list of available regions for serverless deployment.
"""
return self._ops_request(
method=http_methods.GET, path=f"/regions/serverless"
method=http_methods.GET, path="/regions/serverless"
).json()

def get_roles(self):
Expand All @@ -487,7 +493,7 @@ def get_roles(self):
dict: A list of roles within the organization.
"""
return self._ops_request(
method=http_methods.GET, path=f"/organizations/roles"
method=http_methods.GET, path="/organizations/roles"
).json()

def create_role(self, role_definition=None):
Expand All @@ -502,7 +508,7 @@ def create_role(self, role_definition=None):
"""
return self._ops_request(
method=http_methods.POST,
path=f"/organizations/roles",
path="/organizations/roles",
json_data=role_definition,
).json()

Expand Down Expand Up @@ -563,7 +569,7 @@ def invite_user(self, user_definition=None):
"""
return self._ops_request(
method=http_methods.PUT,
path=f"/organizations/users",
path="/organizations/users",
json_data=user_definition,
).json()

Expand All @@ -575,7 +581,7 @@ def get_users(self):
dict: A list of users within the organization.
"""
return self._ops_request(
method=http_methods.GET, path=f"/organizations/users"
method=http_methods.GET, path="/organizations/users"
).json()

def get_user(self, user=""):
Expand Down Expand Up @@ -631,7 +637,7 @@ def get_clients(self):
dict: A list of client IDs and their associated secrets.
"""
return self._ops_request(
method=http_methods.GET, path=f"/clientIdSecrets"
method=http_methods.GET, path="/clientIdSecrets"
).json()

def create_token(self, roles=None):
Expand All @@ -646,7 +652,7 @@ def create_token(self, roles=None):
"""
return self._ops_request(
method=http_methods.POST,
path=f"/clientIdSecrets",
path="/clientIdSecrets",
json_data=roles,
).json()

Expand All @@ -671,7 +677,7 @@ def get_organization(self):
Returns:
dict: The details of the organization.
"""
return self._ops_request(method=http_methods.GET, path=f"/currentOrg").json()
return self._ops_request(method=http_methods.GET, path="/currentOrg").json()

def get_access_lists(self):
"""
Expand All @@ -680,7 +686,7 @@ def get_access_lists(self):
Returns:
dict: A list of access lists.
"""
return self._ops_request(method=http_methods.GET, path=f"/access-lists").json()
return self._ops_request(method=http_methods.GET, path="/access-lists").json()

def get_access_list_template(self):
"""
Expand All @@ -690,7 +696,7 @@ def get_access_list_template(self):
dict: An access list template.
"""
return self._ops_request(
method=http_methods.GET, path=f"/access-list/template"
method=http_methods.GET, path="/access-list/template"
).json()

def validate_access_list(self):
Expand All @@ -701,7 +707,7 @@ def validate_access_list(self):
dict: The validation result of the access list configuration.
"""
return self._ops_request(
method=http_methods.POST, path=f"/access-list/validate"
method=http_methods.POST, path="/access-list/validate"
).json()

def get_private_links(self):
Expand All @@ -712,7 +718,7 @@ def get_private_links(self):
dict: A list of private link connections.
"""
return self._ops_request(
method=http_methods.GET, path=f"/organizations/private-link"
method=http_methods.GET, path="/organizations/private-link"
).json()

def get_streaming_providers(self):
Expand All @@ -723,7 +729,7 @@ def get_streaming_providers(self):
dict: A list of available streaming service providers.
"""
return self._ops_request(
method=http_methods.GET, path=f"/streaming/providers"
method=http_methods.GET, path="/streaming/providers"
).json()

def get_streaming_tenants(self):
Expand All @@ -734,7 +740,7 @@ def get_streaming_tenants(self):
dict: A list of streaming tenants and their details.
"""
return self._ops_request(
method=http_methods.GET, path=f"/streaming/tenants"
method=http_methods.GET, path="/streaming/tenants"
).json()

def create_streaming_tenant(self, tenant=None):
Expand All @@ -749,7 +755,7 @@ def create_streaming_tenant(self, tenant=None):
"""
return self._ops_request(
method=http_methods.POST,
path=f"/streaming/tenants",
path="/streaming/tenants",
json_data=tenant,
).json()

Expand Down
7 changes: 4 additions & 3 deletions tests/astrapy/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import os
import uuid

from dotenv import load_dotenv
from faker import Faker

logger = logging.getLogger(__name__)
fake = Faker()


from dotenv import load_dotenv

load_dotenv()


Expand All @@ -54,7 +53,7 @@ def test_client_type(devops_client):
@pytest.mark.describe("should get all databases")
def test_get_databases(devops_client):
response = devops_client.get_databases()
assert type(response) is list
assert isinstance(response, list)


@pytest.mark.describe("should create a database")
Expand All @@ -71,6 +70,8 @@ def test_create_database(devops_client):
"dbType": "vector",
}
response = devops_client.create_database(database_definition=database_definition)
assert response is not None
assert "id" in response
assert response["id"] is not None
ASTRA_TEMP_DB = response["id"]

Expand Down
4 changes: 2 additions & 2 deletions tests/astrapy/test_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Iterable, TypeVar

from astrapy.db import AstraDB
from astrapy.defaults import DEFAULT_KEYSPACE_NAME, DEFAULT_REGION
from astrapy.defaults import DEFAULT_KEYSPACE_NAME

from dotenv import load_dotenv
import pytest
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_collection():
assert inserted_ids == {str(i) for i in range(N)}
yield astra_db_collection
if int(os.getenv("TEST_PAGINATION_SKIP_DELETE_COLLECTION", "0")) == 0:
res = astra_db.delete_collection(collection_name=TEST_COLLECTION_NAME)
_ = astra_db.delete_collection(collection_name=TEST_COLLECTION_NAME)


@pytest.mark.describe(
Expand Down

0 comments on commit c3e29f6

Please sign in to comment.