Skip to content

Commit

Permalink
using seppl 0.1.0 and its I/O classes now
Browse files Browse the repository at this point in the history
  • Loading branch information
fracpete committed Feb 5, 2024
1 parent d6f7621 commit 97dd57a
Show file tree
Hide file tree
Showing 62 changed files with 218 additions and 1,003 deletions.
8 changes: 3 additions & 5 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
Changelog
=========

0.0.6 (????-??-??)
0.1.0 (2024-02-05)
------------------

- fixed output format of `to-llama2-format` filter
- `llama2-to-pairs` filter has more robust parsing now
- removed numpy dependency for calculating the gcd for a list of integers in the
`Splitter` class (module `ldc.utils`)
- moved meta-data related classes/methods into `ldc.metadata` module

- upgraded seppl to 0.1.0
- switched to seppl classes: Splitter, MetaDataHandler, Reader, Writer, StreamWriter, BatchWriter


0.0.5 (2024-01-24)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -529,13 +529,13 @@ and then outputs it in zstandard-compressed jsonlines format:

```python
from wai.logging import LOGGING_INFO, init_logging
from seppl.io import execute
from ldc.core import Session, ENV_LLM_LOGLEVEL
from ldc.base_io import COMPRESSION_ZSTD
from ldc.registry import register_plugins
from ldc.supervised.pairs import AlpacaReader, PAIRDATA_FIELDS
from ldc.pretrain import JsonLinesPretrainWriter
from ldc.filter import PairsToPretrain, Keyword
from ldc.execution import execute

init_logging(env_var=ENV_LLM_LOGLEVEL)
register_plugins()
Expand Down
1 change: 0 additions & 1 deletion plugins/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
* [max-length-pt](max-length-pt.md)
* [max-records](max-records.md)
* [metadata](metadata.md)
* [multi-filter](multi-filter.md)
* [pairs-to-llama2](pairs-to-llama2.md)
* [pairs-to-pretrain](pairs-to-pretrain.md)
* [pretrain-sentences-to-pairs](pretrain-sentences-to-pairs.md)
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def _read(f):
"pyarrow",
"pyzstd",
"huggingface-hub",
"seppl>=0.0.11",
"seppl>=0.1.0",
"pyyaml",
"wai.logging",
],
version="0.0.5",
version="0.1.0",
author='Peter Reutemann',
author_email='fracpete@waikato.ac.nz',
entry_points={
Expand Down
215 changes: 31 additions & 184 deletions src/ldc/base_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,16 @@
import logging

import chardet
import glob
import gzip
import lzma
import os
import pyzstd
import seppl.io

from typing import Union, Iterable, List, Optional
from typing import Iterable, Optional

from ldc.core import CommandlineHandler, DomainHandler, Session, SessionHandler
from ldc.core import DomainHandler
from wai.logging import LOGGING_WARNING
from seppl import OutputProducer, InputConsumer


COMPRESSION_BZIP2 = "bz2"
Expand All @@ -38,73 +37,6 @@
""" the determined max check length for determining the file encoding. """


def _as_path_list(value: Optional[Union[str, List[str]]], invalid_msg: str) -> Optional[List[str]]:
    """
    Normalizes a single path or a list of paths into a list of paths.

    :param value: the path(s) to normalize, can be None
    :type value: str or list
    :param invalid_msg: the message of the exception raised for unsupported types
    :type invalid_msg: str
    :return: the list of paths, None if no value was supplied
    :rtype: list
    """
    if value is None:
        return None
    if isinstance(value, str):
        return [value]
    if isinstance(value, list):
        return value
    raise Exception(invalid_msg)


def locate_files(inputs: Optional[Union[str, List[str]]], input_lists: Optional[Union[str, List[str]]] = None,
                 fail_if_empty: bool = False) -> List[str]:
    """
    Locates all the files from the specified inputs, which may contain globs.
    glob results get sorted to ensure the same file order each time.
    Directories matched by a glob are skipped. Paths read from input lists are
    only checked for existence; blank lines in an input list are ignored.

    :param inputs: the input path(s) with optional globs
    :type inputs: str or list
    :param input_lists: text file(s) that list the actual input files to use (one path per line)
    :type input_lists: str or list
    :param fail_if_empty: whether to throw an exception if no files were located
    :type fail_if_empty: bool
    :return: the expanded list of files (glob matches first, then input list entries)
    :rtype: list
    :raises Exception: if neither inputs nor input lists were supplied, if either has an
        unsupported type, or if no files were located and fail_if_empty is set
    """
    if (inputs is None) and (input_lists is None):
        raise Exception("Neither input paths nor input lists provided!")

    inputs = _as_path_list(inputs, "Invalid inputs, must be string(s)!")
    input_lists = _as_path_list(input_lists, "Invalid input lists, must be string(s)!")

    result = []

    # globs
    for pattern in inputs or []:
        result.extend(f for f in sorted(glob.glob(pattern)) if not os.path.isdir(f))

    # path lists
    for list_file in input_lists or []:
        if not os.path.exists(list_file):
            print("WARNING: Input list does not exist: %s" % list_file)
            continue
        if os.path.isdir(list_file):
            print("WARNING: Input list points to directory: %s" % list_file)
            continue
        with open(list_file, "r") as fp:
            lines = [x.strip() for x in fp]
        for line in lines:
            # blank lines are not paths: skip them instead of warning about a missing '' path
            if not line:
                continue
            if not os.path.exists(line):
                print("WARNING: Path from input list '%s' does not exist: %s" % (list_file, line))
                continue
            result.append(line)

    if fail_if_empty and not result:
        # report both sources, since either of them may have been the only one supplied
        raise Exception("Failed to locate any files using: inputs=%s, input_lists=%s"
                        % (str(inputs), str(input_lists)))

    return result


def encoding_max_check_length() -> int:
"""
Returns the maximum number of bytes to use for determining the file encoding.
Expand Down Expand Up @@ -270,7 +202,7 @@ def generate_output(input_path: str, output_path: str, ext: str, compression: Op
return output_path


class Reader(CommandlineHandler, OutputProducer, DomainHandler, SessionHandler, abc.ABC):
class Reader(seppl.io.Reader, seppl.Initializable, DomainHandler, abc.ABC):
"""
Ancestor of classes that read data.
"""
Expand All @@ -285,111 +217,20 @@ def __init__(self, logger_name: str = None, logging_level: str = LOGGING_WARNING
:type logging_level: str
"""
super().__init__(logger_name=logger_name, logging_level=logging_level)
self._session = None

@property
def session(self) -> Session:
"""
Returns the current session object
:return: the session object
:rtype: Session
"""
return self._session

@session.setter
def session(self, s: Session):
"""
Sets the session object to use.
:param s: the session object
:type s: Session
"""
self._session = s

def read(self) -> Iterable:
"""
Loads the data and returns the items one by one.
:return: the data
:rtype: Iterable
"""
raise NotImplementedError()

def has_finished(self) -> bool:
"""
Returns whether reading has finished.
:return: True if finished
:rtype: bool
"""
raise NotImplementedError()


class Writer(CommandlineHandler, InputConsumer, DomainHandler, SessionHandler, abc.ABC):
class StreamWriter(seppl.io.StreamWriter, seppl.Initializable, DomainHandler, abc.ABC):
"""
Ancestor of classes that write data.
Ancestor for classes that write data one record at a time.
"""

def __init__(self, logger_name: str = None, logging_level: str = LOGGING_WARNING):
"""
Initializes the handler.
:param logger_name: the name to use for the logger
:type logger_name: str
:param logging_level: the logging level to use
:type logging_level: str
"""
super().__init__(logger_name=logger_name, logging_level=logging_level)
self._session = None
self._last_input = None

@property
def session(self) -> Session:
"""
Returns the current session object
:return: the session object
:rtype: Session
"""
return self._session

@session.setter
def session(self, s: Session):
"""
Sets the session object to use.
:param s: the session object
:type s: Session
"""
self._session = s

def _has_input_changed(self, current_input: str = None, update: bool = False) -> bool:
    """
    Determines whether the given input differs from the one that was processed last.

    :param current_input: the input to compare, falls back to the session's current_input if None
    :type current_input: str
    :param update: whether to record the input as the last processed one right away
    :type update: bool
    :return: True if the input differs from the last processed one
    :rtype: bool
    """
    candidate = self.session.current_input if current_input is None else current_input
    changed = self._last_input != candidate
    if update:
        self._update_last_input(candidate)
    return changed

def _update_last_input(self, current_input: str):
def write_stream(self, data):
"""
Updates the last input that was processed.
Saves the data one by one.
:param current_input: the "new" last input
:type current_input: str
:param data: the data to write (single record or iterable of records)
"""
self._last_input = current_input
raise NotImplementedError()

def _output_needs_changing(self, current_output: str, target: str, ext: str) -> bool:
"""
Expand All @@ -412,21 +253,7 @@ def _output_needs_changing(self, current_output: str, target: str, ext: str) ->
return False


class StreamWriter(Writer, abc.ABC):
"""
Ancestor for classes that write data one record at a time.
"""

def write_stream(self, data):
"""
Saves the data one by one.
:param data: the data to write (single record or iterable of records)
"""
raise NotImplementedError()


class BatchWriter(Writer, abc.ABC):
class BatchWriter(seppl.io.BatchWriter, seppl.Initializable, DomainHandler, abc.ABC):
"""
Ancestor of classes that write data all at once.
"""
Expand All @@ -439,3 +266,23 @@ def write_batch(self, data: Iterable):
:type data: Iterable
"""
raise NotImplementedError()

def _output_needs_changing(self, current_output: str, target: str, ext: str) -> bool:
    """
    Decides whether a new output has to be used for the session's current input.

    :param current_output: the output currently in use, None if there is none yet
    :type current_output: str
    :param target: the output target
    :type target: str
    :param ext: the extension for the output file, incl dot
    :type ext: str
    :return: True if there is no current output or it differs from the generated one
    :rtype: bool
    """
    # nothing in use yet, so an output is always required
    if current_output is None:
        return True
    session = self.session
    expected = generate_output(session.current_input, target, ext, session.options.compression)
    return True if current_output != expected else False
Loading

0 comments on commit 97dd57a

Please sign in to comment.