Skip to content

Commit

Permalink
Resolved comments 5 PR#44: new integration and unit tests for mutox, …
Browse files Browse the repository at this point in the history
…other related changes
  • Loading branch information
David-OC17 committed Nov 21, 2024
1 parent 96f5d6f commit a126f2d
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ with torch.inference_mode():

with torch.inference_mode():
emb = t2vec_model.predict(["She worked hard and made a significant contribution to the team."], source_lang='eng_Latn')
x = classifier(emb.to(device).to(dtype)) # tensor([[-58.0625]], device='cuda:0', dtype=torch.float16)
x = classifier(emb.to(device).to(dtype)) # tensor([[-53.5938]], device='cuda:0', dtype=torch.float16)

with torch.inference_mode():
emb = t2vec_model.predict(["El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."], source_lang='spa_Latn')
x = classifier(emb.to(device).to(dtype)) # tensor([[-24.6094]], device='cuda:0', dtype=torch.float16)
x = classifier(emb.to(device).to(dtype)) # tensor([[-21.4062]], device='cuda:0', dtype=torch.float16)
```

For a CLI way of running the MuTox pipeline, go to [Seamless Communication/.../MuTox](https://github.com/facebookresearch/seamless_communication/tree/main/src/seamless_communication/cli/toxicity/mutox).
Expand Down
4 changes: 3 additions & 1 deletion sonar/models/mutox/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def build_model(self) -> MutoxClassifier:
model_h3,
)

return MutoxClassifier(model_all,).to(
return MutoxClassifier(
model_all,
).to(
device=self.device,
dtype=self.dtype,
)
Expand Down
1 change: 0 additions & 1 deletion sonar/models/mutox/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tens
return outputs



@dataclass
class MutoxConfig:
"""Holds the configuration of a Mutox Classifier model."""
Expand Down
122 changes: 122 additions & 0 deletions tests/integration_tests/test_mutox.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch
from sonar.models.mutox.loader import load_mutox_model
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline


@pytest.mark.parametrize(
"input_texts, source_lang, expected_outputs",
[
(
["De peur que le pays ne se prostitue et ne se remplisse de crimes."],
"fra_Latn",
[-19.7812],
),
(
["She worked hard and made a significant contribution to the team."],
"eng_Latn",
[-53.5938],
),
(
[
"El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."
],
"spa_Latn",
[-21.4062],
),
],
)
def test_sonar_mutox_classifier_integration(input_texts, source_lang, expected_outputs):
"""Integration test to compare classifier outputs with expected values."""
if torch.cuda.is_available():
device = torch.device("cuda:0")
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

t2vec_model = TextToEmbeddingModelPipeline(
encoder="text_sonar_basic_encoder",
tokenizer="text_sonar_basic_encoder",
device=device,
)

classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval()

with torch.inference_mode():
embeddings = t2vec_model.predict(input_texts, source_lang=source_lang)
outputs = classifier(embeddings.to(device).to(dtype)).squeeze()

if outputs.dim() == 0:
outputs = [outputs.item()]
else:
outputs = outputs.tolist()

# Compare the outputs to expected values within a small tolerance
for output, expected in zip(outputs, expected_outputs):
assert abs(output - expected) < 0.1, (
f"Expected output {expected}, but got {output}. "
"Outputs should be close to expected values."
)


@pytest.mark.parametrize(
"input_texts, source_lang, expected_probabilities",
[
(
["De peur que le pays ne se prostitue et ne se remplisse de crimes."],
"fra_Latn",
[0.0],
),
(
["She worked hard and made a significant contribution to the team."],
"eng_Latn",
[0.0],
),
(
[
"El no tiene ni el más mínimo talento, todo lo que ha logrado ha sido gracias a sobornos y manipulaciones."
],
"spa_Latn",
[0.0],
),
],
)
def test_sonar_mutox_classifier_probability_integration(
input_texts, source_lang, expected_probabilities
):
"""Integration test to verify classifier output probabilities."""

if torch.cuda.is_available():
device = torch.device("cuda:0")
dtype = torch.float16
else:
device = torch.device("cpu")
dtype = torch.float32

t2vec_model = TextToEmbeddingModelPipeline(
encoder="text_sonar_basic_encoder",
tokenizer="text_sonar_basic_encoder",
device=device,
)

classifier = load_mutox_model("sonar_mutox", device=device, dtype=dtype).eval()

for text, lang, expected_prob in zip(
input_texts, [source_lang] * len(input_texts), expected_probabilities
):
with torch.inference_mode():
emb = t2vec_model.predict([text], source_lang=lang)

prob = classifier(emb.to(device).to(dtype), output_prob=True)

assert abs(prob.item() - expected_prob) < 0.001, (
f"Expected probability {expected_prob}, but got {prob.item()}. "
"Output probability should be within a reasonable range."
)
23 changes: 23 additions & 0 deletions tests/unit_tests/test_mutox.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,29 @@ def test_mutox_classifier_forward():
), f"Expected output shape (3, 1), but instead got {output.shape}"


def test_mutox_classifier_forward_with_output_prob():
"""Test that MutoxClassifier forward pass applies sigmoid when output_prob=True."""
test_model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 1),
)
model = MutoxClassifier(test_model)

test_input = torch.randn(3, 10)

output = model(test_input, output_prob=True)

assert output.shape == (
3,
1,
), f"Expected output shape (3, 1), but instead got {output.shape}"

assert (output >= 0).all() and (
output <= 1
).all(), "Expected output values to be within the range [0, 1]"


def test_mutox_config():
"""Test that MutoxConfig stores the configuration for a model."""
config = MutoxConfig(input_size=512)
Expand Down

0 comments on commit a126f2d

Please sign in to comment.