refactoring
speed1313 committed Nov 14, 2024
1 parent 7f57408 commit 1df0fc6
Showing 2 changed files with 159 additions and 130 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -5,7 +5,6 @@ authors = [
{ name = "speed1313", email = "speedtry13@icloud.com" }
]
dependencies = [
"click>=8.1.7",
"vllm>=0.6.1",
"datasets>=3.0.0",
"wandb>=0.18.0",
288 changes: 159 additions & 129 deletions src/text2dataset/main.py
@@ -1,6 +1,5 @@
import datasets
from datasets import load_dataset
import click
import os
from datasets import Dataset
import wandb
@@ -13,6 +12,115 @@
from text2dataset.utils import State
from text2dataset.reader import create_dataset
import yaml
import argparse
from dataclasses import dataclass


@dataclass(frozen=True)
class Args:
model_id: str
batch_size: int
tensor_parallel_size: int
pipeline_parallel_size: int
gpu_id: int
input_path: str
source_column: str
target_column: str
push_to_hub: bool
push_to_hub_path: str
output_dir: str
output_format: str
number_sample_per_shard: int
resume_from_checkpoint: bool
use_wandb: bool
wandb_project: str
wandb_run_name: str
prompt_template_path: str
temperature: float
top_p: float
max_tokens: int
target_lang: str
keep_columns: str | None
split: str


def parse_args():
parser = argparse.ArgumentParser(description="Argument parser for model inference")
parser.add_argument(
"--model_id",
type=str,
default="llm-jp/llm-jp-3-3.7b-instruct",
help="Model name. e.g., llm-jp/llm-jp-3-3.7b-instruct. Specify 'gpt-4o-mini-2024-07-18' for OpenAI API or 'deepl' for DeepL API.",
)
parser.add_argument(
"--batch_size", type=int, default=1024, help="Batch size for vLLM inference."
)
parser.add_argument("--tensor_parallel_size", type=int, default=1)
parser.add_argument("--pipeline_parallel_size", type=int, default=1)
parser.add_argument("--gpu_id", type=int, default=0)
parser.add_argument(
"--input_path",
type=str,
default="data/english_quotes.json",
help="Local file path or Hugging Face dataset name.",
)
parser.add_argument(
"--source_column",
type=str,
default="txt",
help="Column name in the dataset to be prompted.",
)
parser.add_argument(
"--target_column",
type=str,
default="txt_ja",
help="Column name in the dataset to store generated text.",
)
parser.add_argument("--push_to_hub", type=bool, default=False)
parser.add_argument("--push_to_hub_path", type=str, default="speed/english_quotes")
parser.add_argument("--output_dir", type=str, default="data/english_quotes_ja")
parser.add_argument("--output_format", type=str, default="json")
parser.add_argument("--number_sample_per_shard", type=int, default=1000)
parser.add_argument(
"--resume_from_checkpoint",
action="store_true",
help="Resume from the last checkpoint.",
)
parser.add_argument("--use_wandb", type=bool, default=False)
parser.add_argument("--wandb_project", type=str, default="text2dataset")
parser.add_argument("--wandb_run_name", type=str, default="")
parser.add_argument(
"--prompt_template_path",
type=str,
default="config/prompt.yaml",
help="Path to the prompt template.",
)
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=0.95)
parser.add_argument("--max_tokens", type=int, default=200)
parser.add_argument(
"--target_lang",
type=str,
default="ja",
help="Target language for translation; used for DeepL API.",
)
parser.add_argument(
"--keep_columns",
type=str,
default=None,
help="Columns to keep in the output dataset, separated by comma. If None, all columns are kept. e.g., 'txt'. target_column is always kept.",
)
parser.add_argument(
"--split",
type=str,
default="train",
help="Split of the dataset to use. e.g., 'train', 'validation', 'test'.",
)

args = parser.parse_args()
return Args(**vars(args))
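
Aside: a minimal sketch (not part of this commit) of what the frozen dataclass buys here. Because Args(**vars(args)) maps every parser flag onto a declared field, a flag added to the parser without a matching field fails fast instead of passing through silently; the --new_flag below is hypothetical:

    parser.add_argument("--new_flag", type=int, default=0)  # hypothetical extra flag
    Args(**vars(parser.parse_args([])))  # TypeError: unexpected keyword argument 'new_flag'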


logger = logging.getLogger(__name__)
logging.basicConfig(
@@ -22,101 +130,13 @@
)


@click.command()
@click.option(
"--model_id",
type=str,
default="llm-jp/llm-jp-3-3.7b-instruct",
help="Model name. e.g. llm-jp/llm-jp-3-3.7b-instruct. If you want to use OpenAI API, specify the model name like 'gpt-4o-mini-2024-07-18'. If you want to use DeepL API, specify 'deepl'.",
)
@click.option(
"--batch_size", type=int, default=1024, help="Batch size for vLLM inference."
)
@click.option("--tensor_parallel_size", type=int, default=1)
@click.option("--pipeline_parallel_size", type=int, default=1)
@click.option("--gpu_id", type=int, default=0)
@click.option(
"--input_path",
type=str,
default="data/english_quotes.json",
help="Local file path or Hugging Face dataset name.",
)
@click.option(
"--source_column",
type=str,
default="txt",
help="Existing column name in the dataset to be prompted.",
)
@click.option(
"--target_column",
type=str,
default="txt_ja",
help="New column name in the dataset to store the generated text.",
)
@click.option("--push_to_hub", type=bool, default=False)
@click.option("--push_to_hub_path", type=str, default="speed/english_quotes")
@click.option("--output_dir", type=str, default="data/english_quotes_ja")
@click.option("--output_format", type=str, default="json")
@click.option("--number_sample_per_shard", type=int, default=1000)
@click.option(
"--resume_from_checkpoint",
type=bool,
default=False,
help="Resume from the last checkpoint.",
)
@click.option("--use_wandb", type=bool, default=False)
@click.option("--wandb_project", type=str, default="text2dataset")
@click.option("--wandb_run_name", type=str, default="")
@click.option(
"--prompt_template_path",
type=str,
default="config/prompt.yaml",
help="Path to the prompt template.",
)
@click.option("--temperature", type=float, default=0.8)
@click.option("--top_p", type=float, default=0.95)
@click.option("--max_tokens", type=int, default=200)
@click.option(
"--target_lang",
type=str,
default="ja",
help="Target language for translation. This is used only for DeepL API.",
)
@click.option(
"--keep_columns",
type=str,
default="txt",
help="Columns to keep in the output dataset. Specify the column names separated by comma.",
)
def main(
model_id: str,
batch_size: int,
output_dir: str,
tensor_parallel_size: int,
pipeline_parallel_size: int,
gpu_id: int,
source_column: str,
target_column: str,
input_path: str,
push_to_hub: bool,
push_to_hub_path: str,
output_format: str,
number_sample_per_shard: int,
resume_from_checkpoint: bool,
use_wandb: bool,
wandb_project: str,
wandb_run_name: str,
prompt_template_path: str,
temperature: float,
top_p: float,
max_tokens: int,
target_lang: str,
keep_columns: str,
):
def main():
args = parse_args()

# Text in source_column of the Dataset will be translated into the target language.
state = State(0, 0, 0)
if resume_from_checkpoint:
state_path = os.path.join(output_dir, "state.jsonl")
if args.resume_from_checkpoint:
state_path = os.path.join(args.output_dir, "state.jsonl")
if os.path.exists(state_path):
with open(state_path, "r") as f:
state = State(**json.load(f), total_processed_examples=0)
@@ -129,53 +149,60 @@ def main(

logger.info("Start translation")

os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

os.makedirs(output_dir, exist_ok=True)
state_path = os.path.join(output_dir, "state.jsonl")
ds = create_dataset(input_path, state)
os.makedirs(args.output_dir, exist_ok=True)
state_path = os.path.join(args.output_dir, "state.jsonl")
ds = create_dataset(args.input_path, state, args.split)
# keep only the specified columns
ds = ds.select_columns(keep_columns.split(","))
if args.keep_columns is not None:
ds = ds.select_columns(args.keep_columns.split(","))
# batch dataloader
data_loader = ds.batch(batch_size=batch_size)
data_loader = ds.batch(batch_size=args.batch_size)

if use_wandb:
if args.use_wandb:
# Log only the parsed arguments as the run config; after this refactor,
# locals() would also capture the dataset and loader objects.
config_parameters = dict(vars(args))
config_parameters.pop("use_wandb")
wandb.init(project=wandb_project, name=wandb_run_name, config=config_parameters)
wandb.init(
project=args.wandb_project,
name=args.wandb_run_name,
config=config_parameters,
)

with open(prompt_template_path) as f:
with open(args.prompt_template_path) as f:
data = yaml.safe_load(f)
template = data["prompt"]

if model_id == "deepl":
translator = DeeplTranslator(target_lang)
elif model_id in ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-07-18"]:
if args.model_id == "deepl":
translator = DeeplTranslator(args.target_lang)
elif args.model_id in ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-07-18"]:
translator = OpenAIAPITranslator(
model_id, template, temperature, top_p, max_tokens
args.model_id, template, args.temperature, args.top_p, args.max_tokens
)
else:
translator = Translator(
model_id,
tensor_parallel_size,
pipeline_parallel_size,
args.model_id,
args.tensor_parallel_size,
args.pipeline_parallel_size,
template,
temperature,
top_p,
max_tokens,
args.temperature,
args.top_p,
args.max_tokens,
)

dataset_buffer = Dataset.from_dict({})

for examples in data_loader:
start_time = time.time()
text_list = examples[source_column]
text_list = examples[args.source_column]
translated = translator.translate(text_list)
# store to buffer
dataset_buffer = datasets.concatenate_datasets(
[
dataset_buffer,
datasets.Dataset.from_dict({**examples, target_column: translated}),
datasets.Dataset.from_dict(
{**examples, args.target_column: translated}
),
]
)
state.total_processed_examples += len(text_list)
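
All three translator backends share one call surface: translate takes a list of strings and returns a list of strings, as used in the loop above. A minimal stand-in with the same shape can be handy for dry-running the sharding pipeline; this stub is hypothetical and not part of the repo:

    class EchoTranslator:
        """Identity 'translator' with the same call shape as Translator."""

        def translate(self, texts: list[str]) -> list[str]:
            return texts  # pass text through unchanged for plumbing tests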
@@ -184,22 +211,25 @@
# write shards to output_dir if the buffer is full
# e.g., with number_sample_per_shard = 100 and len(dataset_buffer) = 1024,
# 1024 // 100 = 10 shards will be written to output_dir
if len(dataset_buffer) >= number_sample_per_shard:
for i in range(len(dataset_buffer) // number_sample_per_shard):
if len(dataset_buffer) >= args.number_sample_per_shard:
for i in range(len(dataset_buffer) // args.number_sample_per_shard):
shard_dict = dataset_buffer[
i * number_sample_per_shard : (i + 1) * number_sample_per_shard
i * args.number_sample_per_shard : (i + 1)
* args.number_sample_per_shard
]
shard_ds = Dataset.from_dict(shard_dict)

state = write_shard(shard_ds, output_dir, output_format, state)
state = write_shard(
shard_ds, args.output_dir, args.output_format, state
)
state.current_shard_id += 1
state.save_state(state_path)

dataset_buffer = Dataset.from_dict(
dataset_buffer[
len(dataset_buffer)
// number_sample_per_shard
* number_sample_per_shard :
// args.number_sample_per_shard
* args.number_sample_per_shard :
]
)
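
The slice above retains only the tail that did not fill a complete shard. A standalone sketch of the arithmetic, with assumed values:

    shard_size = 100                       # number_sample_per_shard
    buffered = 1024                        # len(dataset_buffer)
    full_shards = buffered // shard_size   # 10 full shards written this round
    kept = buffered - full_shards * shard_size  # 24 rows remain in the buffer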

@@ -214,25 +244,25 @@

# write the remaining examples
if len(dataset_buffer) > 0:
state = write_shard(dataset_buffer, output_dir, output_format, state)
state = write_shard(dataset_buffer, args.output_dir, args.output_format, state)
state.save_state(state_path)

if push_to_hub:
if output_format == "jsonl" or output_format == "json":
if args.push_to_hub:
if args.output_format == "jsonl" or args.output_format == "json":
# jsonl without state.jsonl
files = os.listdir(output_dir)
files = os.listdir(args.output_dir)
if "state.jsonl" in files:
files.remove("state.jsonl")
# Sort files by shard id to keep the order.
files.sort(key=lambda x: int(x.split(".")[0]))
translated_ds = load_dataset(
"json", data_files=[os.path.join(output_dir, f) for f in files]
"json", data_files=[os.path.join(args.output_dir, f) for f in files]
)
elif output_format == "parquet":
elif args.output_format == "parquet":
translated_ds = load_dataset(
"parquet", data_files=os.path.join(output_dir, "*.parquet")
"parquet", data_files=os.path.join(args.output_dir, "*.parquet")
)
translated_ds.push_to_hub(push_to_hub_path, private=True)
translated_ds.push_to_hub(args.push_to_hub_path, private=True)


if __name__ == "__main__":
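
A hypothetical invocation of the refactored CLI (the module path and values are assumptions; the flags come from the parser above):

    python -m text2dataset.main \
        --model_id llm-jp/llm-jp-3-3.7b-instruct \
        --input_path data/english_quotes.json \
        --source_column txt --target_column txt_ja \
        --output_dir data/english_quotes_ja --output_format json \
        --number_sample_per_shard 1000 \
        --push_to_hub --push_to_hub_path speed/english_quotes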
