Skip to content

Commit

Permalink
moved meta-data related classes/methods into ldc.metadata module
Browse files Browse the repository at this point in the history
  • Loading branch information
fracpete committed Feb 4, 2024
1 parent 80c245b commit d6f7621
Show file tree
Hide file tree
Showing 21 changed files with 127 additions and 114 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Changelog
- `llama2-to-pairs` filter has more robust parsing now
- removed numpy dependency for calculating the gcd for a list of integers in the
`Splitter` class (module `ldc.utils`)
- moved meta-data related classes/methods into `ldc.metadata` module



Expand Down
51 changes: 1 addition & 50 deletions src/ldc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import traceback

from dataclasses import dataclass
from typing import List, Union, Dict, Optional
from typing import List, Union

from seppl import Plugin
from seppl import check_compatibility as seppl_check_compatibility
Expand Down Expand Up @@ -244,55 +244,6 @@ def finalize(self):
self.logger().info("Finalizing...")


class MetaDataHandler:
    """
    Mixin interface for classes that manage meta-data.

    Every method defined here raises NotImplementedError; classes using the
    mixin supply the actual implementations.
    """

    def has_metadata(self) -> bool:
        """
        Indicates whether meta-data is present.

        :return: True if meta-data is present
        :rtype: bool
        """
        raise NotImplementedError()

    def get_metadata(self) -> Optional[Dict]:
        """
        Gives access to the meta-data.

        :return: the meta-data, None if not available
        :rtype: dict
        """
        raise NotImplementedError()

    def set_metadata(self, metadata: Optional[Dict]):
        """
        Replaces the meta-data with the supplied one.

        :param metadata: the new meta-data, can be None
        :type metadata: dict
        """
        raise NotImplementedError()


def get_metadata(o) -> Optional[Dict]:
    """
    Looks up the meta-data of the specified object.

    MetaDataHandler instances are queried through their get_metadata()
    method. Any other object is checked for a "meta" attribute, which is
    only returned when it is a dict.

    :param o: the object to get the meta-data from
    :return: the meta-data, None if not available
    :rtype: dict
    """
    # handlers know how to supply their own meta-data
    if isinstance(o, MetaDataHandler):
        return o.get_metadata()

    # fall back to a plain "meta" attribute
    if not hasattr(o, "meta"):
        return None
    candidate = getattr(o, "meta")
    return candidate if isinstance(candidate, dict) else None


