Skip to content
This repository has been archived by the owner on Sep 27, 2024. It is now read-only.

Commit

Permalink
MCT 1.1 release: update TFX dependencies and update MLMD_Model_Card_T…
Browse files Browse the repository at this point in the history
…oolkit_Demo

PiperOrigin-RevId: 394550241
  • Loading branch information
shuklak13 authored and ml-fairness-infra-github committed Sep 2, 2021
1 parent aeac57a commit 64f7e9f
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 44 deletions.
5 changes: 4 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
<!-- mdlint off(HEADERS_TOO_MANY_H1) -->

# Current Version (Still in Development)
# Release 1.1.0

## Major Features and Improvements

## Bug fixes and other changes

* Update TFX compatibility to TFX 1.2.
* Fix bug where all datasets from MLMD were being compressed into one model_card.Dataset object.

## Breaking changes and Deprecations

# Release 1.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
"outputs": [],
"source": [
"!pip install --upgrade pip==20.2\n",
"!pip install \"tfx==0.26.0\"\n",
"!pip install \"tfx==1.2.0\"\n",
"!pip install model-card-toolkit"
]
},
Expand Down Expand Up @@ -151,7 +151,6 @@
"from tfx.components import CsvExampleGen\n",
"from tfx.components import Evaluator\n",
"from tfx.components import Pusher\n",
"from tfx.components import ResolverNode\n",
"from tfx.components import SchemaGen\n",
"from tfx.components import StatisticsGen\n",
"from tfx.components import Trainer\n",
Expand All @@ -167,7 +166,6 @@
"from tfx.types import Channel\n",
"from tfx.types.standard_artifacts import Model\n",
"from tfx.types.standard_artifacts import ModelBlessing\n",
"from tfx.utils.dsl_utils import external_input\n",
"\n",
"import ml_metadata as mlmd"
]
Expand Down Expand Up @@ -290,7 +288,7 @@
"# `pipeline_root` and `metadata_connection_config` may be passed to\n",
"# InteractiveContext. Calls to InteractiveContext are no-ops outside of the\n",
"# notebook.\n",
"context = InteractiveContext()"
"context = InteractiveContext(pipeline_name=\"Census Income Classification Pipeline\")"
]
},
{
Expand Down Expand Up @@ -322,7 +320,7 @@
},
"outputs": [],
"source": [
"example_gen = CsvExampleGen(input=external_input(_data_root))\n",
"example_gen = CsvExampleGen(input_base=_data_root)\n",
"context.run(example_gen)"
]
},
Expand Down Expand Up @@ -356,7 +354,7 @@
"outputs": [],
"source": [
"# Get the URI of the output artifact representing the training examples, which is a directory\n",
"train_uri = os.path.join(example_gen.outputs['examples'].get()[0].uri, 'train')\n",
"train_uri = os.path.join(example_gen.outputs['examples'].get()[0].uri, 'Split-train')\n",
"\n",
"# Get the list of files in this directory (all compressed TFRecord files)\n",
"tfrecord_filenames = [os.path.join(train_uri, name)\n",
Expand Down Expand Up @@ -1231,10 +1229,10 @@
"# different collections. \n",
"model_card.quantitative_analysis.graphics.collection = filter_graphs(\n",
" model_card.quantitative_analysis.graphics.collection, TARGET_EVAL_GRAPH_NAMES)\n",
"model_card.model_parameters.data.eval.graphics.collection = filter_graphs(\n",
" model_card.model_parameters.data.eval.graphics.collection, TARGET_DATASET_GRAPH_NAMES)\n",
"model_card.model_parameters.data.train.graphics.collection = filter_graphs(\n",
" model_card.model_parameters.data.train.graphics.collection, TARGET_DATASET_GRAPH_NAMES)"
"model_card.model_parameters.data[0].graphics.collection = filter_graphs(\n",
" model_card.model_parameters.data[0].graphics.collection, TARGET_DATASET_GRAPH_NAMES)\n",
"model_card.model_parameters.data[1].graphics.collection = filter_graphs(\n",
" model_card.model_parameters.data[1].graphics.collection, TARGET_DATASET_GRAPH_NAMES)"
]
},
{
Expand All @@ -1254,13 +1252,15 @@
},
"outputs": [],
"source": [
"model_card.model_parameters.data.train.graphics.description = (\n",
"model_card.model_parameters.data[0].name = 'train_set'\n",
"model_card.model_parameters.data[0].graphics.description = (\n",
" 'This section includes graphs displaying the class distribution for the '\n",
" '“Race” and “Sex” attributes in our training dataset. We chose to '\n",
" 'show these graphs in particular because we felt it was important that '\n",
" 'users see the class imbalance.'\n",
")\n",
"model_card.model_parameters.data.eval.graphics.description = (\n",
"model_card.model_parameters.data[1].name = 'eval_set'\n",
"model_card.model_parameters.data[1].graphics.description = (\n",
" 'Like the training set, we provide graphs showing the class distribution '\n",
" 'of the data we used to evaluate our model’s performance. '\n",
")\n",
Expand Down
5 changes: 3 additions & 2 deletions model_card_toolkit/model_card_toolkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,9 @@ def _scaffold_model_card(self) -> ModelCard:
graphics.annotate_eval_result_plots(model_card, eval_result)

for stats_artifact in stats_artifacts:
train_stats = tfx_util.read_stats_proto(stats_artifact.uri, 'train')
eval_stats = tfx_util.read_stats_proto(stats_artifact.uri, 'eval')
train_stats = tfx_util.read_stats_proto(stats_artifact.uri,
'Split-train')
eval_stats = tfx_util.read_stats_proto(stats_artifact.uri, 'Split-eval')
graphics.annotate_dataset_feature_statistics_plots(
model_card, [train_stats, eval_stats])
return model_card
Expand Down
3 changes: 2 additions & 1 deletion model_card_toolkit/template/html/default_template.html.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
{% endmacro %}
{% macro render_graphics(graphics) %}
<div class="img-container">
{% if graphics.description %}<p>{{ graphics.description }}</p>{% endif %}
{% for graph in graphics %}
<div class="img-item">
<img src='data:image/jpeg;base64,{{ graph.image }}' alt='{{ graph.name }}' />
Expand Down Expand Up @@ -131,7 +132,7 @@
<h2>Model Details</h2>
{% if model_details.overview %}<h3>Overview</h3>
{{ model_details.overview }}{% endif %}
{% if model_details.version %}<h3>Version</h3>
{% if model_details.version and model_details.version.name %}<h3>Version</h3>
{{ render_if_exist('name', model_details.version.name) }}
{{ render_if_exist('date', model_details.version.date) }}
{{ render_if_exist('diff', model_details.version.diff) }}
Expand Down
Loading

0 comments on commit 64f7e9f

Please sign in to comment.