From 449bf30c66391f78e8dffdab365b3af205edc7ee Mon Sep 17 00:00:00 2001 From: gentlegiantJGC Date: Tue, 27 Feb 2024 12:05:32 +0000 Subject: [PATCH] Fixed get overloads Moved default None above tag variant. None matches the tag variant causing the wrong type hint. Added more get overloads --- src/amulet_nbt/__init__.pyi | 74 ++++++----- src/amulet_nbt/_tag/compound.pyx | 117 ++++++++---------- src/amulet_nbt/_tag/compound.pyx.tp | 21 +--- .../tpf/CompoundGetSetdefault.pyx.tpf | 8 +- 4 files changed, 100 insertions(+), 120 deletions(-) diff --git a/src/amulet_nbt/__init__.pyi b/src/amulet_nbt/__init__.pyi index 8a10099b..99d177ca 100644 --- a/src/amulet_nbt/__init__.pyi +++ b/src/amulet_nbt/__init__.pyi @@ -402,25 +402,7 @@ class CompoundTag(AbstractBaseMutableTag, MutableMapping[str | bytes, AnyNBT]): """A shallow copy of the CompoundTag as a python dictionary.""" @overload - def get( - self, - key: str | bytes, - default: _TagT = None, - cls: Type[_TagT] = AbstractBaseTag, - ) -> _TagT: ... - @overload - def get( - self, - key: str | bytes, - default: None = None, - cls: Type[_TagT] = AbstractBaseTag, - ) -> None: ... - def get( - self, - key: str | bytes, - default: _TagT | None = None, - cls: Type[_TagT] = AbstractBaseTag, - ) -> _TagT | None: + def get(self, key: str | bytes, default: None = None) -> AnyNBT | None: """Get an item from the CompoundTag. :param key: The key to get @@ -431,76 +413,92 @@ class CompoundTag(AbstractBaseMutableTag, MutableMapping[str | bytes, AnyNBT]): :raises: TypeError if the stored type is not a subclass of cls. """ + @overload + def get(self, key: str | bytes, default: _TagT = None) -> AnyNBT: ... + @overload + def get( + self, + key: str | bytes, + default: None = None, + cls: Type[_TagT] = AbstractBaseTag, + ) -> _TagT | None: ... + @overload + def get( + self, + key: str | bytes, + default: _TagT = None, + cls: Type[_TagT] = AbstractBaseTag, + ) -> _TagT: ... @staticmethod def fromkeys(keys: Iterable[str | bytes], value: AnyNBT = None): ... @overload - def get_byte(self, key: str | bytes, default: ByteTag = None) -> ByteTag: ... - @overload def get_byte(self, key: str | bytes, default: None = None) -> ByteTag | None: ... @overload - def get_short(self, key: str | bytes, default: ShortTag = None) -> ShortTag: ... + def get_byte(self, key: str | bytes, default: ByteTag = None) -> ByteTag: ... @overload def get_short(self, key: str | bytes, default: None = None) -> ShortTag | None: ... @overload - def get_int(self, key: str | bytes, default: IntTag = None) -> IntTag: ... + def get_short(self, key: str | bytes, default: ShortTag = None) -> ShortTag: ... @overload def get_int(self, key: str | bytes, default: None = None) -> IntTag | None: ... @overload - def get_long(self, key: str | bytes, default: LongTag = None) -> LongTag: ... + def get_int(self, key: str | bytes, default: IntTag = None) -> IntTag: ... @overload def get_long(self, key: str | bytes, default: None = None) -> LongTag | None: ... @overload - def get_float(self, key: str | bytes, default: FloatTag = None) -> FloatTag: ... + def get_long(self, key: str | bytes, default: LongTag = None) -> LongTag: ... @overload def get_float(self, key: str | bytes, default: None = None) -> FloatTag | None: ... @overload - def get_double(self, key: str | bytes, default: DoubleTag = None) -> DoubleTag: ... + def get_float(self, key: str | bytes, default: FloatTag = None) -> FloatTag: ... @overload def get_double( self, key: str | bytes, default: None = None ) -> DoubleTag | None: ... @overload - def get_string(self, key: str | bytes, default: StringTag = None) -> StringTag: ... + def get_double(self, key: str | bytes, default: DoubleTag = None) -> DoubleTag: ... @overload def get_string( self, key: str | bytes, default: None = None ) -> StringTag | None: ... @overload - def get_list(self, key: str | bytes, default: ListTag = None) -> ListTag: ... + def get_string(self, key: str | bytes, default: StringTag = None) -> StringTag: ... @overload def get_list(self, key: str | bytes, default: None = None) -> ListTag | None: ... @overload + def get_list(self, key: str | bytes, default: ListTag = None) -> ListTag: ... + @overload + def get_compound( + self, key: str | bytes, default: None = None + ) -> CompoundTag | None: ... + @overload def get_compound( self, key: str | bytes, default: CompoundTag = None ) -> CompoundTag: ... @overload - def get_compound( + def get_byte_array( self, key: str | bytes, default: None = None - ) -> CompoundTag | None: ... + ) -> ByteArrayTag | None: ... @overload def get_byte_array( self, key: str | bytes, default: ByteArrayTag = None ) -> ByteArrayTag: ... @overload - def get_byte_array( + def get_int_array( self, key: str | bytes, default: None = None - ) -> ByteArrayTag | None: ... + ) -> IntArrayTag | None: ... @overload def get_int_array( self, key: str | bytes, default: IntArrayTag = None ) -> IntArrayTag: ... @overload - def get_int_array( + def get_long_array( self, key: str | bytes, default: None = None - ) -> IntArrayTag | None: ... + ) -> LongArrayTag | None: ... @overload def get_long_array( self, key: str | bytes, default: LongArrayTag = None ) -> LongArrayTag: ... - @overload - def get_long_array( - self, key: str | bytes, default: None = None - ) -> LongArrayTag | None: ... def setdefault_byte( self, key: str | bytes, default: ByteTag | None = None ) -> ByteTag: ... diff --git a/src/amulet_nbt/_tag/compound.pyx b/src/amulet_nbt/_tag/compound.pyx index 12cceda9..46208d07 100644 --- a/src/amulet_nbt/_tag/compound.pyx +++ b/src/amulet_nbt/_tag/compound.pyx @@ -315,22 +315,13 @@ cdef class CompoundTag(AbstractBaseMutableTag): return wrap_node(&dereference(it).second) @overload - def get( - self, - key: str | bytes, - default: TagT = None, - cls: Type[TagT] = AbstractBaseTag, - ) -> TagT: - ... - + def get(self, key: str | bytes, default: None = None) -> AnyNBT | None:... @overload - def get( - self, - key: str | bytes, - default: None = None, - cls: Type[TagT] = AbstractBaseTag, - ) -> TagT | None: - ... + def get(self, key: str | bytes, default: _TagT = None) -> AnyNBT: ... + @overload + def get(self, key: str | bytes, default: None = None, cls: Type[_TagT] = AbstractBaseTag) -> _TagT | None: ... + @overload + def get(self, key: str | bytes, default: _TagT = None, cls: Type[_TagT] = AbstractBaseTag) -> _TagT: ... def get( self, @@ -457,16 +448,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_byte( self, key: str | bytes, - default: amulet_nbt.ByteTag = None, - ) -> amulet_nbt.ByteTag: + default: None = None, + ) -> amulet_nbt.ByteTag | None: ... @overload def get_byte( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.ByteTag | None: + default: amulet_nbt.ByteTag, + ) -> amulet_nbt.ByteTag: ... def get_byte(self, string key: str | bytes, ByteTag default: amulet_nbt.ByteTag | None = None) -> amulet_nbt.ByteTag | None: @@ -522,16 +513,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_short( self, key: str | bytes, - default: amulet_nbt.ShortTag = None, - ) -> amulet_nbt.ShortTag: + default: None = None, + ) -> amulet_nbt.ShortTag | None: ... @overload def get_short( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.ShortTag | None: + default: amulet_nbt.ShortTag, + ) -> amulet_nbt.ShortTag: ... def get_short(self, string key: str | bytes, ShortTag default: amulet_nbt.ShortTag | None = None) -> amulet_nbt.ShortTag | None: @@ -587,16 +578,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_int( self, key: str | bytes, - default: amulet_nbt.IntTag = None, - ) -> amulet_nbt.IntTag: + default: None = None, + ) -> amulet_nbt.IntTag | None: ... @overload def get_int( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.IntTag | None: + default: amulet_nbt.IntTag, + ) -> amulet_nbt.IntTag: ... def get_int(self, string key: str | bytes, IntTag default: amulet_nbt.IntTag | None = None) -> amulet_nbt.IntTag | None: @@ -652,16 +643,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_long( self, key: str | bytes, - default: amulet_nbt.LongTag = None, - ) -> amulet_nbt.LongTag: + default: None = None, + ) -> amulet_nbt.LongTag | None: ... @overload def get_long( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.LongTag | None: + default: amulet_nbt.LongTag, + ) -> amulet_nbt.LongTag: ... def get_long(self, string key: str | bytes, LongTag default: amulet_nbt.LongTag | None = None) -> amulet_nbt.LongTag | None: @@ -717,16 +708,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_float( self, key: str | bytes, - default: amulet_nbt.FloatTag = None, - ) -> amulet_nbt.FloatTag: + default: None = None, + ) -> amulet_nbt.FloatTag | None: ... @overload def get_float( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.FloatTag | None: + default: amulet_nbt.FloatTag, + ) -> amulet_nbt.FloatTag: ... def get_float(self, string key: str | bytes, FloatTag default: amulet_nbt.FloatTag | None = None) -> amulet_nbt.FloatTag | None: @@ -782,16 +773,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_double( self, key: str | bytes, - default: amulet_nbt.DoubleTag = None, - ) -> amulet_nbt.DoubleTag: + default: None = None, + ) -> amulet_nbt.DoubleTag | None: ... @overload def get_double( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.DoubleTag | None: + default: amulet_nbt.DoubleTag, + ) -> amulet_nbt.DoubleTag: ... def get_double(self, string key: str | bytes, DoubleTag default: amulet_nbt.DoubleTag | None = None) -> amulet_nbt.DoubleTag | None: @@ -847,16 +838,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_string( self, key: str | bytes, - default: amulet_nbt.StringTag = None, - ) -> amulet_nbt.StringTag: + default: None = None, + ) -> amulet_nbt.StringTag | None: ... @overload def get_string( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.StringTag | None: + default: amulet_nbt.StringTag, + ) -> amulet_nbt.StringTag: ... def get_string(self, string key: str | bytes, StringTag default: amulet_nbt.StringTag | None = None) -> amulet_nbt.StringTag | None: @@ -912,16 +903,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_list( self, key: str | bytes, - default: amulet_nbt.ListTag = None, - ) -> amulet_nbt.ListTag: + default: None = None, + ) -> amulet_nbt.ListTag | None: ... @overload def get_list( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.ListTag | None: + default: amulet_nbt.ListTag, + ) -> amulet_nbt.ListTag: ... def get_list(self, string key: str | bytes, ListTag default: amulet_nbt.ListTag | None = None) -> amulet_nbt.ListTag | None: @@ -977,16 +968,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_compound( self, key: str | bytes, - default: amulet_nbt.CompoundTag = None, - ) -> amulet_nbt.CompoundTag: + default: None = None, + ) -> amulet_nbt.CompoundTag | None: ... @overload def get_compound( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.CompoundTag | None: + default: amulet_nbt.CompoundTag, + ) -> amulet_nbt.CompoundTag: ... def get_compound(self, string key: str | bytes, CompoundTag default: amulet_nbt.CompoundTag | None = None) -> amulet_nbt.CompoundTag | None: @@ -1042,16 +1033,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_byte_array( self, key: str | bytes, - default: amulet_nbt.ByteArrayTag = None, - ) -> amulet_nbt.ByteArrayTag: + default: None = None, + ) -> amulet_nbt.ByteArrayTag | None: ... @overload def get_byte_array( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.ByteArrayTag | None: + default: amulet_nbt.ByteArrayTag, + ) -> amulet_nbt.ByteArrayTag: ... def get_byte_array(self, string key: str | bytes, ByteArrayTag default: amulet_nbt.ByteArrayTag | None = None) -> amulet_nbt.ByteArrayTag | None: @@ -1107,16 +1098,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_int_array( self, key: str | bytes, - default: amulet_nbt.IntArrayTag = None, - ) -> amulet_nbt.IntArrayTag: + default: None = None, + ) -> amulet_nbt.IntArrayTag | None: ... @overload def get_int_array( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.IntArrayTag | None: + default: amulet_nbt.IntArrayTag, + ) -> amulet_nbt.IntArrayTag: ... def get_int_array(self, string key: str | bytes, IntArrayTag default: amulet_nbt.IntArrayTag | None = None) -> amulet_nbt.IntArrayTag | None: @@ -1172,16 +1163,16 @@ cdef class CompoundTag(AbstractBaseMutableTag): def get_long_array( self, key: str | bytes, - default: amulet_nbt.LongArrayTag = None, - ) -> amulet_nbt.LongArrayTag: + default: None = None, + ) -> amulet_nbt.LongArrayTag | None: ... @overload def get_long_array( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.LongArrayTag | None: + default: amulet_nbt.LongArrayTag, + ) -> amulet_nbt.LongArrayTag: ... def get_long_array(self, string key: str | bytes, LongArrayTag default: amulet_nbt.LongArrayTag | None = None) -> amulet_nbt.LongArrayTag | None: diff --git a/src/amulet_nbt/_tag/compound.pyx.tp b/src/amulet_nbt/_tag/compound.pyx.tp index f4c53521..397336e6 100644 --- a/src/amulet_nbt/_tag/compound.pyx.tp +++ b/src/amulet_nbt/_tag/compound.pyx.tp @@ -221,22 +221,13 @@ cdef class CompoundTag(AbstractBaseMutableTag): return wrap_node(&dereference(it).second) @overload - def get( - self, - key: str | bytes, - default: TagT = None, - cls: Type[TagT] = AbstractBaseTag, - ) -> TagT: - ... - + def get(self, key: str | bytes, default: None = None) -> AnyNBT | None:... @overload - def get( - self, - key: str | bytes, - default: None = None, - cls: Type[TagT] = AbstractBaseTag, - ) -> TagT | None: - ... + def get(self, key: str | bytes, default: _TagT = None) -> AnyNBT: ... + @overload + def get(self, key: str | bytes, default: None = None, cls: Type[_TagT] = AbstractBaseTag) -> _TagT | None: ... + @overload + def get(self, key: str | bytes, default: _TagT = None, cls: Type[_TagT] = AbstractBaseTag) -> _TagT: ... def get( self, diff --git a/src/amulet_nbt/tpf/CompoundGetSetdefault.pyx.tpf b/src/amulet_nbt/tpf/CompoundGetSetdefault.pyx.tpf index 123324c5..4cb0aede 100644 --- a/src/amulet_nbt/tpf/CompoundGetSetdefault.pyx.tpf +++ b/src/amulet_nbt/tpf/CompoundGetSetdefault.pyx.tpf @@ -2,16 +2,16 @@ def get_{{tag_name}}( self, key: str | bytes, - default: amulet_nbt.{{py_cls}} = None, - ) -> amulet_nbt.{{py_cls}}: + default: None = None, + ) -> amulet_nbt.{{py_cls}} | None: ... @overload def get_{{tag_name}}( self, key: str | bytes, - default: None = None, - ) -> amulet_nbt.{{py_cls}} | None: + default: amulet_nbt.{{py_cls}}, + ) -> amulet_nbt.{{py_cls}}: ... def get_{{tag_name}}(self, string key: str | bytes, {{py_cls}} default: amulet_nbt.{{py_cls}} | None = None) -> amulet_nbt.{{py_cls}} | None: