diff --git a/.github/workflows/ci-pipeline.yml b/.github/workflows/ci-pipeline.yml index ac647e23..0413a251 100644 --- a/.github/workflows/ci-pipeline.yml +++ b/.github/workflows/ci-pipeline.yml @@ -9,7 +9,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-2022] - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.10", "3.11"] steps: - uses: actions/checkout@v3 - if: matrix.os == 'ubuntu-latest' @@ -21,7 +21,6 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install requirements run: | - pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -r test/requirements.txt pip install -r test/torch_requirements.txt pip install . @@ -43,7 +42,6 @@ jobs: python-version: 3.9 - name: Generate coverage report run: | - pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -r test/requirements.txt pip install -r test/torch_requirements.txt pip install . @@ -65,8 +63,8 @@ jobs: python-version: 3.9 - name: Install dependencies run: | - pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install -r docs/requirements.txt + pip install -r test/torch_requirements.txt pip install . - name: Build documentation run: cd docs && make clean && make html # Use SPHINXOPTS="-W" to fail on warning. diff --git a/.gitignore b/.gitignore index 31e3f5c0..75833b93 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ data generated auto_examples .vscode +AUTHORS +ChangeLog # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 4998e308..1570cecb 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![contributors](https://img.shields.io/github/contributors-anon/neuromorphs/tonic)](https://github.com/neuromorphs/tonic/pulse) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/neuromorphs/tonic/main?labpath=docs%2Ftutorials) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5079802.svg)](https://doi.org/10.5281/zenodo.5079802) -[![Discord](https://img.shields.io/discord/852094154188259338)](https://discord.gg/V6FHBZURkg) +[![Discord](https://img.shields.io/discord/1044548629622439977)](https://discord.gg/qubbM4uPuA) **Tonic** is a tool to facilitate the download, manipulation and loading of event-based/spike-based data. It's like PyTorch Vision but for neuromorphic data! diff --git a/docs/datasets.rst b/docs/datasets.rst index fd1f46e5..883edf91 100644 --- a/docs/datasets.rst +++ b/docs/datasets.rst @@ -53,6 +53,14 @@ Star tracking EBSSA +Eye tracking +------------------- +.. autosummary:: + :toctree: generated/ + :template: class_dataset.rst + + ThreeET_Eyetracking + .. currentmodule:: tonic.prototype.datasets Prototype iterable datasets @@ -65,4 +73,4 @@ Prototype iterable datasets NCARS STMNIST Gen1AutomotiveDetection - Gen4AutomotiveDetectionMini + Gen4AutomotiveDetectionMini \ No newline at end of file diff --git a/docs/getting_involved/communication_channels.rst b/docs/getting_involved/communication_channels.rst index eea1f02a..31d2c4b8 100644 --- a/docs/getting_involved/communication_channels.rst +++ b/docs/getting_involved/communication_channels.rst @@ -3,11 +3,10 @@ Communication channels Discord ------- -We would be very happy if you got in touch with us, so please don't hesitate! -The easiest way is to join our Discord channel. There we can reply more or less -instantly. 
The #tonic channel is part of SynSense's public space along other channels
+The easiest way to get in touch with us is via Discord. There we can reply more or less
+instantly. The #tonic channel is part of Open Neuromorphic's public space alongside other channels
 for all things revolving around SNN training.
-The link to join is https://discord.gg/V6FHBZURkg.
+The link to join is https://discord.gg/qubbM4uPuA.
 
 Github
 ------
diff --git a/docs/index.md b/docs/index.md
index 45d2744a..6ecd444b 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -4,6 +4,7 @@
 [![Documentation Status](https://readthedocs.org/projects/tonic/badge/?version=latest)](https://tonic.readthedocs.io/en/latest/?badge=latest)
 [![contributors](https://img.shields.io/github/contributors-anon/neuromorphs/tonic)](https://github.com/neuromorphs/tonic/pulse)
 [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5079802.svg)](https://doi.org/10.5281/zenodo.5079802)
+[![Discord](https://img.shields.io/discord/1044548629622439977)](https://discord.gg/qubbM4uPuA)
 
 **Download and manipulate neuromorphic datasets fast and easily!**
 
@@ -53,4 +54,4 @@ how-tos/how-tos
 reading_material/reading_material
 getting_involved/getting_involved
 about/about
-```
\ No newline at end of file
+```
diff --git a/docs/reading_material/intro-snns.rst b/docs/reading_material/intro-snns.rst
index 4c72b2a4..6363119e 100644
--- a/docs/reading_material/intro-snns.rst
+++ b/docs/reading_material/intro-snns.rst
@@ -22,7 +22,7 @@
 ANN is typically a tensor with high data precision, but low
 temporal resolution, the input for an SNN are binary flags of
 spikes with comparatively high temporal precision in the
-order of s. The unit in the SNN integrates all of the incoming spikes,
+order of µs. The unit in the SNN integrates all of the incoming spikes,
 which affect the internal parameters such as membrane potential.
 The unit in the ANN merely computes the linear combination for inputs on all synapses and
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 4cc375ea..67a09d88 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -4,8 +4,6 @@ sphinx-book-theme
 sphinx-gallery
 myst_nb
 pbr
-torchvision
 ipywidgets
 matplotlib
-torchdata
 sphinx-autoapi
diff --git a/test/test_datasets.py b/test/test_datasets.py
index 760dca02..7bc4e964 100644
--- a/test/test_datasets.py
+++ b/test/test_datasets.py
@@ -98,13 +98,45 @@ class EBSSATestCase(dataset_utils.DatasetTestCase):
     def inject_fake_data(self, tmpdir):
         testfolder = os.path.join(tmpdir, "EBSSA")
         os.makedirs(testfolder, exist_ok=True)
-        filename = "A5ooN9edo7TnNPx/download/labelled_ebssa.h5"
+        filename = "Jpw3Adae5kReMrN/download/labelled_ebssa.h5"
         download_url(
             url=base_url + filename, root=testfolder, filename="labelled_ebssa.h5"
         )
         return {"n_samples": 1}
 
 
+class ThreeET_EyetrackingTestCase(dataset_utils.DatasetTestCase):
+    DATASET_CLASS = datasets.ThreeET_Eyetracking
+    FEATURE_TYPES = (datasets.ThreeET_Eyetracking.dtype,)
+    TARGET_TYPES = (np.ndarray,)
+    KWARGS = {"split": "train"}
+
+    def inject_fake_data(self, tmpdir):
+        testfolder = os.path.join(tmpdir, "ThreeET_Eyetracking")
+        os.makedirs(testfolder, exist_ok=True)
+        os.makedirs(os.path.join(testfolder, "data"), exist_ok=True)
+        os.makedirs(os.path.join(testfolder, "labels"), exist_ok=True)
+        # write the name of the single test file into train_files.txt under testfolder
+        with open(os.path.join(testfolder, "train_files.txt"), "w") as f:
+            f.write("testcase\n")
+        filename = "testcase"
+
+        # download test h5 file
+        download_url(
+            url=base_url + "4aiA4BAqz5km4Gc/download/" + filename + ".h5",
+            root=os.path.join(testfolder, "data"),
+            filename=filename + ".h5",
+        )
+        # download test labels
+        download_url(
+            url=base_url + "G6ejNmXNnB2sKyc/download/" + filename + ".txt",
+            root=os.path.join(testfolder, "labels"),
+            filename=filename + ".txt",
+        )
+
+        return {"n_samples": 1}
+
+
 class NCaltech101TestCase(dataset_utils.DatasetTestCase):
     DATASET_CLASS = datasets.NCALTECH101
     FEATURE_TYPES = (datasets.NCALTECH101.dtype,)
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 24ba4781..dc3259de 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -151,7 +151,7 @@ def test_transform_drop_events_by_area(area_ratio):
             break
 
     assert (
-            dropped_area_found is True
+        dropped_area_found is True
     ), f"There should be an area with {dropped_events} events dropped in the obtained sequence."
 
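Reviewer note: the ThreeET_EyetrackingTestCase above exercises the ThreeET_Eyetracking dataset class added further down in this diff (tonic/datasets/threeET_eyetracking.py). A minimal usage sketch, assuming this branch is installed; the save_to path is hypothetical, and indexing a sample triggers the dataset download on first use:

import tonic

dataset = tonic.datasets.ThreeET_Eyetracking(save_to="./data", split="train")
events, target = dataset[0]  # structured event array and label array for one recording
print(events.dtype.names)    # ('t', 'x', 'y', 'p')
print(target.shape)          # (n_labels, label_dim), parsed from the matching .txt file
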
@@ -223,6 +223,51 @@ def test_transform_drop_pixel(coordinates, hot_pixel_frequency): assert events is not orig_events +@pytest.mark.parametrize( + "hot_pixel_frequency, event_max_freq", + [(59, 60), (10, 60)], +) +def test_transform_drop_pixel_unequal_sensor(hot_pixel_frequency, event_max_freq): + orig_events, sensor_size = create_random_input( + n_events=40000, sensor_size=(15, 20, 2) + ) + orig_events = orig_events.tolist() + orig_events += [(0, 0, int(t * 1e3), 1) for t in np.arange(1, 1e6, 1e3 / event_max_freq)] + orig_events += [(0, 19, int(t * 1e3), 1) for t in np.arange(1, 1e6, 1e3 / event_max_freq)] + orig_events += [(14, 0, int(t * 1e3), 1) for t in np.arange(1, 1e6, 1e3 / event_max_freq)] + orig_events += [(14, 19, int(t * 1e3), 1) for t in np.arange(1, 1e6, 1e3 / event_max_freq)] + # cast back to numpy events + orig_events = np.asarray(orig_events, np.dtype([("x", int), ("y", int), ("t", int), ("p", int)])) + + transform = transforms.DropPixel( + coordinates=None, hot_pixel_frequency=hot_pixel_frequency + ) + + events = transform(orig_events) + assert len(np.where((events["x"] == 0) & (events["y"] == 0))[0]) == 0 + assert len(np.where((events["x"] == 14) & (events["y"] == 0))[0]) == 0 + assert len(np.where((events["x"] == 0) & (events["y"] == 19))[0]) == 0 + assert len(np.where((events["x"] == 14) & (events["y"] == 19))[0]) == 0 + + +@pytest.mark.parametrize( + "coordinates, hot_pixel_frequency", + [(((9, 11), (10, 12), (11, 13)), None), (None, 10000)], +) +def test_transform_drop_pixel_empty(coordinates, hot_pixel_frequency): + orig_events, sensor_size = create_random_input( + n_events=0, sensor_size=(15, 20, 2) + ) + + transform = transforms.DropPixel(coordinates=None, hot_pixel_frequency=hot_pixel_frequency) + events = transform(orig_events) + assert len(events) == len(orig_events) + + transform = transforms.DropPixel(coordinates=coordinates, hot_pixel_frequency=None) + events = transform(orig_events) + assert len(events) == len(orig_events) + + @pytest.mark.parametrize( "coordinates, hot_pixel_frequency", [(((199, 11), (199, 12), (11, 13)), None), (None, 5000)], @@ -247,7 +292,8 @@ def test_transform_drop_pixel_raster(coordinates, hot_pixel_frequency): assert not merged_polarity_raster[merged_polarity_raster > 5000].sum().sum() -@pytest.mark.parametrize("time_factor, spatial_factor, target_size", [(1, 0.25, None), (1e-3, (1, 2), None), (1, 1, (5, 5))]) +@pytest.mark.parametrize("time_factor, spatial_factor, target_size", + [(1, 0.25, None), (1e-3, (1, 2), None), (1, 1, (5, 5))]) def test_transform_downsample(time_factor, spatial_factor, target_size): orig_events, sensor_size = create_random_input() @@ -256,43 +302,42 @@ def test_transform_downsample(time_factor, spatial_factor, target_size): ) events = transform(orig_events) - + if not isinstance(spatial_factor, tuple): spatial_factor = (spatial_factor, spatial_factor) - + if target_size is None: assert np.array_equal( (orig_events["t"] * time_factor).astype(orig_events["t"].dtype), events["t"] ) assert np.array_equal(np.floor(orig_events["x"] * spatial_factor[0]), events["x"]) assert np.array_equal(np.floor(orig_events["y"] * spatial_factor[1]), events["y"]) - + else: spatial_factor_test = np.asarray(target_size) / sensor_size[:-1] assert np.array_equal(np.floor(orig_events["x"] * spatial_factor_test[0]), events["x"]) assert np.array_equal(np.floor(orig_events["y"] * spatial_factor_test[1]), events["y"]) - + assert events is not orig_events - - -@pytest.mark.parametrize("target_size, dt, downsampling_method, 
noise_threshold, differentiator_time_bins",
+@pytest.mark.parametrize("target_size, dt, downsampling_method, noise_threshold, differentiator_time_bins",
                          [((50, 50), 0.05, 'integrator', 1, None), ((20, 15), 5, 'differentiator', 3, 1)])
-def test_transform_event_downsampling(target_size, dt, downsampling_method, noise_threshold, 
+def test_transform_event_downsampling(target_size, dt, downsampling_method, noise_threshold,
                                       differentiator_time_bins):
-    orig_events, sensor_size = create_random_input()
-    
-    transform = transforms.EventDownsampling(sensor_size=sensor_size, target_size=target_size, dt=dt, 
+    orig_events, sensor_size = create_random_input()
+
+    transform = transforms.EventDownsampling(sensor_size=sensor_size, target_size=target_size, dt=dt,
                                              downsampling_method=downsampling_method, noise_threshold=noise_threshold,
                                              differentiator_time_bins=differentiator_time_bins)
-    
+
     events = transform(orig_events)
-    
+
     assert len(events) <= len(orig_events)
     assert np.logical_and(np.all(events["x"] <= target_size[0]), np.all(events["y"] <= target_size[1]))
     assert events is not orig_events
-    
+
 
 @pytest.mark.parametrize("target_size", [(50, 50), (10, 5)])
 def test_transform_random_crop(target_size):
@@ -465,13 +510,13 @@ def test_transform_spatial_jitter(variance, clip_outliers):
         assert np.isclose(events["y"].all(), orig_events["y"].all(), atol=2 * variance)
 
         assert (
-                events["x"] - orig_events["x"]
-                == (events["x"] - orig_events["x"]).astype(int)
+            events["x"] - orig_events["x"]
+            == (events["x"] - orig_events["x"]).astype(int)
         ).all()
 
         assert (
-                events["y"] - orig_events["y"]
-                == (events["y"] - orig_events["y"]).astype(int)
+            events["y"] - orig_events["y"]
+            == (events["y"] - orig_events["y"]).astype(int)
         ).all()
 
     else:
@@ -503,8 +548,8 @@ def test_transform_time_jitter(std, clip_negative, sort_timestamps):
         np.testing.assert_array_equal(events["y"], orig_events["y"])
         np.testing.assert_array_equal(events["p"], orig_events["p"])
         assert (
-                events["t"] - orig_events["t"]
-                == (events["t"] - orig_events["t"]).astype(int)
+            events["t"] - orig_events["t"]
+            == (events["t"] - orig_events["t"]).astype(int)
         ).all()
 
     assert events is not orig_events
@@ -562,17 +607,7 @@ def test_transform_time_skew(coefficient, offset):
     assert events is not orig_events
 
 
-@pytest.mark.parametrize(
-    "n",
-    [
-        100,
-        0,
-        (
-            10,
-            100,
-        ),
-    ],
-)
+@pytest.mark.parametrize("n", [100, 0, (10, 100)])
 def test_transform_uniform_noise(n):
     orig_events, sensor_size = create_random_input()
 
@@ -597,6 +632,16 @@ def test_transform_uniform_noise(n):
     assert events is not orig_events
 
 
+@pytest.mark.parametrize("n", [100, 0, (10, 100)])
+def test_transform_uniform_noise_empty(n):
+    orig_events, sensor_size = create_random_input(n_events=0)
+    assert len(orig_events) == 0
+
+    transform = transforms.UniformNoise(sensor_size=sensor_size, n=n)
+    events = transform(orig_events)
+    assert len(events) == 0  # check that an empty array is returned, independent of n.
+ + def test_transform_time_alignment(): orig_events, sensor_size = create_random_input() @@ -606,3 +651,36 @@ def test_transform_time_alignment(): assert np.min(events["t"]) == 0 assert events is not orig_events + + +def test_toframe_empty(): + orig_events, sensor_size = create_random_input(n_events=0) + assert len(orig_events) == 0 + + with pytest.raises(ValueError): # check that empty array raises error if no slicing method is specified + transform = transforms.ToFrame(sensor_size=sensor_size) + frame = transform(orig_events) + + n_event_bins = 100 + transform = transforms.ToFrame(sensor_size=sensor_size, n_event_bins=n_event_bins) + frame = transform(orig_events) + assert frame.shape == (n_event_bins, sensor_size[2], sensor_size[0], sensor_size[1]) + assert frame.sum() == 0 + + n_time_bins = 100 + transform = transforms.ToFrame(sensor_size=sensor_size, n_time_bins=n_time_bins) + frame = transform(orig_events) + assert frame.shape == (n_time_bins, sensor_size[2], sensor_size[0], sensor_size[1]) + assert frame.sum() == 0 + + event_count = 1e3 + transform = transforms.ToFrame(sensor_size=sensor_size, event_count=event_count) + frame = transform(orig_events) + assert frame.shape == (1, sensor_size[2], sensor_size[0], sensor_size[1]) + assert frame.sum() == 0 + + time_window = 1e3 + transform = transforms.ToFrame(sensor_size=sensor_size, time_window=time_window) + frame = transform(orig_events) + assert frame.shape == (1, sensor_size[2], sensor_size[0], sensor_size[1]) + assert frame.sum() == 0 \ No newline at end of file diff --git a/test/torch_requirements.txt b/test/torch_requirements.txt index 729804fc..fd14302b 100644 --- a/test/torch_requirements.txt +++ b/test/torch_requirements.txt @@ -1,4 +1,5 @@ --index-url https://download.pytorch.org/whl/cpu torch==2.1.0 torchaudio==2.1.0 +torchvision==0.16.0 torchdata diff --git a/tonic/datasets/__init__.py b/tonic/datasets/__init__.py index f131ec42..b16988f9 100644 --- a/tonic/datasets/__init__.py +++ b/tonic/datasets/__init__.py @@ -11,6 +11,7 @@ from .nmnist import NMNIST from .pokerdvs import POKERDVS from .s_mnist import SMNIST +from .threeET_eyetracking import ThreeET_Eyetracking from .tum_vie import TUMVIE from .visual_place_recognition import VPR @@ -28,6 +29,7 @@ "SHD", "SMNIST", "SSC", + "ThreeET_Eyetracking", "TUMVIE", "VPR", "DVSLip", diff --git a/tonic/datasets/dvsgesture.py b/tonic/datasets/dvsgesture.py index a6b48a5a..21e027c5 100644 --- a/tonic/datasets/dvsgesture.py +++ b/tonic/datasets/dvsgesture.py @@ -86,13 +86,21 @@ def __init__( if not self._check_exists(): self.download() + self.users = [] + self.lighting = [] file_path = os.path.join(self.location_on_system, self.folder_name) for path, dirs, files in os.walk(file_path): - dirs.sort() - for file in files: - if file.endswith("npy"): - self.data.append(path + "/" + file) - self.targets.append(int(file[:-4])) + rel_path = os.path.relpath(path, file_path) + if rel_path != ".": + user, lighting = rel_path.split("_", 1) + user = int(user[4:]) + dirs.sort() + for file in files: + if file.endswith("npy"): + self.data.append(os.path.join(path, file)) + self.targets.append(int(file[:-4])) + self.users.append(user) + self.lighting.append(lighting) def __getitem__(self, index): """ diff --git a/tonic/datasets/ebssa.py b/tonic/datasets/ebssa.py index 6b194029..30b5ed66 100644 --- a/tonic/datasets/ebssa.py +++ b/tonic/datasets/ebssa.py @@ -11,7 +11,11 @@ class EBSSA(Dataset): """`EBSSA `_ - There are six different splits provided in this dataset. 
The labelled section of the dataset contains 84 recordings and 84 label files. The unlabelled section of the dataset contains 153 recordings in folders marked "Unlabelled".
+    There are six different splits provided in this dataset. The labelled section of the dataset contains 84 recordings and 84 label files.
+    The unlabelled section of the dataset contains 153 recordings in folders marked "Unlabelled".
+    If the automatic download from Google Drive fails, please download the file manually from https://drive.google.com/uc?id=1lCh2HWvxEzzaBHT5TlPuyUn6XPM5OVWN
+    and put it in a folder called 'EBSSA'. Then point ``save_to`` to its parent folder. For example, if you put the data file in 'datasets/EBSSA/labelled_ebssa.h5',
+    create the dataset object with 'dataset = tonic.datasets.EBSSA('datasets', split='labelled')'.
 
     ::
 
diff --git a/tonic/datasets/threeET_eyetracking.py b/tonic/datasets/threeET_eyetracking.py
new file mode 100644
index 00000000..8c50cac4
--- /dev/null
+++ b/tonic/datasets/threeET_eyetracking.py
@@ -0,0 +1,113 @@
+import os
+from typing import Any, Callable, Optional, Tuple
+
+import h5py
+import numpy as np
+
+from tonic.dataset import Dataset
+from tonic.io import make_structured_array
+
+
+class ThreeET_Eyetracking(Dataset):
+    """3ET DVS eye tracking `3ET <https://arxiv.org/abs/2308.11771>`_
+    ::
+
+        @article{chen20233et,
+            title={3ET: Efficient Event-based Eye Tracking using a Change-Based ConvLSTM Network},
+            author={Chen, Qinyu and Wang, Zuowen and Liu, Shih-Chii and Gao, Chang},
+            journal={arXiv preprint arXiv:2308.11771},
+            year={2023}
+        }
+
+    Parameters:
+        save_to (string): Location to save files to on disk.
+        transform (callable, optional): A callable of transforms to apply to the data.
+        split (string, optional): The dataset split to use, ``train`` or ``val``.
+        target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
+        transforms (callable, optional): A callable of transforms that is applied to both data and
+                                         labels at the same time.
+
+    Returns:
+        A dataset object that can be indexed or iterated over.
+        One sample returns a tuple of (events, targets).
+    """
+
+    url = "https://dl.dropboxusercontent.com/s/1hyer8egd8843t9/ThreeET_Eyetracking.zip?dl=0"
+    filename = "ThreeET_Eyetracking.zip"
+    file_md5 = "b6c652b06fdfd85721f39e2dbe12f4e8"
+
+    sensor_size = (240, 180, 2)
+    dtype = np.dtype([("t", int), ("x", int), ("y", int), ("p", int)])
+    ordering = dtype.names
+
+    def __init__(
+        self,
+        save_to: str,
+        split: str = "train",
+        transform: Optional[Callable] = None,
+        target_transform: Optional[Callable] = None,
+        transforms: Optional[Callable] = None,
+    ):
+        super().__init__(
+            save_to,
+            transform=transform,
+            target_transform=target_transform,
+            transforms=transforms,
+        )
+
+        # download the data from the url if it is not present on disk
+        if not self._check_exists():
+            self.download()
+
+        data_dir = os.path.join(save_to, "ThreeET_Eyetracking")
+        # Load file names from the provided split lists
+        if split == "train":
+            filenames = self.load_filenames(os.path.join(data_dir, "train_files.txt"))
+        elif split == "val":
+            filenames = self.load_filenames(os.path.join(data_dir, "val_files.txt"))
+        else:
+            raise ValueError("Invalid split name, must be 'train' or 'val'.")
+
+        # Get the data file paths and target file paths
+        self.data = [os.path.join(data_dir, "data", f + ".h5") for f in filenames]
+        self.targets = [os.path.join(data_dir, "labels", f + ".txt") for f in filenames]
+
+    def __getitem__(self, index: int) -> Tuple[Any, Any]:
+        """
+        Returns:
+            (events, target) where target is a numpy array of eye-tracking labels.
+ """ + # get events from .h5 file + with h5py.File(self.data[index], "r") as f: + events = f["events"][:] + # load the sparse labels + with open(self.targets[index], "r") as f: + target = np.array( + [line.strip().split() for line in f.readlines()], np.float64 + ) + + events = make_structured_array( + events[:, 0], # time in us + events[:, 1], # x + events[:, 2], # y + events[:, 3], # polarity in 1 or 0 + dtype=self.dtype, + ) + + if self.transform is not None: + events = self.transform(events) + if self.target_transform is not None: + target = self.target_transform(target) + if self.transforms is not None: + events, target = self.transforms(events, target) + return events, target + + def __len__(self): + return len(self.data) + + def _check_exists(self): + return self._is_file_present() + + def load_filenames(self, path): + with open(path, "r") as f: + return [line.strip() for line in f.readlines()] diff --git a/tonic/functional/drop_pixel.py b/tonic/functional/drop_pixel.py index 5065072a..2a8098cb 100644 --- a/tonic/functional/drop_pixel.py +++ b/tonic/functional/drop_pixel.py @@ -22,7 +22,7 @@ def identify_hot_pixel(events: np.ndarray, hot_pixel_frequency: float): hist = np.histogram2d( events["x"], events["y"], - bins=(np.arange(events["y"].max() + 1), np.arange(events["x"].max() + 1)), + bins=(np.arange(events["x"].max() + 2), np.arange(events["y"].max() + 2)), )[0] max_occur = hot_pixel_frequency * total_time * 1e-6 hot_pixels = np.asarray((hist > max_occur).nonzero()).T diff --git a/tonic/functional/to_frame.py b/tonic/functional/to_frame.py index cd1bacdf..af8dc23a 100644 --- a/tonic/functional/to_frame.py +++ b/tonic/functional/to_frame.py @@ -9,14 +9,14 @@ def to_frame_numpy( - events, - sensor_size, - time_window=None, - event_count=None, - n_time_bins=None, - n_event_bins=None, - overlap=0.0, - include_incomplete=False, + events, + sensor_size, + time_window=None, + event_count=None, + n_time_bins=None, + n_event_bins=None, + overlap=0.0, + include_incomplete=False, ): """Accumulate events to frames by slicing along constant time (time_window), constant number of events (event_count) or constant number of frames (n_time_bins / n_event_bins). 
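Reviewer note on the identify_hot_pixel change above: np.histogram2d takes bin edges rather than bin counts, so giving every integer coordinate 0..max its own bin requires the edges np.arange(max + 2); the old call also passed the y-based edges for x and vice versa. A standalone numpy check of that reasoning, independent of tonic:

import numpy as np

x = np.array([0, 1, 2, 2])
y = np.array([0, 0, 1, 1])

# edges [0, 1, ..., max+1] give exactly one bin per integer coordinate
hist, _, _ = np.histogram2d(
    x, y, bins=(np.arange(x.max() + 2), np.arange(y.max() + 2))
)
print(hist.shape)  # (3, 2): one row per x value, one column per y value
print(hist[2, 1])  # 2.0: both events at pixel (x=2, y=1) land in their own bin
# with np.arange(max + 1), pixel x=2 would be folded into the x=1 bin
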
@@ -37,11 +37,11 @@ def to_frame_numpy( assert "x" and "t" and "p" in events.dtype.names if ( - not sum( - param is not None - for param in [time_window, event_count, n_time_bins, n_event_bins] - ) - == 1 + not sum( + param is not None + for param in [time_window, event_count, n_time_bins, n_event_bins] + ) + == 1 ): raise ValueError( "Please assign a value to exactly one of the parameters time_window," @@ -93,3 +93,4 @@ def to_frame_numpy( for i, event_slice in enumerate(event_slices): np.add.at(frames, (i, event_slice["p"].astype(int), event_slice["x"]), 1) return frames + diff --git a/tonic/transforms.py b/tonic/transforms.py index 9a1c214c..41302a0f 100644 --- a/tonic/transforms.py +++ b/tonic/transforms.py @@ -61,7 +61,7 @@ def __call__(self, events: np.ndarray) -> np.ndarray: if type(self.size) == int: self.size = [self.size, self.size] offsets = (self.sensor_size[0] - self.size[0]) // 2, ( - self.sensor_size[1] - self.size[1] + self.sensor_size[1] - self.size[1] ) // 2 offset_idx = [max(offset, 0) for offset in offsets] cropped_events = events[ @@ -69,7 +69,7 @@ def __call__(self, events: np.ndarray) -> np.ndarray: & (events["x"] < (offset_idx[0] + self.size[0])) & (offset_idx[1] <= events["y"]) & (events["y"] < (offset_idx[1] + self.size[1])) - ] + ] cropped_events["x"] -= offsets[0] cropped_events["y"] -= offsets[1] return cropped_events @@ -228,6 +228,9 @@ class DropPixel: hot_pixel_frequency: Optional[int] = None def __call__(self, events): + if len(events) == 0: + return events # return empty array + if events.dtype.names is not None: # assert "x", "y", "p" in events.dtype.names if self.hot_pixel_frequency: @@ -723,7 +726,7 @@ def __call__(self, events): @dataclass(frozen=True) class UniformNoise: """Adds a fixed number of n noise events that are uniformly distributed across sensor size - dimensions such as x, y, t and p. + dimensions such as x, y, t and p. Not applied if the input is empty. 
    Parameters:
         sensor_size: a 3-tuple of x,y,p for sensor_size
@@ -744,6 +747,9 @@ def get_params(n: Union[int, Tuple[int, int]]):
         return n
 
     def __call__(self, events):
+        if len(events) == 0:
+            return events
+
         n = self.get_params(n=self.n)
         return functional.uniform_noise_numpy(
             events=events, sensor_size=self.sensor_size, n=n
@@ -782,10 +788,10 @@ class NumpyAsType:
 
     def __call__(self, events):
         source_is_structured_array = (
-            hasattr(events.dtype, "names") and events.dtype.names != None
+            hasattr(events.dtype, "names") and events.dtype.names is not None
         )
         target_is_structured_array = (
-            hasattr(self.dtype, "names") and self.dtype.names != None
+            hasattr(self.dtype, "names") and self.dtype.names is not None
         )
         if source_is_structured_array and not target_is_structured_array:
             return np.lib.recfunctions.structured_to_unstructured(events, self.dtype)
@@ -891,16 +897,29 @@ class ToFrame:
     include_incomplete: bool = False
 
     def __call__(self, events):
-        return functional.to_frame_numpy(
-            events=events,
-            sensor_size=self.sensor_size,
-            time_window=self.time_window,
-            event_count=self.event_count,
-            n_time_bins=self.n_time_bins,
-            n_event_bins=self.n_event_bins,
-            overlap=self.overlap,
-            include_incomplete=self.include_incomplete,
-        )
+
+        # if events are empty, return a frame in the expected format
+        if len(events) == 0:
+            if self.time_window is not None or self.event_count is not None:
+                return np.zeros((1, self.sensor_size[2], self.sensor_size[0], self.sensor_size[1]))
+            elif self.n_event_bins is not None:
+                return np.zeros((self.n_event_bins, self.sensor_size[2], self.sensor_size[0], self.sensor_size[1]))
+            elif self.n_time_bins is not None:
+                return np.zeros((self.n_time_bins, self.sensor_size[2], self.sensor_size[0], self.sensor_size[1]))
+            else:
+                raise ValueError("No slicing method specified.")
+
+        else:
+            return functional.to_frame_numpy(
+                events=events,
+                sensor_size=self.sensor_size,
+                time_window=self.time_window,
+                event_count=self.event_count,
+                n_time_bins=self.n_time_bins,
+                n_event_bins=self.n_event_bins,
+                overlap=self.overlap,
+                include_incomplete=self.include_incomplete,
+            )
 
 
 @dataclass(frozen=True)
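Reviewer note: a short sketch of the empty-input behaviour that the ToFrame change above introduces, matching the new test_toframe_empty test; the sensor size is arbitrary:

import numpy as np
import tonic.transforms as transforms

dtype = np.dtype([("x", int), ("y", int), ("t", int), ("p", int)])
empty_events = np.empty(0, dtype=dtype)

# with a bin count, empty input yields all-zero frames of the expected shape
to_frame = transforms.ToFrame(sensor_size=(34, 34, 2), n_time_bins=10)
frames = to_frame(empty_events)
print(frames.shape)  # (10, 2, 34, 34)
print(frames.sum())  # 0.0

# without any slicing method, the empty input raises instead of failing silently
try:
    transforms.ToFrame(sensor_size=(34, 34, 2))(empty_events)
except ValueError as err:
    print(err)  # No slicing method specified.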