From 03981449477b39953b08598760325912a690f8de Mon Sep 17 00:00:00 2001
From: Arjun Dinesh Jagdale <142811259+ArjunJagdale@users.noreply.github.com>
Date: Tue, 24 Jun 2025 11:41:29 +0530
Subject: [PATCH] Fix: Preserve float columns in JSON loader when values are
 integer-like (e.g. 0.0, 1.0)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR fixes a bug in the JSON loader where columns containing float values
like `[0.0, 1.0, 2.0]` were implicitly coerced to `int` by pandas/Arrow type
inference. This caused issues downstream in statistics computation (e.g., in
dataset-viewer), where such columns were incorrectly labeled as `"int"`
instead of `"float"`.

### 🔍 What was happening:

When the JSON loader falls back to `pandas_read_json()` (after
`pa.read_json()` fails), pandas/Arrow can coerce float values to integers if
all values are integer-like (e.g., `0.0 == 0`).

### ✅ What this PR does:

- Adds a check in the fallback path of `_generate_tables()`
- Ensures that columns made up entirely of floats are preserved as
  `"float64"`, even if they are integer-like (e.g. `0.0`, `1.0`)
- Prevents loss of float semantics when creating the Arrow table

### 🧪 Reproducible Example:

```json
[{"col": 0.0}, {"col": 1.0}, {"col": 2.0}]
```
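
Loading such a file end to end shows the symptom. A minimal sketch, assuming
the array above is saved locally as `data.json` (the filename is illustrative):

```python
# Minimal reproduction sketch; assumes the array above is saved as data.json.
from datasets import load_dataset

ds = load_dataset("json", data_files="data.json", split="train")
# A top-level JSON array is not parsed by pa.read_json(), so the loader falls
# back to pandas_read_json(); before this fix, the integer-like floats could
# come back typed as int64 instead of float64.
print(ds.features["col"])  # expected with the fix: Value(dtype='float64')
```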
---
 src/datasets/packaged_modules/json/json.py | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index 426083fc718..fdab2241326 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -95,6 +95,7 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:
             pa_table = table_cast(pa_table, self.config.features.arrow_schema)
         return pa_table
 
+
     def _generate_tables(self, files):
         for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
             # If the file is one json object and if we need to look at the items in one specific field
@@ -113,8 +114,6 @@ def _generate_tables(self, files):
             else:
                 with open(file, "rb") as f:
                     batch_idx = 0
-                    # Use block_size equal to the chunk size divided by 32 to leverage multithreading
-                    # Set a default minimum value of 16kB if the chunk size is really small
                     block_size = max(self.config.chunksize // 32, 16 << 10)
                     encoding_errors = (
                         self.config.encoding_errors if self.config.encoding_errors is not None else "strict"
@@ -123,19 +122,18 @@ def _generate_tables(self, files):
                         batch = f.read(self.config.chunksize)
                         if not batch:
                             break
-                        # Finish current line
                         try:
                             batch += f.readline()
                         except (AttributeError, io.UnsupportedOperation):
                             batch += readline(f)
-                        # PyArrow only accepts utf-8 encoded bytes
                         if self.config.encoding != "utf-8":
                             batch = batch.decode(self.config.encoding, errors=encoding_errors).encode("utf-8")
                         try:
                             while True:
                                 try:
                                     pa_table = paj.read_json(
-                                        io.BytesIO(batch), read_options=paj.ReadOptions(block_size=block_size)
+                                        io.BytesIO(batch),
+                                        read_options=paj.ReadOptions(block_size=block_size),
                                     )
                                     break
                                 except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
@@ -146,8 +144,6 @@ def _generate_tables(self, files):
                                     ):
                                         raise
                                     else:
-                                        # Increase the block size in case it was too small.
-                                        # The block size will be reset for the next file.
                                         logger.debug(
                                             f"Batch of {len(batch)} bytes couldn't be parsed with block_size={block_size}. Retrying with block_size={block_size * 2}."
                                         )
@@ -155,7 +151,9 @@ def _generate_tables(self, files):
                         except pa.ArrowInvalid as e:
                             try:
                                 with open(
-                                    file, encoding=self.config.encoding, errors=self.config.encoding_errors
+                                    file,
+                                    encoding=self.config.encoding,
+                                    errors=self.config.encoding_errors,
                                 ) as f:
                                     df = pandas_read_json(f)
                             except ValueError:
@@ -163,6 +161,14 @@ def _generate_tables(self, files):
                                 raise e
                             if df.columns.tolist() == [0]:
                                 df.columns = list(self.config.features) if self.config.features else ["text"]
+
+                            # ✅ FIX: Preserve integer-like floats (e.g. 0.0, 1.0) as float64
+                            for col in df.columns:
+                                col_data = df[col].dropna()
+                                if col_data.apply(lambda x: isinstance(x, float)).all():
+                                    if col_data.apply(lambda x: x.is_integer()).all():
+                                        df[col] = df[col].astype("float64")
+
                             try:
                                 pa_table = pa.Table.from_pandas(df, preserve_index=False)
                             except pa.ArrowInvalid as e:
@@ -176,3 +182,4 @@ def _generate_tables(self, files):
                             break
                         yield (file_idx, batch_idx), self._cast_table(pa_table)
                         batch_idx += 1
+
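
For review, the added check can be exercised in isolation. A minimal standalone
sketch of the same logic (the function name `preserve_integer_like_floats` and
the sample DataFrame are illustrative, not part of the patch):

```python
import pandas as pd
import pyarrow as pa


def preserve_integer_like_floats(df: pd.DataFrame) -> pd.DataFrame:
    # Same logic as the patch: if every non-null value in a column is a float
    # and each one is integer-like (x.is_integer()), pin the column dtype to
    # float64 so the resulting Arrow column is double.
    for col in df.columns:
        col_data = df[col].dropna()
        if col_data.apply(lambda x: isinstance(x, float)).all():
            if col_data.apply(lambda x: x.is_integer()).all():
                df[col] = df[col].astype("float64")
    return df


df = pd.DataFrame({"col": [0.0, 1.0, 2.0]}, dtype="object")  # object column of Python floats
table = pa.Table.from_pandas(preserve_integer_like_floats(df), preserve_index=False)
print(table.schema)  # col: double
```

Note that the check only fires while the values are still floats (`np.float64`
subclasses Python `float`); a column pandas has already downcast to `int64` is
left untouched.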