From 9fbab55aa141c78f0e850a842998cafd3ead4d67 Mon Sep 17 00:00:00 2001
From: Jan Deriu
Date: Wed, 19 Jun 2024 09:05:55 +0200
Subject: [PATCH] bugfixes

Main changes:
- examples/local_minhash_deduplication.py: deduplicate the curiavista corpus instead of swissdox
- pipelines/curia_vista.py: ingest the Transcript table (Text column only) and write with SwissAIJsonlWriter
- src/swiss_ai/readers/curia_vista.py: resumable ID-based paging, language detection, per-language deduplication
- src/swiss_ai/readers/swissdox.py and curia_vista.py: metadata restructured into language / year / optional
- pipelines/swissdox_raw.py: switch to SwissAIJsonlWriter, raise tasks to 64; pipelines/hugginface_pipeline.py: raise workers to 16
- src/swiss_ai/writers/jsonl.py: log a warning instead of raising on invalid documents
- src/datatrove/executor/local.py: write stats.json with explicit utf-8 encoding
- new package src/swiss_ai/pipeline with a pii_removal stub
---
 examples/local_minhash_deduplication.py |  2 +-
 pipelines/curia_vista.py                | 18 +++-
 pipelines/hugginface_pipeline.py        |  3 +-
 pipelines/swissdox_raw.py               |  9 +--
 src/datatrove/executor/local.py         |  2 +-
 src/swiss_ai/pipeline/__init__.py       |  0
 src/swiss_ai/pipeline/pii_removal.py    |  3 +
 src/swiss_ai/readers/curia_vista.py     | 87 +++++++++++++++++--------
 src/swiss_ai/readers/swissdox.py        | 16 +++-
 src/swiss_ai/writers/jsonl.py           |  4 +-
 10 files changed, 90 insertions(+), 54 deletions(-)
 create mode 100644 src/swiss_ai/pipeline/__init__.py
 create mode 100644 src/swiss_ai/pipeline/pii_removal.py

diff --git a/examples/local_minhash_deduplication.py b/examples/local_minhash_deduplication.py
index 3800fb09..c195eabd 100644
--- a/examples/local_minhash_deduplication.py
+++ b/examples/local_minhash_deduplication.py
@@ -14,7 +14,7 @@
 # you can also change ngrams or the number of buckets and their size here
 minhash_config = MinhashConfig(use_64bit_hashes=True)  # better precision -> fewer false positives (collisions)
 
-corpus = 'swissdox'
+corpus = 'curiavista'
 
 S3_MINHASH_BASE_PATH = f"/work_space_data/{corpus}/minhash/"
 
diff --git a/pipelines/curia_vista.py b/pipelines/curia_vista.py
index adf8fa30..391bbfef 100644
--- a/pipelines/curia_vista.py
+++ b/pipelines/curia_vista.py
@@ -1,6 +1,6 @@
 from swiss_ai.readers.curia_vista import RawCuriaVistaReader
 from datatrove.pipeline.tokens import TokensCounter, LengthCounter
-from datatrove.pipeline.writers import JsonlWriter
+from swiss_ai.writers.jsonl import SwissAIJsonlWriter
 from datatrove.pipeline.readers import JsonlReader
 from datatrove.executor.local import LocalPipelineExecutor
 from datetime import datetime
@@ -8,7 +8,10 @@
 now = datetime.now()
 
 if __name__ == '__main__':
-    table = 'Business'
+    table = 'Transcript'
+    transcr_cols = [
+        'Text'
+    ]
 
     now = datetime.now()
     batch = now.strftime("%Y_%m_%d_%H_%M_%S")
@@ -20,17 +23,12 @@
         RawCuriaVistaReader(
             table=table,
             progress=True,
-            columns=[
-                'SubmittedText',
-                'FederalCouncilResponseText',
-                'InitialSituation',
-                'Proceedings'
-            ],
-            limit=100
+            columns=transcr_cols,
+            limit=1500000
         ),
         TokensCounter(tokenizer_name_or_path='t5-small'),
         LengthCounter(),
-        JsonlWriter(
+        SwissAIJsonlWriter(
             output_folder=f"/work_space_data/curiavista/{table}/jsonl_{batch}"
         )
     ]
diff --git a/pipelines/hugginface_pipeline.py b/pipelines/hugginface_pipeline.py
index 61ae5f9f..dc762457 100644
--- a/pipelines/hugginface_pipeline.py
+++ b/pipelines/hugginface_pipeline.py
@@ -60,7 +60,6 @@ def _multilegal_adapter(data: dict, path: str, id_in_file: int | str):
             limit=1000
         ),
         TokensCounter(tokenizer_name_or_path='t5-small'),
-        LengthCounter(),
         SwissAIJsonlWriter(
             output_folder="/work_space_data/multilegal_pile/jsonl"
         )
