diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 426083fc718..fdab2241326 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -95,6 +95,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table
+
 
     def _generate_tables(self, files):
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             # If the file is one json object and if we need to look at the items in one specific field
@@ -113,8 +114,6 @@ def _generate_tables(self, files):
             else:
                 with open(file, "rb") as f:
                     batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
                     block_size = max(self.config.chunksize // 32, 16 << 10)
                     encoding_errors = (
                         self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
@@ -123,19 +122,18 @@ def _generate_tables(self, files):
                         batch = f.read(self.config.chunksize)
                         if not batch:
                             break
-                        # Finish current line
                         try:
                             batch += f.readline()
                         except (AttributeError, io.UnsupportedOperation):
                             batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
                         if self.config.encoding != "utf-8":
                             batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
                         try:
                             while True:
                                 try:
                                     pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
+                                        io.BytesIO(batch),
+                                        read_options=paj.ReadOptions(block_size=block_size),
                                     )
                                     break
                                 except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
@@ -146,8 +144,6 @@ def _generate_tables(self, files):
                                     ):
                                         raise
                                     else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
                                         logger.debug(
                                             f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
                                         )
@@ -155,7 +151,9 @@ def _generate_tables(self, files):
                         except pa.ArrowInvalid as e:
                             try:
                                 with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
+                                    file,
+                                    encoding=self.config.encoding,
+                                    errors=self.config.encoding_errors,
                                 ) as f:
                                     df = pandas_read_json(f)
                             except ValueError:
@@ -163,6 +161,14 @@ def _generate_tables(self, files):
                                 raise e
                             if df.columns.tolist() == [0]:
                                 df.columns = list(self.config.features) if self.config.features else ["text"]
+
+                            # Fix: coerce columns of integer-valued floats (e.g. 0.0, 1.0) back to float64
+                            # so they keep the float dtype declared in the JSON file
+                            for col in df.columns:
+                                col_data = df[col].dropna()
+                                if len(col_data) and col_data.apply(lambda x: isinstance(x, float) and x.is_integer()).all():
+                                    df[col] = df[col].astype("float64")
+
                             try:
                                 pa_table = pa.Table.from_pandas(df, preserve_index=False)
                             except pa.ArrowInvalid as e:
@@ -176,3 +182,4 @@ def _generate_tables(self, files):
                             break
                     yield (file_idx, batch_idx), self._cast_table(pa_table)
                     batch_idx += 1
+
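
For context on the retry loop these hunks touch: `paj.read_json` parses its input in fixed-size blocks and raises `ArrowInvalid` with a "straddling" message when a single JSON document crosses a block boundary, so the builder doubles `block_size` and retries until the batch parses or the block already covers the whole batch. A standalone sketch of that behavior (the sample data is invented; real inputs are `chunksize`-sized batches):

```python
import io

import pyarrow as pa
import pyarrow.json as paj

# Invented sample: two JSON Lines documents
batch = b'{"a": 1}\n{"a": 2}\n'
block_size = 16 << 10  # 16 KiB floor, as in the builder

while True:
    try:
        pa_table = paj.read_json(
            io.BytesIO(batch),
            read_options=paj.ReadOptions(block_size=block_size),
        )
        break
    except pa.ArrowInvalid as e:
        # A document crossing a block boundary is recoverable: retry with
        # bigger blocks. Anything else (or a block that already covers the
        # whole batch) is a real parse error.
        if "straddling" not in str(e) or block_size > len(batch):
            raise
        block_size *= 2

print(pa_table)  # pyarrow.Table with column "a"
```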
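
On the new coercion loop: it fires only when every non-null value of a column is an integer-valued Python `float`, and then pins the column to `float64` before `pa.Table.from_pandas` runs. A minimal sketch of the effect; the object-dtype `score` column is an invented stand-in for what the pandas fallback path can produce, not actual `pandas_read_json` output:

```python
import pandas as pd
import pyarrow as pa

# Invented stand-in: integer-valued Python floats in an object-dtype column
df = pd.DataFrame({"score": pd.Series([0.0, 1.0, None], dtype=object)})

for col in df.columns:
    col_data = df[col].dropna()
    # Only coerce when every non-null value is a float like 0.0 or 1.0;
    # the len() guard keeps empty/all-null columns untouched
    if len(col_data) and col_data.apply(lambda x: isinstance(x, float) and x.is_integer()).all():
        df[col] = df[col].astype("float64")

print(df["score"].dtype)  # float64
print(pa.Table.from_pandas(df, preserve_index=False).schema)  # score: double
```

One caveat worth checking in review: pandas' default dtype inference in `read_json` typically downcasts an all-integer-valued float column to `int64` before this loop runs, in which case the `isinstance(x, float)` check never fires; the loop helps when values actually surface as Python floats (e.g. via an `object` or pyarrow-backed dtype).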