Yet another fix for the data iterator. Added a test. (#188)
* Yet another fix for the data iterator. Added a test that would catch this kind of problem

* Bump minor version
fhieber authored Nov 8, 2017
1 parent 79f3d50 commit 88940a7
Showing 7 changed files with 164 additions and 63 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -10,6 +10,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa

For each item we will potentially have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [1.10.5]
### Fixed
- Fixed yet another bug with the data iterator.

## [1.10.4]
### Fixed
- Fixed a bug with the revised data iterator not correctly appending EOS symbols for variable-length batches.
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.10.4'
__version__ = '1.10.5'
1 change: 1 addition & 0 deletions sockeye/data_io.py
@@ -676,6 +676,7 @@ def reset(self):

        self.nd_source = []
        self.nd_target = []
        self.nd_label = []
        self.indices = []
        for i in range(len(self.data_source)):
            # shuffle indices within each bucket
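(The one-line change above is the whole code fix: reset() now clears self.nd_label alongside self.nd_source and self.nd_target before repopulating the per-bucket arrays. Below is a minimal, self-contained sketch of the failure mode this appears to guard against, assuming, as the truncated hunk suggests, that the loop following these lines appends one array per bucket to each list. The ToyIter class is illustrative only and is not Sockeye's actual iterator.)

import numpy as np

class ToyIter:
    """Illustrative only: mimics an iterator that rebuilds per-bucket arrays in reset()."""

    def __init__(self, data_source, data_target, data_label):
        self.data_source = data_source
        self.data_target = data_target
        self.data_label = data_label
        self.nd_source = []
        self.nd_target = []
        self.nd_label = []

    def reset(self, with_fix=True):
        self.nd_source = []
        self.nd_target = []
        if with_fix:
            self.nd_label = []  # the added line: without it, labels accumulate across resets
        for i in range(len(self.data_source)):
            self.nd_source.append(np.asarray(self.data_source[i]))
            self.nd_target.append(np.asarray(self.data_target[i]))
            self.nd_label.append(np.asarray(self.data_label[i]))

it = ToyIter([[1, 2]], [[3, 4]], [[4, 5]])
it.reset(with_fix=False)
it.reset(with_fix=False)
print(len(it.nd_source), len(it.nd_label))  # prints "1 2": labels out of sync with source/target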
25 changes: 25 additions & 0 deletions test/common.py
@@ -14,6 +14,7 @@
import os
import random
import sys
from contextlib import contextmanager
from tempfile import TemporaryDirectory
from typing import Optional, Tuple
from unittest.mock import patch
@@ -101,6 +102,30 @@ def generate_digits_file(source_path: str,
            print(" ".join(digits), file=target_out)


@contextmanager
def tmp_digits_dataset(prefix: str,
                       train_line_count: int, train_max_length: int,
                       dev_line_count: int, dev_max_length: int,
                       sort_target: bool = False,
                       seed_train: int = 13, seed_dev: int = 13):
    with TemporaryDirectory(prefix=prefix) as work_dir:
        # Simple digits files for train/dev data
        train_source_path = os.path.join(work_dir, "train.src")
        train_target_path = os.path.join(work_dir, "train.tgt")
        dev_source_path = os.path.join(work_dir, "dev.src")
        dev_target_path = os.path.join(work_dir, "dev.tgt")
        generate_digits_file(train_source_path, train_target_path, train_line_count, train_max_length,
                             sort_target=sort_target, seed=seed_train)
        generate_digits_file(dev_source_path, dev_target_path, dev_line_count, dev_max_length, sort_target=sort_target,
                             seed=seed_dev)
        data = {'work_dir': work_dir,
                'source': train_source_path,
                'target': train_target_path,
                'validation_source': dev_source_path,
                'validation_target': dev_target_path}
        yield data

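(A short usage sketch of the new helper, essentially what the refactored integration and system tests below do; the prefix and counts are hypothetical, and only the dictionary keys visible above are assumed.)

# Hypothetical example: generate a throwaway digits corpus and use the returned paths.
with tmp_digits_dataset("example.", train_line_count=100, train_max_length=30,
                        dev_line_count=20, dev_max_length=30) as data:
    print(data['work_dir'])                                       # temporary dir, deleted on exit
    print(data['source'], data['target'])                         # training source/target files
    print(data['validation_source'], data['validation_target'])   # dev source/target files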

_TRAIN_PARAMS_COMMON = "--use-cpu --max-seq-len {max_len} --source {train_source} --target {train_target}" \
" --validation-source {dev_source} --validation-target {dev_target} --output {model}"

25 changes: 8 additions & 17 deletions test/integration/test_seq_copy_int.py
@@ -11,12 +11,9 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
from tempfile import TemporaryDirectory

import pytest

from test.common import generate_digits_file, run_train_translate
from test.common import run_train_translate, tmp_digits_dataset

_TRAIN_LINE_COUNT = 100
_DEV_LINE_COUNT = 10
@@ -81,23 +78,17 @@
@pytest.mark.parametrize("train_params, translate_params", ENCODER_DECODER_SETTINGS)
def test_seq_copy(train_params, translate_params):
"""Task: copy short sequences of digits"""
with TemporaryDirectory(prefix="test_seq_copy") as work_dir:
# Simple digits files for train/dev data
train_source_path = os.path.join(work_dir, "train.src")
train_target_path = os.path.join(work_dir, "train.tgt")
dev_source_path = os.path.join(work_dir, "dev.src")
dev_target_path = os.path.join(work_dir, "dev.tgt")
generate_digits_file(train_source_path, train_target_path, _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH)
generate_digits_file(dev_source_path, dev_target_path, _DEV_LINE_COUNT, _LINE_MAX_LENGTH)
with tmp_digits_dataset("test_seq_copy", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT,
_LINE_MAX_LENGTH) as data:
# Test model configuration, including the output equivalence of batch and no-batch decoding
translate_params_batch = translate_params + " --batch-size 2"
# Ignore return values (perplexity and BLEU) for integration test
run_train_translate(train_params,
translate_params,
translate_params_batch,
train_source_path,
train_target_path,
dev_source_path,
dev_target_path,
data['source'],
data['target'],
data['validation_source'],
data['validation_target'],
max_seq_len=_LINE_MAX_LENGTH + 1,
work_dir=work_dir)
work_dir=data['work_dir'])
47 changes: 15 additions & 32 deletions test/system/test_seq_copy_sys.py
@@ -11,12 +11,9 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
from tempfile import TemporaryDirectory

import pytest

from test.common import generate_digits_file, run_train_translate
from test.common import tmp_digits_dataset, run_train_translate

_TRAIN_LINE_COUNT = 10000
_DEV_LINE_COUNT = 100
@@ -91,24 +88,18 @@
])
def test_seq_copy(train_params, translate_params, perplexity_thresh, bleu_thresh):
"""Task: copy short sequences of digits"""
with TemporaryDirectory(prefix="test_seq_copy.") as work_dir:
# Simple digits files for train/dev data
train_source_path = os.path.join(work_dir, "train.src")
train_target_path = os.path.join(work_dir, "train.tgt")
dev_source_path = os.path.join(work_dir, "dev.src")
dev_target_path = os.path.join(work_dir, "dev.tgt")
generate_digits_file(train_source_path, train_target_path, _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, seed=_SEED_TRAIN)
generate_digits_file(dev_source_path, dev_target_path, _DEV_LINE_COUNT, _LINE_MAX_LENGTH, seed=_SEED_DEV)
with tmp_digits_dataset("test_seq_copy.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT,
_LINE_MAX_LENGTH, seed_train=_SEED_TRAIN, seed_dev=_SEED_DEV) as data:
# Test model configuration
perplexity, bleu = run_train_translate(train_params,
translate_params,
None, # no second set of parameters
train_source_path,
train_target_path,
dev_source_path,
dev_target_path,
data['source'],
data['target'],
data['validation_source'],
data['validation_target'],
max_seq_len=_LINE_MAX_LENGTH + 1,
work_dir=work_dir)
work_dir=data['work_dir'])

assert perplexity <= perplexity_thresh
assert bleu >= bleu_thresh
@@ -171,25 +162,17 @@ def test_seq_copy(train_params, translate_params, perplexity_thresh, bleu_thresh
])
def test_seq_sort(train_params, translate_params, perplexity_thresh, bleu_thresh):
"""Task: sort short sequences of digits"""
with TemporaryDirectory(prefix="test_seq_sort.") as work_dir:
# Simple digits files for train/dev data
train_source_path = os.path.join(work_dir, "train.src")
train_target_path = os.path.join(work_dir, "train.tgt")
dev_source_path = os.path.join(work_dir, "dev.src")
dev_target_path = os.path.join(work_dir, "dev.tgt")
generate_digits_file(train_source_path, train_target_path, _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH,
sort_target=True, seed=_SEED_TRAIN)
generate_digits_file(dev_source_path, dev_target_path, _DEV_LINE_COUNT, _LINE_MAX_LENGTH,
sort_target=True, seed=_SEED_DEV)
with tmp_digits_dataset("test_seq_sort.", _TRAIN_LINE_COUNT, _LINE_MAX_LENGTH, _DEV_LINE_COUNT,
_LINE_MAX_LENGTH, sort_target=True, seed_train=_SEED_TRAIN, seed_dev=_SEED_DEV) as data:
# Test model configuration
perplexity, bleu = run_train_translate(train_params,
translate_params,
None,
train_source_path,
train_target_path,
dev_source_path,
dev_target_path,
data['source'],
data['target'],
data['validation_source'],
data['validation_target'],
max_seq_len=_LINE_MAX_LENGTH + 1,
work_dir=work_dir)
work_dir=data['work_dir'])
assert perplexity <= perplexity_thresh
assert bleu >= bleu_thresh
123 changes: 110 additions & 13 deletions test/unit/test_data_io.py
@@ -11,10 +11,13 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import numpy as np
import pytest

import sockeye.constants as C
import sockeye.data_io
from sockeye import constants as C
from sockeye import data_io
from sockeye import vocab
from test.common import tmp_digits_dataset

define_bucket_tests = [(50, 10, [10, 20, 30, 40, 50]),
(50, 20, [20, 40, 50]),
@@ -26,25 +29,32 @@

@pytest.mark.parametrize("max_seq_len, step, expected_buckets", define_bucket_tests)
def test_define_buckets(max_seq_len, step, expected_buckets):
buckets = sockeye.data_io.define_buckets(max_seq_len, step=step)
buckets = data_io.define_buckets(max_seq_len, step=step)
assert buckets == expected_buckets


define_parallel_bucket_tests = [(50, 50, 10, 1.0, [(10, 10), (20, 20), (30, 30), (40, 40), (50, 50)]),
(50, 50, 10, 0.5, [(10, 5), (20, 10), (30, 15), (40, 20), (50, 25), (50, 30), (50, 35), (50, 40), (50, 45), (50, 50)]),
(10, 10, 10, 0.1, [(10, 2), (10, 3), (10, 4), (10, 5), (10, 6), (10, 7), (10, 8), (10, 9), (10, 10)]),
(50, 50, 10, 0.5,
[(10, 5), (20, 10), (30, 15), (40, 20), (50, 25), (50, 30), (50, 35), (50, 40),
(50, 45), (50, 50)]),
(10, 10, 10, 0.1,
[(10, 2), (10, 3), (10, 4), (10, 5), (10, 6), (10, 7), (10, 8), (10, 9), (10, 10)]),
(10, 5, 10, 0.01, [(10, 2), (10, 3), (10, 4), (10, 5)]),
(50, 50, 10, 2.0, [(5, 10), (10, 20), (15, 30), (20, 40), (25, 50), (30, 50), (35, 50), (40, 50), (45, 50), (50, 50)]),
(50, 50, 10, 2.0,
[(5, 10), (10, 20), (15, 30), (20, 40), (25, 50), (30, 50), (35, 50), (40, 50),
(45, 50), (50, 50)]),
(5, 10, 10, 10.0, [(2, 10), (3, 10), (4, 10), (5, 10)]),
(5, 10, 10, 11.0, [(2, 10), (3, 10), (4, 10), (5, 10)]),
(50, 50, 50, 0.5, [(50, 25), (50, 50)]),
(50, 50, 50, 1.5, [(33, 50), (50, 50)]),
(75, 75, 50, 1.5, [(33, 50), (66, 75), (75, 75)])]


@pytest.mark.parametrize("max_seq_len_source, max_seq_len_target, bucket_width, length_ratio, expected_buckets", define_parallel_bucket_tests)
@pytest.mark.parametrize("max_seq_len_source, max_seq_len_target, bucket_width, length_ratio, expected_buckets",
define_parallel_bucket_tests)
def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, length_ratio, expected_buckets):
buckets = sockeye.data_io.define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width=bucket_width, length_ratio=length_ratio)
buckets = data_io.define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width=bucket_width,
length_ratio=length_ratio)
assert buckets == expected_buckets


@@ -60,7 +70,7 @@ def test_define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_
@pytest.mark.parametrize("buckets, length, expected_bucket",
get_bucket_tests)
def test_get_bucket(buckets, length, expected_bucket):
bucket = sockeye.data_io.get_bucket(length, buckets)
bucket = data_io.get_bucket(length, buckets)
assert bucket == expected_bucket


Expand All @@ -70,7 +80,7 @@ def test_get_bucket(buckets, length, expected_bucket):

@pytest.mark.parametrize("line, expected_tokens", get_tokens_tests)
def test_get_tokens(line, expected_tokens):
tokens = list(sockeye.data_io.get_tokens(line))
tokens = list(data_io.get_tokens(line))
assert tokens == expected_tokens


Expand All @@ -80,15 +90,15 @@ def test_get_tokens(line, expected_tokens):

@pytest.mark.parametrize("tokens, vocab, expected_ids", tokens2ids_tests)
def test_tokens2ids(tokens, vocab, expected_ids):
ids = sockeye.data_io.tokens2ids(tokens, vocab)
ids = data_io.tokens2ids(tokens, vocab)
assert ids == expected_ids


@pytest.mark.parametrize("buckets, expected_default_bucket_key",
[([(10, 10), (20, 20), (30, 30), (40, 40), (50, 50)], (50, 50)),
([(5, 10), (10, 20), (15, 30), (25, 50), (20, 40)], (25, 50))])
def test_get_default_bucket_key(buckets, expected_default_bucket_key):
default_bucket_key = sockeye.data_io.get_default_bucket_key(buckets)
default_bucket_key = data_io.get_default_bucket_key(buckets)
assert default_bucket_key == expected_default_bucket_key


@@ -104,6 +114,93 @@ def test_get_default_bucket_key(buckets, expected_default_bucket_key):
@pytest.mark.parametrize("buckets, source_length, target_length, expected_bucket_index, expected_bucket",
get_parallel_bucket_tests)
def test_get_parallel_bucket(buckets, source_length, target_length, expected_bucket_index, expected_bucket):
bucket_index, bucket = sockeye.data_io.get_parallel_bucket(buckets, source_length, target_length)
bucket_index, bucket = data_io.get_parallel_bucket(buckets, source_length, target_length)
assert bucket_index == expected_bucket_index
assert bucket == expected_bucket


@pytest.mark.parametrize("source, target, expected_mean, expected_std",
                         [([[1, 1, 1], [2, 2, 2], [3, 3, 3]],
                           [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 1.0, 0.0),
                          ([[1, 1], [2, 2], [3, 3]],
                           [[1, 1, 1], [2, 2, 2], [3, 3, 3]], 1.5, 0.0),
                          ([[1, 1, 1], [2, 2]],
                           [[1, 1, 1], [2], [3, 3, 3]], 0.75, 0.25)])
def test_length_statistics(source, target, expected_mean, expected_std):
    mean, std = data_io.length_statistics(source, target)
    assert np.isclose(mean, expected_mean)
    assert np.isclose(std, expected_std)

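(Judging from the fixtures above, length_statistics appears to return the mean and standard deviation of the per-pair target/source token-length ratios; the hand computation below reproduces the second and third rows. This is an inference from the test data, not a statement about the implementation itself.)

import numpy as np

source = [[1, 1], [2, 2], [3, 3]]            # token lengths 2, 2, 2
target = [[1, 1, 1], [2, 2, 2], [3, 3, 3]]   # token lengths 3, 3, 3
ratios = np.array([len(t) / len(s) for s, t in zip(source, target)])
print(ratios.mean(), ratios.std())           # 1.5 0.0, matching expected_mean / expected_std

source = [[1, 1, 1], [2, 2]]                 # zip stops at the shorter list here
target = [[1, 1, 1], [2], [3, 3, 3]]
ratios = np.array([len(t) / len(s) for s, t in zip(source, target)])
print(ratios.mean(), ratios.std())           # 0.75 0.25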

def test_get_training_data_iters():
    train_line_count = 100
    train_max_length = 30
    dev_line_count = 20
    dev_max_length = 30
    expected_mean = 1.1476392401276574
    expected_std = 0.2318455878853099
    batch_size = 5
    with tmp_digits_dataset("tmp_corpus",
                            train_line_count, train_max_length, dev_line_count, dev_max_length) as data:
        # tmp common vocab
        vcb = vocab.build_from_paths([data['source'], data['target']])

        train_iter, val_iter, config_data = data_io.get_training_data_iters(data['source'], data['target'],
                                                                            data['validation_source'],
                                                                            data['validation_target'],
                                                                            vocab_source=vcb,
                                                                            vocab_target=vcb,
                                                                            vocab_source_path=None,
                                                                            vocab_target_path=None,
                                                                            batch_size=batch_size,
                                                                            batch_by_words=False,
                                                                            batch_num_devices=1,
                                                                            fill_up="replicate",
                                                                            max_seq_len_source=train_max_length,
                                                                            max_seq_len_target=train_max_length,
                                                                            bucketing=True,
                                                                            bucket_width=10)
        assert config_data.source == data['source']
        assert config_data.target == data['target']
        assert config_data.validation_source == data['validation_source']
        assert config_data.validation_target == data['validation_target']
        assert config_data.vocab_source is None
        assert config_data.vocab_target is None
        assert config_data.max_observed_source_seq_len == train_max_length - 1
        assert config_data.max_observed_target_seq_len == train_max_length
        assert np.isclose(config_data.length_ratio_mean, expected_mean)
        assert np.isclose(config_data.length_ratio_std, expected_std)

        assert train_iter.batch_size == batch_size
        assert val_iter.batch_size == batch_size
        assert train_iter.default_bucket_key == (train_max_length, train_max_length)
        assert val_iter.default_bucket_key == (dev_max_length, dev_max_length)
        assert train_iter.max_observed_source_len == config_data.max_observed_source_seq_len
        assert train_iter.max_observed_target_len == config_data.max_observed_target_seq_len
        assert train_iter.pad_id == vcb[C.PAD_SYMBOL]
        assert train_iter.dtype == 'float32'
        assert not train_iter.batch_by_words
        assert train_iter.fill_up == 'replicate'

        # test some batches
        bos_id = vcb[C.BOS_SYMBOL]
        expected_first_target_symbols = np.full((batch_size,), bos_id, dtype='float32')
        for epoch in range(2):
            while train_iter.iter_next():
                batch = train_iter.next()
                assert len(batch.data) == 2
                assert len(batch.label) == 1
                assert batch.bucket_key in train_iter.buckets
                source = batch.data[0].asnumpy()
                target = batch.data[1].asnumpy()
                label = batch.label[0].asnumpy()
                assert source.shape[0] == batch_size
                assert target.shape[0] == batch_size
                assert label.shape[0] == batch_size
                # target first symbol should be BOS
                assert np.array_equal(target[:, 0], expected_first_target_symbols)
                # label first symbol should be 2nd target symbol
                assert np.array_equal(label[:, 0], target[:, 1])
                # each label sequence contains one EOS symbol
                assert np.sum(label == vcb[C.EOS_SYMBOL]) == batch_size
            train_iter.reset()

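(The per-batch assertions above imply the usual teacher-forcing layout for target and label; a small illustrative sketch with placeholder tokens rather than real vocabulary ids:)

# For a target sentence t1 ... tn the iterator is expected to yield, per row:
#   target = [BOS, t1, ..., tn, PAD, ...]   (decoder input)
#   label  = [t1, ..., tn, EOS, PAD, ...]   (decoder training target)
# i.e. label is target shifted left by one with EOS appended, which is what the
# assertions on target[:, 0], label[:, 0] and the per-sequence EOS count verify.
target = ["<s>", "4", "7", "1"]
label = target[1:] + ["</s>"]
assert label[0] == target[1]
assert label.count("</s>") == 1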