Skip to content

Commit

Permalink
Add correct types to Scalar and Primitive constructors in `litera…
Browse files Browse the repository at this point in the history
…ls.py` (#2778)

* Add some type-information to `literals.py`

Signed-off-by: Felix Mulder <felix@Felixs-MBP.home>

* Fixup imports

Signed-off-by: Felix Mulder <felix.mulder@gmail.com>

---------

Signed-off-by: Felix Mulder <felix@Felixs-MBP.home>
Signed-off-by: Felix Mulder <felix.mulder@gmail.com>
  • Loading branch information
felixmulder authored Oct 1, 2024
1 parent cc4d27b commit 6e70129
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime as _datetime
from datetime import timedelta as _timedelta
from datetime import timezone as _timezone
from typing import Dict, Optional

Expand Down Expand Up @@ -48,12 +49,12 @@ def from_flyte_idl(cls, pb2_object):
class Primitive(_common.FlyteIdlEntity):
def __init__(
self,
integer=None,
float_value=None,
string_value=None,
boolean=None,
datetime=None,
duration=None,
integer: Optional[int] = None,
float_value: Optional[float] = None,
string_value: Optional[str] = None,
boolean: Optional[bool] = None,
datetime: Optional[_datetime] = None,
duration: Optional[_timedelta] = None,
):
"""
This object proxies the primitives supported by the Flyte IDL system. Only one value can be set.
Expand All @@ -77,35 +78,35 @@ def __init__(
self._duration = duration

@property
def integer(self):
def integer(self) -> Optional[int]:
"""
:rtype: int
"""
return self._integer

@property
def float_value(self):
def float_value(self) -> Optional[float]:
"""
:rtype: float
"""
return self._float_value

@property
def string_value(self):
def string_value(self) -> Optional[str]:
"""
:rtype: Text
"""
return self._string_value

@property
def boolean(self):
def boolean(self) -> Optional[bool]:
"""
:rtype: bool
"""
return self._boolean

@property
def datetime(self):
def datetime(self) -> Optional[_datetime]:
"""
:rtype: datetime.datetime
"""
Expand All @@ -114,7 +115,7 @@ def datetime(self):
return self._datetime.replace(tzinfo=_timezone.utc)

@property
def duration(self):
def duration(self) -> Optional[_timedelta]:
"""
:rtype: datetime.timedelta
"""
Expand Down Expand Up @@ -703,15 +704,15 @@ def from_flyte_idl(cls, pb2_object):
class Scalar(_common.FlyteIdlEntity):
def __init__(
self,
primitive: Primitive = None,
blob: Blob = None,
binary: Binary = None,
schema: Schema = None,
union: Union = None,
none_type: Void = None,
error: Error = None,
generic: Struct = None,
structured_dataset: StructuredDataset = None,
primitive: Optional[Primitive] = None,
blob: Optional[Blob] = None,
binary: Optional[Binary] = None,
schema: Optional[Schema] = None,
union: Optional[Union] = None,
none_type: Optional[Void] = None,
error: Optional[Error] = None,
generic: Optional[Struct] = None,
structured_dataset: Optional[StructuredDataset] = None,
):
"""
Scalar wrapper around Flyte types. Only one can be specified.
Expand Down

0 comments on commit 6e70129

Please sign in to comment.