@@ -217,23 +217,36 @@ def test_full(epochs: int, slices: int):
217
217
218
218
test_rng = default_rng (seed = SEED )
219
219
220
- for slice_idx in range (epochs * slices ):
221
- safetensors_file = f"{ folder } /{ slice_idx + 1 } _output.safetensors"
222
- assert os .path .exists (safetensors_file ), "Missing data from disk"
223
- data = stt .load_file (safetensors_file )
224
-
225
- # will see what's going on in a bit
226
- # for val_1, val_idx in zip(data['Y'], test_rng.permutation(len(raw_data))):
227
- # val_2 = 1 if raw_data[val_idx][1] == "pos" else 0
228
- # assert float(val_1) == float(val_2)
220
+ slice_idx = 1
221
+ for _ in range (epochs ):
222
+
223
+ epoch_data = []
224
+ for _ in range (slices ):
225
+ safetensors_file = f"{ folder } /{ slice_idx } _output.safetensors"
226
+ assert os .path .exists (safetensors_file ), "Missing data from disk"
227
+
228
+ data = stt .load_file (safetensors_file )
229
+
230
+ for val in data ['Y' ]:
231
+ epoch_data .append (val )
232
+
233
+ for token , sentiment in zip (data ['X' ], data ['Y' ]):
234
+ if token == A_TOKEN :
235
+ assert sentiment == 1
236
+ elif token == B_TOKEN :
237
+ assert sentiment == 0
238
+ else :
239
+ assert False , f"Invalid token: { token } "
240
+
241
+ slice_idx += 1
229
242
230
- for token , sentiment in zip ( data [ 'X' ], data [ 'Y' ]):
231
- if token == A_TOKEN :
232
- assert sentiment == 1
233
- elif token == B_TOKEN :
234
- assert sentiment == 0
235
- else :
236
- assert False , f"Invalid token: { token } "
243
+ epoch_permutation = test_rng . permutation ( len ( raw_data ))
244
+
245
+ for val_1 , val_idx in zip ( epoch_data , epoch_permutation ):
246
+ val_2 = 1 if raw_data [ val_idx ][ 1 ] == "pos" else 0
247
+ assert float ( val_1 ) == float ( val_2 )
248
+
249
+
237
250
238
251
239
252
0 commit comments