diff --git a/csv_wrangler/exporter.py b/csv_wrangler/exporter.py index 466697b..8a83150 100644 --- a/csv_wrangler/exporter.py +++ b/csv_wrangler/exporter.py @@ -29,6 +29,13 @@ class BaseExporter(metaclass=ABCMeta): def to_list(self) -> List[List[str]]: # pragma: no cover pass + def as_response(self, filename: str='export') -> HttpResponse: + response = HttpResponse(content_type='text/csv') + response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename) + writer = csv.writer(response) + [writer.writerow(row) for row in self.to_list()] + return response + class Exporter(Generic[T], BaseExporter, metaclass=ABCMeta): @@ -65,13 +72,6 @@ def to_list(self) -> List[List[str]]: ] return [self.get_header_labels()] + lines - def as_response(self, filename: str='export') -> HttpResponse: - response = HttpResponse(content_type='text/csv') - response['Content-Disposition'] = 'attachment; filename="{}.csv"'.format(filename) - writer = csv.writer(response) - [writer.writerow(row) for row in self.to_list()] - return response - class MultiExporter(BaseExporter): @@ -119,22 +119,12 @@ def to_list(self) -> List[List[str]]: return [self.get_csv_header_labels()] + lines -class PassthroughExporter(Exporter): +class PassthroughExporter(BaseExporter): data = [] # type: List[List[str]] def __init__(self, data: List[List[str]]) -> None: self.data = data - def make_header(self, field_name: str, idx: int) -> Header[List[str]]: - return Header(label=field_name, callback=lambda record: record[idx]) - - def fetch_records(self) -> List[List[str]]: - return self.data[1:] - - def get_headers(self) -> List[Header[List[str]]]: - return [ - self.make_header(field_name, idx) - for idx, field_name in enumerate(self.data[0]) - ] - + def to_list(self) -> List[List[str]]: + return self.data diff --git a/csv_wrangler/test_exporter.py b/csv_wrangler/test_exporter.py index 55ae5f9..a2220f2 100644 --- a/csv_wrangler/test_exporter.py +++ b/csv_wrangler/test_exporter.py @@ -123,3 +123,14 @@ def test_passthrough_to_list(self) -> None: self.assertEqual(results[0], ['a', 'b', 'c']) self.assertEqual(results[1], ['1', '2', '3']) self.assertEqual(results[2], ['2', '3', '4']) + + def test_malformed_passthrough(self) -> None: + exporter = PassthroughExporter([ + ['a', 'b', 'c'], + [], + ['d', 'e', 'f'], + ]) + results = exporter.to_list() + self.assertEqual(results[0], ['a', 'b', 'c']) + self.assertEqual(results[1], []) + self.assertEqual(results[2], ['d', 'e', 'f'])