@@ -54,6 +54,8 @@ class DatasetImporterBuilder(
5454 'ssot_session_key' ,
5555 ]
5656
57+ images_from_observation_dict = {}
58+
5759
5860 @abc .abstractmethod
5961 def get_description (self ):
@@ -83,9 +85,15 @@ def _info(self) -> tfds.core.DatasetInfo:
8385
8486 tmp = dict (features )
8587
88+ # add all image features from observations to a new featuresdict
89+ self .images_from_observation_dict = self .get_images_from_observation_dict ()
90+ if self .images_from_observation_dict :
91+ tmp ['display_image' ] = self .images_from_observation_dict
92+
8693 for key in self .KEYS_TO_STRIP :
8794 if key in tmp :
8895 del tmp [key ]
96+
8997 features = tfds .features .FeaturesDict (tmp )
9098
9199 return tfds .core .DatasetInfo (
@@ -120,15 +128,28 @@ def _generate_examples(
120128 def converter_fn (example ):
121129 # Decode the RLDS Episode and transform it to numpy.
122130 example_out = dict (example )
131+
123132 example_out ['steps' ] = tf .data .Dataset .from_tensor_slices (
124133 example_out ['steps' ]
125134 ).map (decode_fn )
135+
126136 steps = list (iter (example_out ['steps' ].take (- 1 )))
127137 example_out ['steps' ] = steps
128138
129139 example_out = dataset_utils .as_numpy (example_out )
140+ first_step = example_out ['steps' ][0 ]
141+ image_feature_dict = {}
142+
143+ for feature_name in self .images_from_observation_dict :
144+ image_feature_dict [feature_name ] = first_step ['observation' ][
145+ feature_name
146+ ]
147+
148+ if image_feature_dict :
149+ example_out ['display_image' ] = image_feature_dict
130150
131151 example_id = example_out ['tfds_id' ].decode ('utf-8' )
152+
132153 del example_out ['tfds_id' ]
133154 for key in self .KEYS_TO_STRIP :
134155 if key in example_out :
@@ -148,3 +169,17 @@ def get_ds_builder(self):
148169 ds_location = self .get_dataset_location ()
149170 ds_builder = tfds .builder_from_directory (ds_location )
150171 return ds_builder
172+
173+ def get_images_from_observation_dict (self ):
174+ features = self .get_ds_builder ().info .features
175+ tmp = dict (features )
176+ images_from_observation = {}
177+ if 'steps' in tmp and 'observation' in tmp ['steps' ]:
178+ observation = tmp ['steps' ]['observation' ]
179+ for feature_name , feature_data in observation .items ():
180+ if isinstance (feature_data , tfds .features .Image ):
181+ images_from_observation [feature_name ] = feature_data
182+ images_from_observation_dict = tfds .features .FeaturesDict (
183+ images_from_observation
184+ )
185+ return images_from_observation_dict
0 commit comments