Skip to content

Commit

Permalink
artifact link
Browse files Browse the repository at this point in the history
Signed-off-by: dafnapension <dafnashein@yahoo.com>
  • Loading branch information
dafnapension committed Nov 18, 2024
1 parent c7a2285 commit a90b2e2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
24 changes: 21 additions & 3 deletions src/unitxt/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,18 @@ def get_raw(obj):
return shallow_copy(obj)


class ArtifactLink(Artifact):
actual_artifact: Artifact
is_deprecated: bool

def prepare(self):
if self.is_deprecated:
UnitxtWarning(
message=f"Artifact {self.get_artifact_type()} is deprecated. "
f"Its replacement, {self.actual_artifact}, is instantiated instead"
)


class ArtifactList(list, Artifact):
def prepare(self):
for artifact in self:
Expand Down Expand Up @@ -467,15 +479,21 @@ def fetch_artifact(artifact_rep) -> Tuple[Artifact, Union[AbstractCatalog, None]
name, _ = separate_inside_and_outside_square_brackets(artifact_rep)
if is_name_legal_for_catalog(name):
catalog, artifact_rep, args = get_catalog_name_and_args(name=artifact_rep)
return catalog.get_with_overwrite(
artifact_to_return = catalog.get_with_overwrite(
artifact_rep, overwrite_args=args
), catalog
)
if isinstance(artifact_to_return, ArtifactLink):
artifact_to_return = artifact_to_return.actual_artifact
return artifact_to_return, catalog

# If Json string, first load into dictionary
if isinstance(artifact_rep, str):
artifact_rep = json.loads(artifact_rep)
# Load from dictionary (fails if not valid dictionary)
return Artifact.from_dict(artifact_rep), None
artifact_to_return = Artifact.from_dict(artifact_rep)
if isinstance(artifact_to_return, ArtifactLink):
artifact_to_return = artifact_to_return.actual_artifact
return artifact_to_return, None


def get_catalog_name_and_args(
Expand Down
21 changes: 21 additions & 0 deletions tests/library/test_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from unitxt.artifact import (
Artifact,
ArtifactLink,
fetch_artifact,
get_artifacts_data_classification,
reset_artifacts_json_cache,
Expand Down Expand Up @@ -217,3 +218,23 @@ def test_misconfigured_data_classification_policy(self):

# "Fixing" the env variable so that it does not affect other tests:
del os.environ["UNITXT_DATA_CLASSIFICATION_POLICY"]

def test_artifact_link(self):
with temp_catalog() as catalog_path:
rename = Rename(field_to_field={"label_text": "label"})
add_to_catalog(
Rename(field_to_field={"label_text": "label"}),
"rename.for.test.dict",
catalog_path=catalog_path,
)
add_to_catalog(
ArtifactLink(
actual_artifact="rename.for.test.dict",
is_deprecated=True,
),
"renamefields.for.test.dict",
catalog_path=catalog_path,
)

artifact, _ = fetch_artifact("renamefields.for.test.dict")
self.assertDictEqual(rename.to_dict(), artifact.to_dict())

0 comments on commit a90b2e2

Please sign in to comment.