Skip to content

Commit 694970a

Browse files
authored
Clean up internal column logic in _run_classifier_helper function (NVIDIA#457)
* first commit
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
* working code and pytest
  Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
---------
Signed-off-by: Sarah Yurick <sarahyurick@gmail.com>
1 parent 3297c1d commit 694970a

File tree

2 files changed

+35
-41
lines changed

2 files changed

+35
-41
lines changed

nemo_curator/classifiers/base.py

+10-25
Original file line number | Diff line number | Diff line change
@@ -121,44 +121,29 @@ def _run_classifier_helper(
121121
prob_col: str = None,
122122
) -> "dask_cudf.DataFrame":
123123

124-
keep_prob = prob_col is not None
125-
prob_internal_col = "_prob"
126-
# TODO: Make crossfit handle this cleanly
127-
pred_internal_col = "labels"
128-
df["sliced_text"] = df[text_field].str.slice(0, max_chars)
124+
if prob_col:
125+
df[prob_col] = 0
126+
else:
127+
prob_col = "_prob"
128+
129129
columns_to_keep_list = df.columns.to_list()
130-
columns_to_keep_list.remove("sliced_text")
131130

132131
classifier_pipe = op.Sequential(
133-
op.Tokenizer(model, cols=["sliced_text"], tokenizer_type="default"),
132+
op.Tokenizer(
133+
model, cols=[text_field], tokenizer_type="default", max_chars=max_chars
134+
),
134135
op.Predictor(
135136
model,
136137
sorted_data_loader=True,
137138
batch_size=batch_size,
138-
pred_output_col=prob_internal_col,
139+
pred_output_col=prob_col,
139140
),
141+
op.Labeler(labels, cols=[prob_col], suffix=label_col),
140142
repartition=df.npartitions,
141143
keep_cols=columns_to_keep_list,
142144
)
143145
df = classifier_pipe(df)
144146

145-
# TODO: Make crossfit handle this cleanly
146-
# to prevent the labeler from dropping the prob_internal_col
147-
# and combine it into a single step
148-
labeling_pipe = op.Sequential(
149-
op.Labeler(labels, cols=[prob_internal_col]),
150-
keep_cols=columns_to_keep_list + [prob_internal_col],
151-
)
152-
df = labeling_pipe(df)
153-
154-
if keep_prob:
155-
df = df.rename(
156-
columns={prob_internal_col: prob_col, pred_internal_col: label_col},
157-
)
158-
else:
159-
df = df.rename(columns={pred_internal_col: label_col})
160-
df = df.drop(columns=[prob_internal_col])
161-
162147
return df
163148

164149

tests/test_classifiers.py

+25-16
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
16-
1715
import pytest
1816
from distributed import Client
1917

@@ -48,24 +46,35 @@ def domain_dataset():
4846

4947

5048
@pytest.mark.gpu
51-
def test_domain_classifier(gpu_client, domain_dataset):
49+
@pytest.mark.parametrize("keep_prob", [True, False])
50+
def test_domain_classifier(gpu_client, domain_dataset, keep_prob):
5251
from nemo_curator.classifiers import DomainClassifier
5352

54-
classifier = DomainClassifier()
55-
result_dataset = classifier(dataset=domain_dataset)
56-
result_pred = result_dataset.df.compute()["domain_pred"]
53+
if keep_prob:
54+
prob_column = "domain_prob"
55+
else:
56+
prob_column = None
5757

58-
expected_pred = cudf.Series(
59-
[
60-
"Computers_and_Electronics",
61-
"Finance",
62-
"Health",
63-
"Jobs_and_Education",
64-
"Travel_and_Transportation",
65-
]
66-
)
58+
classifier = DomainClassifier(prob_column=prob_column)
59+
result_dataset = classifier(dataset=domain_dataset)
6760

68-
assert result_pred.equals(expected_pred)
61+
if keep_prob:
62+
result_df = result_dataset.df.compute()
63+
assert "domain_prob" in result_df.columns
64+
else:
65+
result_pred = result_dataset.df.compute()["domain_pred"]
66+
67+
expected_pred = cudf.Series(
68+
[
69+
"Computers_and_Electronics",
70+
"Finance",
71+
"Health",
72+
"Jobs_and_Education",
73+
"Travel_and_Transportation",
74+
]
75+
)
76+
77+
assert result_pred.equals(expected_pred)
6978

7079

7180
@pytest.mark.gpu

0 commit comments

Comments
 (0)