Skip to content

Commit

Permalink
Refactor parse() methods to conform with naming conventions (#352)
Browse files Browse the repository at this point in the history
* Refactor `parse()` methods to conform with naming conventions

* Rename test param to reflect its type
  • Loading branch information
pederhan authored Dec 2, 2024
1 parent f5edd2c commit 2d4a87c
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 65 deletions.
22 changes: 7 additions & 15 deletions mreg_cli/api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,46 +21,38 @@ class MACAddressField(FrozenModel):

address: MacAddress

@classmethod
def validate(cls, value: str | MacAddress | Self) -> Self:
"""Validate a MAC address and return it as a string."""
try:
return cls.validate_naive(value)
except ValueError as e:
raise InputFailure(e) from e

# HACK: extremely hacky workaround for our custom exceptions always
# logging errors/warnings even when caught.
@classmethod
def validate_naive(cls, value: str | MacAddress | Self) -> Self:
def validate(cls, value: str | MacAddress | Self) -> Self:
"""Validate but raise built-in exceptions on failure."""
if isinstance(value, MACAddressField):
return cls.validate_naive(value.address)
return cls.validate(value.address)
try:
return cls(address=value) # pyright: ignore[reportArgumentType]
except ValidationError as e:
raise ValueError(f"Invalid MAC address '{value}'") from e
raise InputFailure(f"Invalid MAC address '{value}'") from e

@classmethod
def parse(cls, obj: Any) -> MacAddress:
def parse_or_raise(cls, obj: Any) -> MacAddress:
"""Parse a MAC address from a string. Returns the MAC address as a string.
:param obj: The object to parse.
:returns: The MAC address as a string.
:raises ValueError: If the object is not a valid MAC address.
"""
# Match interface of NetworkOrIP.parse
# Match interface of NetworkOrIP.parse_or_raise
return cls.validate(obj).address

@classmethod
def parse_optional(cls, obj: Any) -> MacAddress | None:
def parse(cls, obj: Any) -> MacAddress | None:
"""Parse a MAC address from a string. Returns None if the MAC address is invalid.
:param obj: The object to parse.
:returns: The MAC address as a string or None if it is invalid.
"""
try:
return cls.parse(obj)
return cls.parse_or_raise(obj)
except ValueError:
return None

Expand Down
90 changes: 52 additions & 38 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any, Callable, ClassVar, Iterable, List, Literal, Self, cast, overload

from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator
from pydantic import ValidationError as PydanticValidationError
from pydantic_extra_types.mac_address import MacAddress
from typing_extensions import Unpack

Expand Down Expand Up @@ -76,39 +77,41 @@ def validate(cls, value: str | IP_AddressT | IP_NetworkT | Self) -> Self:
return cls.validate(value.ip_or_network)
try:
return cls(ip_or_network=value) # pyright: ignore[reportArgumentType] # validator handles this
except ValidationError as e:
except PydanticValidationError as e:
raise InputFailure(f"Invalid IP address or network: {value}") from e

@overload
@classmethod
def parse(cls, value: Any, mode: None = None) -> IP_AddressT | IP_NetworkT: ...
def parse_or_raise(cls, value: Any, mode: None = None) -> IP_AddressT | IP_NetworkT: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["ip"]) -> IP_AddressT: ...
def parse_or_raise(cls, value: Any, mode: Literal["ip"]) -> IP_AddressT: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["ipv4"]) -> ipaddress.IPv4Address: ...
def parse_or_raise(cls, value: Any, mode: Literal["ipv4"]) -> ipaddress.IPv4Address: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["ipv6"]) -> ipaddress.IPv6Address: ...
def parse_or_raise(cls, value: Any, mode: Literal["ipv6"]) -> ipaddress.IPv6Address: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["network"]) -> IP_NetworkT: ...
def parse_or_raise(cls, value: Any, mode: Literal["network"]) -> IP_NetworkT: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["networkv4"]) -> ipaddress.IPv4Network: ...
def parse_or_raise(cls, value: Any, mode: Literal["networkv4"]) -> ipaddress.IPv4Network: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["networkv6"]) -> ipaddress.IPv6Network: ...
def parse_or_raise(cls, value: Any, mode: Literal["networkv6"]) -> ipaddress.IPv6Network: ...

