From 9070c60ee950896fabc85cd360d78f4b1ac0ee06 Mon Sep 17 00:00:00 2001 From: Michal Schielmann Date: Wed, 7 Aug 2024 08:33:50 +0200 Subject: [PATCH] added support for idempotency key --- shift4/blacklist.py | 4 +- shift4/cards.py | 14 +++- shift4/charges.py | 22 +++-- shift4/credits.py | 10 ++- shift4/customers.py | 10 ++- shift4/disputes.py | 12 ++- shift4/payment_methods.py | 4 +- shift4/plans.py | 10 ++- shift4/request_options.py | 14 ++++ shift4/resource.py | 27 +++++-- shift4/subscriptions.py | 12 ++- tests/integration/test_charges.py | 102 ++++++++++++++++++++++++ tests/integration/test_customers.py | 12 +-- tests/integration/test_disputes.py | 2 +- tests/integration/test_events.py | 2 +- tests/integration/test_file_uploads.py | 2 +- tests/integration/test_fraud_warning.py | 2 +- tests/integration/test_plans.py | 6 +- tests/integration/testcase.py | 12 ++- 19 files changed, 220 insertions(+), 59 deletions(-) create mode 100644 shift4/request_options.py diff --git a/shift4/blacklist.py b/shift4/blacklist.py index 29fb77b..917cd85 100644 --- a/shift4/blacklist.py +++ b/shift4/blacklist.py @@ -2,8 +2,8 @@ class Blacklist(Resource): - def create(self, params): - return self._post("/blacklist", params) + def create(self, params, request_options=None): + return self._post("/blacklist", params, request_options=request_options) def get(self, blacklist_rule_id): return self._get("/blacklist/%s" % blacklist_rule_id) diff --git a/shift4/cards.py b/shift4/cards.py index 81dc1e1..3dbd9b0 100644 --- a/shift4/cards.py +++ b/shift4/cards.py @@ -2,14 +2,20 @@ class Cards(Resource): - def create(self, customer_id, params): - return self._post("/customers/%s/cards" % customer_id, params) + def create(self, customer_id, params, request_options=None): + return self._post( + "/customers/%s/cards" % customer_id, params, request_options=request_options + ) def get(self, customer_id, card_id): return self._get("/customers/%s/cards/%s" % (customer_id, card_id)) - def update(self, customer_id, card_id, params): - return self._post("/customers/%s/cards/%s" % (customer_id, card_id), params) + def update(self, customer_id, card_id, params, request_options=None): + return self._post( + "/customers/%s/cards/%s" % (customer_id, card_id), + params, + request_options=request_options, + ) def delete(self, customer_id, card_id): return self._delete("/customers/%s/cards/%s" % (customer_id, card_id)) diff --git a/shift4/charges.py b/shift4/charges.py index 73af10e..678d276 100644 --- a/shift4/charges.py +++ b/shift4/charges.py @@ -2,20 +2,26 @@ class Charges(Resource): - def create(self, params): - return self._post("/charges", params) + def create(self, params, request_options=None): + return self._post("/charges", params, request_options=request_options) def get(self, charge_id): return self._get("/charges/%s" % charge_id) - def update(self, charge_id, params): - return self._post("/charges/%s" % charge_id, params) + def update(self, charge_id, params, request_options=None): + return self._post( + "/charges/%s" % charge_id, params, request_options=request_options + ) def list(self, params=None): return self._get("/charges", params) - def capture(self, charge_id): - return self._post("/charges/%s/capture" % charge_id) + def capture(self, charge_id, request_options=None): + return self._post( + "/charges/%s/capture" % charge_id, request_options=request_options + ) - def refund(self, charge_id, params=None): - return self._post("/charges/%s/refund" % charge_id, params) + def refund(self, charge_id, params=None, request_options=None): + return self._post( + "/charges/%s/refund" % charge_id, params, request_options=request_options + ) diff --git a/shift4/credits.py b/shift4/credits.py index 34bd31a..db61c52 100644 --- a/shift4/credits.py +++ b/shift4/credits.py @@ -2,14 +2,16 @@ class Credits(Resource): - def create(self, params): - return self._post("/credits", params) + def create(self, params, request_options=None): + return self._post("/credits", params, request_options=request_options) def get(self, credit_id): return self._get("/credits/%s" % credit_id) - def update(self, credit_id, params): - return self._post("/credits/%s" % credit_id, params) + def update(self, credit_id, params, request_options=None): + return self._post( + "/credits/%s" % credit_id, params, request_options=request_options + ) def list(self, params=None): return self._get("/credits", params) diff --git a/shift4/customers.py b/shift4/customers.py index 9dbe210..a53b0cd 100644 --- a/shift4/customers.py +++ b/shift4/customers.py @@ -2,14 +2,16 @@ class Customers(Resource): - def create(self, params): - return self._post("/customers", params) + def create(self, params, request_options=None): + return self._post("/customers", params, request_options=request_options) def get(self, customer_id): return self._get("/customers/%s" % customer_id) - def update(self, customer_id, params): - return self._post("/customers/%s" % customer_id, params) + def update(self, customer_id, params, request_options=None): + return self._post( + "/customers/%s" % customer_id, params, request_options=request_options + ) def delete(self, customer_id): return self._delete("/customers/%s" % customer_id) diff --git a/shift4/disputes.py b/shift4/disputes.py index 62020dc..849efa9 100644 --- a/shift4/disputes.py +++ b/shift4/disputes.py @@ -5,11 +5,15 @@ class Disputes(Resource): def get(self, dispute_id): return self._get("/disputes/%s" % dispute_id) - def update(self, dispute_id, params): - return self._post("/disputes/%s" % dispute_id, params) + def update(self, dispute_id, params, request_options=None): + return self._post( + "/disputes/%s" % dispute_id, params, request_options=request_options + ) - def close(self, dispute_id): - return self._post("/disputes/%s/close" % dispute_id) + def close(self, dispute_id, request_options=None): + return self._post( + "/disputes/%s/close" % dispute_id, request_options=request_options + ) def list(self, params=None): return self._get("/disputes", params) diff --git a/shift4/payment_methods.py b/shift4/payment_methods.py index 2f3c27e..c8ac81c 100644 --- a/shift4/payment_methods.py +++ b/shift4/payment_methods.py @@ -2,8 +2,8 @@ class PaymentMethods(Resource): - def create(self, params): - return self._post("/payment-methods", params) + def create(self, params, request_options=None): + return self._post("/payment-methods", params, request_options=request_options) def get(self, payment_method_id): return self._get("/payment-methods/%s" % payment_method_id) diff --git a/shift4/plans.py b/shift4/plans.py index 2bfe1ea..d3104c4 100644 --- a/shift4/plans.py +++ b/shift4/plans.py @@ -2,14 +2,16 @@ class Plans(Resource): - def create(self, params): - return self._post("/plans", params) + def create(self, params, request_options=None): + return self._post("/plans", params, request_options=request_options) def get(self, plan_id): return self._get("/plans/%s" % plan_id) - def update(self, plan_id, params): - return self._post("/plans/%s" % plan_id, params) + def update(self, plan_id, params, request_options=None): + return self._post( + "/plans/%s" % plan_id, params, request_options=request_options + ) def delete(self, plan_id): return self._delete("/plans/%s" % plan_id) diff --git a/shift4/request_options.py b/shift4/request_options.py new file mode 100644 index 0000000..a54d6bf --- /dev/null +++ b/shift4/request_options.py @@ -0,0 +1,14 @@ +class RequestOptions: + __idempotency_key = None + + def name(self): + return self.__class__.__name__.lower() + + def set_idempotency_key(self, idempotency_key): + self.__idempotency_key = idempotency_key + + def has_idempotency_key(self): + return self.__idempotency_key is not None + + def get_idempotency_key(self): + return self.__idempotency_key diff --git a/shift4/resource.py b/shift4/resource.py index e3590aa..5b06229 100644 --- a/shift4/resource.py +++ b/shift4/resource.py @@ -13,8 +13,10 @@ def name(self): def _get(self, path, params=None, url=None): return self.__request("GET", path, params=params, url=url) - def _post(self, path, json=None, url=None): - return self.__request("POST", path, json=json, url=url) + def _post(self, path, json=None, url=None, request_options=None): + return self.__request( + "POST", path, json=json, url=url, request_options=request_options + ) def _multipart(self, path, params=None, files=None, url=None): return self.__request("POST", path, params=params, files=files, url=url) @@ -23,14 +25,23 @@ def _delete(self, path, params=None, url=None): return self.__request("DELETE", path, params=params, url=url) @classmethod - def __request(cls, method, path, params=None, json=None, files=None, url=None): + def __request( + cls, + method, + path, + params=None, + json=None, + files=None, + url=None, + request_options=None, + ): if url is None: url = api.api_url.rstrip("/") resp = requests.request( method, url=url + path, auth=(api.secret_key, ""), - headers=cls.__create_headers(), + headers=cls.__create_headers(request_options), files=files, params=params, json=json, @@ -51,12 +62,14 @@ def __request(cls, method, path, params=None, json=None, files=None, url=None): ) @classmethod - def __create_headers(cls): + def __create_headers(cls, request_options=None): user_agent = "Shift4-Python/%s (Python/%s.%s.%s)" % ( __version__, sys.version_info.major, sys.version_info.minor, sys.version_info.micro, ) - - return {"User-Agent": user_agent} + headers = {"User-Agent": user_agent} + if request_options is not None and request_options.has_idempotency_key(): + headers["Idempotency-Key"] = request_options.get_idempotency_key() + return headers diff --git a/shift4/subscriptions.py b/shift4/subscriptions.py index cfa514b..636e33e 100644 --- a/shift4/subscriptions.py +++ b/shift4/subscriptions.py @@ -2,14 +2,18 @@ class Subscriptions(Resource): - def create(self, params): - return self._post("/subscriptions", params) + def create(self, params, request_options=None): + return self._post("/subscriptions", params, request_options=request_options) def get(self, subscription_id): return self._get("/subscriptions/%s" % subscription_id) - def update(self, subscription_id, params): - return self._post("/subscriptions/%s" % subscription_id, params) + def update(self, subscription_id, params, request_options=None): + return self._post( + "/subscriptions/%s" % subscription_id, + params, + request_options=request_options, + ) def cancel(self, subscription_id): return self._delete("/subscriptions/%s" % subscription_id) diff --git a/tests/integration/test_charges.py b/tests/integration/test_charges.py index 654e9ef..3c2c179 100644 --- a/tests/integration/test_charges.py +++ b/tests/integration/test_charges.py @@ -1,3 +1,5 @@ +from shift4.request_options import RequestOptions +from . import random_string from .data.charges import valid_charge_req from .data.customers import valid_customer_req from .testcase import TestCase @@ -87,3 +89,103 @@ def test_list(self, api): self.assert_list_response_contains_exactly_by_id( charges_after_last_id, [charge2, charge1] ) + + def test_will_not_create_duplicate_if_same_idempotency_key_is_used(self, api): + # given + request_options = RequestOptions() + request_options.set_idempotency_key(random_string()) + charge_req = valid_charge_req() + + # when + first_call_response = api.charges.create( + charge_req, request_options=request_options + ) + second_call_response = api.charges.create( + charge_req, request_options=request_options + ) + + # then + assert first_call_response == second_call_response + + def test_will_create_two_instances_if_different_idempotency_keys_are_used( + self, api + ): + # given + request_options = RequestOptions() + request_options.set_idempotency_key(random_string()) + other_request_options = RequestOptions() + other_request_options.set_idempotency_key(random_string()) + charge_req = valid_charge_req() + + # when + first_call_response = api.charges.create( + charge_req, request_options=request_options + ) + second_call_response = api.charges.create( + charge_req, request_options=other_request_options + ) + + # then + assert first_call_response != second_call_response + + def test_will_create_two_instances_if_no_idempotency_keys_are_used(self, api): + # given + charge_req = valid_charge_req() + + # when + first_call_response = api.charges.create(charge_req) + second_call_response = api.charges.create(charge_req) + + # then + assert first_call_response != second_call_response + + def test_will_throw_exception_if_same_idempotency_key_is_used_for_two_different_create_requests( + self, api + ): + # given + request_options = RequestOptions() + request_options.set_idempotency_key(random_string()) + charge_req = valid_charge_req() + + # when + api.charges.create(charge_req, request_options=request_options) + charge_req["amount"] = "42" + exception = self.assert_shift4_exception( + api.charges.create(charge_req, request_options=request_options) + ) + + # then + assert exception.type == "invalid_request" + assert exception.code is None + assert ( + exception.message + == "Idempotent key used for request with different parameters." + ) + + def test_will_throw_exception_if_same_idempotency_key_is_used_for_two_different_update_requests( + self, api + ): + # given + request_options = RequestOptions() + request_options.set_idempotency_key(random_string()) + charge_req = valid_charge_req() + created = api.charges.create(charge_req, request_options=request_options) + update_request_params = { + "description": "updated description", + "metadata": {"key": "updated value"}, + } + + # when + api.charges.update(created["id"], update_request_params, request_options) + update_request_params["description"] = "other description" + exception = self.assert_shift4_exception( + api.charges.update(created["id"], update_request_params) + ) + + # then + assert exception.type == "invalid_request" + assert exception.code is None + assert ( + exception.message + == "Idempotent key used for request with different parameters." + ) diff --git a/tests/integration/test_customers.py b/tests/integration/test_customers.py index 5c261d4..c5d2abe 100644 --- a/tests/integration/test_customers.py +++ b/tests/integration/test_customers.py @@ -19,18 +19,18 @@ def test_create_and_get(self, api): def test_create_without_email(self, api): exception = self.assert_shift4_exception(api.customers.create, {}) assert exception.type == "invalid_request" - assert exception.code == None + assert exception.code is None assert exception.message == "email: Must not be empty." - assert exception.charge_id == None - assert exception.blacklist_rule_id == None + assert exception.charge_id is None + assert exception.blacklist_rule_id is None def test_get_with_invalid_id(self, api): exception = self.assert_shift4_exception(api.customers.get, "1") assert exception.type == "invalid_request" - assert exception.code == None + assert exception.code is None assert exception.message == "Customer '1' does not exist" - assert exception.charge_id == None - assert exception.blacklist_rule_id == None + assert exception.charge_id is None + assert exception.blacklist_rule_id is None def test_delete(self, api): # given diff --git a/tests/integration/test_disputes.py b/tests/integration/test_disputes.py index 0b26222..e6c72a4 100644 --- a/tests/integration/test_disputes.py +++ b/tests/integration/test_disputes.py @@ -54,4 +54,4 @@ def test_list(self, api): # when response = api.disputes.list({"limit": 100}) # then - self.assertListResponseContainsInAnyOrderById(response, [dispute]) + self.assert_list_response_contains_in_any_order_by_id(response, [dispute]) diff --git a/tests/integration/test_events.py b/tests/integration/test_events.py index 9d0f555..e1e0be4 100644 --- a/tests/integration/test_events.py +++ b/tests/integration/test_events.py @@ -26,4 +26,4 @@ def test_list(self, api): # when response = api.events.list({"limit": 100}) # then - self.assertListResponseContainsInAnyOrderById(response, [event]) + self.assert_list_response_contains_in_any_order_by_id(response, [event]) diff --git a/tests/integration/test_file_uploads.py b/tests/integration/test_file_uploads.py index 0315dcd..48b0554 100644 --- a/tests/integration/test_file_uploads.py +++ b/tests/integration/test_file_uploads.py @@ -43,4 +43,4 @@ def test_list(self, tmpdir, api): # when response = api.file_uploads.list({"limit": 100}) # - self.assertListResponseContainsInAnyOrderById(response, [uploaded]) + self.assert_list_response_contains_in_any_order_by_id(response, [uploaded]) diff --git a/tests/integration/test_fraud_warning.py b/tests/integration/test_fraud_warning.py index 6dde8e7..c9d25f9 100644 --- a/tests/integration/test_fraud_warning.py +++ b/tests/integration/test_fraud_warning.py @@ -36,4 +36,4 @@ def test_list(self, api): # when response = api.fraud_warnings.list({"limit": 100}) # then - self.assertListResponseContainsInAnyOrderById(response, [fraud_warning]) + self.assert_list_response_contains_in_any_order_by_id(response, [fraud_warning]) diff --git a/tests/integration/test_plans.py b/tests/integration/test_plans.py index f4cf109..3905ac3 100644 --- a/tests/integration/test_plans.py +++ b/tests/integration/test_plans.py @@ -53,5 +53,7 @@ def test_list(self, api): deleted_plans = api.plans.list({"limit": 100, "deleted": True}) # then - self.assertListResponseContainsInAnyOrderById(all_plans, [plan2, plan1]) - self.assertListResponseContainsInAnyOrderById(deleted_plans, [deleted_plan]) + self.assert_list_response_contains_in_any_order_by_id(all_plans, [plan2, plan1]) + self.assert_list_response_contains_in_any_order_by_id( + deleted_plans, [deleted_plan] + ) diff --git a/tests/integration/testcase.py b/tests/integration/testcase.py index df5c48d..73adc55 100644 --- a/tests/integration/testcase.py +++ b/tests/integration/testcase.py @@ -24,7 +24,8 @@ def api(self): shift4.uploads_url = previous_uploads_url shift4.api_url = previous_api_url - def assert_shift4_exception(self, fun, *args, **kwargs): + @staticmethod + def assert_shift4_exception(fun, *args, **kwargs): try: fun(*args, **kwargs) except shift4.Shift4Exception as e: @@ -34,19 +35,22 @@ def assert_shift4_exception(self, fun, *args, **kwargs): else: pytest.fail("Didn't receive exception") - def assert_card_matches_request(self, card, card_req): + @staticmethod + def assert_card_matches_request(card, card_req): assert card["first6"] == card_req["number"][:6] assert card["last4"] == card_req["number"][-4:] assert card["expMonth"] == card_req["expMonth"] assert card["expYear"] == card_req["expYear"] assert card["cardholderName"] == card_req["cardholderName"] - def assert_list_response_contains_exactly_by_id(self, response, objects): + @staticmethod + def assert_list_response_contains_exactly_by_id(response, objects): response_ids = list(map(lambda o: o["id"], response["list"])) object_ids = list(map(lambda o: o["id"], objects)) assert response_ids == object_ids - def assertListResponseContainsInAnyOrderById(self, response, objects): + @staticmethod + def assert_list_response_contains_in_any_order_by_id(response, objects): response_ids = list(map(lambda o: o["id"], response["list"])) object_ids = list(map(lambda o: o["id"], objects)) for oid in object_ids: