Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated endpoint parsing #618

Merged
merged 13 commits into from
Oct 31, 2023
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,6 @@ venv.bak/

# mypy
.mypy_cache/

# OSX specific files
.DS_Store
20 changes: 12 additions & 8 deletions dapr/aio/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from dapr.clients.exceptions import DaprInternalError
from dapr.clients.grpc._state import StateOptions, StateItem
from dapr.clients.grpc._helpers import getWorkflowRuntimeStatus
from dapr.conf.helpers import parse_endpoint
from dapr.conf.helpers import GrpcEndpoint
from dapr.conf import settings
from dapr.proto import api_v1, api_service_v1, common_v1
from dapr.proto.runtime.v1.dapr_pb2 import UnsubscribeConfigurationResponse
Expand Down Expand Up @@ -139,14 +139,18 @@ def __init__(
address = settings.DAPR_GRPC_ENDPOINT or (f"{settings.DAPR_RUNTIME_HOST}:"
f"{settings.DAPR_GRPC_PORT}")

self._scheme, self._hostname, self._port = parse_endpoint(address)
try:
self._endpoint = GrpcEndpoint(address)
except ValueError as error:
raise DaprInternalError(f'{error}') from error

if self._scheme == "https":
self._channel = grpc.aio.secure_channel(f"{self._hostname}:{self._port}",
if self._endpoint.is_secure():
self._channel = grpc.aio.secure_channel(self._endpoint.get_endpoint(),
credentials=self.get_credentials(),
options=options)
options=options) # type: ignore
else:
self._channel = grpc.aio.insecure_channel(address, options) # type: ignore
self._channel = grpc.aio.insecure_channel(self._endpoint.get_endpoint(),
options) # type: ignore

if settings.DAPR_API_TOKEN:
api_token_interceptor = DaprClientInterceptorAsync([
Expand All @@ -164,7 +168,7 @@ def get_credentials(self):

async def close(self):
"""Closes Dapr runtime gRPC channel."""
if self._channel:
if hasattr(self, '_channel') and self._channel:
self._channel.close()

async def __aenter__(self) -> Self: # type: ignore
Expand Down Expand Up @@ -1442,7 +1446,7 @@ async def wait(self, timeout_s: float):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(timeout_s)
try:
s.connect((self._hostname, self._port))
s.connect((self._endpoint.get_hostname(), self._endpoint.get_port_as_int()))
return
except Exception as e:
remaining = (start + timeout_s) - time.time()
Expand Down
31 changes: 17 additions & 14 deletions dapr/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
validateNotNone,
validateNotBlankString,
)
from dapr.conf.helpers import parse_endpoint
from dapr.conf.helpers import GrpcEndpoint
from dapr.clients.grpc._request import (
InvokeMethodRequest,
BindingRequest,
Expand Down Expand Up @@ -138,15 +138,18 @@ def __init__(
address = settings.DAPR_GRPC_ENDPOINT or (f"{settings.DAPR_RUNTIME_HOST}:"
f"{settings.DAPR_GRPC_PORT}")

self._scheme, self._hostname, self._port = parse_endpoint(address)
try:
self._endpoint = GrpcEndpoint(address)
except ValueError as error:
raise DaprInternalError(f'{error}') from error

if self._scheme == "https":
self._channel = grpc.secure_channel(f"{self._hostname}:{self._port}", # type: ignore
if self._endpoint.is_secure():
self._channel = grpc.secure_channel(self._endpoint.get_endpoint(), # type: ignore
self.get_credentials(),

options=options)
else:
self._channel = grpc.insecure_channel(address, options=options) # type: ignore
self._channel = grpc.insecure_channel(self._endpoint.get_endpoint(), # type: ignore
options=options)

if settings.DAPR_API_TOKEN:
api_token_interceptor = DaprClientInterceptor([
Expand All @@ -166,7 +169,7 @@ def get_credentials(self):

def close(self):
"""Closes Dapr runtime gRPC channel."""
if self._channel:
if hasattr(self, '_channel') and self._channel:
self._channel.close()

def __del__(self):
Expand Down Expand Up @@ -805,8 +808,8 @@ def delete_state(
:class:`DaprResponse` gRPC metadata returned from callee
"""
if metadata is not None:
warn('metadata argument is deprecated. Dapr already intercepts API token headers '
'and this is not needed.', DeprecationWarning, stacklevel=2)
warn('metadata argument is deprecated. Dapr already intercepts API token '
'headers and this is not needed.', DeprecationWarning, stacklevel=2)

if not store_name or len(store_name) == 0 or len(store_name.strip()) == 0:
raise ValueError("State store name cannot be empty")
Expand Down Expand Up @@ -861,8 +864,8 @@ def get_secret(
:class:`GetSecretResponse` object with the secret and metadata returned from callee
"""
if metadata is not None:
warn('metadata argument is deprecated. Dapr already intercepts API token headers '
'and this is not needed.', DeprecationWarning, stacklevel=2)
warn('metadata argument is deprecated. Dapr already intercepts API token '
'headers and this is not needed.', DeprecationWarning, stacklevel=2)

req = api_v1.GetSecretRequest(
store_name=store_name,
Expand Down Expand Up @@ -908,8 +911,8 @@ def get_bulk_secret(
:class:`GetBulkSecretResponse` object with secrets and metadata returned from callee
"""
if metadata is not None:
warn('metadata argument is deprecated. Dapr already intercepts API token headers '
'and this is not needed.', DeprecationWarning, stacklevel=2)
warn('metadata argument is deprecated. Dapr already intercepts API token '
'headers and this is not needed.', DeprecationWarning, stacklevel=2)

req = api_v1.GetBulkSecretRequest(
store_name=store_name,
Expand Down Expand Up @@ -1431,7 +1434,7 @@ def wait(self, timeout_s: float):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(timeout_s)
try:
s.connect((self._hostname, self._port))
s.connect((self._endpoint.get_hostname(), self._endpoint.get_port_as_int()))
return
except Exception as e:
remaining = (start + timeout_s) - time.time()
Expand Down
181 changes: 128 additions & 53 deletions dapr/conf/helpers.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,131 @@
from typing import Tuple


def parse_endpoint(address: str) -> Tuple[str, str, int]:
scheme = "http"
fqdn = "localhost"
port = 80
addr = address

addr_list = address.split("://")

if len(addr_list) == 2:
# A scheme was explicitly specified
scheme = addr_list[0]
if scheme == "https":
port = 443
addr = addr_list[1]

addr_list = addr.split(":")
if len(addr_list) == 2:
# A port was explicitly specified
if len(addr_list[0]) > 0:
fqdn = addr_list[0]
# Account for Endpoints of the type http://localhost:3500/v1.0/invoke
addr_list = addr_list[1].split("/")
port = addr_list[0] # type: ignore
elif len(addr_list) == 1:
# No port was specified
# Account for Endpoints of the type :3500/v1.0/invoke
addr_list = addr_list[0].split("/")
fqdn = addr_list[0]
else:
# IPv6 address
addr_list = addr.split("]:")
if len(addr_list) == 2:
# A port was explicitly specified
fqdn = addr_list[0]
fqdn = fqdn.replace("[", "")

addr_list = addr_list[1].split("/")
port = addr_list[0] # type: ignore
elif len(addr_list) == 1:
# No port was specified
addr_list = addr_list[0].split("/")
fqdn = addr_list[0]
fqdn = fqdn.replace("[", "")
fqdn = fqdn.replace("]", "")
from urllib.parse import urlparse, parse_qs


class URIParseConfig:
DEFAULT_SCHEME = "dns"
DEFAULT_HOSTNAME = "localhost"
DEFAULT_PORT = 443
DEFAULT_TLS = False
DEFAULT_AUTHORITY = ""
ACCEPTED_SCHEMES = ["dns", "unix", "unix-abstract", "vsock", "http", "https", "grpc", "grpcs"]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @elena-kolevska
something's off here, both grpc and grpcs are not present in the naming resolution doc and are also missing in the go-sdk: https://github.com/dapr/go-sdk/blob/main/client/internal/parse.go#L160

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in this PR: #700

VALID_SCHEMES = ["dns", "unix", "unix-abstract", "vsock", "grpc", "grpcs"]


class GrpcEndpoint:
def __init__(self, url: str):
self.authority = URIParseConfig.DEFAULT_AUTHORITY
self.url = url

url = self.preprocess_url(url)
parsed_url = urlparse(url)
validate_path_and_query(parsed_url.path, parsed_url.query, parsed_url.scheme)
tls = extract_tls_from_query(parsed_url.query, parsed_url.scheme)

self.scheme = parsed_url.scheme or URIParseConfig.DEFAULT_SCHEME
self.hostname = parsed_url.hostname or URIParseConfig.DEFAULT_HOSTNAME
self.port = parsed_url.port or URIParseConfig.DEFAULT_PORT
self.tls = tls or URIParseConfig.DEFAULT_TLS

def is_secure(self) -> bool:
return self.tls

def get_scheme(self) -> str:
return self.scheme if self.scheme in URIParseConfig.VALID_SCHEMES \
else URIParseConfig.DEFAULT_SCHEME

def get_port(self) -> str:
port = self.get_port_as_int()
if port == 0:
return ""

return str(port)

def get_port_as_int(self) -> int:
if self.scheme in ["unix", "unix-abstract"]:
return 0

return self.port

def get_hostname(self) -> str:
hostname = self.hostname
if self.hostname.count(":") == 7:
# IPv6 address
hostname = f"[{hostname}]"
return hostname

def get_endpoint(self) -> str:
scheme = self.get_scheme()
port = "" if len(self.get_port()) == 0 else f":{self.port}"

if scheme == "unix":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a warning if the user uses the deprecated "https" or "http" schemas?

Copy link
Contributor Author

@elena-kolevska elena-kolevska Oct 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

separator = "://" if self.url.startswith("unix://") else ":"
return f"{scheme}{separator}{self.hostname}"

if scheme == "vsock":
port = "" if self.port == 0 else f":{self.port}"
return f"{scheme}:{self.get_hostname()}{port}"

if scheme == "unix-abstract":
return f"{scheme}:{self.get_hostname()}{port}"

if scheme == "dns":
authority = f"//{self.authority}/" if self.authority else ""
return f"{scheme}:{authority}{self.get_hostname()}{port}"

return f"{scheme}:{self.get_hostname()}{port}"

def preprocess_url(self, url: str) -> str:
url_list = url.split(":")
if len(url_list) == 3 and "://" not in url:
# A URI like dns:mydomain:5000 or vsock:mycid:5000 was used
url = url.replace(":", "://", 1)
elif len(url_list) == 2 and "://" not in url and url_list[
0] in URIParseConfig.ACCEPTED_SCHEMES:
# A URI like dns:mydomain was used
url = url.replace(":", "://", 1)
else:
raise ValueError(f"Invalid address: {address}")
url_list = url.split("://")
if len(url_list) == 1:
# If a scheme was not explicitly specified in the URL
# we need to add a default scheme,
# because of how urlparse works
url = f'{URIParseConfig.DEFAULT_SCHEME}://{url}'
else:
# If a scheme was explicitly specified in the URL
# we need to make sure it is a valid scheme
scheme = url_list[0]
if scheme not in URIParseConfig.ACCEPTED_SCHEMES:
raise ValueError(f"Invalid scheme '{scheme}' in URL '{url}'")

# We should do a special check if the scheme is dns, and it uses
# an authority in the format of dns:[//authority/]host[:port]
if scheme.lower() == "dns":
# A URI like dns://authority/mydomain was used
url_list = url.split("/")
if len(url_list) < 4:
raise ValueError(f"Invalid dns authority '{url_list[2]}' in URL '{url}'")
self.authority = url_list[2]
url = f'dns://{url_list[3]}'
return url


def validate_path_and_query(path: str, query: str, scheme: str) -> None:
if path:
raise ValueError(f"Paths are not supported for gRPC endpoints: '{path}'")
if query:
query_dict = parse_qs(query)
if 'tls' in query_dict and scheme in ["http", "https"]:
raise ValueError(
f"The tls query parameter is not supported for http(s) endpoints: '{query}'")
query_dict.pop('tls', None)
if query_dict:
raise ValueError(f"Query parameters are not supported for gRPC endpoints: '{query}'")

try:
port = int(port)
except ValueError:
raise ValueError(f"invalid port: {port}")

return scheme, fqdn, port
def extract_tls_from_query(query: str, scheme: str) -> bool:
query_dict = parse_qs(query)
tls_str = query_dict.get('tls', [""])[0]
tls = tls_str.lower() == 'true'
if scheme == "https":
tls = True
return tls
16 changes: 4 additions & 12 deletions tests/clients/test_secure_dapr_async_grpc_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,25 @@ def tearDown(self):
@patch.object(settings, "DAPR_GRPC_ENDPOINT", "https://domain1.com:5000")
def test_init_with_DAPR_GRPC_ENDPOINT(self):
dapr = DaprGrpcClientAsync()
self.assertEqual("domain1.com", dapr._hostname)
self.assertEqual(5000, dapr._port)
self.assertEqual("https", dapr._scheme)
self.assertEqual("dns:domain1.com:5000", dapr._endpoint.get_endpoint())

@patch.object(settings, "DAPR_GRPC_ENDPOINT", "https://domain1.com:5000")
def test_init_with_DAPR_GRPC_ENDPOINT_and_argument(self):
dapr = DaprGrpcClientAsync("https://domain2.com:5002")
self.assertEqual("domain2.com", dapr._hostname)
self.assertEqual(5002, dapr._port)
self.assertEqual('https', dapr._scheme)
self.assertEqual("dns:domain2.com:5002", dapr._endpoint.get_endpoint())

@patch.object(settings, "DAPR_GRPC_ENDPOINT", "https://domain1.com:5000")
@patch.object(settings, "DAPR_RUNTIME_HOST", "domain2.com")
@patch.object(settings, "DAPR_GRPC_PORT", "5002")
def test_init_with_DAPR_GRPC_ENDPOINT_and_DAPR_RUNTIME_HOST(self):
dapr = DaprGrpcClientAsync()
self.assertEqual("domain1.com", dapr._hostname)
self.assertEqual(5000, dapr._port)
self.assertEqual('https', dapr._scheme)
self.assertEqual("dns:domain1.com:5000", dapr._endpoint.get_endpoint())

@patch.object(settings, "DAPR_RUNTIME_HOST", "domain1.com")
@patch.object(settings, "DAPR_GRPC_PORT", "5000")
def test_init_with_argument_and_DAPR_GRPC_ENDPOINT_and_DAPR_RUNTIME_HOST(self):
dapr = DaprGrpcClientAsync("https://domain2.com:5002")
self.assertEqual("domain2.com", dapr._hostname)
self.assertEqual(5002, dapr._port)
self.assertEqual('https', dapr._scheme)
self.assertEqual("dns:domain2.com:5002", dapr._endpoint.get_endpoint())

async def test_dapr_api_token_insertion(self):
pass
Expand Down
Loading