Skip to content

Commit

Permalink
Add support for batched Iteration (#102)
Browse files Browse the repository at this point in the history
* Add:Support for batched iteration

* Refactor:Data to be list of numpy arrays

* Add:test for tuple of numpy arrays

* Add:batched iteration support
* iter_batches function in Dataset class returns a BatchLoader object
* BatchLoader class added
* Moved utils.py
* Renamed utils.py
* Created test_data.py
* Cleanup
* Fix Typo

* Update src/sparsezoo/utils/data.py

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>

* Fix:Single-Input cases
Address:PR review comments

* Update:Rename tests/utils.py to tests/helpers.py
Fix:Unwrapping Single Input Errors

* Update:Rename tests/utils.py to tests/helpers.py
Fix:Unwrapping Single Input Errors
Update:tests_data.py

* Update:fixes from PR comments

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
  • Loading branch information
rahul-tuli and bfineran authored Jul 20, 2021
1 parent 7538b0d commit 5ca44f8
Show file tree
Hide file tree
Showing 13 changed files with 346 additions and 11 deletions.
112 changes: 110 additions & 2 deletions src/sparsezoo/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import math
from collections import OrderedDict
from typing import Dict, Iterable, Iterator, List, Tuple, Union
from typing import Dict, Generator, Iterable, Iterator, List, Tuple, Union

import numpy

Expand All @@ -28,10 +28,102 @@

__all__ = ["Dataset", "RandomDataset", "DataLoader"]


_LOGGER = logging.getLogger(__name__)


# A utility class to load data in batches for fixed number of iterations


class _BatchLoader:
__slots__ = [
"_data",
"_batch_size",
"_was_wrapped_originally",
"_iterations",
"_batch_buffer",
"_batch_template",
"_batches_created",
]

def __init__(
self,
data: Iterable[Union[numpy.ndarray, List[numpy.ndarray]]],
batch_size: int,
iterations: int,
):
self._data = data
self._was_wrapped_originally = type(self._data[0]) is list
if not self._was_wrapped_originally:
self._data = [self._data]
self._batch_size = batch_size
self._iterations = iterations
if batch_size <= 0 or iterations <= 0:
raise ValueError(
f"Both batch size and number of iterations should be positive, "
f"supplied values (batch_size, iterations):{(batch_size, iterations)}"
)

self._batch_buffer = []
self._batch_template = self._init_batch_template()
self._batches_created = 0

def __iter__(self) -> Generator[List[numpy.ndarray], None, None]:
yield from self._multi_input_batch_generator()

@property
def _buffer_is_full(self) -> bool:
return len(self._batch_buffer) == self._batch_size

@property
def _all_batches_loaded(self) -> bool:
return self._batches_created >= self._iterations

def _multi_input_batch_generator(
self,
) -> Generator[List[numpy.ndarray], None, None]:
# A generator for with each element of the form
# [[(batch_size, features_a), (batch_size, features_b), ...]]
while not self._all_batches_loaded:
yield from self._batch_generator(source=self._data)

def _batch_generator(self, source) -> Generator[List[numpy.ndarray], None, None]:
# batches from source
for sample in source:
self._batch_buffer.append(sample)
if self._buffer_is_full:
_batch = self._make_batch()
yield _batch
self._batch_buffer = []
self._batches_created += 1
if self._all_batches_loaded:
break

def _init_batch_template(
self,
) -> Iterable[Union[List[numpy.ndarray], numpy.ndarray]]:
# A placeholder for batches
return [
numpy.ascontiguousarray(
numpy.zeros((self._batch_size, *_input.shape), dtype=_input.dtype)
)
for _input in self._data[0]
]

def _make_batch(self) -> Iterable[Union[numpy.ndarray, List[numpy.ndarray]]]:
# Copy contents of buffer to batch placeholder
# and return A list of numpy array(s) representing the batch

batch = [
numpy.stack([sample[idx] for sample in self._batch_buffer], out=template)
for idx, template in enumerate(self._batch_template)
]

if not self._was_wrapped_originally:
# unwrap outer list
batch = batch[0]
return batch


class Dataset(Iterable):
"""
A numpy dataset implementation
Expand Down Expand Up @@ -76,6 +168,22 @@ def data(self) -> List[Union[numpy.ndarray, Dict[str, numpy.ndarray]]]:
"""
return self._data

def iter_batches(
self, batch_size: int, iterations: int
) -> Generator[List[numpy.ndarray], None, None]:
"""
A function to iterate over data in batches
:param batch_size: non-negative integer representing the size of each
:param iterations: non-negative integer representing
the number of batches to return
:returns: A generator for batches, each batch is enclosed in a list
Each batch is of the form [(batch_size, *feature_shape)]
"""
return _BatchLoader(
data=self.data, batch_size=batch_size, iterations=iterations
)


class RandomDataset(Dataset):
"""
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/classification/test_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sparsezoo.models.classification import efficientnet_b0, efficientnet_b4
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/classification/test_inception.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sparsezoo.models.classification import inception_v3
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/classification/test_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import pytest

from sparsezoo.models.classification import mobilenet_v1, mobilenet_v2
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/classification/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
resnet_101_2x,
resnet_152,
)
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/classification/test_vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
vgg_19,
vgg_19bn,
)
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/detection/test_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sparsezoo.models.detection import ssd_resnet50_300
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/detection/test_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from sparsezoo.models.detection import yolo_v3
from tests.sparsezoo.utils import model_constructor
from tests.sparsezoo.helpers import model_constructor


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/test_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from sparsezoo import Zoo
from sparsezoo.utils import CACHE_DIR
from tests.sparsezoo.utils import validate_downloaded_model
from tests.sparsezoo.helpers import validate_downloaded_model


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/sparsezoo/models/test_zoo_extensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

from sparsezoo.models import Zoo
from tests.sparsezoo.utils import download_and_verify
from tests.sparsezoo.helpers import download_and_verify


def _get_models(domain, sub_domain) -> List[str]:
Expand Down
13 changes: 13 additions & 0 deletions tests/sparsezoo/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit 5ca44f8

Please sign in to comment.