diff --git a/gokart/file_processor.py b/gokart/file_processor.py index c52fa105..707b3112 100644 --- a/gokart/file_processor.py +++ b/gokart/file_processor.py @@ -14,6 +14,7 @@ from luigi.format import TextFormat from gokart.object_storage import ObjectStorage +from gokart.utils import load_dill_with_pandas_backward_compatibility logger = getLogger(__name__) @@ -82,8 +83,9 @@ def format(self): def load(self, file): if not ObjectStorage.is_buffered_reader(file): - return dill.loads(file.read()) - return dill.load(_ChunkedLargeFileReader(file)) + # we cannot use dill.load(file) because ReadableS3File does not have a 'readline' method + return load_dill_with_pandas_backward_compatibility(BytesIO(file.read())) + return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file)) def dump(self, obj, file): self._write(dill.dumps(obj, protocol=4), file) diff --git a/gokart/utils.py b/gokart/utils.py index 99915523..0d6d6617 100644 --- a/gokart/utils.py +++ b/gokart/utils.py @@ -2,9 +2,17 @@ import os import sys -from typing import Iterable, TypeVar +from typing import Any, Iterable, Protocol, TypeVar, Union +import dill import luigi +import pandas as pd + + +class FileLike(Protocol): + def read(self, n: int) -> bytes: ... + + def readline(self) -> bytes: ... def add_config(file_path: str): @@ -58,3 +66,14 @@ def flatten(targets: FlattenableItems[T]) -> list[T]: for result in targets: flat += flatten(result) return flat + + +def load_dill_with_pandas_backward_compatibility(file: FileLike) -> Any: + """Load a binary dumped by dill, with pandas backward compatibility. + pd.read_pickle can load binaries dumped by older pandas versions, and also any object dumped by pickle. + Since it is unclear whether every object dumped by dill can be loaded by pd.read_pickle, we use dill.load as a fallback.
+ """ + try: + return pd.read_pickle(file) + except Exception: + return dill.load(file) diff --git a/test/test_file_processor.py b/test/test_file_processor.py index f5f1640d..c4243b63 100644 --- a/test/test_file_processor.py +++ b/test/test_file_processor.py @@ -1,11 +1,15 @@ +import os import tempfile import unittest from typing import Callable +import boto3 import pandas as pd from luigi import LocalTarget +from moto import mock_aws from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, PickleFileProcessor +from gokart.object_storage import ObjectStorage class TestCsvFileProcessor(unittest.TestCase): @@ -115,6 +119,23 @@ def run(self) -> int: self.assertEqual(loaded.run(), obj.run()) + @mock_aws + def test_dump_and_load_with_readables3file(self): + conn = boto3.resource('s3', region_name='us-east-1') + conn.create_bucket(Bucket='test') + file_path = os.path.join('s3://test/', 'test.pkl') + + var = 'abc' + processor = PickleFileProcessor() + + target = ObjectStorage.get_object_storage_target(file_path, processor.format()) + with target.open('w') as f: + processor.dump(var, f) + with target.open('r') as f: + loaded = processor.load(f) + + self.assertEqual(loaded, var) + class TestFeatherFileProcessor(unittest.TestCase): def test_feather_should_return_same_dataframe(self):