Skip to content

Commit

Permalink
Add renamed ResolverStrategies at tfx.dsl.input_resolution.strategies. (#3755)
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 373274374

Co-authored-by: jjong <jjong@google.com>
  • Loading branch information
dhruvesh09 and chongkong authored May 17, 2021
1 parent e0fac27 commit bc56366
Show file tree
Hide file tree
Showing 7 changed files with 893 additions and 0 deletions.
13 changes: 13 additions & 0 deletions tfx/dsl/input_resolution/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
108 changes: 108 additions & 0 deletions tfx/dsl/input_resolution/strategies/latest_artifact_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental Resolver for getting the latest artifact."""

from typing import Dict, List, Optional

from tfx import types
from tfx.dsl.components.common import resolver
from tfx.orchestration import data_types
from tfx.orchestration import metadata
from tfx.types import artifact_utils
from tfx.utils import doc_controls

import ml_metadata as mlmd


class LatestArtifactStrategy(resolver.ResolverStrategy):
  """Strategy that resolves the latest n(=1) artifacts per each channel.

  "Latest" is determined by artifact id: candidates are sorted by `id` in
  descending order and the first `desired_num_of_artifacts` are kept.

  Note that this ResolverStrategy is experimental and is subject to change in
  terms of both interface and implementation.

  Don't construct LatestArtifactStrategy directly, example usage:
  ```
    model_resolver = Resolver(
        instance_name='latest_model_resolver',
        strategy_class=LatestArtifactStrategy,
        model=Channel(type=Model))
    model_resolver.outputs['model']
  ```
  """

  def __init__(self, desired_num_of_artifacts: Optional[int] = 1):
    """Initializes the strategy.

    Args:
      desired_num_of_artifacts: How many of the latest artifacts to keep for
        each key. Defaults to 1.
    """
    # NOTE(review): the annotation allows None, but `_resolve` passes this
    # value to `min()` together with an int, which fails for None.
    self._desired_num_of_artifact = desired_num_of_artifacts

  def _resolve(
      self, input_dict: Dict[str, List[types.Artifact]]
  ) -> Dict[str, List[types.Artifact]]:
    """Keeps, for each key, the artifacts with the largest ids.

    Args:
      input_dict: Candidate artifacts keyed by input name.

    Returns:
      A dict with the same keys whose values hold at most the desired number
      of artifacts, ordered by descending id.
    """
    result = {}
    for k, artifact_list in input_dict.items():
      # Larger id first, so the head of the list is the latest artifact.
      sorted_artifact_list = sorted(
          artifact_list, key=lambda a: a.id, reverse=True)
      result[k] = sorted_artifact_list[:min(
          len(sorted_artifact_list), self._desired_num_of_artifact)]
    return result

  @doc_controls.do_not_generate_docs
  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[str, types.Channel],
  ) -> resolver.ResolveResult:
    """Resolves the latest artifacts for each source channel.

    Args:
      pipeline_info: Identifies the pipeline whose context is looked up.
      metadata_handler: Metadata handler used to fetch the pipeline context
        and the qualified artifacts.
      source_channels: Channels to resolve, keyed by input name.

    Returns:
      A ResolveResult holding the selected artifacts per key, and per key
      whether at least the desired number of artifacts was found.

    Raises:
      RuntimeError: If no pipeline context exists for `pipeline_info`.
    """
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # Report the pipeline that was looked up; `pipeline_context` is always
      # None on this branch and would make the message useless.
      raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

    candidate_dict = {}
    for k, c in source_channels.items():
      candidate_artifacts = metadata_handler.get_qualified_artifacts(
          contexts=[pipeline_context],
          type_name=c.type_name,
          producer_component_id=c.producer_component_id,
          output_key=c.output_key)
      candidate_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in candidate_artifacts
      ]

    resolved_dict = self._resolve(candidate_dict)
    # A key counts as resolved only when enough artifacts were available.
    resolve_state_dict = {
        k: len(artifact_list) >= self._desired_num_of_artifact
        for k, artifact_list in resolved_dict.items()
    }

    return resolver.ResolveResult(
        per_key_resolve_result=resolved_dict,
        per_key_resolve_state=resolve_state_dict)

  @doc_controls.do_not_generate_docs
  def resolve_artifacts(
      self, store: mlmd.MetadataStore,
      input_dict: Dict[str, List[types.Artifact]]
  ) -> Optional[Dict[str, List[types.Artifact]]]:
    """Resolves the latest artifacts from the given input_dict.

    Args:
      store: An MLMD MetadataStore object. Not used by this strategy.
      input_dict: The input_dict to resolve from.

    Returns:
      If every key has at least the desired number of artifacts, returns a
      Dict[str, List[Artifact]] with the latest artifacts per key. Otherwise,
      returns None.
    """
    resolved_dict = self._resolve(input_dict)
    all_min_count_met = all(
        len(artifact_list) >= self._desired_num_of_artifact
        for artifact_list in resolved_dict.values())
    return resolved_dict if all_min_count_met else None
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for LatestArtifactStrategy."""

import tensorflow as tf
from tfx import types
from tfx.dsl.input_resolution.strategies import latest_artifact_strategy
from tfx.orchestration import data_types
from tfx.orchestration import metadata
from tfx.types import standard_artifacts
from tfx.utils import test_case_utils

from ml_metadata.proto import metadata_store_pb2


class LatestArtifactStrategyTest(test_case_utils.TfxTest):
  """Tests for LatestArtifactStrategy in both resolve modes."""

  def setUp(self):
    """Creates a SQLite-configured metadata handler and pipeline/component info."""
    super().setUp()
    self._connection_config = metadata_store_pb2.ConnectionConfig()
    self._connection_config.sqlite.SetInParent()
    self._metadata = self.enter_context(
        metadata.Metadata(connection_config=self._connection_config))
    self._store = self._metadata.store
    self._pipeline_info = data_types.PipelineInfo(
        pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
    self._component_info = data_types.ComponentInfo(
        component_type='a.b.c',
        component_id='my_component',
        pipeline_info=self._pipeline_info)

  def testStrategy(self):
    """resolve() returns the artifact with the largest id for the channel."""
    contexts = self._metadata.register_pipeline_contexts_if_not_exists(
        self._pipeline_info)
    artifact_one = standard_artifacts.Examples()
    artifact_one.uri = 'uri_one'
    self._metadata.publish_artifacts([artifact_one])
    artifact_two = standard_artifacts.Examples()
    artifact_two.uri = 'uri_two'
    self._metadata.register_execution(
        exec_properties={},
        pipeline_info=self._pipeline_info,
        component_info=self._component_info,
        contexts=contexts)
    self._metadata.publish_execution(
        component_info=self._component_info,
        output_artifacts={'key': [artifact_one, artifact_two]})
    expected_artifact = max(artifact_one, artifact_two, key=lambda a: a.id)

    strategy = latest_artifact_strategy.LatestArtifactStrategy()
    resolve_result = strategy.resolve(
        pipeline_info=self._pipeline_info,
        metadata_handler=self._metadata,
        source_channels={
            'input':
                types.Channel(
                    type=artifact_one.type,
                    producer_component_id=self._component_info.component_id,
                    output_key='key')
        })

    self.assertTrue(resolve_result.has_complete_result)
    self.assertEqual([
        artifact.uri
        for artifact in resolve_result.per_key_resolve_result['input']
    ], [expected_artifact.uri])
    self.assertTrue(resolve_result.per_key_resolve_state['input'])

  def testStrategy_IrMode(self):
    """resolve_artifacts() picks the artifact with the largest id."""
    artifact_one = standard_artifacts.Examples()
    artifact_one.uri = 'uri_one'
    artifact_one.id = 1
    artifact_two = standard_artifacts.Examples()
    artifact_two.uri = 'uri_two'
    # Each artifact needs its own distinct id; previously artifact_one was
    # assigned twice and artifact_two was left without an explicit id.
    artifact_two.id = 2

    expected_artifact = max(artifact_one, artifact_two, key=lambda a: a.id)

    strategy = latest_artifact_strategy.LatestArtifactStrategy()
    result = strategy.resolve_artifacts(
        self._store, {'input': [artifact_two, artifact_one]})
    self.assertIsNotNone(result)
    self.assertEqual([a.uri for a in result['input']],
                     [expected_artifact.uri])


# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
Loading

0 comments on commit bc56366

Please sign in to comment.