@@ -69,7 +68,7 @@ def _multilegal_adapter(data: dict, path: str, id_in_file: int | str):
     exec = LocalPipelineExecutor(
         pipeline=pipeline,
         tasks=16,
-        workers=1,
+        workers=16,
         start_method="spawn",
         logging_dir="/work_space_data/multilegal_pile/logging"
     )
diff --git a/pipelines/swissdox_raw.py b/pipelines/swissdox_raw.py
index 4d6b8c29..801de925 100644
--- a/pipelines/swissdox_raw.py
+++ b/pipelines/swissdox_raw.py
@@ -1,13 +1,11 @@
 """
 
 """
-
 from swiss_ai.readers.swissdox import RawSwissDoxReader
 from datatrove.pipeline.tokens import TokensCounter, LengthCounter
-from datatrove.pipeline.writers import JsonlWriter
+from swiss_ai.writers.jsonl import SwissAIJsonlWriter
 from datatrove.executor.local import LocalPipelineExecutor
 
-os.environ["HF_BASE"] = "/work_space_data/hf_cache/"
 
 if __name__ == '__main__':
     pipeline = [
@@ -16,15 +14,14 @@
             limit=-1
         ),
         TokensCounter(tokenizer_name_or_path='t5-small'),
-        LengthCounter(),
-        JsonlWriter(
+        SwissAIJsonlWriter(
             output_folder="/work_space_data/swissdox/jsonl"
         )
     ]
 
     exec = LocalPipelineExecutor(
         pipeline=pipeline,
-        tasks=16,
+        tasks=64,
         workers=16,
         start_method="spawn",
         logging_dir="/work_space_data/swissdox/logging"
diff --git a/src/datatrove/executor/local.py b/src/datatrove/executor/local.py
index 93b6c172..ef6df4ca 100644
--- a/src/datatrove/executor/local.py
+++ b/src/datatrove/executor/local.py
@@ -141,7 +141,7 @@ def run(self):
             )
             # merged stats
             stats = sum(stats, start=PipelineStats())
-            with self.logging_dir.open("stats.json", "wt") as statsfile:
+            with self.logging_dir.open("stats.json", "wt", encoding='utf-8') as statsfile:
                 stats.save_to_disk(statsfile)
             logger.success(stats.get_repr(f"All {self.local_tasks} tasks"))
         return stats
diff --git a/src/swiss_ai/pipeline/__init__.py b/src/swiss_ai/pipeline/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/swiss_ai/pipeline/pii_removal.py b/src/swiss_ai/pipeline/pii_removal.py
new file mode 100644
index 00000000..df2118cc
--- /dev/null
+++ b/src/swiss_ai/pipeline/pii_removal.py
@@ -0,0 +1,3 @@
+from presidio_analyzer import AnalyzerEngine
+from presidio_anonymizer import AnonymizerEngine
+
diff --git a/src/swiss_ai/readers/curia_vista.py b/src/swiss_ai/readers/curia_vista.py
index 1b827cd5..65853ee4 100644
--- a/src/swiss_ai/readers/curia_vista.py
+++ b/src/swiss_ai/readers/curia_vista.py
@@ -6,7 +6,7 @@
 import xml.etree.ElementTree as ET
 
 from tqdm import tqdm
-
+from langdetect import detect
 from datatrove.io import DataFolderLike, get_datafolder
 from datatrove.pipeline.readers.base import BaseReader, DocumentsPipeline
 
@@ -68,7 +68,7 @@
             if not child.tag.endswith('entry'):
                 continue
             idx = child[-1][0][-1].text
-            indices.add(idx)
+            indices.add(int(idx))
         return indices
 
     def retrieve_single_record_for_id(self, in_id):
@@ -91,13 +91,19 @@
         return all_data
 
     def _curia_vista_adapter(self, data: dict, path: str, id_in_file: int | str):
         text = ''.join([f'\n\n{col}\n\n{data[col]}' for col in self.columns if data[col] is not None])
