Skip to content

Commit

Permalink
Merge pull request #28 from kostrykin/develop
Browse files Browse the repository at this point in the history
Refactor for v1.0.0-beta3
  • Loading branch information
kostrykin authored Aug 30, 2024
2 parents 69b5767 + fcf4b89 commit 14f00f0
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 13 deletions.
23 changes: 20 additions & 3 deletions repype/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,18 @@ def suggest_stage_id(class_name: str) -> str:
"""
Suggests a stage identifier based on a class name.
This function validates the class name, then finds and groups tokens in the class name.
This function validates the `class_name`, then tokenizes it.
Tokens are grouped if they are consecutive and alphanumeric, but do not start with numbers.
The function then converts the tokens to lowercase, removes underscores, and joins them with hyphens.
Example:
.. runblock:: pycon
>>> from repype.stage import suggest_stage_id
>>> print(suggest_stage_id('TheGreatMapperStage'))
>>> print(suggest_stage_id('TheGreat123PCMapper'))
Arguments:
class_name: The name of the class to suggest a stage identifier for.
Expand All @@ -59,7 +67,11 @@ def suggest_stage_id(class_name: str) -> str:
AssertionError: If the class name is not valid.
"""
assert class_name != '_' and re.match('[a-zA-Z]', class_name) and re.match('^[a-zA-Z_](?:[a-zA-Z0-9_])*$', class_name), f'not a valid class name: "{class_name}"'
tokens1 = re.findall('[A-Z0-9][^A-Z0-9_]*', class_name)

# Find all tokens in the class name (a letter or digit, followed by lowercase letters up to the next uppercase letter, digit, or underscore)
tokens1 = re.findall('[a-zA-Z0-9][^A-Z0-9_]*', class_name)

# Join tokens that are alphanumeric and consecutive
tokens2 = list()
i1 = 0
while i1 < len(tokens1):
Expand All @@ -73,7 +85,12 @@ def suggest_stage_id(class_name: str) -> str:
else:
break
tokens2.append(token.lower().replace('_', ''))
if len(tokens2) >= 2 and tokens2[-1] == 'stage': tokens2 = tokens2[:-1]

# Remove the last token if it is "stage"
if len(tokens2) >= 2 and tokens2[-1] == 'stage':
tokens2 = tokens2[:-1]

# Join the tokens
return '-'.join(tokens2)


Expand Down
20 changes: 15 additions & 5 deletions repype/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,27 @@ def root(self) -> Self:
return self.parent.root if self.parent else self

@property
def marginal_stages(self) -> List[str]:
def marginal_stages(self) -> Iterator[str]:
"""
The stages which are considered marginal.
Outputs of marginal stages are removed from the *pipeline data objects* when storing the results of the task.
The default implementation reads the list of marginal stages from the ``marginal_stages`` field in the task specification.
Returns:
List of the stage identifiers corresponding to the marginal stages.
Yields:
Stage identifiers corresponding to the marginal stages.
"""
return self.full_spec.get('marginal_stages', [])
for stage_spec in self.full_spec.get('marginal_stages', []):
assert isinstance(stage_spec, str), f'Stage identifier must be a string (f{type(stage_spec)}).'

# Load the stage from a module
if '.' in stage_spec:
stage_cls = load_from_module(stage_spec)
yield stage_cls().id

# Use the stage identifier directly
else:
yield stage_spec

@property
def data_filepath(self) -> pathlib.Path:
Expand Down Expand Up @@ -352,7 +362,7 @@ def get_marginal_fields(self, pipeline: repype.pipeline.Pipeline) -> FrozenSet[s
Returns:
Set of marginal fields.
"""
marginal_fields = sum((list(stage.outputs) for stage in pipeline.stages if stage.id in self.marginal_stages), list())
marginal_fields = sum((list(stage.outputs) for stage in pipeline.stages if stage.id in set(self.marginal_stages)), list())
return frozenset(marginal_fields)

def load(self, pipeline: Optional[repype.pipeline.Pipeline] = None) -> TaskData:
Expand Down
2 changes: 1 addition & 1 deletion repype/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
VERSION_MAJOR = 1
VERSION_MINOR = 0
VERSION_PATCH = 0
VERSION_SUFFIX = 'beta2'
VERSION_SUFFIX = 'beta3'

VERSION = '%d.%d.%s%s' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, '-%s' % VERSION_SUFFIX if VERSION_SUFFIX else '')
1 change: 1 addition & 0 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test(self):
self.assertEqual(repype.stage.suggest_stage_id('TheGreat123PCMapper' ), 'the-great-123-pc-mapper')
self.assertEqual(repype.stage.suggest_stage_id('TheGreatMapperStage' ), 'the-great-mapper' )
self.assertEqual(repype.stage.suggest_stage_id('Stage' ), 'stage' )
self.assertEqual(repype.stage.suggest_stage_id('stage1_abc_cls' ), 'stage-1-abc-cls' )

def test_illegal(self):
self.assertRaises(AssertionError, lambda: repype.stage.suggest_stage_id(''))
Expand Down
22 changes: 18 additions & 4 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def test_with_changed_pipeline(self, path):
self.assertTrue(task.is_pending(self.pipeline, config))


class Task__marginal_states(unittest.TestCase):
class Task__marginal_stages(unittest.TestCase):

@testsuite.with_temporary_paths(1)
def test_from_spec_missing(self, path):
Expand All @@ -549,7 +549,7 @@ def test_from_spec_missing(self, path):
parent = None,
spec = dict(),
)
self.assertEqual(task.marginal_stages, [])
self.assertEqual(list(task.marginal_stages), [])

@testsuite.with_temporary_paths(1)
def test_from_spec(self, path):
Expand All @@ -563,7 +563,21 @@ def test_from_spec(self, path):
],
),
)
self.assertEqual(task.marginal_stages, ['stage1', 'stage2'])
self.assertEqual(list(task.marginal_stages), ['stage1', 'stage2'])

@testsuite.with_temporary_paths(1)
def test_from_spec_class_names(self, path):
task = repype.task.Task(
path = path,
parent = None,
spec = dict(
marginal_stages = [
'tests.test_task.Task__create_pipeline.stage1_cls',
'tests.test_task.Task__create_pipeline.stage2_cls',
],
),
)
self.assertEqual(list(task.marginal_stages), ['stage1', 'stage2'])

@testsuite.with_temporary_paths(1)
def test_override(self, path):
Expand All @@ -579,7 +593,7 @@ class DerivedTask(repype.task.Task):
parent = None,
spec = dict(),
)
self.assertEqual(task.marginal_stages, ['stage1', 'stage2'])
self.assertEqual(list(task.marginal_stages), ['stage1', 'stage2'])


class Task__get_marginal_fields(unittest.TestCase):
Expand Down

0 comments on commit 14f00f0

Please sign in to comment.