-
Notifications
You must be signed in to change notification settings - Fork 722
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add renamed ResolverStrategies at tfx.dsl.input_resolution.strategies. (
#3755) PiperOrigin-RevId: 373274374 Co-authored-by: jjong <jjong@google.com>
- Loading branch information
1 parent
e0fac27
commit bc56366
Showing
7 changed files
with
893 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2021 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
108 changes: 108 additions & 0 deletions
108
tfx/dsl/input_resolution/strategies/latest_artifact_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2021 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Experimental Resolver for getting the latest artifact.""" | ||
|
||
from typing import Dict, List, Optional | ||
|
||
from tfx import types | ||
from tfx.dsl.components.common import resolver | ||
from tfx.orchestration import data_types | ||
from tfx.orchestration import metadata | ||
from tfx.types import artifact_utils | ||
from tfx.utils import doc_controls | ||
|
||
import ml_metadata as mlmd | ||
|
||
|
||
class LatestArtifactStrategy(resolver.ResolverStrategy): | ||
"""Strategy that resolves the latest n(=1) artifacts per each channel. | ||
Note that this ResolverStrategy is experimental and is subject to change in | ||
terms of both interface and implementation. | ||
Don't construct LatestArtifactStrategy directly, example usage: | ||
``` | ||
model_resolver = Resolver( | ||
instance_name='latest_model_resolver', | ||
strategy_class=LatestArtifactStrategy, | ||
model=Channel(type=Model)) | ||
model_resolver.outputs['model'] | ||
``` | ||
""" | ||
|
||
def __init__(self, desired_num_of_artifacts: Optional[int] = 1): | ||
self._desired_num_of_artifact = desired_num_of_artifacts | ||
|
||
def _resolve(self, input_dict: Dict[str, List[types.Artifact]]): | ||
result = {} | ||
for k, artifact_list in input_dict.items(): | ||
sorted_artifact_list = sorted( | ||
artifact_list, key=lambda a: a.id, reverse=True) | ||
result[k] = sorted_artifact_list[:min( | ||
len(sorted_artifact_list), self._desired_num_of_artifact)] | ||
return result | ||
|
||
@doc_controls.do_not_generate_docs | ||
def resolve( | ||
self, | ||
pipeline_info: data_types.PipelineInfo, | ||
metadata_handler: metadata.Metadata, | ||
source_channels: Dict[str, types.Channel], | ||
) -> resolver.ResolveResult: | ||
pipeline_context = metadata_handler.get_pipeline_context(pipeline_info) | ||
if pipeline_context is None: | ||
raise RuntimeError('Pipeline context absent for %s' % pipeline_context) | ||
|
||
candidate_dict = {} | ||
for k, c in source_channels.items(): | ||
candidate_artifacts = metadata_handler.get_qualified_artifacts( | ||
contexts=[pipeline_context], | ||
type_name=c.type_name, | ||
producer_component_id=c.producer_component_id, | ||
output_key=c.output_key) | ||
candidate_dict[k] = [ | ||
artifact_utils.deserialize_artifact(a.type, a.artifact) | ||
for a in candidate_artifacts | ||
] | ||
|
||
resolved_dict = self._resolve(candidate_dict) | ||
resolve_state_dict = { | ||
k: len(artifact_list) >= self._desired_num_of_artifact | ||
for k, artifact_list in resolved_dict.items() | ||
} | ||
|
||
return resolver.ResolveResult( | ||
per_key_resolve_result=resolved_dict, | ||
per_key_resolve_state=resolve_state_dict) | ||
|
||
@doc_controls.do_not_generate_docs | ||
def resolve_artifacts( | ||
self, store: mlmd.MetadataStore, | ||
input_dict: Dict[str, List[types.Artifact]] | ||
) -> Optional[Dict[str, List[types.Artifact]]]: | ||
"""Resolves artifacts from channels by querying MLMD. | ||
Args: | ||
store: An MLMD MetadataStore object. | ||
input_dict: The input_dict to resolve from. | ||
Returns: | ||
If `min_count` for every input is met, returns a | ||
Dict[str, List[Artifact]]. Otherwise, return None. | ||
""" | ||
resolved_dict = self._resolve(input_dict) | ||
all_min_count_met = all( | ||
len(artifact_list) >= self._desired_num_of_artifact | ||
for artifact_list in resolved_dict.values()) | ||
return resolved_dict if all_min_count_met else None |
99 changes: 99 additions & 0 deletions
99
tfx/dsl/input_resolution/strategies/latest_artifact_strategy_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# Copyright 2021 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Test for LatestArtifactStrategy.""" | ||
|
||
import tensorflow as tf | ||
from tfx import types | ||
from tfx.dsl.input_resolution.strategies import latest_artifact_strategy | ||
from tfx.orchestration import data_types | ||
from tfx.orchestration import metadata | ||
from tfx.types import standard_artifacts | ||
from tfx.utils import test_case_utils | ||
|
||
from ml_metadata.proto import metadata_store_pb2 | ||
|
||
|
||
class LatestArtifactStrategyTest(test_case_utils.TfxTest): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self._connection_config = metadata_store_pb2.ConnectionConfig() | ||
self._connection_config.sqlite.SetInParent() | ||
self._metadata = self.enter_context( | ||
metadata.Metadata(connection_config=self._connection_config)) | ||
self._store = self._metadata.store | ||
self._pipeline_info = data_types.PipelineInfo( | ||
pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id') | ||
self._component_info = data_types.ComponentInfo( | ||
component_type='a.b.c', | ||
component_id='my_component', | ||
pipeline_info=self._pipeline_info) | ||
|
||
def testStrategy(self): | ||
contexts = self._metadata.register_pipeline_contexts_if_not_exists( | ||
self._pipeline_info) | ||
artifact_one = standard_artifacts.Examples() | ||
artifact_one.uri = 'uri_one' | ||
self._metadata.publish_artifacts([artifact_one]) | ||
artifact_two = standard_artifacts.Examples() | ||
artifact_two.uri = 'uri_two' | ||
self._metadata.register_execution( | ||
exec_properties={}, | ||
pipeline_info=self._pipeline_info, | ||
component_info=self._component_info, | ||
contexts=contexts) | ||
self._metadata.publish_execution( | ||
component_info=self._component_info, | ||
output_artifacts={'key': [artifact_one, artifact_two]}) | ||
expected_artifact = max(artifact_one, artifact_two, key=lambda a: a.id) | ||
|
||
strategy = latest_artifact_strategy.LatestArtifactStrategy() | ||
resolve_result = strategy.resolve( | ||
pipeline_info=self._pipeline_info, | ||
metadata_handler=self._metadata, | ||
source_channels={ | ||
'input': | ||
types.Channel( | ||
type=artifact_one.type, | ||
producer_component_id=self._component_info.component_id, | ||
output_key='key') | ||
}) | ||
|
||
self.assertTrue(resolve_result.has_complete_result) | ||
self.assertEqual([ | ||
artifact.uri | ||
for artifact in resolve_result.per_key_resolve_result['input'] | ||
], [expected_artifact.uri]) | ||
self.assertTrue(resolve_result.per_key_resolve_state['input']) | ||
|
||
def testStrategy_IrMode(self): | ||
artifact_one = standard_artifacts.Examples() | ||
artifact_one.uri = 'uri_one' | ||
artifact_one.id = 1 | ||
artifact_two = standard_artifacts.Examples() | ||
artifact_two.uri = 'uri_two' | ||
artifact_one.id = 2 | ||
|
||
expected_artifact = max(artifact_one, artifact_two, key=lambda a: a.id) | ||
|
||
strategy = latest_artifact_strategy.LatestArtifactStrategy() | ||
result = strategy.resolve_artifacts( | ||
self._store, {'input': [artifact_two, artifact_one]}) | ||
self.assertIsNotNone(result) | ||
self.assertEqual([a.uri for a in result['input']], | ||
[expected_artifact.uri]) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
Oops, something went wrong.