-        meta_data = {k: v for k, v in data.items() if k not in self.columns}
-        if not text:
-            text = 'DUMMY_TEXT'
-            meta_data['delete'] = True
+        opt_meta_data = {k: v for k, v in data.items() if k not in self.columns}
+        lang = opt_meta_data['LanguageOfText'].lower() if opt_meta_data['LanguageOfText'] is not None else None
+        if lang is None:
+            lang = detect(text)
+
+        meta_data = {
+            'optional': opt_meta_data,
+            'language': lang,
+            'year': int(opt_meta_data['MeetingDate'][:4])
+        }
 
         return {
             "text": text,
@@ -107,25 +113,53 @@
     def run(self, data: DocumentsPipeline = None, rank: int = 0, world_size: int = 1) -> DocumentsPipeline:
-        processed_ids = set([doc.id for doc in data])
+        processed_ids = set()
+        processed_dp = set()
+        try:
+            for document in data:
+                processed_ids.add(document.id)
+                dp = f'{document.id}_{document.metadata["language"]}'
+                if dp in processed_dp:
+                    continue
+                processed_dp.add(dp)
+                if document.metadata["language"] is None:
+                    document.metadata["language"] = detect(document.text)
+
+                yield document
+        except TypeError:
+            pass  # data is None when this reader is the first step in the pipeline
+
+        print(f'Already processed {len(processed_ids)} Documents')
+        if len(processed_ids) > 0:
+            last_id = max(processed_ids)
+        else:
+            last_id = 0
+        ids = ['dummy']
         limit = self.limit
-        if not limit == -1:
-            id_url = f"{self.base_url}?$top={limit}&$filter=Language eq 'DE'&$select=ID"
-        else:
-            id_url = f"{self.base_url}?$filter=Language eq 'DE'&$select=ID"
-        ids = self.parse_ids(id_url)
-        ids = ids.difference(processed_ids)
-
-        for nr, entry_id in tqdm(enumerate(ids, start=1)):
-            with self.track_time():
-                entries = self.retrieve_single_record_for_id(entry_id)
-
-            if nr % 10 == 0:
-                time.sleep(self.timeout)
-
-            for data_dict in entries:
-                document = self.get_document_from_dict(data_dict, self.table, entry_id)
-                if not document:
-                    continue
-                yield document
+
+        global_count = 0
+        while len(ids) > 0 and global_count < limit:
+            id_url = f"{self.base_url}?$filter=Language eq 'DE' and ID gt {last_id} &$orderby=ID&$select=ID&$top=100"
+            ids = self.parse_ids(id_url)
+            ids = ids.difference(processed_ids)
+            ids = sorted(ids)
+
+            for nr, entry_id in tqdm(enumerate(ids, start=1)):
+                with self.track_time():
+                    entries = self.retrieve_single_record_for_id(entry_id)
+                    global_count += len(entries)
+
+                last_id = entry_id
+                if nr % 10 == 0:
+                    time.sleep(self.timeout)
+
+                for data_dict in entries:
+                    document = self.get_document_from_dict(data_dict, self.table, entry_id)
+                    if not document:
+                        continue
+                    yield document
+                if global_count >= limit:
+                    break
+                processed_ids.add(entry_id)
+            time.sleep(60)
diff --git a/src/swiss_ai/readers/swissdox.py b/src/swiss_ai/readers/swissdox.py
index 6c2497c3..9386c75e 100644
--- a/src/swiss_ai/readers/swissdox.py
+++ b/src/swiss_ai/readers/swissdox.py
@@ -95,6 +95,7 @@
                 if ignroe_article:
                     continue
                 tmp_text = f"{tmp_text}\n{content}"
+
     def load_meta_data(self, filepath):
         meta_data_full = {}
         with self.meta_data_folder.open(filepath, "r", encoding='utf-8', compression=self.compression) as mf:
@@ -108,11 +109,16 @@
                 news_paper_short = sline[-3]
                 date = sline[-4]
 
-                meta_dict = json.loads(sdict)
-                meta_dict['news_paper_short'] = news_paper_short.strip()
-                meta_dict['news_paper'] = news_paper.strip()
-                meta_dict['pub_date'] = date.strip()
-                meta_dict['lang'] = lang
+                opt_meta_dict = json.loads(sdict)
+                opt_meta_dict['news_paper_short'] = news_paper_short.strip()
+                opt_meta_dict['news_paper'] = news_paper.strip()
+                opt_meta_dict['pub_date'] = date.strip()
+                meta_dict = {
+                    'language': lang,
+                    'year': int(date.strip().split('-')[0]),
+                    'optional': opt_meta_dict
+                }
+
                 meta_data_full[lid] = meta_dict
 
         return meta_data_full
diff --git a/src/swiss_ai/writers/jsonl.py b/src/swiss_ai/writers/jsonl.py
index bae76e04..1bd31a2d 100644
--- a/src/swiss_ai/writers/jsonl.py
+++ b/src/swiss_ai/writers/jsonl.py
@@ -2,6 +2,7 @@
 from typing import IO, Callable
 
 from datatrove.io import DataFolderLike
+from loguru import logger
 from datatrove.pipeline.writers.disk_base import DiskWriter
 from swiss_ai.utils.language_list import LANGUAGE_CODES
 from datetime import datetime
@@ -85,7 +86,6 @@ def _check_required_metadata(required_metadata: dict):
     def _write(self, document: dict, file_handler: IO, _filename: str):
         passed_check = SwissAIJsonlWriter.check_document(document)
         if not passed_check:
-            #TODO handle this better and give more descriptive feedback
-            raise ValueError('Document is not valid')
+            logger.warning(f'Document not valid: {str(document)}')
 
         file_handler.write(json.dumps(document, ensure_ascii=False) + "\n")