forked from rwth-i6/returnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMetaDataset.py
More file actions
1418 lines (1257 loc) · 54.8 KB
/
MetaDataset.py
File metadata and controls
1418 lines (1257 loc) · 54.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
There are use cases in which we want to combine several datasets:
* **Multimodality:** features from several datasets should be provided at the same time
* Examples: multi-source translation, speech translation with source CTC loss
for stability (needs both source audio and transcription)
* **Multi-Task Learning:** several datasets should be used alternatingly,
such that at each time the dataset of the corresponding task is selected
* Examples: multi-task speech translation (either from audio or from text)
* **Combination of Corpora:** the training data should be split into different datasets.
This allows creating a combined corpus dynamically
and avoids manual concatenation/shuffling.
* Examples: multi-lingual translation systems
(datasets can be reused from corresponding bilingual systems)
The dataset classes MetaDataset and CombinedDataset which perform these tasks are implemented in MetaDataset.py.
"""
from __future__ import print_function
from Dataset import Dataset, DatasetSeq, init_dataset, convert_data_dims
from CachedDataset2 import CachedDataset2
from Util import NumbersDict, load_json
from Log import log
from random import Random
import numpy
import sys
import typing
class EpochWiseFilter:
    """
    Applies some filter to the sequences (e.g. by seq length) for some epoch.
    """

    def __init__(self, epochs_opts, debug_msg_prefix="EpochWiseFilter"):
        """
        :param dict[(int,int|None),dict[str]] epochs_opts: (ep_start, ep_end) -> epoch opts.
          ep_start None means 1; ep_end None or -1 means open-ended.
        :param str debug_msg_prefix: prefix for the log messages
        """
        self.epochs_opts = epochs_opts
        self.debug_msg_prefix = debug_msg_prefix

    @classmethod
    def filter_epoch(cls, opts, seq_order, get_seq_len, debug_msg_prefix):
        """
        Applies the filter options of one epoch range to ``seq_order``.

        Supported options:

        * ``max_mean_len``: keep the shortest seqs, where the number of kept seqs is determined via
          ``Util.binary_search_any`` on (mean length of the shortest ``num`` seqs) - max_mean_len.

        :param dict[str]|Util.CollectionReadCheckCovered opts:
        :param list[int] seq_order: list of seq idxs
        :param ((int)->int) get_seq_len: seq idx -> len
        :param str debug_msg_prefix:
        :return: new seq_order. The relative order of the kept seqs is unchanged.
        :rtype: list[int]
        """
        import Util
        if not isinstance(opts, Util.CollectionReadCheckCovered):
            opts = Util.CollectionReadCheckCovered(opts)
        max_mean_len = opts.get("max_mean_len")
        # With no seqs there is nothing to filter; also the numpy indexing below requires a non-empty 2D array.
        if max_mean_len and seq_order:
            # Rows are (len, seq_idx), sorted by len (then seq_idx).
            lens_and_seqs = numpy.array(sorted([(get_seq_len(idx), idx) for idx in seq_order]))
            best_num = Util.binary_search_any(
                cmp=lambda num: numpy.mean(lens_and_seqs[:num, 0]) - max_mean_len,
                low=1, high=len(lens_and_seqs) + 1)
            assert best_num is not None
            selected_seq_idxs = set(lens_and_seqs[:best_num, 1])
            # Select subset of seq_order. Keep order as-is.
            seq_order = [seq_idx for seq_idx in seq_order if seq_idx in selected_seq_idxs]
            print(
                ("%sOld mean seq len is %f, new is %f, requested max is %f."
                 " Old num seqs is %i, new num seqs is %i.") %
                (debug_msg_prefix,
                 float(numpy.mean(lens_and_seqs[:, 0])), float(numpy.mean(lens_and_seqs[:best_num, 0])),
                 max_mean_len, len(lens_and_seqs), best_num),
                file=log.v4)
        opts.assert_all_read()
        return seq_order

    def filter(self, epoch, seq_order, get_seq_len):
        """
        Applies all filters whose (inclusive) epoch range contains ``epoch``,
        in ascending order of the normalized (ep_start, ep_end) ranges.

        :param int|None epoch: None is treated as epoch 1
        :param list[int] seq_order: list of seq idxs
        :param ((int)->int) get_seq_len: seq idx -> len
        :return: new seq_order
        :rtype: list[int]
        """
        epoch = epoch or 1
        old_num_seqs = len(seq_order)
        # Normalize the ranges before sorting, so that sorting never compares None with int
        # (a TypeError in Python 3, e.g. for the keys (1, 5) and (1, None)).
        ranges = []
        for (ep_start, ep_end), value in self.epochs_opts.items():
            if ep_start is None:
                ep_start = 1
            if ep_end is None or ep_end == -1:
                ep_end = sys.maxsize
            assert isinstance(ep_start, int) and isinstance(ep_end, int) and 1 <= ep_start <= ep_end
            assert isinstance(value, dict)
            ranges.append(((ep_start, ep_end), value))
        ranges.sort(key=lambda item: item[0])
        any_filter = False
        for (ep_start, ep_end), value in ranges:
            if ep_start <= epoch <= ep_end:
                any_filter = True
                seq_order = self.filter_epoch(
                    opts=value, debug_msg_prefix="%s, epoch %i. " % (self.debug_msg_prefix, epoch),
                    seq_order=seq_order, get_seq_len=get_seq_len)
        if any_filter:
            print("%s, epoch %i. Old num seqs %i, new num seqs %i." % (
                self.debug_msg_prefix, epoch, old_num_seqs, len(seq_order)), file=log.v4)
        else:
            print("%s, epoch %i. No filter for this epoch." % (self.debug_msg_prefix, epoch), file=log.v4)
        return seq_order
class MetaDataset(CachedDataset2):
    """
    The MetaDataset is to be used in the case of **Multimodality**.
    Here, the datasets are expected to describe different features of the **same training sequences**.
    These features will all be available to the network at the same time.

    The datasets to be combined are given via the input parameter ``"datasets"``.
    To define which training examples from the different datasets belong together,
    a ``"seq_list_file"`` in pickle format has to be created.
    It contains a list of sequence tags for each dataset (see example below).
    Note that in general each dataset type has its own tag format, e.g. for the TranslationDataset it is ``line-<n>``,
    for the SprintDataset it is ``<corpusname>/<recording>/<segment id>``.
    **Providing a sequence list can be omitted**, if the set of sequence tags is the same for all datasets.
    When using multiple ExternSprintDataset instances, the sprint segment file can be provided as sequence list.
    In this case the MetaDataset assumes that the sequences with equal tag correspond to each other.
    This e.g. works when combining TranslationDatasets if all the text files are sentence aligned.

    **Example of Sequence List:**

    .. code::

        { 'sprint': [
            'corpus/ted_1/1',
            'corpus/ted_1/2',
            'corpus/ted_1/3',
            'corpus/ted_1/4'],
          'translation': [
            'line-0',
            'line-1',
            'line-2',
            'line-3']
        }

    Python dict stored in pickle file. E.g. the sequence tagged with 'corpus/ted_1/3' in the 'sprint' dataset
    corresponds to the sequence tagged 'line-2'
    in the 'translation' dataset.

    **Example of MetaDataset config:**

    .. code::

        train = {"class": "MetaDataset", "seq_list_file": "seq_list.pkl",
                 "datasets": {"sprint": train_sprint, "translation": train_translation},
                 "data_map": {"data": ("sprint", "data"),
                              "target_text_sprint": ("sprint", "orth_classes"),
                              "source_text": ("translation", "data"),
                              "target_text": ("translation", "classes")},
                 "seq_ordering": "random",
                 "partition_epoch": 2,
                 }

    This combines a SprintDataset and a TranslationDataset.
    These are defined as ``"train_sprint"`` and ``"train_translation"`` separately.
    *Note that the current implementation expects one input feature to be called "data".*

    **Sequence Sorting:**

    If the selected sequence order uses the length of the data (e.g. when using "sorted" or any kind of "laplace"),
    a sub-dataset has to be specified via ``seq_order_control_dataset``.
    The desired sorting needs to be set as parameter in this sub-dataset, setting ``seq_ordering`` for the MetaDataset
    will be ignored.
    """

    def __init__(self,
                 datasets,
                 data_map,
                 seq_list_file=None,
                 seq_order_control_dataset=None,
                 seq_lens_file=None,
                 data_dims=None,
                 data_dtypes=None,
                 window=1, **kwargs):
        """
        :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
        :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
          Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
        :param str|None seq_list_file: filename. pickle. dict[str,list[str]], dataset-key -> list of sequence tags.
          Can be None if tag format is the same for all datasets.
          Then the sequence list will be default sequence order of default dataset (``data_map["data"][0]``),
          or seq_order_control_dataset.
          You only need it if the tag name is not the same for all datasets.
          It will currently not act as filter,
          as the subdataset controls the sequence order (and thus what seqs to use).
        :param str|None seq_order_control_dataset: if set, this dataset will define the order for each epoch.
        :param str|None seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len.
          Use if getting sequence length from loading data is too costly.
        :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
          Deprecated/Only to double check. Read from data if not specified.
        :param dict[str,str] data_dtypes: self-data-key -> dtype.
          Currently unused: dtypes are always read from the sub-datasets (see :func:`get_data_dtype`).
        :param int window: only 1 is supported
        """
        assert window == 1  # not implemented
        super(MetaDataset, self).__init__(**kwargs)
        assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

        self.data_map = data_map
        self.dataset_keys = set([m[0] for m in self.data_map.values()])  # type: typing.Set[str]
        self.data_keys = set(self.data_map.keys())  # type: typing.Set[str]
        assert "data" in self.data_keys
        self.target_list = sorted(self.data_keys - {"data"})
        self.default_dataset_key = seq_order_control_dataset or self.data_map["data"][0]
        self.seq_order_control_dataset = seq_order_control_dataset

        # This will only initialize datasets needed for features occurring in data_map
        self.datasets = {
            key: init_dataset(datasets[key], extra_kwargs={"name": "%s_%s" % (self.name, key)})
            for key in self.dataset_keys}  # type: typing.Dict[str,Dataset]

        self.seq_list_original = self._load_seq_list(seq_list_file)
        self.num_total_seqs = len(self.seq_list_original[self.default_dataset_key])
        for key in self.dataset_keys:
            assert len(self.seq_list_original[key]) == self.num_total_seqs

        # seq tag (of the default dataset) -> index into the original seq list
        self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original[self.default_dataset_key])}

        self._seq_lens = None  # type: typing.Optional[typing.Dict[str,NumbersDict]]
        self._num_timesteps = None  # type: typing.Optional[NumbersDict]
        if seq_lens_file:
            seq_lens = load_json(filename=seq_lens_file)
            assert isinstance(seq_lens, dict)
            # dict[str,NumbersDict], seq-tag -> data-key -> len
            self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
            self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original[self.default_dataset_key]])

        if data_dims:
            data_dims = convert_data_dims(data_dims)
            self.data_dims = data_dims
            assert "data" in data_dims
            for key in self.target_list:
                assert key in data_dims
        else:
            self.data_dims = {}

        for data_key in self.data_keys:
            dataset_key, dataset_data_key = self.data_map[data_key]
            dataset = self.datasets[dataset_key]
            if not data_dims:
                self.data_dims[data_key] = dataset.num_outputs[dataset_data_key]
            if dataset_data_key in dataset.labels:
                self.labels[data_key] = dataset.labels[dataset_data_key]

        self.num_inputs = self.data_dims["data"][0]
        self.num_outputs = self.data_dims

        self.orig_seq_order_is_initialized = False
        self.seq_list_ordered = None  # type: typing.Optional[typing.Dict[str,typing.List[str]]]

    def _is_same_seq_name_for_each_dataset(self):
        """
        This should be fast: it only checks object identity of the per-dataset seq lists
        (they are identical objects when a single list was shared for all datasets).

        :rtype: bool
        """
        main_list = self.seq_list_original[self.default_dataset_key]
        return all(other_list is main_list for other_list in self.seq_list_original.values())

    def _load_seq_list(self, seq_list_file=None):
        """
        :param str|None seq_list_file:
        :return: dict: dataset key -> seq list
        :rtype: dict[str,list[str]]
        """
        if seq_list_file:
            seq_list = Dataset._load_seq_list_file(seq_list_file, expect_list=False)
        else:
            # We create a sequence list from all the sequences of the default dataset and hope that it also applies to the
            # other datasets. This can only work if all datasets have the same tag format and the sequences in the other
            # datasets are a subset of those in the default dataset.
            default_dataset = self.datasets[self.default_dataset_key]
            assert isinstance(default_dataset, Dataset)
            print("Reading sequence list for MetaDataset %r from sub-dataset %r" % (self.name, default_dataset.name),
                  file=log.v3)
            seq_list = default_dataset.get_all_tags()
            # Catch index out of bounds errors. Whether the tags are actually valid will be checked in _check_dataset_seq().
            for key in self.dataset_keys:
                if key == self.default_dataset_key:
                    continue
                try:
                    other_num_seqs = self.datasets[key].get_total_num_seqs()
                except NotImplementedError:
                    continue  # we don't know. but continue for now...
                if other_num_seqs >= len(seq_list):
                    continue  # ok
                print("Dataset %r has less sequences (%i) than in sequence list (%i) read from %r, this cannot work out!" % (
                    key, other_num_seqs, len(seq_list), self.default_dataset_key), file=log.v1)
                other_tags = self.datasets[key].get_all_tags()
                # Use sets for the membership tests. Both tag lists can be huge,
                # and list membership would make this diagnostic O(n*m).
                other_tags_set = set(other_tags)
                for tag in seq_list:
                    if tag not in other_tags_set:
                        print(
                            "Seq tag %r in dataset %r but not in dataset %r." % (tag, self.default_dataset_key, key),
                            file=log.v1)
                        break  # only print one
                seq_list_set = set(seq_list)
                for tag in other_tags:
                    if tag not in seq_list_set:
                        print(
                            "Seq tag %r in dataset %r but not in dataset %r." % (tag, key, self.default_dataset_key),
                            file=log.v1)
                        break  # only print one
                raise Exception("Dataset %r is missing seqs." % key)

        assert isinstance(seq_list, (list, dict))
        if isinstance(seq_list, list):
            # The same list object for all datasets (see _is_same_seq_name_for_each_dataset).
            seq_list = {key: seq_list for key in self.dataset_keys}
        return seq_list

    def _get_dataset_seq_length(self, seq_idx):
        """
        Length of the "data" key of the default sub-dataset for the given seq.
        Used as length function for the seq ordering in :func:`init_seq_order`.

        :param int seq_idx: index into the original (unordered) seq list of the default dataset
        :return: seq length
        """
        if not self.orig_seq_order_is_initialized:
            # To use get_seq_length() we first have to init the sequence order once in original order.
            # If sequence lengths are not needed by get_seq_order_for_epoch this is never executed.
            self.datasets[self.default_dataset_key].init_seq_order(
                epoch=self.epoch, seq_list=self.seq_list_original[self.default_dataset_key])
            self.orig_seq_order_is_initialized = True

        return self.datasets[self.default_dataset_key].get_seq_length(seq_idx)["data"]

    def init_seq_order(self, epoch=None, seq_list=None):
        """
        Initializes the seq order for the epoch and passes the resulting per-dataset tag lists to the sub-datasets.

        :param int|None epoch:
        :param list[str]|None seq_list: tags (of the default dataset) in the order to use
        :return: whether the seq order was (re)initialized
        :rtype: bool
        """
        need_reinit = self.epoch is None or self.epoch != epoch or seq_list
        super(MetaDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list)

        if not need_reinit:
            self._num_seqs = len(self.seq_list_ordered[self.default_dataset_key])
            return False

        seq_order_dataset = None
        if seq_list:
            seq_index = [self.tag_idx[tag] for tag in seq_list]
        elif self.seq_order_control_dataset:
            seq_order_dataset = self.datasets[self.seq_order_control_dataset]
            assert isinstance(seq_order_dataset, Dataset)
            seq_order_dataset.init_seq_order(epoch=epoch)
            seq_index = seq_order_dataset.get_current_seq_order()
        else:
            if self._seq_lens:
                def get_seq_len(s):
                    """
                    :param int s:
                    :rtype: int
                    """
                    return self._seq_lens[self.seq_list_original[self.default_dataset_key][s]]["data"]
            elif self._seq_order_seq_lens_file:
                get_seq_len = self._get_seq_order_seq_lens_by_idx
            else:
                self.orig_seq_order_is_initialized = False
                get_seq_len = self._get_dataset_seq_length
            seq_index = self.get_seq_order_for_epoch(epoch, self.num_total_seqs, get_seq_len)
        self._num_seqs = len(seq_index)
        self.seq_list_ordered = {key: [ls[s] for s in seq_index] for (key, ls) in self.seq_list_original.items()}

        for dataset_key, dataset in self.datasets.items():
            assert isinstance(dataset, Dataset)
            if dataset is seq_order_dataset:
                continue  # already initialized above, it defined the order
            dataset.init_seq_order(epoch=epoch, seq_list=self.seq_list_ordered[dataset_key])
        return True

    def get_all_tags(self):
        """
        :return: list of all seq tags, of the whole dataset, without partition epoch
        :rtype: list[str]
        """
        return self.seq_list_original[self.default_dataset_key]

    def finish_epoch(self):
        """
        This would get called at the end of the epoch. Also finishes the epoch of all sub-datasets.
        """
        super(MetaDataset, self).finish_epoch()
        for dataset in self.datasets.values():
            assert isinstance(dataset, Dataset)
            dataset.finish_epoch()

    def _load_seqs(self, start, end):
        """
        :param int start: inclusive seq idx start
        :param int end: exclusive seq idx end. can be more than num_seqs
        """
        # Pass on original start|end to super _load_seqs, to perform extra checks and cleanup.
        # However, for load_seqs on our subdatasets, and other extra checks,
        # do not redo them if they were already done.
        # _load_seqs is often called many times with the same start|end, during chunked batch construction.
        start_ = start
        if self.added_data:
            start_ = max(self.added_data[-1].seq_idx + 1, start)

        if start_ < end:
            for dataset_key in self.dataset_keys:
                self.datasets[dataset_key].load_seqs(start_, end)
                # end can be more than num_seqs, thus only check the seqs which actually exist.
                check_end = min(end, len(self.seq_list_ordered[dataset_key]))
                for seq_idx in range(start_, check_end):
                    self._check_dataset_seq(dataset_key, seq_idx)

        super(MetaDataset, self)._load_seqs(start=start, end=end)

    def _check_dataset_seq(self, dataset_key, seq_idx):
        """
        Asserts that the sub-dataset has the expected seq tag at this (ordered) seq idx.

        :param str dataset_key:
        :param int seq_idx:
        """
        dataset_seq_tag = self.datasets[dataset_key].get_tag(seq_idx)
        self_seq_tag = self.seq_list_ordered[dataset_key][seq_idx]
        assert dataset_seq_tag == self_seq_tag

    def _get_data(self, seq_idx, data_key):
        """
        :type seq_idx: int
        :type data_key: str
        :rtype: numpy.ndarray
        """
        dataset_key, dataset_data_key = self.data_map[data_key]
        dataset = self.datasets[dataset_key]  # type: Dataset
        return dataset.get_data(seq_idx, dataset_data_key)

    def _collect_single_seq(self, seq_idx):
        """
        :type seq_idx: int
        :rtype: DatasetSeq
        """
        seq_tag = self.seq_list_ordered[self.default_dataset_key][seq_idx]
        features = self._get_data(seq_idx, "data")
        targets = {target: self._get_data(seq_idx, target) for target in self.target_list}
        return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)

    def get_seq_length(self, sorted_seq_idx):
        """
        :param int sorted_seq_idx:
        :rtype: NumbersDict
        """
        if self._seq_lens:
            return self._seq_lens[self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]]
        return super(MetaDataset, self).get_seq_length(sorted_seq_idx)

    def get_tag(self, sorted_seq_idx):
        """
        :param int sorted_seq_idx:
        :rtype: str
        """
        return self.seq_list_ordered[self.default_dataset_key][sorted_seq_idx]

    def get_target_list(self):
        """
        :rtype: list[str]
        """
        return self.target_list

    def get_data_shape(self, data_key):
        """
        :param str data_key:
        :rtype: list[int]
        """
        dataset_key, dataset_data_key = self.data_map[data_key]
        return self.datasets[dataset_key].get_data_shape(dataset_data_key)

    def get_data_dtype(self, key):
        """
        :param str key:
        :rtype: str
        """
        dataset_key, dataset_data_key = self.data_map[key]
        return self.datasets[dataset_key].get_data_dtype(dataset_data_key)

    def is_data_sparse(self, key):
        """
        :param str key:
        :rtype: bool
        """
        dataset_key, dataset_data_key = self.data_map[key]
        return self.datasets[dataset_key].is_data_sparse(dataset_data_key)
class ClusteringDataset(CachedDataset2):
    """
    This is a special case of MetaDataset,
    with one main subdataset, and we add a cluster-idx for each seq.
    We will read the cluster-map (seq-name -> cluster-idx) here directly.
    """

    def __init__(self, dataset, cluster_map_file, n_clusters, single_cluster=False, **kwargs):
        """
        :param dict[str] dataset: kwargs for init_dataset of the main sub-dataset
        :param str cluster_map_file: Sprint XML corpus-key-map file, seq-name -> cluster-idx
        :param int n_clusters: number of clusters. All cluster indices must be in [0, n_clusters).
        :param bool single_cluster: if True, _generate_batches starts a new batch whenever the cluster changes,
          i.e. a batch never mixes seqs of different clusters
        """
        # Use super(ClusteringDataset, self) (not super(CachedDataset2, self)),
        # so that CachedDataset2.__init__ is not skipped.
        super(ClusteringDataset, self).__init__(**kwargs)
        self.dataset = init_dataset(dataset)
        self.n_clusters = n_clusters
        self.single_cluster = single_cluster
        self.cluster_map = self._load_cluster_map(cluster_map_file)
        self.cluster_idx_dtype = "int32"
        self.num_inputs = self.dataset.num_inputs
        self.num_outputs = self.dataset.num_outputs.copy()
        self.num_outputs["cluster_idx"] = (n_clusters, 1)  # will be a single int32
        self.expected_load_seq_start = 0

    def _load_cluster_map(self, filename):
        """
        Reads the cluster map from a Sprint XML corpus-key-map file.

        :param str filename:
        :return: seq-name -> cluster-idx
        :rtype: dict[str,int]
        """
        import re
        with open(filename) as f:
            # Strip surrounding whitespace, so that indented XML lines are recognized as well.
            lines = [line.strip() for line in f.read().splitlines()]
        # NOTE(review): "coprus" looks like a typo, but presumably it is the tag name the Sprint format uses.
        # Verify before changing.
        assert "<coprus-key-map>" in lines[:3], "We expect the Sprint XML format."
        # It has lines like: <map-item key="CHiME3/dt05_bth/M03_22GC010M_BTH.CH5/1" value="0"/>
        pattern = re.compile('<map-item key="(.*)" value="(.*)"/>')
        cluster_map = {}  # type: typing.Dict[str,int]  # seq-name -> cluster-idx
        for line in lines:
            if not line.startswith("<map-item"):
                continue
            match = pattern.match(line)
            assert match, "Cannot parse cluster map line %r in file %r." % (line, filename)
            seq_name, cluster_idx_s = match.groups()
            cluster_idx = int(cluster_idx_s)
            assert 0 <= cluster_idx < self.n_clusters
            cluster_map[seq_name] = cluster_idx
        return cluster_map

    def init_seq_order(self, epoch=None, seq_list=None):
        """
        :param int|None epoch:
        :param list[str]|None seq_list:
        :rtype: bool
        """
        self.dataset.init_seq_order(epoch=epoch, seq_list=seq_list)
        return super(ClusteringDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list)

    def get_data_keys(self):
        """
        :return: data keys of the sub-dataset, plus "cluster_idx"
        :rtype: list[str]
        """
        return self.dataset.get_data_keys() + ["cluster_idx"]

    def get_data_dtype(self, key):
        """
        :param str key:
        :rtype: str
        """
        if key == "cluster_idx":
            return self.cluster_idx_dtype
        return self.dataset.get_data_dtype(key)

    @property
    def num_seqs(self):
        """
        :rtype: int
        """
        return self.dataset.num_seqs

    def is_less_than_num_seqs(self, n):
        """
        :param int n:
        :rtype: bool
        """
        return self.dataset.is_less_than_num_seqs(n)

    def _load_seqs(self, start, end):
        """
        :param int start:
        :param int end:
        """
        self.dataset.load_seqs(start, end)
        super(ClusteringDataset, self)._load_seqs(start=start, end=end)

    def get_tag(self, seq_idx):
        """
        :param int seq_idx:
        :rtype: str
        """
        return self.dataset.get_tag(seq_idx)

    def _collect_single_seq(self, seq_idx):
        """
        :param int seq_idx:
        :return: all data of the sub-dataset for this seq, plus "cluster_idx" (shape (1,))
        :rtype: DatasetSeq
        """
        seq_name = self.get_tag(seq_idx)
        data = {key: self.dataset.get_data(seq_idx=seq_idx, key=key) for key in self.dataset.get_data_keys()}
        data["cluster_idx"] = numpy.array([self.cluster_map[seq_name]], dtype=self.cluster_idx_dtype)
        return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_name, features=data["data"], targets=data)

    # noinspection PyMethodOverriding
    def _generate_batches(self, recurrent_net, batch_size, max_seqs=-1, seq_drop=0.0, max_seq_length=None,
                          used_data_keys=None):
        """
        Like the default batch generation, but if ``single_cluster`` is set,
        a new batch is started whenever the cluster of consecutive seqs changes.

        :param bool recurrent_net:
        :param int batch_size: max number of frames per batch. 0 means unlimited.
        :param int max_seqs: max number of seqs per batch. -1 means unlimited.
        :param float seq_drop: probability to drop a seq
        :param int|None max_seq_length: None means unlimited.
          Negative: limit on the length of "classes" (absolute value).
        :param list[str]|None used_data_keys:
        :return: yields Batch objects
        """
        if max_seq_length is None:
            max_seq_length = sys.maxsize
        if batch_size == 0:
            batch_size = sys.maxsize
        assert batch_size > 0
        if max_seqs == -1:
            max_seqs = float('inf')
        assert max_seqs > 0
        assert seq_drop <= 1.0
        chunk_size = self.chunk_size
        chunk_step = self.chunk_step
        from EngineBatch import Batch
        batch = Batch()
        last_seq_idx = None
        for seq_idx, t_start, t_end in self.iterate_seqs(
                chunk_size=chunk_size, chunk_step=chunk_step, used_data_keys=used_data_keys):
            if self.single_cluster:
                if last_seq_idx is not None and last_seq_idx != seq_idx:
                    last_seq_name = self.get_tag(last_seq_idx)
                    seq_name = self.get_tag(seq_idx)
                    if self.cluster_map[last_seq_name] != self.cluster_map[seq_name]:
                        print("ClusteringDataset::_generate_batches", last_seq_idx, "is not", seq_idx, file=log.v5)
                        yield batch
                        batch = Batch()
            length = t_end - t_start
            if max_seq_length < 0 and length['classes'] > -max_seq_length:
                continue
            elif 0 < max_seq_length < length.max_value():
                continue
            if length.max_value() > batch_size:
                print("warning: sequence length (%i) larger than limit (%i)" % (length.max_value(), batch_size),
                      file=log.v4)
            if self.rnd_seq_drop.random() < seq_drop:
                continue
            dt, ds = batch.try_sequence_as_slice(length)
            if ds > 1 and ((dt * ds).max_value() > batch_size or ds > max_seqs):
                yield batch
                batch = Batch()
            print("batch add slice length", length, file=log.v5)
            batch.add_sequence_as_slice(seq_idx=seq_idx, seq_start_frame=t_start, length=length)
            last_seq_idx = seq_idx

        if batch.get_all_slices_num_frames().max_value() > 0:
            yield batch
class ConcatDataset(CachedDataset2):
    """
    This concatenates multiple datasets. They are expected to provide the same data-keys and data-dimensions.
    It will go through the datasets always in order.
    """

    def __init__(self, datasets, **kwargs):
        """
        :param list[dict[str]] datasets: list of kwargs for init_dataset
        """
        super(ConcatDataset, self).__init__(**kwargs)
        self.datasets = [init_dataset(d_kwargs) for d_kwargs in datasets]
        assert self.datasets
        self.num_inputs = self.datasets[0].num_inputs
        self.num_outputs = self.datasets[0].num_outputs
        self.labels = self.datasets[0].labels
        for ds in self.datasets[1:]:
            assert ds.num_inputs == self.num_inputs
            assert ds.num_outputs == self.num_outputs
        # offsets[i] = -(number of seqs in datasets 0..i-1). Filled lazily in _load_seqs.
        self.dataset_seq_idx_offsets = None  # type: typing.Optional[typing.List[int]]

    def init_seq_order(self, epoch=None, seq_list=None):
        """
        :param int|None epoch:
        :param list[str]|None seq_list: In case we want to set a predefined order.
          It is split in order among the sub-datasets, according to their num_seqs.
        :return: whether the seq order was (re)initialized
        :rtype: bool
        """
        # Like MetaDataset: a given seq_list always requires a reinit, even in the same epoch.
        need_reinit = self.epoch is None or self.epoch != epoch or bool(seq_list)
        super(ConcatDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list)
        self.dataset_seq_idx_offsets = [0]
        if not need_reinit:
            return False

        if seq_list:  # reference order
            seq_lists = []
            for dataset in self.datasets:
                # This depends on the num_seqs of our childs.
                # Append one sub-list per dataset (``+=`` would flatten the tags into seq_lists).
                seq_lists.append(seq_list[:dataset.num_seqs])
                seq_list = seq_list[dataset.num_seqs:]
            assert len(seq_list) == 0  # we have consumed all
        else:
            seq_lists = [None] * len(self.datasets)
            if self.seq_ordering != "default":
                # Not sure about these cases (random / sorted). Maybe a separate implementation makes more sense.
                raise NotImplementedError("seq_ordering %s" % self.seq_ordering)

        assert len(seq_lists) == len(self.datasets)
        for dataset, sub_list in zip(self.datasets, seq_lists):
            dataset.init_seq_order(epoch=epoch, seq_list=sub_list)
        return True

    def _get_dataset_for_seq_idx(self, seq_idx):
        """
        :param int seq_idx:
        :return: index of the sub-dataset which contains seq_idx, based on the offsets known so far
        :rtype: int
        """
        for i, offset in enumerate(self.dataset_seq_idx_offsets):
            if seq_idx + offset < 0:
                return i - 1
        return len(self.dataset_seq_idx_offsets) - 1

    def _load_seqs(self, start, end):
        """
        :param int start:
        :param int end:
        """
        sub_start = start
        # We maybe need to call load_seqs on several of our datasets, thus we need this loop.
        while True:
            dataset_idx = self._get_dataset_for_seq_idx(sub_start)
            dataset = self.datasets[dataset_idx]
            dataset_seq_idx_start = sub_start + self.dataset_seq_idx_offsets[dataset_idx]
            dataset_seq_idx_end = end + self.dataset_seq_idx_offsets[dataset_idx]
            dataset.load_seqs(dataset_seq_idx_start, dataset_seq_idx_end)
            if dataset.is_less_than_num_seqs(dataset_seq_idx_end):
                # We are still inside this dataset and have loaded everything.
                # Thus we can stop now.
                break
            # We have reached the end of the dataset.
            if dataset_idx + 1 == len(self.datasets):
                # We are at the last dataset.
                break
            # Continue with the next one. Sets (or appends) the offset of the next dataset.
            self.dataset_seq_idx_offsets[dataset_idx + 1:dataset_idx + 2] = [
                self.dataset_seq_idx_offsets[dataset_idx] - dataset.num_seqs]
            sub_start = -self.dataset_seq_idx_offsets[dataset_idx + 1]
        super(ConcatDataset, self)._load_seqs(start=start, end=end)

    def _collect_single_seq(self, seq_idx):
        """
        :param int seq_idx:
        :rtype: DatasetSeq
        """
        dataset_idx = self._get_dataset_for_seq_idx(seq_idx)
        dataset = self.datasets[dataset_idx]
        dataset_seq_idx = seq_idx + self.dataset_seq_idx_offsets[dataset_idx]
        seq_tag = dataset.get_tag(dataset_seq_idx)
        features = dataset.get_input_data(dataset_seq_idx)
        targets = {k: dataset.get_targets(k, dataset_seq_idx) for k in dataset.get_target_list()}
        return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)

    @property
    def num_seqs(self):
        """
        :return: sum of num_seqs of all sub-datasets
        :rtype: int
        """
        return sum(ds.num_seqs for ds in self.datasets)

    def get_target_list(self):
        """
        :rtype: list[str]
        """
        return self.datasets[0].get_target_list()
class CombinedDataset(CachedDataset2):
"""
The CombinedDataset is to be used in the cases of **Multi-Task Learning** and **Combination of Corpora**.
Here, in general, the datasets describe **different training sequences**.
For each sequence, only the features of the corresponding dataset will be available.
Features of the other datasets are set to empty arrays.
The input parameter ``"datasets"`` is the same as for the MetaDataset.
The ``"data_map"`` is reversed to allow for several datasets mapping to the same feature.
The ``"default"`` ``"seq_ordering"`` is to first go through all sequences of the first dataset,
then the second and so on.
All other sequence orderings (``"random"``, ``"sorted"``, ``"laplace"``, ...) are supported
and based on this "default" ordering.
There is a special sequence ordering ``"random_dataset"``, where we pick datasets at random,
while keeping the sequence order within the datasets as is.
To adjust the ratio of number of training examples from the different datasets in an epoch,
one can use ``"repeat_epoch"`` in some of the datasets to
increase their size relative to the others.
Also, ``"partition_epoch"`` in some of the datasets can be used to shrink them relative to the others.
**Example of CombinedDataset config:**
.. code::
train = {"class": "CombinedDataset",
"datasets": {"sprint": train_sprint, "translation": train_translation},
"data_map": {("sprint", "data"): "data",
("sprint", "orth_classes"): "orth_classes",
("translation", "data"): "source_text",
("translation", "classes"): "orth_classes"},
"seq_ordering": "default",
"partition_epoch": 2,
}
This combines a SprintDataset and a TranslationDataset.
These are defined as ``"train_sprint"`` and ``"train_translation"`` separately.
*Note that the current implementation expects one input feature to be called "data".*
Note: The mapping has been inverted. We now expect (dataset-key, dataset-data-key) -> self-data-key
am-dataset:data -> am-data, am-dataset:classes -> am-classes, lm-dataset:data -> lm-data.
For each sequence idx, it will select one of the given datasets, fill in the data-keys of this dataset
and will return empty sequences for the remaining datasets.
The default sequence ordering is to first go through all sequences of dataset 1, then dataset 2 and so on. If
seq_ordering is set to 'random_dataset', we always pick one of the datasets at random (equally distributed over the
sum of num-seqs), but still go through the sequences of a particular dataset in the order defined for it in the config
(in order if not defined). For 'sorted' or 'laplace' the sequence length as provided by the datasets is used to sort
all sequences jointly. Note, that this overrides the sequence order of the sub-datasets (also the case for 'random').
'partition_epoch' of the CombinedDataset is applied to the joint sequence order for all sequences.
'partition_epoch' of the sub-datasets is still applied. This can be used to adjust the relative size of
the datasets. (However, do not combine 'partition_epoch' on both levels, as this leads to an unexpected selection
of sequences.) To upscale a dataset, rather than downscaling the others via 'partition_epoch', use the
'repeat_epoch' option.
Also see :class:`MetaDataset`.
"""
def __init__(self,
             datasets,
             data_map,
             data_dims=None,
             data_dtypes=None,
             window=1, **kwargs):
    """
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[(str,str),str] data_map: (dataset-key, dataset-data-key) -> self-data-key.
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
      Deprecated/Only to double check. Read from data if not specified.
    :param dict[str,str] data_dtypes: self-data-key -> dtype. Read from data if not specified.
    :param int window: must be 1, other values are not implemented
    :param kwargs: passed on to the base class constructor
    """
    assert window == 1  # not implemented
    super(CombinedDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets
    self.rnd = Random(self.epoch)
    self.dataset_keys = {key_tuple[0] for key_tuple in data_map}  # type: typing.Set[str]
    # The integer dataset index used internally follows the sorted dataset keys.
    self.dataset_idx2key_map = dict(enumerate(sorted(self.dataset_keys)))  # idx -> dataset-key
    self.data_keys = set(data_map.values())  # type: typing.Set[str]
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})
    # Build target lookup table that maps from dataset_key and data_key (data key used by CombinedDataset)
    # to dataset_data_key (data_key used by the sub-dataset). This is needed in get_data() to access data
    # by data_key. Maps to None if data_key does not correspond to a feature in datasets[dataset_key].
    target_lookup_table = {}
    for dataset_key in self.dataset_keys:
        target_lookup_table[dataset_key] = {
            data_key: dataset_key_tuple[1]
            for dataset_key_tuple, data_key in data_map.items()
            if dataset_key_tuple[0] == dataset_key}
        for key in self.data_keys:
            target_lookup_table[dataset_key].setdefault(key, None)
    self.target_lookup_table = target_lookup_table
    # This will only initialize datasets needed for features occurring in data_map
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}
    # Per-dataset estimates in dataset-index order (sorted keys, as in dataset_idx2key_map).
    # Each estimated_num_seqs property is read only once; the total is derived from the list.
    self.estimated_num_seq_per_subset = [self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())]
    self._estimated_num_seqs = sum(self.estimated_num_seq_per_subset)
    if data_dims:
        data_dims = convert_data_dims(data_dims)
        self.data_dims = data_dims
        assert "data" in data_dims
        for key in self.target_list:
            assert key in data_dims
    else:
        self.data_dims = {}
    # Without explicit data_dims, they are taken from the sub-datasets' num_outputs below.
    read_dims_from_datasets = not data_dims
    for (dataset_key, dataset_data_key), data_key in data_map.items():
        dataset = self.datasets[dataset_key]
        if read_dims_from_datasets:
            self.data_dims[data_key] = dataset.num_outputs[dataset_data_key]
        if dataset_data_key in dataset.labels:
            self.labels[data_key] = dataset.labels[dataset_data_key]
    self.num_inputs = self.data_dims["data"][0]
    self.num_outputs = self.data_dims
    self.data_dtypes = {data_key: _select_dtype(data_key, self.data_dims, data_dtypes) for data_key in self.data_keys}
    self.dataset_seq_idx_list = None  # type: typing.Optional[typing.List[typing.Tuple[int,int]]]
    self.seq_order = None  # type: typing.Optional[typing.List[int]]
    self.dataset_sorted_seq_idx_list = None  # type: typing.Optional[typing.List[typing.Tuple[int,int]]]
    self.used_num_seqs_per_subset = None  # type: typing.Optional[typing.List[int]]
def init_seq_order(self, epoch=None, seq_list=None):
    """
    Start an epoch: re-seed self.rnd, re-initialize every sub-dataset's own sequence order
    and reset the lazily built combined sequence order.

    :param int|None epoch:
    :param list[str]|None seq_list: not supported, must be None
    :return: False if the epoch is unchanged (nothing re-initialized), True otherwise
    :rtype: bool
    """
    assert seq_list is None, "seq_list not supported for %s" % self.__class__
    # Evaluated before calling the base class, which presumably updates self.epoch.
    is_same_epoch = self.epoch is not None and self.epoch == epoch
    super(CombinedDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list)
    self.rnd.seed(epoch or 1)
    if is_same_epoch:
        return False
    # Init the sub-datasets as usual, so that their own sorting and partition epoch still apply.
    # NOTE(review): an earlier comment announced a second init_seq_order call with a seq list
    # (for joint sorting / partition epoch of all seqs); no such call is visible here -- confirm.
    for sub_dataset in self.datasets.values():
        sub_dataset.init_seq_order(epoch=epoch)
    # The combined order is filled lazily by _expand_dataset_sec_idxs().
    # NOTE(review): estimated_num_seq_per_subset is not reset here, so its corrections carry over epochs.
    self.used_num_seqs_per_subset = [0 for _ in range(len(self.datasets))]
    self.dataset_sorted_seq_idx_list = []
    return True
def _expand_dataset_sec_idxs(self, num_values):
    """
    Lazily extend the combined sequence order by up to num_values entries.

    Each new entry is a tuple (dataset_idx, sub_seq_idx) appended to self.dataset_sorted_seq_idx_list,
    where dataset_idx indexes self.dataset_idx2key_map and sub_seq_idx is the next not yet used
    sequence index of that sub-dataset. How the sub-dataset is chosen depends on self.seq_ordering:

    * "default": the first sub-dataset (by index) which still has sequences left.
    * "reversed": the last sub-dataset (by index) which still has sequences left.
    * "random_dataset": a random sub-dataset, with probability proportional to its expected number of
      remaining sequences (estimated_num_seq_per_subset minus used_num_seqs_per_subset).
      The estimates are corrected on the fly when a sub-dataset turns out to be shorter or longer.

    :param int num_values: Add num_values entries to the dataset-segment-idx mapping table
    :return: False if all sub-datasets ran out of sequences before num_values entries were added, else True
    :rtype: bool
    :raises Exception: for any other seq_ordering, as those need the number of sequences in advance
    """
    num_datasets = len(self.datasets)

    def has_remaining_seqs(dataset_idx):
        # Whether this sub-dataset still has a sequence which was not assigned yet.
        dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
        return dataset.is_less_than_num_seqs(self.used_num_seqs_per_subset[dataset_idx])

    for _ in range(num_values):
        if self.seq_ordering in ("default", "reversed"):  # i.e. in (reversed) order of the datasets
            candidate_idxs = range(num_datasets)
            if self.seq_ordering == "reversed":
                candidate_idxs = reversed(candidate_idxs)
            dataset_idx = next((idx for idx in candidate_idxs if has_remaining_seqs(idx)), None)
            if dataset_idx is None:
                return False  # No dataset has remaining data
        elif self.seq_ordering == "random_dataset":
            # NOTE(review): this samples from the global numpy RNG, not from the epoch-seeded self.rnd,
            # so the order is only reproducible if numpy is seeded elsewhere -- confirm this is intended.
            while True:
                # Build probability table
                expected_remaining_seqs = [
                    estimated - used
                    for estimated, used in zip(self.estimated_num_seq_per_subset, self.used_num_seqs_per_subset)]
                total_remaining = float(sum(expected_remaining_seqs))
                if total_remaining < 0.1:
                    # The estimates say everything is used up, but they may be too low: check for real.
                    nonempty_datasets = [idx for idx in range(num_datasets) if has_remaining_seqs(idx)]
                    if not nonempty_datasets:
                        return False  # No more data to add
                    dataset_idx = numpy.random.choice(nonempty_datasets)
                    # This dataset was under-estimated; keep (estimated - used) unchanged for it.
                    self.estimated_num_seq_per_subset[dataset_idx] += 1
                    break
                # Sample from all datasets which are expected to contain more data.
                prob_table = [remaining / total_remaining for remaining in expected_remaining_seqs]
                dataset_idx = numpy.random.choice(num_datasets, p=prob_table)
                if has_remaining_seqs(dataset_idx):
                    break  # Found good data
                # Over-estimated: the dataset is exhausted, so give it zero probability from now on.
                self.estimated_num_seq_per_subset[dataset_idx] = self.used_num_seqs_per_subset[dataset_idx]
        else:
            raise Exception("The sorting method '{}' is not implemented for the case that number of sequences "
                            "is not known in advance.".format(self.seq_ordering))
        # We now have a valid dataset index to take the next segment from
        self.dataset_sorted_seq_idx_list.append((dataset_idx, self.used_num_seqs_per_subset[dataset_idx]))
        self.used_num_seqs_per_subset[dataset_idx] += 1
    return True
def _load_seqs(self, start, end):
    """
    Make sure the combined sequences [start, end) are loaded in their sub-datasets,
    then let the base class load them.

    :param int start: inclusive combined seq idx
    :param int end: exclusive combined seq idx
    """
    missing = end - len(self.dataset_sorted_seq_idx_list)
    if missing > 0:
        # The segment order is built lazily: fix it for the next few segments first.
        self._expand_dataset_sec_idxs(missing)
    # Group the requested sub-dataset seq indices per dataset index.
    sub_seq_idxs_by_dataset = {}
    for dataset_idx, sub_seq_idx in self.dataset_sorted_seq_idx_list[start:end]:
        sub_seq_idxs_by_dataset.setdefault(dataset_idx, []).append(sub_seq_idx)
    for dataset_idx in range(len(self.datasets)):
        sub_seq_idxs = sub_seq_idxs_by_dataset.get(dataset_idx)
        if not sub_seq_idxs:
            continue
        # Load the contiguous range covering all requested seqs of this sub-dataset.
        sub_dataset = self.datasets[self.dataset_idx2key_map[dataset_idx]]
        sub_dataset.load_seqs(min(sub_seq_idxs), max(sub_seq_idxs) + 1)
    super(CombinedDataset, self)._load_seqs(start=start, end=end)
def _get_data(self, dataset_key, dataset_seq_idx, data_key):
    """
    Get the data of one data key for one sequence of a sub-dataset.

    :param str dataset_key: key of the sub-dataset in self.datasets
    :param int dataset_seq_idx: seq idx within that sub-dataset
    :param str data_key: data key of this CombinedDataset (not of the sub-dataset)
    :return: the sub-dataset's data if data_key is mapped for this sub-dataset (see data_map),
      otherwise an empty placeholder array with the dtype of data_key.
      For dense data (len(shape) >= 2 in self.data_dims) the placeholder keeps the feature dimension,
      i.e. it has shape (0, ..., 0, dim), so that it has the same rank as real data of that key.
    :rtype: numpy.ndarray
    """
    dataset_data_key = self.target_lookup_table[dataset_key][data_key]
    if dataset_data_key is not None:
        dataset = self.datasets[dataset_key]  # type: Dataset
        return dataset.get_data(dataset_seq_idx, dataset_data_key)
    # data_dims entries are (dim, len(shape)); len(shape) == 1 means sparse.
    dims = self.data_dims[data_key]
    dim, ndim = dims[0], dims[1]
    shape = [0] * max(ndim, 1)
    if ndim >= 2 and dim is not None:
        shape[-1] = dim
    return numpy.zeros(shape, dtype=self.data_dtypes[data_key])
def _collect_single_seq(self, seq_idx):
"""