Skip to content

Commit

Permalink
Add and improve type annotations (partial #5358)
Browse files Browse the repository at this point in the history
  • Loading branch information
nadove-ucsc committed Sep 6, 2023
1 parent 5746788 commit 20b2270
Show file tree
Hide file tree
Showing 11 changed files with 57 additions and 47 deletions.
6 changes: 3 additions & 3 deletions src/azul/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def inner_func(t):
return inner_func


def none_safe_itemgetter(*items):
def none_safe_itemgetter(*items: str) -> Callable:
"""
Like `itemgetter` except that the returned callable returns `None`
(or a tuple of `None`) if it's passed None.
Expand Down Expand Up @@ -157,7 +157,7 @@ def f(v):
return f


def compose_keys(f, g):
def compose_keys(f: Callable, g: Callable) -> Callable:
"""
Composes unary functions.
Expand Down Expand Up @@ -270,7 +270,7 @@ def __init__(self, depth: int, leaf_factory):
if depth else
leaf_factory)

def to_dict(self):
def to_dict(self) -> dict:
return {
k: v.to_dict() if isinstance(v, NestedDict) else v
for k, v in self.items()
Expand Down
2 changes: 1 addition & 1 deletion src/azul/indexer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def from_json(cls, json: SourcedBundleFQIDJSON) -> 'SourcedBundleFQID':
source = cls.source_ref_cls().from_json(json.pop('source'))
return cls(source=source, **json)

def upcast(self):
def upcast(self) -> BundleFQID:
return BundleFQID(uuid=self.uuid,
version=self.version)

Expand Down
8 changes: 6 additions & 2 deletions src/azul/indexer/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from collections.abc import (
Iterable,
)
from typing import (
Optional,
)

from azul.indexer import (
BundlePartition,
Expand All @@ -14,6 +17,7 @@
)
from azul.indexer.document import (
Contribution,
EntityType,
FieldTypes,
)

Expand All @@ -22,7 +26,7 @@ class Transformer(metaclass=ABCMeta):

@classmethod
@abstractmethod
def entity_type(cls) -> str:
def entity_type(cls) -> EntityType:
"""
The type of entity this transformer creates and aggregates
contributions for.
Expand Down Expand Up @@ -63,7 +67,7 @@ def transform(self, partition: BundlePartition) -> Iterable[Contribution]:

@classmethod
@abstractmethod
def get_aggregator(cls, entity_type) -> EntityAggregator:
def get_aggregator(cls, entity_type: EntityType) -> Optional[EntityAggregator]:
"""
Returns the aggregator to be used for entities of the given type that
occur in the document to be aggregated. A document for an entity of
Expand Down
12 changes: 6 additions & 6 deletions src/azul/openapi/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class optional(NamedTuple):


# noinspection PyShadowingBuiltins
def object(additional_properties=False, **props: Union[TYPE, optional]):
def object(additional_properties=False, **props: Union[TYPE, optional]) -> JSON:
"""
>>> from azul.doctests import assert_json
>>> assert_json(object(x=int, y=int, relative=optional(bool)))
Expand Down Expand Up @@ -96,7 +96,7 @@ def object(additional_properties=False, **props: Union[TYPE, optional]):
additionalProperties=additional_properties)


def properties(**props: TYPE):
def properties(**props: TYPE) -> JSON:
"""
Returns a JSON schema `properties` attribute value.
Expand All @@ -115,7 +115,7 @@ def properties(**props: TYPE):
return {name: make_type(prop) for name, prop in props.items()}


def array(item: TYPE, *items: TYPE, **kwargs):
def array(item: TYPE, *items: TYPE, **kwargs) -> JSON:
"""
Returns the schema for an array of items of a given type, or a sequence of
types.
Expand Down Expand Up @@ -213,7 +213,7 @@ def enum(*items: PrimitiveJSON, type_: TYPE = None) -> JSON:
}


def pattern(regex: Union[str, re.Pattern], _type: TYPE = str):
def pattern(regex: Union[str, re.Pattern], _type: TYPE = str) -> JSON:
"""
Returns schema for a JSON string matching the given pattern.
Expand Down Expand Up @@ -463,7 +463,7 @@ def make_type(t: TYPE) -> JSON:
assert False, type(t)


def union(*ts: TYPE, for_openapi: bool = True):
def union(*ts: TYPE, for_openapi: bool = True) -> JSON:
"""
The union of one or more types.
Expand Down Expand Up @@ -496,7 +496,7 @@ def union(*ts: TYPE, for_openapi: bool = True):
return {'anyOf': ts}


def nullable(t: TYPE, for_openapi: bool = True):
def nullable(t: TYPE, for_openapi: bool = True) -> JSON:
"""
Given a schema, return a schema that additionally permits the `null` value.
Expand Down
6 changes: 4 additions & 2 deletions src/azul/plugins/metadata/hca/indexer/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,15 @@
BundlePartition,
)
from azul.indexer.aggregate import (
EntityAggregator,
SimpleAggregator,
)
from azul.indexer.document import (
ClosedRange,
Contribution,
ContributionCoordinates,
EntityReference,
EntityType,
FieldType,
FieldTypes,
Nested,
Expand Down Expand Up @@ -470,7 +472,7 @@ def __init__(self,
...

@classmethod
def get_aggregator(cls, entity_type):
def get_aggregator(cls, entity_type: EntityType) -> Optional[EntityAggregator]:
if entity_type == 'files':
return FileAggregator()
elif entity_type in SampleTransformer.inner_entity_types():
Expand Down Expand Up @@ -1699,7 +1701,7 @@ def _singleton_entity(self) -> DatedEntity:
return BundleAsEntity(self.api_bundle)

@classmethod
def get_aggregator(cls, entity_type):
def get_aggregator(cls, entity_type: EntityType) -> Optional[EntityAggregator]:
if entity_type == 'files':
return None
else:
Expand Down
46 changes: 24 additions & 22 deletions src/azul/plugins/metadata/hca/service/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
AnyJSON,
JSON,
JSONs,
MutableJSON,
MutableJSONs,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -185,7 +187,7 @@ def __init__(self, aggs: JSON):
super().__init__()
self.aggs = aggs

def make_response(self):
def make_response(self) -> SummaryResponse:
def agg_value(*path: str) -> AnyJSON:
agg = self.aggs
for name in path:
Expand Down Expand Up @@ -294,24 +296,24 @@ def __init__(self,
self.entity_type = entity_type
self.catalog = catalog

def make_response(self):
def make_response(self) -> SearchResponse:
return SearchResponse(pagination=self.pagination,
termFacets=self.make_facets(),
hits=self.make_hits())

def make_bundles(self, entry):
def make_bundles(self, entry) -> MutableJSONs:
return [
{"bundleUuid": b["uuid"], "bundleVersion": b["version"]}
for b in entry["bundles"]
]

def make_sources(self, entry):
def make_sources(self, entry) -> MutableJSONs:
return [
{'sourceId': s['id'], 'sourceSpec': s['spec']}
for s in entry['sources']
]

def make_protocols(self, entry):
def make_protocols(self, entry) -> MutableJSONs:
return [
*(
{
Expand Down Expand Up @@ -340,7 +342,7 @@ def make_protocols(self, entry):
)
]

def make_dates(self, entry):
def make_dates(self, entry) -> MutableJSONs:
return [
{
'aggregateLastModifiedDate': dates['aggregate_last_modified_date'],
Expand All @@ -353,7 +355,7 @@ def make_dates(self, entry):
for dates in entry['contents']['dates']
]

def make_projects(self, entry):
def make_projects(self, entry) -> MutableJSONs:
projects = []
contents = entry['contents']
for project in contents["projects"]:
Expand Down Expand Up @@ -420,7 +422,7 @@ def make_translated_file(self, file: JSON) -> JSON:
}
return translated_file

def make_specimen(self, specimen):
def make_specimen(self, specimen) -> MutableJSON:
return {
"id": specimen["biomaterial_id"],
"organ": specimen.get("organ", None),
Expand All @@ -430,7 +432,7 @@ def make_specimen(self, specimen):
"source": specimen.get("_source", None)
}

def make_specimens(self, entry):
def make_specimens(self, entry) -> MutableJSONs:
return [
self.make_specimen(specimen)
for specimen in entry["contents"]["specimens"]
Expand All @@ -444,32 +446,32 @@ def make_specimens(self, entry):
('totalCellsRedundant', 'total_estimated_cells_redundant')
]

def make_cell_suspension(self, cell_suspension):
def make_cell_suspension(self, cell_suspension) -> MutableJSON:
return {
k: cell_suspension.get(v, None)
for k, v in self.cell_suspension_fields
}

def make_cell_suspensions(self, entry):
def make_cell_suspensions(self, entry) -> MutableJSONs:
return [
self.make_cell_suspension(cs)
for cs in entry["contents"]["cell_suspensions"]
]

def make_cell_line(self, cell_line):
def make_cell_line(self, cell_line) -> MutableJSON:
return {
"id": cell_line["biomaterial_id"],
"cellLineType": cell_line.get("cell_line_type", None),
"modelOrgan": cell_line.get("model_organ", None),
}

def make_cell_lines(self, entry):
def make_cell_lines(self, entry) -> MutableJSONs:
return [
self.make_cell_line(cell_line)
for cell_line in entry["contents"]["cell_lines"]
]

def make_donor(self, donor):
def make_donor(self, donor) -> MutableJSON:
return {
"id": donor["biomaterial_id"],
"donorCount": donor.get("donor_count", None),
Expand All @@ -481,23 +483,23 @@ def make_donor(self, donor):
"disease": donor.get("diseases", None)
}

def make_donors(self, entry):
def make_donors(self, entry) -> MutableJSONs:
return [self.make_donor(donor) for donor in entry["contents"]["donors"]]

def make_organoid(self, organoid):
def make_organoid(self, organoid) -> MutableJSON:
return {
"id": organoid["biomaterial_id"],
"modelOrgan": organoid.get("model_organ", None),
"modelOrganPart": organoid.get("model_organ_part", None)
}

def make_organoids(self, entry):
def make_organoids(self, entry) -> MutableJSONs:
return [
self.make_organoid(organoid)
for organoid in entry["contents"]["organoids"]
]

def make_sample(self, sample, entity_dict, entity_type):
def make_sample(self, sample, entity_dict, entity_type) -> MutableJSON:
is_aggregate = isinstance(sample['document_id'], list)
organ_prop = 'organ' if entity_type == 'specimens' else 'model_organ'
return {
Expand All @@ -506,7 +508,7 @@ def make_sample(self, sample, entity_dict, entity_type):
**entity_dict
}

def make_samples(self, entry):
def make_samples(self, entry) -> MutableJSONs:
pieces = [
(self.make_cell_line, 'cellLines', 'sample_cell_lines'),
(self.make_organoid, 'organoids', 'sample_organoids'),
Expand All @@ -518,10 +520,10 @@ def make_samples(self, entry):
for sample in entry['contents'].get(sample_entity_type, [])
]

def make_hits(self):
def make_hits(self) -> MutableJSONs:
return list(map(self.make_hit, self.hits))

def make_hit(self, es_hit):
def make_hit(self, es_hit) -> MutableJSON:
hit = Hit(protocols=self.make_protocols(es_hit),
entryId=es_hit['entity_id'],
sources=self.make_sources(es_hit),
Expand Down Expand Up @@ -606,7 +608,7 @@ def choose_entry(_term):
# https://github.com/DataBiosphere/azul/issues/2460
type='terms')

def make_facets(self):
def make_facets(self) -> MutableJSON:
facets = {}
for facet, agg in self.aggs.items():
if facet != '_project_agg': # Filter out project specific aggs
Expand Down
2 changes: 1 addition & 1 deletion src/azul/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def from_json(cls, json: JSON) -> 'Filters':
return cls(explicit=json['explicit'],
source_ids=set(json['source_ids']))

def to_json(self):
def to_json(self) -> JSON:
return {
'explicit': self.explicit,
'source_ids': sorted(self.source_ids)
Expand Down
2 changes: 1 addition & 1 deletion src/azul/service/manifest_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def command_lines(cls,
}

@classmethod
def _option(cls, s: str):
def _option(cls, s: str) -> str:
"""
>>> f = CurlManifestGenerator._option
>>> f('')
Expand Down
4 changes: 2 additions & 2 deletions src/azul/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
)


def to_camel_case(text: str):
def to_camel_case(text: str) -> str:
camel_cased = ''.join(part.title() for part in text.split('_'))
return camel_cased[0].lower() + camel_cased[1:]


def departition(before, sep, after):
def departition(before: Optional[str], sep: str, after: Optional[str]) -> str:
"""
>>> departition(None, '.', 'after')
'after'
Expand Down
2 changes: 1 addition & 1 deletion src/azul/uuids.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def from_json(cls, partition: JSON) -> UUID_PARTITION:
def to_json(self) -> JSON:
return attr.asdict(self)

def contains(self, member: UUID):
def contains(self, member: UUID) -> bool:
"""
>>> p = UUIDPartition(prefix_length=7, prefix=0b0111_1111)
>>> p.contains(UUID('fdd4524e-14c4-41d7-9071-6cadab09d75c'))
Expand Down
Loading

0 comments on commit 20b2270

Please sign in to comment.