diff --git a/src/gallia/dumpcap.py b/src/gallia/dumpcap.py index 758acf822..7d033f737 100644 --- a/src/gallia/dumpcap.py +++ b/src/gallia/dumpcap.py @@ -17,8 +17,9 @@ from urllib.parse import urlparse from gallia.log import get_logger +from gallia.net import split_host_port from gallia.transports import TargetURI, TransportScheme -from gallia.utils import auto_int, handle_task_error, set_task_handler_ctx_variable, split_host_port +from gallia.utils import auto_int, handle_task_error, set_task_handler_ctx_variable logger = get_logger(__name__) diff --git a/src/gallia/net.py b/src/gallia/net.py index 86f379dac..756ca68bd 100644 --- a/src/gallia/net.py +++ b/src/gallia/net.py @@ -2,7 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 +import ipaddress import subprocess +from urllib.parse import urlparse import pydantic from pydantic.networks import IPvAnyAddress @@ -13,6 +15,39 @@ logger = get_logger(__name__) +def split_host_port( + hostport: str, + default_port: int | None = None, +) -> tuple[str, int | None]: + """Splits a combination of ip address/hostname + port into hostname/ip address + and port. The default_port argument can be used to return a port if it is + absent in the hostport argument.""" + # Special case: If hostport is an ipv6 then the urlparser does some weird + # things with the colons and tries to parse ports. Catch this case early. + host = "" + port = default_port + try: + # If hostport is a valid ip address (v4 or v6) there + # is no port included + host = str(ipaddress.ip_address(hostport)) + except ValueError: + pass + + # Only parse if hostport is not a valid ip address. + if host == "": + # urlparse() and urlsplit() insists on absolute URLs starting with "//". + url = urlparse(f"//{hostport}") + host = url.hostname if url.hostname else url.netloc + port = url.port if url.port else default_port + return host, port + + +def join_host_port(host: str, port: int) -> str: + if ":" in host: + return f"[{host}]:port" + return f"{host}:{port}" + + class AddrInfo(pydantic.BaseModel): family: str local: IPvAnyAddress diff --git a/src/gallia/transports/base.py b/src/gallia/transports/base.py index bc9672c8f..c11bd8c18 100644 --- a/src/gallia/transports/base.py +++ b/src/gallia/transports/base.py @@ -10,8 +10,8 @@ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from gallia.log import get_logger +from gallia.net import join_host_port from gallia.transports.schemes import TransportScheme -from gallia.utils import join_host_port logger = get_logger(__name__) diff --git a/src/gallia/utils.py b/src/gallia/utils.py index 8878e9b6a..357c34dbf 100644 --- a/src/gallia/utils.py +++ b/src/gallia/utils.py @@ -7,7 +7,6 @@ import asyncio import contextvars import importlib.util -import ipaddress import logging import re import sys @@ -16,7 +15,6 @@ from pathlib import Path from types import ModuleType from typing import TYPE_CHECKING, Any -from urllib.parse import urlparse import aiofiles @@ -45,39 +43,6 @@ def strtobool(val: str) -> bool: raise ValueError(f"invalid truth value {val!r}") -def split_host_port( - hostport: str, - default_port: int | None = None, -) -> tuple[str, int | None]: - """Splits a combination of ip address/hostname + port into hostname/ip address - and port. The default_port argument can be used to return a port if it is - absent in the hostport argument.""" - # Special case: If hostport is an ipv6 then the urlparser does some weird - # things with the colons and tries to parse ports. Catch this case early. - host = "" - port = default_port - try: - # If hostport is a valid ip address (v4 or v6) there - # is no port included - host = str(ipaddress.ip_address(hostport)) - except ValueError: - pass - - # Only parse if hostport is not a valid ip address. - if host == "": - # urlparse() and urlsplit() insists on absolute URLs starting with "//". - url = urlparse(f"//{hostport}") - host = url.hostname if url.hostname else url.netloc - port = url.port if url.port else default_port - return host, port - - -def join_host_port(host: str, port: int) -> str: - if ":" in host: - return f"[{host}]:port" - return f"{host}:{port}" - - def camel_to_snake(s: str) -> str: """Convert a CamelCase string to a snake_case string.""" # https://stackoverflow.com/a/1176023 diff --git a/tests/pytest/test_helpers.py b/tests/pytest/test_helpers.py index 57b762120..b90a6c6c0 100644 --- a/tests/pytest/test_helpers.py +++ b/tests/pytest/test_helpers.py @@ -5,11 +5,11 @@ import pytest from gallia.log import setup_logging +from gallia.net import split_host_port from gallia.services.uds.core.utils import ( address_and_size_length, uds_memory_parameters, ) -from gallia.utils import split_host_port setup_logging()