* Main code transfer of Mutox classifier from SeamlessM4T
* Minor changes, import inside mutox broken
* Corrections from PR #44 comments
* Added unit tests for mutox builder, classifier
* Resolved comments PR#44: Added MutoxConfig opt layer, style changes, repo decoupling, other
* Resolved comments 2 PR#44: Missing comments, style changes, others
* Resolved comments 3 PR#44: opt sigmoid layer change, card edit, other linter/mypy related
* Resolved comments 4 PR#44
* Resolved comments 5 PR#44: new integration and unit tests for mutox, other related changes
* Resolved comments 6 PR#44: modifying integration test to increase coverage, other changes to satisfy mypy
* Added line #type: ignore to pass mypy check
1 parent f17dffa, commit a914bbd
Showing 11 changed files with 833 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Copyright (c) Meta Platforms, Inc. and affiliates\n", | ||
"# All rights reserved.\n", | ||
"#\n", | ||
"# This source code is licensed under the license found in the\n", | ||
"# MIT_LICENSE file in the root directory of this source tree." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# MUTOX toxicity classification\n", | ||
"\n", | ||
"Mutox enables toxicity scoring for speech and text using sonar embeddings and a classifier trained with a _Binary Cross Entropy loss with logits_ objective. To obtain probabilities from the classifier's output, apply a sigmoid layer. This notebook demonstrates encoding speech and text into sonar embeddings and classifying their toxicity." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"if torch.cuda.is_available():\n", | ||
" device = torch.device(\"cuda:0\")\n", | ||
" dtype = torch.float16\n", | ||
"else:\n", | ||
" device = torch.device(\"cpu\")\n", | ||
" dtype = torch.float32" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Speech Scoring" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"1. download some demo audio segments\n", | ||
"2. create a tsv file to feed to the speech scoring pipeline\n", | ||
"3. load the model and build the pipeline\n", | ||
"4. go through the batches in the pipeline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get demo file\n", | ||
"import urllib.request\n", | ||
"import tempfile\n", | ||
"\n", | ||
"files = [\n", | ||
" (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n", | ||
" (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n", | ||
"]\n", | ||
"\n", | ||
"tmpdir = Path(tempfile.mkdtemp())\n", | ||
"tsv_file = (tmpdir / 'data.tsv')\n", | ||
"with tsv_file.open('w') as tsv_file_p:\n", | ||
" print('path', file=tsv_file_p)\n", | ||
" for (uri, name) in files:\n", | ||
" dl = tmpdir / name\n", | ||
" urllib.request.urlretrieve(uri, dl)\n", | ||
" print(dl, file=tsv_file_p)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from sonar.inference_pipelines.speech import SpeechInferenceParams\n", | ||
"from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n", | ||
"\n", | ||
"pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n", | ||
" mutox_classifier_name =\"sonar_mutox\",\n", | ||
" encoder_name=f\"sonar_speech_encoder_eng\",\n", | ||
" device=device,\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n", | ||
" data_file=tsv_file,\n", | ||
" audio_root_dir=None,\n", | ||
" audio_path_index=0,\n", | ||
" target_lang=\"eng\",\n", | ||
" batch_size=4,\n", | ||
" pad_idx=0,\n", | ||
" device=device,\n", | ||
" fbank_dtype=torch.float32,\n", | ||
" n_parallel=4\n", | ||
"))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n", | ||
"/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for batch in pipeline:\n", | ||
" ex = batch['audio']\n", | ||
" for idx, path in enumerate(ex['path']):\n", | ||
" print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# cleanup tmp dir\n", | ||
"import shutil\n", | ||
"shutil.rmtree(tmpdir)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Text Scoring\n", | ||
"\n", | ||
"1. load the sonar text encoder\n", | ||
"2. load the mutox classifier model\n", | ||
"3. compute embedding for a sentence\n", | ||
"4. score this embedding" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from sonar.models.mutox.loader import load_mutox_model\n", | ||
"from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n", | ||
"\n", | ||
"t2vec_model = TextToEmbeddingModelPipeline(\n", | ||
" encoder=\"text_sonar_basic_encoder\",\n", | ||
" tokenizer=\"text_sonar_basic_encoder\",\n", | ||
" device=device,\n", | ||
")\n", | ||
"text_column='lang_txt'\n", | ||
"classifier = load_mutox_model(\n", | ||
" \"sonar_mutox\",\n", | ||
" device=device,\n", | ||
" dtype=dtype,\n", | ||
").eval()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)" | ||
] | ||
}, | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"with torch.inference_mode():\n", | ||
" emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n", | ||
" x = classifier(emb.to(device).half())\n", | ||
"\n", | ||
"x" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "SONAR", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
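The notebook's note on the training objective says the classifier outputs raw logits and that a sigmoid converts them into probabilities. A minimal sketch of that conversion follows, written in plain Python so it runs without PyTorch. The `sigmoid` helper and the `logits` dictionary are illustrative names, not part of the SONAR API, and the values are the scores shown in the notebook outputs above. Inside the notebook itself, `torch.sigmoid` applied to the output tensor would serve the same purpose.

```python
import math


def sigmoid(logit: float) -> float:
    """Map a raw logit to a probability, avoiding overflow for large |logit|."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)  # safe: logit < 0, so exp() cannot overflow
    return z / (1.0 + z)


# Raw scores copied from the notebook outputs
# (illustrative container, not a SONAR data structure).
logits = {
    "commonvoice_example_en_clocks.wav": -42.40079116821289,
    "LJ037-0171_sr16k.wav": -47.90427780151367,
    "French text example": -19.7812,
}

probabilities = {name: sigmoid(score) for name, score in logits.items()}

for name, prob in probabilities.items():
    print(f"{name}\t{prob:.3e}")
```

All three logits are strongly negative, so the resulting toxicity probabilities are very close to zero.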
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# This card is a duplicate of the original found at | ||
# [Facebook Research's Seamless Communication repository] | ||
# (https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml). | ||
# It is included here to prevent circular dependencies between the Seamless Communication | ||
|
||
name: sonar_mutox | ||
model_type: mutox_classifier | ||
model_arch: mutox | ||
checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt" | ||
input_size: 1024 |