Skip to content

Commit

Permalink
Update data_factory_test.py: remove data path replacement verfication
Browse files Browse the repository at this point in the history
  • Loading branch information
LSC2204 authored Jan 25, 2025
1 parent 044a881 commit ac5f98e
Showing 1 changed file with 1 addition and 21 deletions.
22 changes: 1 addition & 21 deletions tests/data_provider/data_factory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,6 @@ def test_data_factory__get_prompt(setup, save_function, expected_shape):
concatenated_result = np.concatenate((prompt_data, arr))
assert len(concatenated_result) == expected_shape[0] + len(arr)

def test_data_factory_load_prompts_invalid_chars(setup):
data_path, prompt_data_path, prompt_data_folder, datasetFactory = setup
buff = ['rho (g/m**3)', 'rh (%)', 'speed (m/s)', 0, 1]
correct = ['rho (g-m_3)', 'rh (_)', 'speed (m-s)', '0', '1']

# Mock data for mock pth file
train_buf = pd.DataFrame(pd.DataFrame(np.zeros((5, 133)), columns=[i for i in range(133)]))

for i, filename in enumerate(correct):
p = prompt_data_folder / f"mock_{filename}_prompt.pth.tar"
train_prompt = pd.DataFrame(train_buf.iloc[i].values)
train_prompt = train_prompt.T
assert train_prompt.shape == (1, 133)
torch.save(train_prompt, str(p))

prompt_data, _ = datasetFactory.loadPrompts(str(data_path), str(prompt_data_path), buff)
assert len(prompt_data) == 5
for i in range(5):
assert len(prompt_data[i]) == 133

def test_data_factory_load_prompts_invalid_prompt_file(setup):
data_path, prompt_data_path, prompt_data_folder, datasetFactory = setup

Expand Down Expand Up @@ -362,4 +342,4 @@ def test_data_factory_getDatasets_missing_prompts(mocker, setup):
datasetFactory.fetch.assert_has_calls(expected_calls)

assert datasetFactory.splitter.get_csv_splits.call_count == 1
assert datasetFactory.processor.process.call_count == 1
assert datasetFactory.processor.process.call_count == 1

0 comments on commit ac5f98e

Please sign in to comment.