From cda873b637207113e6b98b8e53bd6fa985924d10 Mon Sep 17 00:00:00 2001 From: Vertex MG Team Date: Sun, 8 Dec 2024 04:10:51 -0800 Subject: [PATCH] MediaPipe Text Classification notebook PiperOrigin-RevId: 703980753 --- ...garden_mediapipe_text_classification.ipynb | 461 ++++++------------ 1 file changed, 158 insertions(+), 303 deletions(-) diff --git a/notebooks/community/model_garden/model_garden_mediapipe_text_classification.ipynb b/notebooks/community/model_garden/model_garden_mediapipe_text_classification.ipynb index fc7023b65..41873d58d 100644 --- a/notebooks/community/model_garden/model_garden_mediapipe_text_classification.ipynb +++ b/notebooks/community/model_garden/model_garden_mediapipe_text_classification.ipynb @@ -4,11 +4,12 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "ur8xi4C7S06n" }, "outputs": [], "source": [ - "# Copyright 2023 Google LLC\n", + "# Copyright 2024 Google LLC\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -31,39 +32,18 @@ "source": [ "# Vertex AI Model Garden MediaPipe with text classification\n", "\n", - "\n", - "
\n", - " \n", - " \"Colab Run in Colab\n", + "\n", + " \n", - "\n", - " \n", - " \n", - "
\n", + " \n", + " \"Google
Run in Colab Enterprise\n", "
\n", "
\n", + " \n", " \n", - " \"GitHub\n", - " View on GitHub\n", - " \n", - " \n", - " \n", - " \"Vertex\n", - "Open in Vertex AI Workbench\n", + " \"GitHub
View on GitHub\n", "
\n", "
" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dwGLvtIeECLK" - }, - "source": [ - "**_NOTE_**: This notebook has been tested in the following environment:\n", - "\n", - "* Python version = 3.9\n", - "\n", - "**NOTE**: The checkpoint and the dataset linked in this Colab are not owned or distributed by Google, and are made available by third parties. Please review the terms and conditions made available by the third parties before using the checkpoint and data." + "
" ] }, { @@ -92,11 +72,7 @@ "* Vertex AI\n", "* Cloud Storage\n", "\n", - "Learn about [Vertex AI\n", - "pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage\n", - "pricing](https://cloud.google.com/storage/pricing), and use the [Pricing\n", - "Calculator](https://cloud.google.com/products/calculator/)\n", - "to generate a cost estimate based on your projected usage." + "Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing), [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage." ] }, { @@ -108,219 +84,113 @@ "## Before you begin" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "z__i0w0lCAsW" - }, - "source": [ - "### Colab only\n", - "Run the following commands to install dependencies and to authenticate with Google Cloud if running on Colab." - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "Jvqs-ehKlaYh" }, "outputs": [], "source": [ - "! pip3 install --upgrade pip\n", + "# @title Setup Google Cloud project\n", "\n", - "import sys\n", + "# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).\n", "\n", - "if \"google.colab\" in sys.modules:\n", - " ! pip3 install --upgrade google-cloud-aiplatform\n", + "# @markdown 2. For finetuning, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Frestricted_image_training_nvidia_a100_80gb_gpus)** to check if your project already has the required 8 Nvidia A100 80 GB GPUs in the us-central1 region. If yes, then run this notebook in the us-central1 region. If you do not have 8 Nvidia A100 80 GPUs or have more GPU requirements than this, then schedule your job with Nvidia H100 GPUs via Dynamic Workload Scheduler using [these instructions](https://cloud.google.com/vertex-ai/docs/training/schedule-jobs-dws). For Dynamic Workload Scheduler, check the [us-central1](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) or [europe-west4](https://console.cloud.google.com/iam-admin/quotas?location=europe-west4&metric=aiplatform.googleapis.com%2Fcustom_model_training_preemptible_nvidia_h100_gpus) quota for Nvidia H100 GPUs. If you do not have enough GPUs, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request quota.\n", "\n", - " # Automatically restart kernel after installs\n", - " import IPython\n", + "# @markdown 3. For serving, **[click here](https://console.cloud.google.com/iam-admin/quotas?location=us-central1&metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_l4_gpus)** to check if your project already has the required 1 L4 GPU in the us-central1 region. If yes, then run this notebook in the us-central1 region. If you need more L4 GPUs for your project, then you can follow [these instructions](https://cloud.google.com/docs/quotas/view-manage#viewing_your_quota_console) to request more. Alternatively, if you want to run predictions with A100 80GB or H100 GPUs, we recommend using the regions listed below. **NOTE:** Make sure you have associated quota in selected regions. Click the links to see your current quota for each GPU type: [Nvidia A100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_a100_80gb_gpus), [Nvidia H100 80GB](https://console.cloud.google.com/iam-admin/quotas?metric=aiplatform.googleapis.com%2Fcustom_model_serving_nvidia_h100_gpus).\n", "\n", - " app = IPython.Application.instance()\n", - " app.kernel.do_shutdown(True)\n", + "# @markdown > | Machine Type | Accelerator Type | Recommended Regions |\n", + "# @markdown | ----------- | ----------- | ----------- |\n", + "# @markdown | a2-ultragpu-1g | 1 NVIDIA_A100_80GB | us-central1, us-east4, europe-west4, asia-southeast1, us-east4 |\n", + "# @markdown | a3-highgpu-2g | 2 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |\n", + "# @markdown | a3-highgpu-4g | 4 NVIDIA_H100_80GB | us-west1, asia-southeast1, europe-west4 |\n", + "# @markdown | a3-highgpu-8g | 8 NVIDIA_H100_80GB | us-central1, us-east5, europe-west4, us-west1, asia-southeast1 |\n", "\n", - " from google.colab import auth as google_auth\n", + "# @markdown 4. **[Optional]** [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. \"us\") is not considered a match for a single region covered by the multi-region range (eg. \"us-central1\"). If not set, a unique GCS bucket will be created instead.\n", "\n", - " google_auth.authenticate_user()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WReHDGG5g0XY" - }, - "source": [ - "#### Set your project ID\n", + "BUCKET_URI = \"gs://\" # @param {type:\"string\"}\n", "\n", - "**If you don't know your project ID**, see the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oM1iC_MfAts1" - }, - "outputs": [], - "source": [ - "PROJECT_ID = \"[your-project-id]\" # @param {type:\"string\"}\n", + "# @markdown 5. **[Optional]** Set region. If not set, the region will be set automatically according to Colab Enterprise environment.\n", "\n", - "# Set the project id\n", - "! gcloud config set project {PROJECT_ID}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "region" - }, - "source": [ - "#### Region\n", + "REGION = \"\" # @param {type:\"string\"}\n", "\n", - "You can also change the `REGION` variable used by Vertex AI. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tTy1gX11kCJY" - }, - "outputs": [], - "source": [ - "REGION = \"us-central1\" # @param {type: \"string\"}\n", - "REGION_PREFIX = REGION.split(\"-\")[0]\n", - "assert REGION_PREFIX in (\n", - " \"us\",\n", - " \"europe\",\n", - " \"asia\",\n", - "), f'{REGION} is not supported. It must be prefixed by \"us\", \"asia\", or \"europe\".'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zgPO1eR3CYjk" - }, - "source": [ - "### Create a Cloud Storage bucket\n", + "! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git\n", "\n", - "Create a storage bucket to store intermediate artifacts such as datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MzGDU7TWdts_" - }, - "outputs": [], - "source": [ - "BUCKET_URI = f\"gs://your-bucket-name-{PROJECT_ID}-unique\" # @param {type:\"string\"}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-EcIXiGsCePi" - }, - "source": [ - "**Only if your bucket doesn't already exist**: Run the following cell to create your Cloud Storage bucket." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NIq7R4HZCfIc" - }, - "outputs": [], - "source": [ - "! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "960505627ddf" - }, - "source": [ - "### Import libraries" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PyQmSRbKA8r-" - }, - "outputs": [], - "source": [ + "import datetime\n", + "import importlib\n", "import json\n", "import os\n", - "from datetime import datetime\n", + "import uuid\n", "\n", - "from google.cloud import aiplatform" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "init_aip:mbsdk,all" - }, - "source": [ - "### Initialize Vertex AI SDK for Python\n", + "from google.cloud import aiplatform\n", "\n", - "Initialize the Vertex AI SDK for Python for your project." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9wExiMUxFk91" - }, - "outputs": [], - "source": [ - "now = datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n", + "common_util = importlib.import_module(\n", + " \"vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util\"\n", + ")\n", "\n", - "STAGING_BUCKET = os.path.join(BUCKET_URI, \"temp/%s\" % now)\n", "\n", - "EVALUATION_RESULT_OUTPUT_DIRECTORY = os.path.join(STAGING_BUCKET, \"evaluation\")\n", - "EVALUATION_RESULT_OUTPUT_FILE = os.path.join(\n", - " EVALUATION_RESULT_OUTPUT_DIRECTORY, \"evaluation.json\"\n", - ")\n", + "# Get the default cloud project id.\n", + "PROJECT_ID = os.environ[\"GOOGLE_CLOUD_PROJECT\"]\n", "\n", - "EXPORTED_MODEL_OUTPUT_DIRECTORY = os.path.join(STAGING_BUCKET, \"model\")\n", - "EXPORTED_MODEL_OUTPUT_FILE = os.path.join(\n", - " EXPORTED_MODEL_OUTPUT_DIRECTORY, \"model.tflite\"\n", - ")\n", + "# Get the default region for launching jobs.\n", + "if not REGION:\n", + " REGION = os.environ[\"GOOGLE_CLOUD_REGION\"]\n", "\n", - "aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "n6IFz75WGCam" - }, - "source": [ - "### Define training machine specs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "riG_qUokg0XZ" - }, - "outputs": [], - "source": [ - "TRAINING_JOB_DISPLAY_NAME = \"mediapipe_text_classifier_%s\" % now\n", - "TRAINING_CONTAINER = f\"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/mediapipe-train\"\n", - "TRAINING_MACHINE_TYPE = \"n1-highmem-16\"\n", - "TRAINING_ACCELERATOR_TYPE = \"NVIDIA_TESLA_V100\"\n", - "TRAINING_ACCELERATOR_COUNT = 2" + "# Enable the Vertex AI API and Compute Engine API, if not already.\n", + "print(\"Enabling Vertex AI API and Compute Engine API.\")\n", + "! gcloud services enable aiplatform.googleapis.com compute.googleapis.com\n", + "\n", + "# Cloud Storage bucket for storing the experiment artifacts.\n", + "# A unique GCS bucket will be created for the purpose of this notebook. If you\n", + "# prefer using your own GCS bucket, change the value yourself below.\n", + "now = datetime.datetime.now().strftime(\"%Y%m%d%H%M%S\")\n", + "BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n", + "\n", + "if BUCKET_URI is None or BUCKET_URI.strip() == \"\" or BUCKET_URI == \"gs://\":\n", + " BUCKET_URI = f\"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}\"\n", + " BUCKET_NAME = \"/\".join(BUCKET_URI.split(\"/\")[:3])\n", + " ! gsutil mb -l {REGION} {BUCKET_URI}\n", + "else:\n", + " assert BUCKET_URI.startswith(\"gs://\"), \"BUCKET_URI must start with `gs://`.\"\n", + " shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep \"Location constraint:\" | sed \"s/Location constraint://\"\n", + " bucket_region = shell_output[0].strip().lower()\n", + " if bucket_region != REGION:\n", + " raise ValueError(\n", + " \"Bucket region %s is different from notebook region %s\"\n", + " % (bucket_region, REGION)\n", + " )\n", + "print(f\"Using this GCS Bucket: {BUCKET_URI}\")\n", + "\n", + "STAGING_BUCKET = os.path.join(BUCKET_URI, \"temporal\")\n", + "MODEL_BUCKET = os.path.join(BUCKET_URI, \"mediapipe_text_classification\")\n", + "\n", + "\n", + "# Initialize Vertex AI API.\n", + "print(\"Initializing Vertex AI API.\")\n", + "aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)\n", + "\n", + "# Gets the default SERVICE_ACCOUNT.\n", + "shell_output = ! gcloud projects describe $PROJECT_ID\n", + "project_number = shell_output[-1].split(\":\")[1].strip().replace(\"'\", \"\")\n", + "SERVICE_ACCOUNT = f\"{project_number}-compute@developer.gserviceaccount.com\"\n", + "print(\"Using this default Service Account:\", SERVICE_ACCOUNT)\n", + "\n", + "\n", + "# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket\n", + "! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME\n", + "\n", + "! gcloud config set project $PROJECT_ID\n", + "! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/storage.admin\"\n", + "! gcloud projects add-iam-policy-binding --no-user-output-enabled {PROJECT_ID} --member=serviceAccount:{SERVICE_ACCOUNT} --role=\"roles/aiplatform.user\"\n", + "\n", + "REGION_PREFIX = REGION.split(\"-\")[0]\n", + "assert REGION_PREFIX in (\n", + " \"us\",\n", + " \"europe\",\n", + " \"asia\",\n", + "), f'{REGION} is not supported. It must be prefixed by \"us\", \"asia\", or \"europe\".'" ] }, { @@ -332,33 +202,24 @@ "## Train your customized models" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "zgPO1eR3CYjk" - }, - "source": [ - "### Get the Dataset\n", - "\n", - "The following code block uses the [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) dataset which contains 67,349 movie reviews for training and 872 movie reviews for testing. The dataset has two classes: positive and negative movie reviews. Positive reviews are labeled with 1 and negative reviews with 0.\n", - "\n", - "The SST-2 dataset is stored as a TSV file. The only difference between the TSV and CSV formats is that TSV uses a tab `\\t` character as its delimiter and CSV uses a comma `,`.\n" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "IndQ_m6ddUEM" }, "outputs": [], "source": [ - "training_data_path = (\n", - " \"gs://mediapipe-tasks/text_classifier/SST-2/train.tsv\" # @param {type:\"string\"}\n", - ")\n", - "validation_data_path = (\n", - " \"gs://mediapipe-tasks/text_classifier/SST-2/dev.tsv\" # @param {type:\"string\"}\n", - ")\n", + "# @title Set the dataset\n", + "\n", + "# @markdown The following code block uses the [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) dataset which contains 67,349 movie reviews for training and 872 movie reviews for testing. The dataset has two classes: positive and negative movie reviews. Positive reviews are labeled with 1 and negative reviews with 0.\n", + "\n", + "# @markdown The SST-2 dataset is stored as a TSV file. The only difference between the TSV and CSV formats is that TSV uses a tab `\\t` character as its delimiter and CSV uses a comma `,`.\n", + "\n", + "training_data_path = \"gs://mediapipe-tasks/text_classifier/SST-2/train.tsv\" # @param {type:\"string\"}\n", + "\n", + "validation_data_path = \"gs://mediapipe-tasks/text_classifier/SST-2/dev.tsv\" # @param {type:\"string\"}\n", "\n", "# The delimiter used in the dataset.\n", "delimiter = \"\\t\" # @param {type:\"string\"}\n", @@ -379,39 +240,30 @@ "label_column = \"label\" # @param {type:\"string\"}" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "aaff6f5be7f6" - }, - "source": [ - "### Set fine-tuning options\n", - "\n", - "You can pick between different model architectures to further customize your training:\n", - "\n", - "* Average Word Embedding Model\n", - "* BERT-classifier\n", - "\n", - "To set the model architecture and other training parameters, adjust the following values:" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "um_XKbmpTaHx" }, "outputs": [], "source": [ - "model_architecture = (\n", - " \"average_word_embedding\" # @param [\"average_word_embedding\", \"mobilebert\"]\n", - ")\n", + "# @title Set fine-tuning options\n", + "\n", + "# @markdown You can pick between different model architectures to further customize your training:\n", + "# @markdown * Average Word Embedding Model\n", + "# @markdown * BERT-classifier\n", + "\n", + "# @markdown To set the model architecture and other training parameters, adjust the below values:\n", + "\n", + "model_architecture = \"average_word_embedding\" # @param [\"average_word_embedding\", \"mobilebert\"]\n", "\n", "# The learning rate to use for gradient descent-based\n", "# optimizers. Defaults to 3e-5 for the BERT-based classifier\n", "# and 0 for the average word-embedding classifier because\n", "# it does not need such an optimizer.\n", - "learning_rate: float = 0.0 # @param {type:\"number\"}\n", + "learning_rate: float = 0.0001 # @param {type:\"number\"}\n", "\n", "# Batch size for training. Defaults to 32 for the average\n", "# word-embedding classifier and 48 for the BERT-based\n", @@ -451,26 +303,31 @@ "vocab_size: int = 10000 # @param {type:\"number\"}" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "HwcCjwlBTQIz" - }, - "source": [ - "### Run fine-tuning\n", - "With your training dataset and fine-tuning options prepared, you are ready to start the fine-tuning process. This process is resource intensive and can take a few minutes to a few hours depending on the model archtiecture and your available compute resources. On Vertex AI with GPU processing, the example fine-tuning below takes between 2-3 minutes to train an Average Word Embedding Model on the SST-2 dataset.\n", - "\n", - "To begin the fine-tuning process, use the following code:\n" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "aec22792ee84" }, "outputs": [], "source": [ + "# @title Run finetuning job\n", + "\n", + "# @markdown With your training dataset and fine-tuning options prepared, you are ready to start the fine-tuning process. This process is resource intensive and can take a few minutes to a few hours depending on the model archtiecture and your available compute resources. On Vertex AI with GPU processing, the example fine-tuning below takes between 2-3 minutes to train an Average Word Embedding Model on the SST-2 dataset.\n", + "\n", + "# @markdown To begin the fine-tuning process, use the following code:\n", + "\n", + "EVALUATION_RESULT_OUTPUT_DIRECTORY = os.path.join(STAGING_BUCKET, \"evaluation\")\n", + "EVALUATION_RESULT_OUTPUT_FILE = os.path.join(\n", + " EVALUATION_RESULT_OUTPUT_DIRECTORY, \"evaluation.json\"\n", + ")\n", + "\n", + "EXPORTED_MODEL_OUTPUT_DIRECTORY = os.path.join(STAGING_BUCKET, \"model\")\n", + "EXPORTED_MODEL_OUTPUT_FILE = os.path.join(\n", + " EXPORTED_MODEL_OUTPUT_DIRECTORY, \"model.tflite\"\n", + ")\n", + "\n", "model_export_path = EXPORTED_MODEL_OUTPUT_DIRECTORY\n", "evaluation_result_path = EVALUATION_RESULT_OUTPUT_DIRECTORY\n", "\n", @@ -502,6 +359,13 @@ " \"dropout_rate\": dropout_rate,\n", "}\n", "\n", + "TRAINING_JOB_DISPLAY_NAME = \"mediapipe_text_classifier_%s\" % now\n", + "TRAINING_CONTAINER = f\"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/mediapipe-train\"\n", + "TRAINING_MACHINE_TYPE = \"n1-highmem-16\"\n", + "TRAINING_ACCELERATOR_TYPE = \"NVIDIA_TESLA_V100\"\n", + "TRAINING_ACCELERATOR_COUNT = 2\n", + "\n", + "\n", "worker_pool_specs = [\n", " {\n", " \"machine_spec\": {\n", @@ -528,6 +392,15 @@ " }\n", "]\n", "\n", + "# Check quota.\n", + "common_util.check_quota(\n", + " project_id=PROJECT_ID,\n", + " region=REGION,\n", + " accelerator_type=TRAINING_ACCELERATOR_TYPE,\n", + " accelerator_count=2,\n", + " is_for_training=True,\n", + ")\n", + "\n", "training_job = aiplatform.CustomJob(\n", " display_name=TRAINING_JOB_DISPLAY_NAME,\n", " project=PROJECT_ID,\n", @@ -538,44 +411,20 @@ "training_job.run()" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "rXMF2tnV_WS0" - }, - "source": [ - "## Export model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g0BGaofgsMsy" - }, - "source": [ - "After finetuning, you can save the Tensorflow Lite model, try it out in the [Text Classification](https://mediapipe-studio.webapps.google.com/demo/text_classifier) demo in MediaPipe Studio or integrate it with your on-device application by following the [Text classification task guide](https://developers.google.com/mediapipe/solutions/text/text_classifier). The exported model contains the generates required model metadata, as well as a classification label file." - ] - }, { "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "NYuQowyZEtxK" }, "outputs": [], "source": [ - "import sys\n", - "\n", - "\n", - "def copy_model(model_source, model_dest):\n", - " ! gsutil cp {model_source} {model_dest}\n", - "\n", - "copy_model(EXPORTED_MODEL_OUTPUT_FILE, \"text_classification_model.tflite\")\n", + "# @title Export model\n", "\n", - "if \"google.colab\" in sys.modules:\n", - " from google.colab import files\n", + "# @markdown After finetuning, you can save the Tensorflow Lite model, try it out in the [Text Classification](https://mediapipe-studio.webapps.google.com/demo/text_classifier) demo in MediaPipe Studio or integrate it with your on-device application by following the [Text classification task guide](https://developers.google.com/mediapipe/solutions/text/text_classifier). The exported model contains the generates required model metadata, as well as a classification label file.\n", "\n", - " files.download(\"text_classification_model.tflite\")" + "! gsutil cp $EXPORTED_MODEL_OUTPUT_FILE text_classification_model.tflite" ] }, { @@ -591,15 +440,21 @@ "cell_type": "code", "execution_count": null, "metadata": { + "cellView": "form", "id": "Ax6vQVZhp9pR" }, "outputs": [], "source": [ + "# @title Clean up training jobs and buckets\n", + "# @markdown Delete temporary GCS buckets.\n", + "\n", + "delete_bucket = False # @param {type:\"boolean\"}\n", + "if delete_bucket:\n", + " ! gsutil -m rm -r $BUCKET_NAME\n", + "\n", "# Delete training data and jobs.\n", "if training_job.list(filter=f'display_name=\"{TRAINING_JOB_DISPLAY_NAME}\"'):\n", - " training_job.delete()\n", - "\n", - "!gsutil rm -r {STAGING_BUCKET}" + " training_job.delete()" ] } ],