@classmethod
def parse(cls, value: Any, mode: IPNetMode | None = None) -> IP_AddressT | IP_NetworkT:
def parse_or_raise(
cls, value: Any, mode: IPNetMode | None = None
) -> IP_AddressT | IP_NetworkT:
"""Parse a value as an IP address or network.
Optionally specify the mode to validate the input as.
Expand All @@ -131,35 +134,46 @@ def parse(cls, value: Any, mode: IPNetMode | None = None) -> IP_AddressT | IP_Ne
return func(ipnet)
return ipnet.ip_or_network

@overload
@classmethod
def parse_ip_optional(
cls, ip: str, version: Literal[4, 6] | None = None
) -> IP_AddressT | None:
"""Check if a value is a valid IP address.
def parse(cls, value: Any, mode: None = None) -> IP_AddressT | IP_NetworkT | None: ...

:param ip: The IP address to parse.
:param version: The IP version to parse as. Parses as any version if None.
:returns: The parsed IP address or None if parsing fails.
"""
mode = "ipv4" if version == 4 else "ipv6" if version == 6 else "ip"
try:
return cls.parse(ip, mode=mode)
except ValueError:
return None
@overload
@classmethod
def parse(cls, value: Any, mode: Literal["ip"]) -> IP_AddressT | None: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["ipv4"]) -> ipaddress.IPv4Address | None: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["ipv6"]) -> ipaddress.IPv6Address | None: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["network"]) -> IP_NetworkT | None: ...

@overload
@classmethod
def parse(cls, value: Any, mode: Literal["networkv4"]) -> ipaddress.IPv4Network | None: ...

@overload
@classmethod
def parse_network_optional(
cls, ip: str, version: Literal[4, 6] | None = None
) -> IP_NetworkT | None:
"""Parse a value as an IP network. Returns None if parsing fails.
def parse(cls, value: Any, mode: Literal["networkv6"]) -> ipaddress.IPv6Network | None: ...

:param ip: The IP network to parse.
:param version: The IP version to parse as. Parses as any version if None.
:returns: The parsed IP network or None if parsing fails.
@classmethod
def parse(cls, value: Any, mode: IPNetMode | None = None) -> IP_AddressT | IP_NetworkT | None:
"""Parse a value as an IP address or network, or None if parsing fails.
Optionally specify the mode to validate the input as.
:param value:The value to parse.
:param mode: The mode to validate the input as.
:returns: The parsed value as an IP address or network, or None.
"""
mode = "networkv4" if version == 4 else "networkv6" if version == 6 else "network"
try:
return cls.parse(ip, mode=mode)
return cls.parse_or_raise(value, mode)
except ValueError:
return None

