Skip to content

Commit

Permalink
MultitaskTrainModules: Unit tests for order and uniqueness in tasks p…
Browse files Browse the repository at this point in the history
…roperty

caikit#707

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Signed-off-by: Kelly A <kellyaa@users.noreply.github.com>
  • Loading branch information
gabe-l-hart authored and kellyaa committed May 3, 2024
1 parent a86011c commit fb22bb4
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions tests/core/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@

# Local
from caikit.core import TaskBase, task
from caikit.interfaces.common.data_model import File
from sample_lib import SampleModule
from sample_lib.data_model.sample import SampleInputType, SampleOutputType, SampleTask
from sample_lib.data_model.sample import (
OtherOutputType,
SampleInputType,
SampleOutputType,
SampleTask,
)
from sample_lib.modules.multi_task import FirstTask, MultiTaskModule, SecondTask
import caikit.core

Expand Down Expand Up @@ -171,7 +177,7 @@ def test_task_is_not_required_for_modules():
class Stuff(caikit.core.ModuleBase):
pass

assert Stuff.tasks == set()
assert Stuff.tasks == []


def test_raises_if_tasks_not_list():
Expand Down Expand Up @@ -611,6 +617,32 @@ def run(self, sample_input: Union[str, int]) -> SampleOutputType:
pass


def test_tasks_property_order():
"""Ensure that the tasks returned by .tasks have a deterministic order that
respects the order given in the module decorator
"""
assert MultiTaskModule.tasks == [FirstTask, SecondTask]


def test_tasks_property_unique():
"""Ensure that entries in the tasks list is unique even when inherited from
modules with the same tasks
"""

@caikit.core.module(
id=str(uuid.uuid4()),
name="DerivedMultitaskModule",
version="0.0.1",
task=SecondTask,
)
class DerivedMultitaskModule(MultiTaskModule):
@SecondTask.taskmethod()
def run_second_task(self, file_input: File) -> OtherOutputType:
return OtherOutputType("I'm a derivative!")

assert DerivedMultitaskModule.tasks == [SecondTask, FirstTask]


# ----------- BACKWARDS COMPATIBILITY ------------------------------------------- ##


Expand Down

0 comments on commit fb22bb4

Please sign in to comment.