# extract_features_fp.py (forked from mahmoodlab/CLAM)
import argparse
import os
import time

import h5py
import numpy as np
import openslide
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset_modules.dataset_h5 import Dataset_All_Bags, Whole_Slide_Bag_FP
from models import get_encoder
from utils.file_utils import save_hdf5

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def compute_w_loader(output_path, loader, model, silent=False):
    """
    args:
        output_path: path to the output .h5 file for the computed features
        loader: pytorch dataloader yielding batches of patches and their coordinates
        model: pytorch model
        silent: quiet outputs for tqdm
    """
    if not silent:
        print(f'processing a total of {len(loader)} batches')

    mode = 'w'
    for count, data in enumerate(tqdm(loader, disable=silent)):
        with torch.inference_mode():
            batch = data['img']
            coords = data['coord'].numpy().astype(np.int32)
            batch = batch.to(device, non_blocking=True)

            features = model(batch)
            features = features.cpu().numpy().astype(np.float32)

            # write the first batch with mode 'w', then append subsequent batches
            asset_dict = {'features': features, 'coords': coords}
            save_hdf5(output_path, asset_dict, attr_dict=None, mode=mode)
            mode = 'a'

    return output_path
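
# A minimal sketch (not part of the original script) of reading back a bag that
# compute_w_loader wrote; the dataset names 'features' (N, D) float32 and
# 'coords' (N, 2) int32 match what is saved above, and the path is a placeholder:
#
#     with h5py.File('FEATURES_DIRECTORY/h5_files/slide_1.h5', 'r') as f:
#         feats = f['features'][:]   # (num_patches, feature_dim)
#         coords = f['coords'][:]    # patch coordinates saved alongside features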

parser = argparse.ArgumentParser(description='Feature Extraction')
parser.add_argument('--data_h5_dir', type=str, default=None)
parser.add_argument('--data_slide_dir', type=str, default=None)
parser.add_argument('--slide_ext', type=str, default='.svs')
parser.add_argument('--csv_path', type=str, default=None)
parser.add_argument('--feat_dir', type=str, default=None)
parser.add_argument('--model_name', type=str, default='resnet50_trunc',
                    choices=['resnet50_trunc', 'uni_v1', 'conch_v1', 'virchow_v2'])
parser.add_argument('--batch_size', type=int, default=256)
parser.add_argument('--no_auto_skip', default=False, action='store_true')
parser.add_argument('--target_patch_size', type=int, default=224)
parser.add_argument('--silent', default=False, action='store_true')
args = parser.parse_args()
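
# The csv at --csv_path lists the slide files to process, one per row, with the
# file extension included (slide_id is recovered below by splitting on --slide_ext).
# Assuming Dataset_All_Bags reads a 'slide_id' column as in upstream CLAM, a
# hypothetical example:
#
#     slide_id
#     slide_1.svs
#     slide_2.svs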

if __name__ == '__main__':
    print('initializing dataset')
    csv_path = args.csv_path
    if csv_path is None:
        raise ValueError('--csv_path must be provided')

    bags_dataset = Dataset_All_Bags(csv_path)

    os.makedirs(args.feat_dir, exist_ok=True)
    os.makedirs(os.path.join(args.feat_dir, 'pt_files'), exist_ok=True)
    os.makedirs(os.path.join(args.feat_dir, 'h5_files'), exist_ok=True)
    dest_files = os.listdir(os.path.join(args.feat_dir, 'pt_files'))

    model, img_transforms = get_encoder(args.model_name, target_img_size=args.target_patch_size)
    _ = model.eval()
    model = model.to(device)
    total = len(bags_dataset)

    loader_kwargs = {'num_workers': 8, 'pin_memory': True} if device.type == 'cuda' else {}

    for bag_candidate_idx in tqdm(range(total), disable=args.silent):
        slide_id = bags_dataset[bag_candidate_idx].split(args.slide_ext)[0]
        bag_name = slide_id + '.h5'
        h5_file_path = os.path.join(args.data_h5_dir, 'patches', bag_name)
        slide_file_path = os.path.join(args.data_slide_dir, slide_id + args.slide_ext)
        print('\nprogress: {}/{}'.format(bag_candidate_idx, total))
        print(slide_id)

        # skip slides whose features were already extracted, unless --no_auto_skip is set
        if not args.no_auto_skip and slide_id + '.pt' in dest_files:
            print('skipped {}'.format(slide_id))
            continue

        output_path = os.path.join(args.feat_dir, 'h5_files', bag_name)
        time_start = time.time()
        wsi = openslide.open_slide(slide_file_path)
        dataset = Whole_Slide_Bag_FP(file_path=h5_file_path,
                                     wsi=wsi,
                                     img_transforms=img_transforms)

        loader = DataLoader(dataset=dataset, batch_size=args.batch_size, **loader_kwargs)
        output_file_path = compute_w_loader(output_path, loader=loader, model=model, silent=args.silent)

        time_elapsed = time.time() - time_start
        print('\ncomputing features for {} took {:.2f} s'.format(output_file_path, time_elapsed))

        with h5py.File(output_file_path, 'r') as file:
            features = file['features'][:]
            print('features size: ', features.shape)
            print('coordinates size: ', file['coords'].shape)

        # also save the feature bag as a .pt tensor for downstream training
        features = torch.from_numpy(features)
        bag_base, _ = os.path.splitext(bag_name)
        torch.save(features, os.path.join(args.feat_dir, 'pt_files', bag_base + '.pt'))
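
# Example invocation (a sketch; the directory names are placeholders, and
# DIR_TO_COORDS is assumed to be the patch-coordinate output of upstream CLAM's
# create_patches_fp.py, i.e. it contains the 'patches/' subfolder used above):
#
#     python extract_features_fp.py --data_h5_dir DIR_TO_COORDS \
#         --data_slide_dir DATA_DIRECTORY --csv_path CSV_FILE_NAME \
#         --feat_dir FEATURES_DIRECTORY --batch_size 256 --slide_ext .svs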