@@ -45,16 +45,6 @@ def __getitems__(self, keys: Iterable[int]) -> T:
4545 """Returns the value for the given `keys`."""
4646
4747
48- def file_instructions (
49- dataset_info : dataset_info_lib .DatasetInfo ,
50- split : splits_lib .Split | None = None ,
51- ) -> list [shard_utils .FileInstruction ]:
52- """Retrieves the file instructions from the DatasetInfo."""
53- split_infos = dataset_info .splits .values ()
54- split_dict = splits_lib .SplitDict (split_infos = split_infos )
55- return split_dict [split ].file_instructions
56-
57-
5848@dataclasses .dataclass
5949class BaseDataSource (MappingView , Sequence ):
6050 """Base DataSource to override all dunder methods with the deserialization.
@@ -94,6 +84,16 @@ def _deserialize(self, record: Any) -> Any:
9484 return features .deserialize_example_np (record , decoders = self .decoders ) # pylint: disable=attribute-error
9585 raise ValueError ('No features set, cannot decode example!' )
9686
87+ @property
88+ def split_info (self ) -> splits_lib .SplitInfo | splits_lib .SubSplitInfo :
89+ """Returns the SplitInfo for the split."""
90+ splits = self .dataset_info .splits
91+ if self .split not in splits :
92+ raise ValueError (
93+ f'Split { self .split } not found in dataset { self .dataset_info .name } !'
94+ )
95+ return splits [self .split ]
96+
9797 def __getitem__ (self , key : SupportsIndex ) -> Any :
9898 record = self .data_source [key .__index__ ()]
9999 return self ._deserialize (record )
@@ -133,7 +133,7 @@ def __repr__(self) -> str:
133133 )
134134
135135 def __len__ (self ) -> int :
136- return self .data_source . __len__ ( )
136+ return sum ( fi . examples_in_shard for fi in self .split_info . file_instructions )
137137
138138 def __iter__ (self ):
139139 for i in range (self .__len__ ()):
0 commit comments