Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Merge pull request #37 from juharris/pad_str
Browse files Browse the repository at this point in the history
pad_sequences: Add support for string value.
  • Loading branch information
Frédéric Branchaud-Charron authored Aug 24, 2018
2 parents dad7fcc + bfb7297 commit b786e96
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
12 changes: 10 additions & 2 deletions keras_preprocessing/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import random
import json
from six.moves import range
import six


def pad_sequences(sequences, maxlen=None, dtype='int32',
Expand All @@ -35,12 +36,13 @@ def pad_sequences(sequences, maxlen=None, dtype='int32',
sequences: List of lists, where each element is a sequence.
maxlen: Int, maximum length of all sequences.
dtype: Type of the output sequences.
To pad sequences with variable length strings, you can use `object`.
padding: String, 'pre' or 'post':
pad either before or after each sequence.
truncating: String, 'pre' or 'post':
remove values from sequences larger than
`maxlen`, either at the beginning or at the end of the sequences.
value: Float, padding value.
value: Float or String, padding value.
# Returns
x: Numpy array with shape `(len(sequences), maxlen)`
Expand Down Expand Up @@ -70,7 +72,13 @@ def pad_sequences(sequences, maxlen=None, dtype='int32',
sample_shape = np.asarray(s).shape[1:]
break

x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_)
if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
"You should set `dtype=object` for variable length strings."
.format(dtype, type(value)))

x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
for idx, s in enumerate(sequences):
if not len(s):
continue # empty list/array was found
Expand Down
22 changes: 22 additions & 0 deletions tests/sequence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import numpy as np
from numpy.testing import assert_allclose
from numpy.testing import assert_equal
from numpy.testing import assert_raises

import keras
Expand Down Expand Up @@ -35,6 +36,27 @@ def test_pad_sequences():
assert_allclose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]])


def test_pad_sequences_str():
    """Exercise `pad_sequences` with string tokens and a string pad value.

    Covers both an `object` dtype and a fixed-width unicode dtype (`'<U3'`),
    for padding ('pre' / 'post') and for truncating ('pre' / 'post'), and
    checks that leaving `dtype` unset while passing a string `value` raises
    a `ValueError`.
    """
    tokens = [['1'], ['1', '2'], ['1', '2', '3']]

    # Padding: every sequence is brought up to maxlen=3 with the 'pad' token.
    padded_pre = sequence.pad_sequences(
        tokens, maxlen=3, padding='pre', value='pad', dtype=object)
    assert_equal(
        padded_pre,
        [['pad', 'pad', '1'], ['pad', '1', '2'], ['1', '2', '3']])

    padded_post = sequence.pad_sequences(
        tokens, maxlen=3, padding='post', value='pad', dtype='<U3')
    assert_equal(
        padded_post,
        [['1', 'pad', 'pad'], ['1', '2', 'pad'], ['1', '2', '3']])

    # Truncating: maxlen=2 is shorter than the longest sequence, so that one
    # loses an element from the front ('pre') or from the back ('post').
    # `padding` is not passed here; the expected arrays show the pad token
    # placed in front of the short sequence.
    truncated_pre = sequence.pad_sequences(
        tokens, maxlen=2, truncating='pre', value='pad', dtype=object)
    assert_equal(truncated_pre, [['pad', '1'], ['1', '2'], ['2', '3']])

    truncated_post = sequence.pad_sequences(
        tokens, maxlen=2, truncating='post', value='pad', dtype='<U3')
    assert_equal(truncated_post, [['pad', '1'], ['1', '2'], ['1', '2']])

    # No dtype is given, so the int32 named in the expected message is in
    # effect and cannot hold the string pad value.
    expected_message = "`dtype` int32 is not compatible with "
    with pytest.raises(ValueError, match=expected_message):
        sequence.pad_sequences(
            tokens, maxlen=2, truncating='post', value='pad')


def test_pad_sequences_vector():
a = [[[1, 1]],
[[2, 1], [2, 2]],
Expand Down

0 comments on commit b786e96

Please sign in to comment.