Mutox classifier (#44)
* Main code transfer of Mutox classifier from SeamlessM4T

* Minor changes, import inside mutox broken

* Corrections from PR #44 comments

* Added unit tests for mutox builder, classifier

* Resolved comments PR#44: Added MutoxConfig opt layer, style changes, repo decoupling, other

* Resolved comments 2 PR#44: Missing comments, style changes, others

* Resolved comments 3 PR#44: opt sigmoid layer change, card edit, other linter/mypy related

* Resolved comments 4 PR#44

* Resolved comments 5 PR#44: new integration and unit tests for mutox, other related changes

* Resolved comments 6 PR#44: modifying integration test to increase coverage, other changes to satisfy mypy

* Added line #type: ignore to pass mypy check
David-OC17 authored Dec 4, 2024
1 parent f17dffa commit a914bbd
Showing 11 changed files with 833 additions and 1 deletion.
43 changes: 43 additions & 0 deletions README.md
@@ -142,6 +142,49 @@ print(blaser_qe(src=src_embs, mt=mt_embs).item()) # 4.708
Detailed model cards with more examples: [facebook/blaser-2.0-ref](https://huggingface.co/facebook/blaser-2.0-ref),
[facebook/blaser-2.0-qe](https://huggingface.co/facebook/blaser-2.0-qe).

### Classifying the toxicity of sentences with MuTox

[MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox) is the first highly multilingual audio-based binary toxicity classifier, released together with a dataset of toxicity labels. The dataset consists of 20k audio utterances for English and Spanish, and 4k for the other 19 languages. The classifier uses the multimodal and multilingual encoders from SONAR. Its output is a logit for the evaluated input being _"toxic"_, according to the definition adopted in the corresponding dataset.

```Python
from sonar.models.mutox.loader import load_mutox_model
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline
import torch

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    dtype = torch.float16
else:
    device = torch.device("cpu")
    dtype = torch.float32

t2vec_model = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=device,
)
text_column='lang_txt'
classifier = load_mutox_model(
    "sonar_mutox",
    device=device,
    dtype=dtype,
).eval()

with torch.inference_mode():
    emb = t2vec_model.predict(["De peur que le pays ne se prostitue et ne se remplisse de crimes."], source_lang='fra_Latn')
    x = classifier(emb.to(device).to(dtype)) # tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)

with torch.inference_mode():
    emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn')
    x = classifier(emb.to(device).to(dtype)) # tensor([[-53.5938]], device='cuda:0', dtype=torch.float16)

with torch.inference_mode():
    emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn')
    x = classifier(emb.to(device).to(dtype)) # tensor([[-21.4062]], device='cuda:0', dtype=torch.float16)
```
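
Because the classifier returns a raw logit, a sigmoid is needed to read it as a probability. The snippet below is a minimal standard-library sketch of that conversion, applied to the three logits shown in the example above. The helper name `logit_to_probability` and the `THRESHOLD` value are illustrative assumptions, not part of the SONAR or MuTox code.

```python
import math


def logit_to_probability(logit: float) -> float:
    """Map a raw classifier logit to a probability with a numerically stable sigmoid."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    # For negative logits, exp(logit) cannot overflow.
    z = math.exp(logit)
    return z / (1.0 + z)


# Logits copied from the example outputs above.
example_logits = {
    "fra_Latn": -19.7812,
    "eng_Latn": -53.5938,
    "spa_Latn": -21.4062,
}

# Hypothetical decision threshold, chosen only for illustration.
THRESHOLD = 0.5

probabilities = {
    lang: logit_to_probability(value) for lang, value in example_logits.items()
}
for lang, prob in probabilities.items():
    label = "toxic" if prob >= THRESHOLD else "non-toxic"
    print(f"{lang}\t{prob:.3e}\t{label}")
```

When working directly with the tensor `x` from the example, `torch.sigmoid(x)` performs the same conversion. A suitable threshold depends on the application and should be tuned on labeled data.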

To run the MuTox pipeline from the command line, see [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox).

### Demo notebooks
See more complete demo notebooks:

246 changes: 246 additions & 0 deletions examples/mutox_example.ipynb
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates\n",
"# All rights reserved.\n",
"#\n",
"# This source code is licensed under the license found in the\n",
"# MIT_LICENSE file in the root directory of this source tree."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MuTox toxicity classification\n",
"\n",
"MuTox enables toxicity scoring for speech and text using SONAR embeddings and a classifier trained with a _Binary Cross Entropy loss with logits_ objective. To obtain probabilities from the classifier's output, apply a sigmoid function. This notebook demonstrates encoding speech and text into SONAR embeddings and classifying their toxicity."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from pathlib import Path\n",
"\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda:0\")\n",
" dtype = torch.float16\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
" dtype = torch.float32"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Speech Scoring"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. download some demo audio segments\n",
"2. create a tsv file to feed to the speech scoring pipeline\n",
"3. load the model and build the pipeline\n",
"4. go through the batches in the pipeline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# get demo file\n",
"import urllib.request\n",
"import tempfile\n",
"\n",
"files = [\n",
" (\"https://dl.fbaipublicfiles.com/seamless/tests/commonvoice_example_en_clocks.wav\", \"commonvoice_example_en_clocks.wav\"),\n",
" (\"https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav\", \"LJ037-0171_sr16k.wav\")\n",
"]\n",
"\n",
"tmpdir = Path(tempfile.mkdtemp())\n",
"tsv_file = (tmpdir / 'data.tsv')\n",
"with tsv_file.open('w') as tsv_file_p:\n",
" print('path', file=tsv_file_p)\n",
" for (uri, name) in files:\n",
" dl = tmpdir / name\n",
" urllib.request.urlretrieve(uri, dl)\n",
" print(dl, file=tsv_file_p)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sonar.inference_pipelines.speech import SpeechInferenceParams\n",
"from sonar.inference_pipelines.mutox_speech import MutoxSpeechClassifierPipeline\n",
"\n",
"pipeline_builder = MutoxSpeechClassifierPipeline.load_model_from_name(\n",
" mutox_classifier_name=\"sonar_mutox\",\n",
" encoder_name=\"sonar_speech_encoder_eng\",\n",
" device=device,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pipeline = pipeline_builder.build_pipeline(SpeechInferenceParams(\n",
" data_file=tsv_file,\n",
" audio_root_dir=None,\n",
" audio_path_index=0,\n",
" target_lang=\"eng\",\n",
" batch_size=4,\n",
" pad_idx=0,\n",
" device=device,\n",
" fbank_dtype=torch.float32,\n",
" n_parallel=4\n",
"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note:** This model was trained using a \"Binary Cross Entropy loss with logits\" objective (as described in the paper). To convert the model's output into probabilities, apply a sigmoid function to the output.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/tmp/tmpqasvhgx6/commonvoice_example_en_clocks.wav\t-42.40079116821289\n",
"/tmp/tmpqasvhgx6/LJ037-0171_sr16k.wav\t-47.90427780151367\n"
]
}
],
"source": [
"for batch in pipeline:\n",
" ex = batch['audio']\n",
" for idx, path in enumerate(ex['path']):\n",
" print(str(path), ex[\"data\"][idx].item(), sep=\"\\t\")\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# cleanup tmp dir\n",
"import shutil\n",
"shutil.rmtree(tmpdir)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Text Scoring\n",
"\n",
"1. load the sonar text encoder\n",
"2. load the mutox classifier model\n",
"3. compute embedding for a sentence\n",
"4. score this embedding"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using the cached checkpoint of mutox. Set `force` to `True` to download again.\n"
]
}
],
"source": [
"from sonar.models.mutox.loader import load_mutox_model\n",
"from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline\n",
"\n",
"t2vec_model = TextToEmbeddingModelPipeline(\n",
" encoder=\"text_sonar_basic_encoder\",\n",
" tokenizer=\"text_sonar_basic_encoder\",\n",
" device=device,\n",
")\n",
"text_column='lang_txt'\n",
"classifier = load_mutox_model(\n",
" \"sonar_mutox\",\n",
" device=device,\n",
" dtype=dtype,\n",
").eval()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-19.7812]], device='cuda:0', dtype=torch.float16)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with torch.inference_mode():\n",
" emb = t2vec_model.predict([\"De peur que le pays ne se prostitue et ne se remplisse de crimes.\"], source_lang='fra_Latn')\n",
" x = classifier(emb.to(device).to(dtype))\n",
"\n",
"x"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "SONAR",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
16 changes: 16 additions & 0 deletions sonar/cards/sonar_mutox.yaml
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This card is a duplicate of the original found at
# [Facebook Research's Seamless Communication repository]
# (https://github.com/facebookresearch/seamless_communication/blob/main/src/seamless_communication/cards/mutox.yaml).
# It is included here to prevent circular dependencies between the Seamless Communication and SONAR repositories.

name: sonar_mutox
model_type: mutox_classifier
model_arch: mutox
checkpoint: "https://dl.fbaipublicfiles.com/seamless/models/mutox.pt"
input_size: 1024