def ensure_valid_domains(plugin: Plugin):
"""
Checks whether valid domains are specified.
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/filter/_max_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List

from wai.logging import LOGGING_WARNING
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION, get_metadata, MetaDataHandler
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION
from ..metadata import MetaDataHandler, get_metadata
from ._core import Filter
from ldc.pretrain import PretrainData
from ldc.supervised.pairs import PairData
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/filter/_record_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List

from wai.logging import LOGGING_WARNING
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION, get_metadata, MetaDataHandler
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION
from ..metadata import MetaDataHandler, get_metadata
from ._core import Filter
from ldc.pretrain import PretrainData
from ldc.supervised.pairs import PairData
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/filter/_reset_ids.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List

from wai.logging import LOGGING_WARNING
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION, get_metadata, MetaDataHandler
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION
from ..metadata import MetaDataHandler, get_metadata
from ._core import Filter
from ldc.pretrain import PretrainData
from ldc.supervised.pairs import PairData
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/filter/_skip_duplicate_ids.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import List

from wai.logging import LOGGING_WARNING
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION, get_metadata
from ldc.core import DOMAIN_PAIRS, DOMAIN_PRETRAIN, DOMAIN_TRANSLATION
from ldc.metadata import get_metadata
from ldc.filter import Filter
from ldc.pretrain import PretrainData
from ldc.supervised.pairs import PairData
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/filter/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import List

from wai.logging import LOGGING_WARNING
from ldc.core import DOMAIN_ANY, get_metadata, MetaDataHandler
from ldc.core import DOMAIN_ANY
from ldc.metadata import MetaDataHandler, get_metadata
from ldc.filter import Filter
from ldc.pretrain import PretrainData
from ldc.supervised.pairs import PairData
Expand Down
75 changes: 75 additions & 0 deletions src/ldc/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Optional, Dict


class MetaDataHandler:
    """
    Mixin interface for classes that manage meta-data.

    Every method defined here raises NotImplementedError; classes using the
    mixin supply the actual implementations.
    """

    def has_metadata(self) -> bool:
        """
        Indicates whether meta-data is present.

        :return: True if meta-data is present
        :rtype: bool
        """
        raise NotImplementedError()

    def get_metadata(self) -> Optional[Dict]:
        """
        Gives access to the meta-data.

        :return: the meta-data, None if not available
        :rtype: dict
        """
        raise NotImplementedError()

    def set_metadata(self, metadata: Optional[Dict]):
        """
        Replaces the meta-data with the supplied one.

        :param metadata: the new meta-data, can be None
        :type metadata: dict
        """
        raise NotImplementedError()


def get_metadata(o) -> Optional[Dict]:
    """
    Looks up the meta-data of the specified object.

    MetaDataHandler instances are queried through their get_metadata()
    method. Any other object is checked for a "meta" attribute, which is
    only returned when it is a dict.

    :param o: the object to get the meta-data from
    :return: the meta-data, None if not available
    :rtype: dict
    """
    # handlers know how to supply their own meta-data
    if isinstance(o, MetaDataHandler):
        return o.get_metadata()

    # fall back to a plain "meta" attribute
    if not hasattr(o, "meta"):
        return None
    candidate = getattr(o, "meta")
    return candidate if isinstance(candidate, dict) else None


def add_metadata(meta: Optional[Dict], key: str, value) -> Optional[Dict]:
    """
    Adds the specified key/value pair to the meta-data.
    If the provided meta-data dictionary is None, it gets instantiated
    (but only when there is a value to store).
    If value is None, nothing gets added.
    If value is a string and empty, nothing gets added.
    When nothing gets added, the meta-data is returned as is, which means
    the result is None if None was passed in.

    :param meta: the meta-data to add to, can be None
    :type meta: dict
    :param key: the key to store the value under
    :type key: str
    :param value: the value to store
    :return: the created or updated dictionary, None if meta was None
             and nothing was added
    :rtype: dict
    """
    # nothing worth storing: hand back whatever was passed in (possibly None)
    if value is None or (isinstance(value, str) and not value):
        return meta
    # lazily create the dictionary on the first real value
    if meta is None:
        meta = {}
    meta[key] = value
    return meta
3 changes: 2 additions & 1 deletion src/ldc/pretrain/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass
from typing import Iterable, List, Dict, Optional, Union

from ldc.core import DOMAIN_PRETRAIN, MetaDataHandler
from ldc.core import DOMAIN_PRETRAIN
from ldc.metadata import MetaDataHandler
from ldc.base_io import Reader, Writer, StreamWriter, BatchWriter
from ldc.filter import Filter

Expand Down
9 changes: 5 additions & 4 deletions src/ldc/pretrain/_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ldc.core import domain_suffix
from ldc.base_io import locate_files, open_file, generate_output
from ._core import PretrainData, PretrainReader, BatchPretrainWriter
from ldc.utils import str_to_column_index, add_meta_data
from ldc.utils import str_to_column_index
from ldc.metadata import add_metadata


class AbstractCsvLikePretrainReader(PretrainReader, abc.ABC):
Expand Down Expand Up @@ -166,18 +167,18 @@ def read(self) -> Iterable[PretrainData]:

# ID?
if id_ is not None:
meta = add_meta_data(meta, "id", id_)
meta = add_metadata(meta, "id", id_)

# additional meta-data columns
if self.col_meta is not None:
if self.no_header:
for i in self.idx_meta:
if i > -1:
meta = add_meta_data(meta, str(i), row[i])
meta = add_metadata(meta, str(i), row[i])
else:
for c in self.col_meta:
if c in row:
meta = add_meta_data(meta, c, row[c])
meta = add_metadata(meta, c, row[c])

if self.no_header:
yield PretrainData(
Expand Down
6 changes: 3 additions & 3 deletions src/ldc/pretrain/_jsonlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ldc.core import domain_suffix
from ldc.base_io import locate_files, open_file, generate_output, is_compressed
from ._core import PretrainData, PretrainReader, StreamPretrainWriter
from ldc.utils import add_meta_data
from ldc.metadata import add_metadata


class JsonLinesPretrainReader(PretrainReader):
Expand Down Expand Up @@ -133,13 +133,13 @@ def read(self) -> Iterable[PretrainData]:

# ID?
if id_ is not None:
meta = add_meta_data(meta, "id", id_)
meta = add_metadata(meta, "id", id_)

# additional meta-data columns
if self.att_meta is not None:
for c in self.att_meta:
if c in item:
meta = add_meta_data(meta, c, item[c])
meta = add_metadata(meta, c, item[c])

yield PretrainData(
content=val_content,
Expand Down
6 changes: 3 additions & 3 deletions src/ldc/pretrain/_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ldc.core import domain_suffix
from ldc.base_io import locate_files, generate_output
from ._core import PretrainData, PretrainReader, BatchPretrainWriter
from ldc.utils import add_meta_data
from ldc.metadata import add_metadata


class ParquetPretrainReader(PretrainReader):
Expand Down Expand Up @@ -129,13 +129,13 @@ def read(self) -> Iterable[PretrainData]:

# ID?
if id_ is not None:
meta = add_meta_data(meta, "id", id_)
meta = add_metadata(meta, "id", id_)

# additional meta-data columns
if self.col_meta is not None:
for c in self.col_meta:
if c in row:
meta = add_meta_data(meta, c, row[c])
meta = add_metadata(meta, c, row[c])

yield PretrainData(
content=val_content,
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/supervised/pairs/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass
from typing import Iterable, List, Dict, Optional, Union

from ldc.core import DOMAIN_PAIRS, MetaDataHandler
from ldc.core import DOMAIN_PAIRS
from ldc.metadata import MetaDataHandler
from ldc.base_io import Reader, Writer, StreamWriter, BatchWriter
from ldc.filter import Filter

Expand Down
9 changes: 5 additions & 4 deletions src/ldc/supervised/pairs/_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from ldc.core import domain_suffix
from ldc.base_io import locate_files, open_file, generate_output
from ._core import PairData, PairReader, BatchPairWriter
from ldc.utils import str_to_column_index, add_meta_data
from ldc.utils import str_to_column_index
from ldc.metadata import add_metadata


class AbstractCsvLikePairsReader(PairReader, abc.ABC):
Expand Down Expand Up @@ -197,18 +198,18 @@ def read(self) -> Iterable[PairData]:

# ID?
if id_ is not None:
meta = add_meta_data(meta, "id", id_)
meta = add_metadata(meta, "id", id_)

# additional meta-data columns
if self.col_meta is not None:
if self.no_header:
for i in self.idx_meta:
if i > -1:
meta = add_meta_data(meta, str(i), row[i])
meta = add_metadata(meta, str(i), row[i])
else:
for c in self.col_meta:
if c in row:
meta = add_meta_data(meta, c, row[c])
meta = add_metadata(meta, c, row[c])

yield PairData(
instruction=val_instruction,
Expand Down
6 changes: 3 additions & 3 deletions src/ldc/supervised/pairs/_jsonlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ldc.core import domain_suffix
from ldc.base_io import locate_files, open_file, generate_output, is_compressed
from ._core import PairData, PairReader, StreamPairWriter
from ldc.utils import add_meta_data
from ldc.metadata import add_metadata


class JsonLinesPairReader(PairReader):
Expand Down Expand Up @@ -150,13 +150,13 @@ def read(self) -> Iterable[PairData]:

# ID?
if id_ is not None:
meta = add_meta_data(meta, "id", id_)
meta = add_metadata(meta, "id", id_)

# additional meta-data columns
if self.att_meta is not None:
for c in self.att_meta:
if c in item:
meta = add_meta_data(meta, c, item[c])
meta = add_metadata(meta, c, item[c])

yield PairData(
instruction=val_instruction,
Expand Down
6 changes: 3 additions & 3 deletions src/ldc/supervised/pairs/_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ldc.core import domain_suffix
from ldc.base_io import locate_files, generate_output
from ._core import PairData, PairReader, BatchPairWriter
from ldc.utils import add_meta_data
from ldc.metadata import add_metadata


class ParquetPairsReader(PairReader):
Expand Down Expand Up @@ -146,13 +146,13 @@ def read(self) -> Iterable[PairData]:

# ID?
if id_ is not None:
meta = add_meta_data(meta, "id", id_)
meta = add_metadata(meta, "id", id_)

# additional meta-data columns
if self.col_meta is not None:
for c in self.col_meta:
if c in row:
meta = add_meta_data(meta, c, row[c])
meta = add_metadata(meta, c, row[c])

yield PairData(
instruction=val_instruction,
Expand Down
3 changes: 2 additions & 1 deletion src/ldc/translation/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from dataclasses import dataclass
from typing import Iterable, List, Dict, Optional, Union

from ldc.core import DOMAIN_TRANSLATION, MetaDataHandler
from ldc.core import DOMAIN_TRANSLATION
from ldc.metadata import MetaDataHandler
from ldc.base_io import Reader, Writer, StreamWriter, BatchWriter
from ldc.filter import Filter

Expand Down
Loading

0 comments on commit d6f7621

Please sign in to comment.