This repository has been archived by the owner on Jul 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 275
/
Copy path: hooks_output_csv_hook_test.py
67 lines (51 loc) · 1.87 KB
/
hooks_output_csv_hook_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import csv
import os
import shutil
import tempfile
import torch
from classy_vision.hooks import OutputCSVHook
from classy_vision.tasks import build_task
from classy_vision.trainer import LocalTrainer
from test.generic.config_utils import get_fast_test_task_config
from test.generic.hook_test_utils import HookTestBase
def parse_csv(file_path):
    """Count the data rows in a tab-delimited file.

    In this module it is used on the file written by ``OutputCSVHook``.
    ``csv.DictReader`` consumes the first line as the header, so that line
    is not included in the count.

    Args:
        file_path: Path of the file to read.

    Returns:
        The number of data rows, excluding the header line.
    """
    with open(file_path, "r", newline="") as handle:
        rows = csv.DictReader(handle, delimiter="\t")
        return sum(1 for _ in rows)
class TestCSVHook(HookTestBase):
    """Tests for ``OutputCSVHook``: registry construction and output written during training."""

    def setUp(self) -> None:
        # Each test gets a fresh scratch directory; tearDown deletes it.
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self) -> None:
        # Remove the scratch directory created in setUp, including all hook output.
        shutil.rmtree(self.base_dir)

    def test_constructors(self) -> None:
        """
        Test that the hooks are constructed correctly.
        """
        out_dir = f"{self.base_dir}/constructor_test/"
        os.makedirs(out_dir)
        hook_config = {"folder": out_dir}
        self.constructor_test_helper(
            config=hook_config,
            hook_type=OutputCSVHook,
            hook_registry_name="output_csv",
            invalid_configs=[],
        )

    def test_train(self) -> None:
        """Train a small task with the CSV hook attached and count the rows it wrote.

        The task always runs on CPU. It runs a second time on GPU when CUDA is
        available.
        """
        # Run CPU first, then GPU if present. This matches the iteration order of
        # the set {False, torch.cuda.is_available()}, which collapses to just
        # False when there is no GPU.
        gpu_modes = [False]
        if torch.cuda.is_available():
            gpu_modes.append(True)

        for use_gpu in gpu_modes:
            out_dir = f"{self.base_dir}/train_test/{use_gpu}"
            os.makedirs(out_dir)

            task = build_task(get_fast_test_task_config(head_num_classes=2))
            hook = OutputCSVHook(out_dir)
            task.set_hooks([hook])
            task.set_use_gpu(use_gpu)

            LocalTrainer().train(task)

            # NOTE(review): 10 presumably equals the number of samples the fast
            # test config emits in the phase(s) the hook records; confirm
            # against get_fast_test_task_config.
            self.assertEqual(parse_csv(hook.output_path), 10)