
Barcodes loaded in inconsistent format (different data types based on batch size) #62

Open
hugo-ekinge opened this issue Oct 29, 2024 · 4 comments

Comments

@hugo-ekinge

hugo-ekinge commented Oct 29, 2024

Hi and thank you for all the work you have done so far.

I am trying to adapt your embed_tiles function to create embeddings for arbitrary samples, not just the ones included in HEST-Benchmark. One issue that I came across when looping over all samples is that assert dset.dtype == val.dtype in save_hdf5 triggered on some samples. Upon further inspection, the issue seems to be that the barcodes are not always of a fixed length. On some samples, the first batch of, say, 128 tiles might contain only numerical barcodes like [b'0'] [b'1'] [b'2'] ... [b'994'] etc. instead of, for example, [b'AAACACCAATAACTGC-1'], which is always of fixed length. This becomes a problem when the number of characters increases to, say, [b'1014'], because suddenly the dtype is |S4 instead of |S3.
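
To make that concrete, here is a minimal standalone sketch (plain NumPy, nothing HEST-specific) of how each batch ends up with the smallest fixed-width bytestring dtype that fits its longest entry:

import numpy as np

# Each batch is converted to the narrowest fixed-width bytestring dtype that
# fits its longest entry, so batches of numeric barcodes can disagree on dtype.
batch_a = np.array([b'0', b'1', b'994'])     # dtype('S3')
batch_b = np.array([b'1014', b'1015'])       # dtype('S4')
print(batch_a.dtype, batch_b.dtype)          # |S3 |S4

# Fixed-length barcodes always produce the same width.
batch_c = np.array([b'AAACACCAATAACTGC-1'])  # dtype('S18')
print(batch_c.dtype)                         # |S18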

One thing I tried was increasing the batch size to 1024 instead, which eliminated the problem of going from |S3 to |S4, but then I get the same issue (although less often) when going from |S4 to |S5 in samples with a large number of barcodes...

This is what I have been able to figure out myself so far, but what would you recommend in order to fix it in a more robust way and properly handle all samples in the dataset?

(An interesting thing I noticed is that in the hest_data/patches folder the barcodes dtype seems to be object rather than any |S type at all. So the conversion seems to happen somewhere in the dataset/dataloader pipeline, and I am wondering whether this is intentional and, if so, what the purpose behind it is.)
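
For reference, a minimal way to check how the barcodes are stored in a patch file (the path below is just a placeholder, and it assumes the file only contains datasets at the top level):

import h5py

# Placeholder path; point this at one of the files under hest_data/patches.
patch_path = 'hest_data/patches/TENX137.h5'

with h5py.File(patch_path, 'r') as f:
    for key, dset in f.items():  # assumes only datasets, no groups
        # Variable-length HDF5 strings show up as an object dtype on the h5py
        # side, whereas fixed-width bytestrings show up as |S<N>.
        print(key, dset.shape, dset.dtype, h5py.check_string_dtype(dset.dtype))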

@guillaumejaume
Collaborator

Can you post the code to replicate? Thanks!

@hugo-ekinge
Author

hugo-ekinge commented Oct 30, 2024

I have made this function, which is largely a trimmed-down version of predict_single_split with some modifications. It is placed inside benchmark.py:

def create_embeddings(args, model_name, device, custom_encoder):
    """ Create embeddings for a single model """
    
    embedding_dir = os.path.join(get_path(args.embed_dataroot), "ALL", model_name)
    os.makedirs(embedding_dir, exist_ok=True)
    
    # Embed patches
    logger.info(f"Embedding ALL tiles using {model_name} encoder and custom code")
    weights_path = get_bench_weights(args.weights_root, model_name)
    if model_name == 'custom_encoder':
        encoder = custom_encoder
    else:
        encoder: InferenceEncoder = inf_encoder_factory(model_name)(weights_path)
    precision = encoder.precision
    
    patches_dir = os.path.join(get_path('hest_data'), 'patches')
    for root, _, files in os.walk(patches_dir):
        for file in tqdm(files):
            if file.endswith('.h5'):
                tile_h5_path = os.path.join(root, file)
                sample_id = os.path.splitext(file)[0]
                embed_path = os.path.join(embedding_dir, f'{sample_id}.h5')
                if not os.path.isfile(embed_path) or args.overwrite:
                    _ = encoder.eval()
                    encoder.to(device)

                    tile_dataset = H5HESTDataset(tile_h5_path, chunk_size=args.batch_size, img_transform=encoder.eval_transforms)
                    tile_dataloader = torch.utils.data.DataLoader(tile_dataset, 
                                                                    batch_size=1, 
                                                                    shuffle=False,
                                                                    num_workers=args.num_workers)

                    _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
                else:
                    logger.info(f"Skipping {sample_id} as it already exists")

To connect it to the existing benchmark workflow I added an argument parser.add_argument('--create_embeddings', type=bool, default='false', help='only create embeddings').
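
(Side note, not taken from the repo: type=bool with a string default is a bit of a trap, since argparse treats any non-empty string, including 'false', as True. If the flag is only meant as an on/off switch, a store_true action avoids that; a minimal sketch:)

import argparse

parser = argparse.ArgumentParser()
# 'store_true' gives False by default and True only when the flag is passed,
# avoiding the type=bool pitfall where any non-empty string is truthy.
parser.add_argument('--create_embeddings', action='store_true',
                    help='only create embeddings')
args = parser.parse_args(['--create_embeddings'])
print(args.create_embeddings)  # True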

I then added the following to benchmark_grid, at the top of the for model_name in model_names: loop:

if args.create_embeddings:
    create_embeddings(args, model_name, device, custom_encoder)
    continue

and right after the loop

if args.create_embeddings:
    break

Not the prettiest, but I just wanted to hook into the existing code to run a few quick tests. Anyway, simply add create_embeddings: True to bench_config.yaml and run the benchmark as normal; it will try to create embeddings for all the data in hest_data/patches and place them in an ALL folder with the other embeddings, organized per model.

With a batch size of 128, I noticed the issue with, for example, sample id TENX137, but there are also others.

To debug, I also added

if dset.dtype != val.dtype:
    print(f'Path: {output_fpath}')
    print(f'Key: {key}')
    print(f'dset.dtype: {dset.dtype}')
    print(f'val.dtype: {val.dtype}')
    print(f'dset: {dset}\n{dset[:]}')
    print(f'val: {val}')

right before the assert dset.dtype == val.dtype in file_utils.py.

Note that if overwrite is not set to True, just rerunning the same thing will skip the samples where the assertion failed previously, since those files have already been created (although they are incomplete).
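
One way around that (an untested sketch against the create_embeddings function above): remove the partial output if embed_tiles fails, so a rerun does not treat the sample as already done.

# Sketch: replace the plain embed_tiles call in create_embeddings above with
# this, so a failing sample does not leave behind a partial .h5 that later
# gets skipped.
try:
    _ = embed_tiles(tile_dataloader, encoder, embed_path, device, precision)
except Exception:
    if os.path.isfile(embed_path):
        os.remove(embed_path)  # drop the incomplete embedding file
    raise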

@hugo-ekinge
Author

Have you been able to replicate?

@guillaumejaume
Collaborator

It's hard to reproduce without seeing all your modifications. You might be able to solve the issue by modifying save_hdf5 so that the string type does not change depending on the length of the string. Can you confirm whether saving to h5 using this function solves the issue? Thanks.

import h5py
import numpy as np

def save_hdf5_revised(output_fpath, 
                      asset_dict, 
                      attr_dict= None, 
                      mode='a', 
                      auto_chunk = True,
                      chunk_size = None):
    with h5py.File(output_fpath, mode) as f:
        for key, val in asset_dict.items():
            data_shape = val.shape
            if len(data_shape) == 1:
                val = np.expand_dims(val, axis=1)
                data_shape = val.shape

            # Determine if the data is of string type
            if np.issubdtype(val.dtype, np.string_) or np.issubdtype(val.dtype, np.unicode_):
                data_type = h5py.string_dtype(encoding='utf-8')
            else:
                data_type = val.dtype

            if key not in f:  # if key does not exist, create dataset
                if auto_chunk:
                    chunks = True  # let h5py decide chunk size
                else:
                    chunks = (chunk_size,) + data_shape[1:]
                dset = f.create_dataset(
                    key,
                    shape=data_shape,
                    chunks=chunks,
                    maxshape=(None,) + data_shape[1:],
                    dtype=data_type
                )
                # Save attribute dictionary
                if attr_dict is not None:
                    if key in attr_dict.keys():
                        for attr_key, attr_val in attr_dict[key].items():
                            dset.attrs[attr_key] = attr_val
                dset[:] = val
            else:
                dset = f[key]
                # Check dtype compatibility before resizing so a mismatch does
                # not leave the dataset extended with unwritten rows.
                if dset.dtype != data_type:
                    raise TypeError(f"Data type mismatch for key '{key}'. Dataset dtype: {dset.dtype}, value dtype: {data_type}")
                dset.resize(len(dset) + data_shape[0], axis=0)
                dset[-data_shape[0]:] = val
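
For a quick sanity check of the intended behaviour (a standalone sketch, assuming save_hdf5_revised above is in scope): appending two batches whose inferred bytestring widths differ, |S3 and |S4, into the same variable-length string dataset should no longer raise a dtype mismatch.

import h5py
import numpy as np

# Two batches with different inferred widths (|S3 vs |S4) appended to the same
# dataset; with the variable-length string dtype this no longer raises.
save_hdf5_revised('barcode_test.h5', {'barcode': np.array([b'0', b'1', b'994'])}, mode='w')
save_hdf5_revised('barcode_test.h5', {'barcode': np.array([b'1014', b'1015'])}, mode='a')

with h5py.File('barcode_test.h5', 'r') as f:
    print(f['barcode'].dtype, f['barcode'][:])  # object dtype, all five barcodes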
