From aab71dac71091e8794e235358caeafb7240a648a Mon Sep 17 00:00:00 2001
From: Guillaume Poulin
Date: Thu, 26 Feb 2015 13:06:45 +0800
Subject: [PATCH] Make the new hdfs format more robust

---
 luigi/hdfs.py     | 94 +++++++++++++++++++++++++++++++++++++----------
 test/hdfs_test.py | 37 +++++++++++++++++++
 2 files changed, 112 insertions(+), 19 deletions(-)

diff --git a/luigi/hdfs.py b/luigi/hdfs.py
index f04a7cdbef..2e1ba2599b 100644
--- a/luigi/hdfs.py
+++ b/luigi/hdfs.py
@@ -746,12 +746,17 @@ class PlainFormat(luigi.format.Format):
     input = 'bytes'
     output = 'hdfs'
 
-    def pipe_reader(cls, path):
+    def hdfs_writer(self, path):
+        return self.pipe_writer(path)
+
+    def hdfs_reader(self, path):
+        return self.pipe_reader(path)
+
+    def pipe_reader(self, path):
         return HdfsReadPipe(path)
 
     def pipe_writer(self, output_pipe):
         return HdfsAtomicWritePipe(output_pipe)
-        return output_pipe
 
 
 class PlainDirFormat(luigi.format.Format):
@@ -759,11 +764,17 @@ class PlainDirFormat(luigi.format.Format):
     input = 'bytes'
     output = 'hdfs'
 
-    def pipe_reader(cls, path):
+    def hdfs_writer(self, path):
+        return self.pipe_writer(path)
+
+    def hdfs_reader(self, path):
+        return self.pipe_reader(path)
+
+    def pipe_reader(self, path):
         # exclude underscore-prefixedfiles/folders (created by MapReduce)
         return HdfsReadPipe("%s/[^_]*" % path)
 
-    def pipe_writer(cls, path):
+    def pipe_writer(self, path):
         return HdfsAtomicWriteDirPipe(path)
 
 
@@ -771,6 +782,30 @@ def pipe_writer(cls, path):
 PlainDir = PlainDirFormat()
 
 
+class CompatibleHdfsFormat(luigi.format.Format):
+
+    output = 'hdfs'
+
+    def __init__(self, writer, reader, input=None):
+        if input is not None:
+            self.input = input
+
+        self.reader = reader
+        self.writer = writer
+
+    def pipe_writer(self, output):
+        return self.writer(output)
+
+    def pipe_reader(self, input):
+        return self.reader(input)
+
+    def hdfs_writer(self, output):
+        return self.writer(output)
+
+    def hdfs_reader(self, input):
+        return self.reader(input)
+
+
 class HdfsTarget(FileSystemTarget):
 
     def __init__(self, path=None, format=None, is_tmp=False, fs=None):
@@ -778,31 +813,52 @@ def __init__(self, path=None, format=None, is_tmp=False, fs=None):
             assert is_tmp
             path = tmppath()
         super(HdfsTarget, self).__init__(path)
+
         if format is None:
             format = luigi.format.get_default_format() >> Plain
 
-        if hasattr(format, 'hdfs_writer'):
+        old_format = (
+            (
+                hasattr(format, 'hdfs_writer') or
+                hasattr(format, 'hdfs_reader')
+            ) and
+            not hasattr(format, 'output')
+        )
+
+        if not old_format and getattr(format, 'output', '') != 'hdfs':
+            format = format >> Plain
+
+        if old_format:
             warnings.warn(
-                'hdfs_writer method for format is deprecated, specify the'
-                'property output of your format as \'hdfs\' instead',
+                'hdfs_writer and hdfs_reader method for format is deprecated, '
+                'specify the property output of your format as \'hdfs\' instead',
                 DeprecationWarning,
                 stacklevel=2
             )
-            format.pipe_writer = format.hdfs_writer
-            format.output = 'hdfs'
 
-        if hasattr(format, 'hdfs_reader'):
-            warnings.warn(
-                'hdfs_reader method for format is deprecated, specify the'
-                'property output of your format as \'hdfs\' instead',
-                DeprecationWarning,
-                stacklevel=2
+            if hasattr(format, 'hdfs_writer'):
+                format_writer = format.hdfs_writer
+            else:
+                w_format = format >> Plain
+                format_writer = w_format.pipe_writer
+
+            if hasattr(format, 'hdfs_reader'):
+                format_reader = format.hdfs_reader
+            else:
+                r_format = format >> Plain
+                format_reader = r_format.pipe_reader
+
+            format = CompatibleHdfsFormat(
+                format_writer,
+                format_reader,
             )
-            format.pipe_reader = format.hdfs_reader
-            format.output = 'hdfs'
 
-        if not hasattr(format, 'output') or format.output != 'hdfs':
-            format = format >> Plain
+        else:
+            format = CompatibleHdfsFormat(
+                format.pipe_writer,
+                format.pipe_reader,
+                getattr(format, 'input', None),
+            )
 
         self.format = format
 
diff --git a/test/hdfs_test.py b/test/hdfs_test.py
index b0854d62a7..00eda5dc0f 100644
--- a/test/hdfs_test.py
+++ b/test/hdfs_test.py
@@ -23,12 +23,27 @@
 import helpers
 import luigi
 import mock
+import luigi.format
 from luigi import hdfs
 from luigi import six
 from minicluster import MiniClusterTestCase
 from nose.plugins.attrib import attr
 
 
+class ComplexOldFormat(luigi.format.Format):
+    """Should take unicode but output bytes
+    """
+
+    def hdfs_writer(self, output_pipe):
+        return self.pipe_writer(luigi.hdfs.Plain.hdfs_writer(output_pipe))
+
+    def pipe_writer(self, output_pipe):
+        return luigi.format.UTF8.pipe_writer(output_pipe)
+
+    def pipe_reader(self, output_pipe):
+        return output_pipe
+
+
 class TestException(Exception):
     pass
 
@@ -234,6 +249,28 @@ def test_multifile(self):
         self.assertEqual(tuple(parts), (b'bar', b'foo'))
 
 
+@attr('minicluster')
+class ComplexOldFormatTest(MiniClusterTestCase):
+    format = ComplexOldFormat()
+
+    def setUp(self):
+        super(ComplexOldFormatTest, self).setUp()
+        self.target = hdfs.HdfsTarget(self._test_file(), format=self.format)
+        if self.target.exists():
+            self.target.remove(skip_trash=True)
+
+    def test_with_write_success(self):
+        with self.target.open('w') as fobj:
+            fobj.write(u'foo')
+        self.assertTrue(self.target.exists())
+
+        with self.target.open('r') as fobj:
+            a = fobj.read()
+
+        self.assertFalse(isinstance(a, six.text_type))
+        self.assertEqual(a, b'foo')
+
+
 @attr('minicluster')
 class HdfsTargetTests(MiniClusterTestCase):
 