
Commit 7357926

Move tests to another file; Add more test cases for openai format
1 parent bf0b180 commit 7357926

File tree

3 files changed: +351, -242 lines changed


src/together/utils/files.py

Lines changed: 34 additions & 61 deletions
@@ -96,21 +96,9 @@ def check_file(
     return report_dict
 
 
-def _has_weights(messages: List[Dict[str, str | bool]]) -> bool:
-    """Check if any message in the conversation has a weight parameter.
-
-    Args:
-        messages (List[Dict[str, str]]): List of messages to check.
-
-    Returns:
-        bool: True if any message has a weight parameter, False otherwise.
-    """
-    return any("weight" in message for message in messages)
-
-
 def validate_messages(
-    messages: List[Dict[str, str | bool]], idx: int = 0
-) -> tuple[List[Dict[str, str | bool]], bool]:
+    messages: List[Dict[str, str | bool]], idx: int
+) -> None:
     """Validate the messages column."""
     if not isinstance(messages, list):
         raise InvalidFileFormatError(
@@ -127,10 +115,7 @@ def validate_messages(
             error_source="key_value",
         )
 
-    has_weights = False
-    # Check for weights in messages
-    if _has_weights(messages):
-        has_weights = True
+    has_weights = any("weight" in message for message in messages)
 
     previous_role = None
     for message in messages:
@@ -189,10 +174,8 @@ def validate_messages(
            )
        previous_role = message["role"]
 
-    return messages, has_weights
 
-
-def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
+def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None:
     """Validate the OpenAI preference dataset format.
 
     Args:
@@ -201,9 +184,6 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
 
     Raises:
         InvalidFileFormatError: If the dataset format is invalid.
-
-    Returns:
-        Dict[str, Any]: The validated example.
     """
     if not isinstance(example["input"], dict):
         raise InvalidFileFormatError(
@@ -219,43 +199,38 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
             error_source="key_value",
         )
 
-    example["input"]["messages"], _ = validate_messages(
-        example["input"]["messages"], idx
-    )
-
-    if not isinstance(example["preferred_output"], list):
-        raise InvalidFileFormatError(
-            message="The dataset is malformed, the `preferred_output` field must be a list.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
-
-    if not isinstance(example["non_preferred_output"], list):
-        raise InvalidFileFormatError(
-            message="The dataset is malformed, the `non_preferred_output` field must be a list.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
+    validate_messages(example["input"]["messages"], idx)
 
-    if len(example["preferred_output"]) != 1:
-        raise InvalidFileFormatError(
-            message="The dataset is malformed, the `preferred_output` list must contain exactly one message.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
+    for output_field in ["preferred_output", "non_preferred_output"]:
+        if not isinstance(example[output_field], list):
+            raise InvalidFileFormatError(
+                message=f"The dataset is malformed, the `{output_field}` field must be a list.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
 
-    if len(example["non_preferred_output"]) != 1:
-        raise InvalidFileFormatError(
-            message="The dataset is malformed, the `non_preferred_output` list must contain exactly one message.",
-            line_number=idx + 1,
-            error_source="key_value",
-        )
+        if len(example[output_field]) != 1:
+            raise InvalidFileFormatError(
+                message=f"The dataset is malformed, the `{output_field}` list must contain exactly one message.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+        if "role" not in example[output_field][0]:
+            raise InvalidFileFormatError(
+                message=f"The dataset is malformed, the `{output_field}` message is missing the `role` field.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+        elif example[output_field][0]["role"] != "assistant":
+            raise InvalidFileFormatError(
+                message=f"The dataset is malformed, the `{output_field}` must contain an assistant message.",
+                line_number=idx + 1,
+                error_source="key_value",
+            )
+
 
-    example["preferred_output"], _ = validate_messages(example["preferred_output"], idx)
-    example["non_preferred_output"], _ = validate_messages(
-        example["non_preferred_output"], idx
-    )
-    return example
+    validate_messages(example["preferred_output"], idx)
+    validate_messages(example["non_preferred_output"], idx)
 
 
 def _check_jsonl(file: Path) -> Dict[str, Any]:
@@ -332,9 +307,7 @@ def _check_jsonl(file: Path) -> Dict[str, Any]:
                        message_column = JSONL_REQUIRED_COLUMNS_MAP[
                            DatasetFormat.CONVERSATION
                        ][0]
-                        messages, has_weights = validate_messages(
-                            json_line[message_column], idx
-                        )
+                        validate_messages(json_line[message_column], idx)
                    else:
                        for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]:
                            if not isinstance(json_line[column], str):

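Net effect of the files.py changes: validate_messages and validate_preference_openai are now pure checks that return None rather than echoing the validated data back, and both output fields are verified in a single loop that enforces list type, exactly one message, and an assistant role. A minimal sketch of what now passes and what now raises, with made-up example data; importing InvalidFileFormatError from together.utils.files is an assumption based on that name being used in the module per the diff above.

# Illustrative only: hypothetical records run through the refactored validator.
# validate_preference_openai comes from src/together/utils/files.py as shown above;
# the example data is invented, and the InvalidFileFormatError import path is assumed.
from together.utils.files import InvalidFileFormatError, validate_preference_openai

ok_example = {
    "input": {"messages": [{"role": "user", "content": "How cold is it today?"}]},
    "preferred_output": [{"role": "assistant", "content": "Around 57°F this morning."}],
    "non_preferred_output": [{"role": "assistant", "content": "Quite cold."}],
}
validate_preference_openai(ok_example, idx=0)  # passes silently, returns None

# Same record, but the preferred output comes from the wrong role.
bad_example = dict(ok_example, preferred_output=[{"role": "user", "content": "Oops."}])
try:
    validate_preference_openai(bad_example, idx=0)
except InvalidFileFormatError:
    # Raised with: "The dataset is malformed, the `preferred_output` must
    # contain an assistant message."
    pass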
tests/unit/test_files_checks.py

Lines changed: 29 additions & 181 deletions
@@ -5,54 +5,6 @@
 from together.constants import MIN_SAMPLES
 from together.utils.files import check_file
 
-_TEST_PREFERENCE_OPENAI_CONTENT = [
-    {
-        "input": {
-            "messages": [
-                {"role": "user", "content": "Hi there, I have a question."},
-                {"role": "assistant", "content": "Hello, how is your day going?"},
-                {
-                    "role": "user",
-                    "content": "Hello, can you tell me how cold San Francisco is today?",
-                },
-            ],
-        },
-        "preferred_output": [
-            {
-                "role": "assistant",
-                "content": "Today in San Francisco, it is not quite cold as expected. Morning clouds will give away "
-                "to sunshine, with a high near 68°F (20°C) and a low around 57°F (14°C).",
-            }
-        ],
-        "non_preferred_output": [
-            {
-                "role": "assistant",
-                "content": "It is not particularly cold in San Francisco today.",
-            }
-        ],
-    },
-    {
-        "input": {
-            "messages": [
-                {
-                    "role": "user",
-                    "content": "What's the best way to learn programming?",
-                },
-            ],
-        },
-        "preferred_output": [
-            {
-                "role": "assistant",
-                "content": "The best way to learn programming is through consistent practice, working on real projects, "
-                "and breaking down complex problems into smaller parts. Start with a beginner-friendly language like Python.",
-            }
-        ],
-        "non_preferred_output": [
-            {"role": "assistant", "content": "Just read some books and you'll be fine."}
-        ],
-    },
-]
-
 
 def test_check_jsonl_valid_general(tmp_path: Path):
     # Create a valid JSONL file
@@ -128,149 +80,45 @@ def test_check_jsonl_valid_conversational_single_turn(tmp_path: Path):
 def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
     # Create a valid JSONL file with conversational format and multiple user-assistant turn pairs
     file = tmp_path / "valid_conversational_multiple_turns.jsonl"
-    content = _TEST_PREFERENCE_OPENAI_CONTENT
-    with file.open("w") as f:
-        f.write("\n".join(json.dumps(item) for item in content))
-
-    report = check_file(file)
-
-    assert report["is_check_passed"]
-    assert report["utf8"]
-    assert report["num_samples"] == len(content)
-    assert report["has_min_samples"]
-
-
-def test_check_jsonl_valid_preference_openai(tmp_path: Path):
-    file = tmp_path / "valid_preference_openai.jsonl"
-    content = _TEST_PREFERENCE_OPENAI_CONTENT
-    with file.open("w") as f:
-        f.write("\n".join(json.dumps(item) for item in content))
-
-    report = check_file(file)
-
-    assert report["is_check_passed"]
-    assert report["utf8"]
-    assert report["num_samples"] == len(content)
-    assert report["has_min_samples"]
-
-
-def test_check_jsonl_invalid_preference_openai_missing_fields(tmp_path: Path):
-    # Test all required fields in OpenAI preference format
-    required_fields = [
-        ("input", "Missing input field"),
-        ("preferred_output", "Missing preferred_output field"),
-        ("non_preferred_output", "Missing non_preferred_output field"),
-    ]
-
-    for field_to_remove, description in required_fields:
-        file = tmp_path / f"invalid_preference_openai_missing_{field_to_remove}.jsonl"
-        content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT]
-
-        # Remove the specified field from the first item
-        del content[0][field_to_remove]
-
-        with file.open("w") as f:
-            f.write("\n".join(json.dumps(item) for item in content))
-
-        report = check_file(file)
-
-        assert not report["is_check_passed"], f"Test should fail when {description}"
-
-
-def test_check_jsonl_invalid_preference_openai_structural_issues(tmp_path: Path):
-    # Test various structural issues in OpenAI preference format
-    test_cases = [
-        {
-            "name": "empty_messages",
-            "modifier": lambda item: item.update({"input": {"messages": []}}),
-            "description": "Empty messages array",
-        },
-        {
-            "name": "missing_role_preferred",
-            "modifier": lambda item: item.update(
-                {"preferred_output": [{"content": "Missing role field"}]}
-            ),
-            "description": "Missing role in preferred_output",
-        },
-        {
-            "name": "missing_role_non_preferred",
-            "modifier": lambda item: item.update(
-                {"non_preferred_output": [{"content": "Missing role field"}]}
-            ),
-            "description": "Missing role in non_preferred_output",
-        },
-        {
-            "name": "wrong_output_format_preferred",
-            "modifier": lambda item: item.update(
-                {"preferred_output": "Not an array but a string"}
-            ),
-            "description": "Wrong format for preferred_output",
-        },
-        {
-            "name": "wrong_output_format_non_preferred",
-            "modifier": lambda item: item.update(
-                {"non_preferred_output": "Not an array but a string"}
-            ),
-            "description": "Wrong format for non_preferred_output",
-        },
-        {
-            "name": "missing_content",
-            "modifier": lambda item: item.update(
-                {"input": {"messages": [{"role": "user"}]}}
-            ),
-            "description": "Missing content in messages",
-        },
-        {
-            "name": "multiple_preferred_outputs",
-            "modifier": lambda item: item.update(
-                {
-                    "preferred_output": [
-                        {"role": "assistant", "content": "First response"},
-                        {"role": "assistant", "content": "Second response"},
-                    ]
-                }
-            ),
-            "description": "Multiple messages in preferred_output",
-        },
+    content = [
         {
-            "name": "multiple_non_preferred_outputs",
-            "modifier": lambda item: item.update(
+            "messages": [
+                {"role": "user", "content": "Is it going to rain today?"},
                 {
-                    "non_preferred_output": [
-                        {"role": "assistant", "content": "First response"},
-                        {"role": "assistant", "content": "Second response"},
-                    ]
-                }
-            ),
-            "description": "Multiple messages in non_preferred_output",
+                    "role": "assistant",
+                    "content": "Yes, expect showers in the afternoon.",
+                },
+                {"role": "user", "content": "What is the weather like in Tokyo?"},
+                {"role": "assistant", "content": "It is sunny with a chance of rain."},
+            ]
         },
         {
-            "name": "empty_preferred_output",
-            "modifier": lambda item: item.update({"preferred_output": []}),
-            "description": "Empty preferred_output array",
+            "messages": [
+                {"role": "user", "content": "Who won the game last night?"},
+                {"role": "assistant", "content": "The home team won by two points."},
+                {"role": "user", "content": "What is the weather like in Amsterdam?"},
+                {"role": "assistant", "content": "It is cloudy with a chance of snow."},
+            ]
         },
         {
-            "name": "empty_non_preferred_output",
-            "modifier": lambda item: item.update({"non_preferred_output": []}),
-            "description": "Empty non_preferred_output array",
+            "messages": [
+                {"role": "system", "content": "You are a kind AI"},
+                {"role": "user", "content": "Who won the game last night?"},
+                {"role": "assistant", "content": "The home team won by two points."},
+                {"role": "user", "content": "What is the weather like in Amsterdam?"},
+                {"role": "assistant", "content": "It is cloudy with a chance of snow."},
+            ]
         },
     ]
+    with file.open("w") as f:
+        f.write("\n".join(json.dumps(item) for item in content))
 
-    for test_case in test_cases:
-        file = tmp_path / f"invalid_preference_openai_{test_case['name']}.jsonl"
-        content = [item.copy() for item in _TEST_PREFERENCE_OPENAI_CONTENT]
-
-        # Apply the modification to the first item
-        test_case["modifier"](content[0])
-
-        with file.open("w") as f:
-            f.write("\n".join(json.dumps(item) for item in content))
-
-        report = check_file(file)
+    report = check_file(file)
 
-        assert not report[
-            "is_check_passed"
-        ], f"Test should fail with {test_case['description']}"
+    assert report["is_check_passed"]
+    assert report["utf8"]
+    assert report["num_samples"] == len(content)
+    assert report["has_min_samples"]
 
 
 def test_check_jsonl_empty_file(tmp_path: Path):

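The preference-format tests removed above were not dropped: per the commit message they were moved to another test file, which is the third changed file and is not shown in this view. Purely as an illustration of the style those tests follow, a negative case against check_file could look like the sketch below; the test name, file name, and record contents are hypothetical, not the relocated code.

# Hypothetical sketch only -- not the relocated test file from this commit.
import json
from pathlib import Path

import pytest

from together.utils.files import check_file

_BAD_PREFERENCE_RECORDS = [
    # preferred_output must be a list, not a bare string
    {
        "input": {"messages": [{"role": "user", "content": "Hi"}]},
        "preferred_output": "not a list",
        "non_preferred_output": [{"role": "assistant", "content": "Hello"}],
    },
    # the single output message must come from the assistant
    {
        "input": {"messages": [{"role": "user", "content": "Hi"}]},
        "preferred_output": [{"role": "assistant", "content": "Hello"}],
        "non_preferred_output": [{"role": "user", "content": "Hello"}],
    },
]


@pytest.mark.parametrize("record", _BAD_PREFERENCE_RECORDS)
def test_check_jsonl_invalid_preference_openai(record, tmp_path: Path):
    file = tmp_path / "invalid_preference_openai.jsonl"
    file.write_text(json.dumps(record))

    report = check_file(file)

    assert not report["is_check_passed"]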