Expand Down Expand Up @@ -1617,7 +1631,7 @@ def __hash__(self):
def ip_network(self) -> IP_NetworkT:
"""IP network object for the network."""
try:
return NetworkOrIP.parse(self.network, mode="network")
return NetworkOrIP.parse_or_raise(self.network, mode="network")
except IPNetworkWarning as e:
logger.error(
"Invalid network address %s for network with ID %s", self.network, self.id
Expand Down Expand Up @@ -1744,7 +1758,7 @@ def output(self, padding: int = 25) -> None:
def fmt(label: str, value: Any) -> None:
manager.add_line(f"{label:<{padding}}{value}")

ipnet = NetworkOrIP.parse(self.network, mode="network")
ipnet = NetworkOrIP.parse_or_raise(self.network, mode="network")
reserved_ips = self.get_reserved_ips()
# Remove network address and broadcast address from reserved IPs
reserved_ips_filtered = [
Expand Down Expand Up @@ -1836,9 +1850,9 @@ def overlaps(self, other: Network | str | IP_NetworkT) -> bool:
if isinstance(other, Network):
other = other.network
if isinstance(other, str):
other = NetworkOrIP.parse(other, mode="network")
other = NetworkOrIP.parse_or_raise(other, mode="network")

self_net = NetworkOrIP.parse(self.network, mode="network")
self_net = NetworkOrIP.parse_or_raise(self.network, mode="network")
return self_net.overlaps(other)

def get_first_available_ip(self) -> IP_AddressT:
Expand Down Expand Up @@ -2763,11 +2777,11 @@ def get_by_any_means(
if identifier.isdigit():
return Host.get_by_id(int(identifier))

if ip := NetworkOrIP.parse_ip_optional(identifier):
if ip := NetworkOrIP.parse(identifier, mode="ip"):
host = cls.get_by_ip_or_raise(ip, inform_as_ptr=inform_as_ptr)
return host

if mac := MACAddressField.parse_optional(identifier):
if mac := MACAddressField.parse(identifier):
return cls.get_by_mac_or_raise(mac)

# Let us try to find the host by name...
Expand Down Expand Up @@ -2861,7 +2875,7 @@ def has_ip_with_mac(self, arg_mac: MACAddressField | str) -> IPAddress | None:
:returns: The IP address object if found, None otherwise.
"""
if not isinstance(arg_mac, MACAddressField):
arg_mac = MACAddressField.validate_naive(arg_mac)
arg_mac = MACAddressField.validate(arg_mac)
return next((ip for ip in self.ipaddresses if ip.macaddress == arg_mac), None)

def ips_with_macaddresses(self) -> list[IPAddress]:
Expand Down
6 changes: 3 additions & 3 deletions mreg_cli/commands/host_submodules/a_aaaa.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _ip_change(args: argparse.Namespace, ipversion: IP_Version) -> None:
if args.old == args.new:
raise EntityAlreadyExists("New and old IP are equal")

old_ip = NetworkOrIP.parse(args.old, mode="ip")
old_ip = NetworkOrIP.parse_or_raise(args.old, mode="ip")

new_ip = NetworkOrIP.validate(args.new)
if new_ip.is_network():
Expand Down Expand Up @@ -88,7 +88,7 @@ def _ip_move(args: argparse.Namespace, ipversion: IP_Version) -> None:
:param args: argparse.Namespace (ip, fromhost, tohost)
:param ipversion: 4 or 6
"""
ip = NetworkOrIP.parse(args.ip, mode="ip")
ip = NetworkOrIP.parse_or_raise(args.ip, mode="ip")
if ip.version != ipversion:
raise InputFailure(
f"IP version {ip.version} does not match the requested version {ipversion}"
Expand Down Expand Up @@ -123,7 +123,7 @@ def _ip_remove(args: argparse.Namespace, ipversion: IP_Version) -> None:
:param args: argparse.Namespace (name, ip)
"""
host = Host.get_by_any_means_or_raise(args.name)
ip = NetworkOrIP.parse(args.ip, mode="ip")
ip = NetworkOrIP.parse_or_raise(args.ip, mode="ip")
if ip.version != ipversion:
raise InputFailure(
f"IP version {ip.version} does not match the requested version {ipversion}"
Expand Down
8 changes: 4 additions & 4 deletions mreg_cli/commands/host_submodules/rr.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def ptr_change(args: argparse.Namespace) -> None:
if not old_host.ptr_overrides:
raise EntityNotFound(f"No PTR records for {old_host}")

ip = NetworkOrIP.parse(args.ip, mode="ip")
ip = NetworkOrIP.parse_or_raise(args.ip, mode="ip")
ptr_override = old_host.get_ptr_override(ip)
if not ptr_override:
raise EntityNotFound(f"No PTR record for {old_host} with IP {ip}")
Expand Down Expand Up @@ -503,7 +503,7 @@ def ptr_remove(args: argparse.Namespace) -> None:
:param args: argparse.Namespace (ip, name)
"""
host = Host.get_by_any_means_or_raise(args.name)
ip = NetworkOrIP.parse(args.ip, mode="ip")
ip = NetworkOrIP.parse_or_raise(args.ip, mode="ip")
ptr_override = host.get_ptr_override(ip)
if not ptr_override:
raise EntityNotFound(f"No PTR record for {host} with IP {ip}")
Expand All @@ -529,7 +529,7 @@ def ptr_add(args: argparse.Namespace) -> None:
:param args: argparse.Namespace (ip, name, force)
"""
ip = NetworkOrIP.parse(args.ip, mode="ip")
ip = NetworkOrIP.parse_or_raise(args.ip, mode="ip")

host = Host.get_by_any_means_or_raise(args.name)
existing_ptrs = PTR_override.get_list_by_field("ipaddress", str(ip))
Expand Down Expand Up @@ -563,7 +563,7 @@ def ptr_show(args: argparse.Namespace) -> None:
:param args: argparse.Namespace (ip)
"""
ip = NetworkOrIP.parse(args.ip, mode="ip")
ip = NetworkOrIP.parse_or_raise(args.ip, mode="ip")
host = Host.get_by_any_means_or_raise(str(ip), inform_as_ptr=False)
if not host.ptr_overrides:
OutputManager().add_line(f"No PTR records for {host.name.hostname}")
Expand Down
2 changes: 1 addition & 1 deletion mreg_cli/commands/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create(args: argparse.Namespace) -> None:
if args.location and not is_valid_location_tag(args.location):
raise InputFailure("Not a valid location tag")

arg_network = NetworkOrIP.parse(args.network, mode="network")
arg_network = NetworkOrIP.parse_or_raise(args.network, mode="network")
networks = Network.get_list()
for network in networks:
if network.overlaps(arg_network):
Expand Down
4 changes: 2 additions & 2 deletions mreg_cli/commands/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def network_list(args: argparse.Namespace) -> None:
permissions = Permission.get_by_query(query=params, ordering="range,group", limit=None)

if args.range is not None:
argnetwork = NetworkOrIP.parse(args.range, mode="network")
argnetwork = NetworkOrIP.parse_or_raise(args.range, mode="network")

for permission in permissions:
permnet = permission.range
Expand Down Expand Up @@ -107,7 +107,7 @@ def network_add(args: argparse.Namespace) -> None:
:param args: argparse.Namespace (range, group, regex)
"""
NetworkOrIP.parse(args.range, mode="network")
NetworkOrIP.parse_or_raise(args.range, mode="network")

query = {
"range": args.range,
Expand Down
28 changes: 26 additions & 2 deletions tests/api/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

from datetime import datetime
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network
from typing import Any
from typing import Any, Callable

import pytest

from mreg_cli.api.models import IPNetMode, Network, NetworkOrIP
from mreg_cli.exceptions import (
InputFailure,
InvalidIPAddress,
InvalidIPv4Address,
InvalidIPv6Address,
Expand Down Expand Up @@ -80,10 +81,33 @@
)
def test_network_or_ip_parse(inp: str, mode: IPNetMode, expect: Any) -> None:
"""Test the network or IP address from string."""
res = NetworkOrIP.parse(inp, mode)
res = NetworkOrIP.parse_or_raise(inp, mode)
assert res == expect


@pytest.mark.parametrize(
"inp, expect_type_call",
[
("192.168.0.1", NetworkOrIP.is_ipv4),
("192.168.0.0/24", NetworkOrIP.is_ipv4_network),
("2001:db8::1", NetworkOrIP.is_ipv6),
("2001:db8::/64", NetworkOrIP.is_ipv6_network),
("2001:db8::/", NetworkOrIP.is_ipv6), # valid address because validator removes suffix
pytest.param(
"192.168.0.0/33", None, marks=pytest.mark.xfail(raises=InputFailure, strict=True)
),
pytest.param(
"2001:db8::/129", None, marks=pytest.mark.xfail(raises=InputFailure, strict=True)
),
],
)
def test_network_or_ip_validate(inp: Any, expect_type_call: Callable[[NetworkOrIP], bool]) -> None:
"""Test the validation of network or IP address."""
res = NetworkOrIP.validate(inp)
# Ensure it's validated as the correct type
assert expect_type_call(res)


@pytest.mark.parametrize(
"inp, expect",
[
Expand Down

0 comments on commit 2d4a87c

Please sign in to comment.