diff --git a/LanguageModeling/GPT/tools/README.md b/LanguageModeling/GPT/tools/README.md new file mode 100644 index 0000000..955b149 --- /dev/null +++ b/LanguageModeling/GPT/tools/README.md @@ -0,0 +1,36 @@ +# GPT模型转换 + +### PyTorch模型转OneFlow模型 + - `meta.proto`,是为生成模型目录下的`meta`文件,需要执行`protoc --python_out=. meta.proto`后生成`meta_pb2.py`,即可`import meta_pb2 as meta_pb` + ``` + syntax = "proto2"; + package gpt; + + message Shape { + repeated int32 dim = 1; + } + + enum DataType { + kInvalidDataType = 0; + kChar = 1; + kFloat = 2; + kDouble = 3; + kInt8 = 4; + kInt32 = 5; + kInt64 = 6; + kUInt8 = 7; + kOFRecord = 8; + kFloat16 = 9; + kTensorBuffer = 10; + } + + message Meta { + required Shape shape = 1; + required DataType data_type = 2 [default = kFloat16]; + } + ``` + - 转换脚本`convert_pt_to_of_gpt.py`,执行`python3 convert_pt_to_of_gpt.py --py_model_dir /path/to/iter_0500000/mp_rank_00/model_optim_rng.pt`即可在当前目录下的`convert_pt_to_of_gpt`生成OneFlow模型 + - `--py_model_dir`,pytorch模型地址 + - `--of_dump_path`,保存转换后的模型路径 + + \ No newline at end of file diff --git a/LanguageModeling/GPT/tools/convert_py_model_to_of.py b/LanguageModeling/GPT/tools/convert_py_model_to_of.py new file mode 100644 index 0000000..eea85e8 --- /dev/null +++ b/LanguageModeling/GPT/tools/convert_py_model_to_of.py @@ -0,0 +1,110 @@ +import argparse +import os +import numpy as np +import torch +import meta_pb2 as meta_pb + + +def get_args(): + + parser = argparse.ArgumentParser() + + ## Required parameters + parser.add_argument( + "--py_model_dir", + type=str, + default="/path/to/iter_0500000/mp_rank_00/model_optim_rng.pt", + help="Path the PyTorch checkpoint file path.", + ) + parser.add_argument( + "--of_dump_path", + type=str, + default="./convert_pt_to_of_gpt_release", + help="Path to the output OneFlow model.", + ) + + return parser.parse_args() + + +def _SaveWeightBlob2File(blob, op_name, save_path, var="out", meta="meta"): + folder = os.path.join(save_path, op_name) + if not os.path.exists(folder): + os.makedirs(folder) + filename = os.path.join(folder, var) + f = open(filename, "wb") + f.write(blob.tobytes()) + meta_info = meta_pb.Meta() + meta_info.shape.dim[:] = blob.shape + meta_info.data_type = meta_pb.kFloat + filename = os.path.join(folder, meta) + f = open(filename, "w") + f.write(str(meta_info)) + f.close() + np.save(filename, blob) + + +def _SaveWeightBlob2FileExtend(blob, op_name, save_path, var="out", meta="meta"): + _SaveWeightBlob2File(blob.numpy(), op_name, save_path, var=var, meta=meta) + _SaveWeightBlob2File( + np.ones_like(blob), op_name + "-v", save_path, var=var, meta=meta + ) + _SaveWeightBlob2File( + np.zeros_like(blob), op_name + "-m", save_path, var=var, meta=meta + ) + + +def convert(args): + path = args.py_model_dir + state_dict = torch.load(path, map_location="cpu") + for model_key, model_value in state_dict["model"]["language_model"][ + "transformer" + ].items(): + if len(model_value.shape) > 1: + model_value = torch.transpose(model_value, 0, 1) + model_value = model_value.float() + op_name_list = model_key.split(".") + if "layers." in model_key: + op_name = model_key.replace("layers.", "model-") + op_name = op_name.replace( + "-%s." % (op_name_list[1]), "-h%s-" % (op_name_list[1]) + ) + else: + op_name = model_key.replace("final_layernorm.", "model-layernorm_f-") + op_name = op_name.replace("input_layernorm.", "layernorm_1-") + op_name = op_name.replace("post_attention_layernorm.", "layernorm_2-") + op_name = op_name.replace("attention.", "attn-") + op_name = op_name.replace("query_key_value.", "c_attn-") + op_name = op_name.replace("dense.", "c_proj-") + op_name = op_name.replace("mlp.dense_h_to_4h.", "mlp-c_fc-") + op_name = op_name.replace("mlp.dense_4h_to_h.", "mlp-c_proj-") + + if ( + "layernorm_1" in op_name + or "layernorm_2" in op_name + or "layernorm_f" in op_name + ): + op_name = op_name.replace("-weight", "-gamma") + op_name = op_name.replace("-bias", "-beta") + + print(model_key, "-" * 8, op_name) + _SaveWeightBlob2FileExtend(model_value, op_name, args.of_dump_path) + + _SaveWeightBlob2FileExtend( + state_dict["model"]["language_model"]["embedding"]["position_embeddings"][ + "weight" + ].float(), + "model-wpe", + args.of_dump_path, + ) + _SaveWeightBlob2FileExtend( + state_dict["model"]["language_model"]["embedding"]["word_embeddings"][ + "weight" + ].float(), + "model-wte", + args.of_dump_path, + ) + + +if __name__ == "__main__": + args = get_args() + convert(args) diff --git a/LanguageModeling/GPT/tools/meta.proto b/LanguageModeling/GPT/tools/meta.proto new file mode 100644 index 0000000..1719ad2 --- /dev/null +++ b/LanguageModeling/GPT/tools/meta.proto @@ -0,0 +1,24 @@ +syntax = "proto2"; + +message Shape { + repeated int32 dim = 1; +} + +enum DataType { + kInvalidDataType = 0; + kChar = 1; + kFloat = 2; + kDouble = 3; + kInt8 = 4; + kInt32 = 5; + kInt64 = 6; + kUInt8 = 7; + kOFRecord = 8; + kFloat16 = 9; + kTensorBuffer = 10; +} + +message Meta { + required Shape shape = 1; + required DataType data_type = 2 [default = kFloat16]; +} diff --git a/LanguageModeling/GPT/tools/meta_pb2.py b/LanguageModeling/GPT/tools/meta_pb2.py new file mode 100644 index 0000000..49f1e7b --- /dev/null +++ b/LanguageModeling/GPT/tools/meta_pb2.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: meta.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='meta.proto', + package='', + syntax='proto2', + serialized_options=None, + create_key=_descriptor._internal_create_key, + serialized_pb=b'\n\nmeta.proto\"\x14\n\x05Shape\x12\x0b\n\x03\x64im\x18\x01 \x03(\x05\"E\n\x04Meta\x12\x15\n\x05shape\x18\x01 \x02(\x0b\x32\x06.Shape\x12&\n\tdata_type\x18\x02 \x02(\x0e\x32\t.DataType:\x08kFloat16*\xa3\x01\n\x08\x44\x61taType\x12\x14\n\x10kInvalidDataType\x10\x00\x12\t\n\x05kChar\x10\x01\x12\n\n\x06kFloat\x10\x02\x12\x0b\n\x07kDouble\x10\x03\x12\t\n\x05kInt8\x10\x04\x12\n\n\x06kInt32\x10\x05\x12\n\n\x06kInt64\x10\x06\x12\n\n\x06kUInt8\x10\x07\x12\r\n\tkOFRecord\x10\x08\x12\x0c\n\x08kFloat16\x10\t\x12\x11\n\rkTensorBuffer\x10\n' +) + +_DATATYPE = _descriptor.EnumDescriptor( + name='DataType', + full_name='DataType', + filename=None, + file=DESCRIPTOR, + create_key=_descriptor._internal_create_key, + values=[ + _descriptor.EnumValueDescriptor( + name='kInvalidDataType', index=0, number=0, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kChar', index=1, number=1, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kFloat', index=2, number=2, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kDouble', index=3, number=3, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kInt8', index=4, number=4, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kInt32', index=5, number=5, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kInt64', index=6, number=6, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kUInt8', index=7, number=7, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kOFRecord', index=8, number=8, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kFloat16', index=9, number=9, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + _descriptor.EnumValueDescriptor( + name='kTensorBuffer', index=10, number=10, + serialized_options=None, + type=None, + create_key=_descriptor._internal_create_key), + ], + containing_type=None, + serialized_options=None, + serialized_start=108, + serialized_end=271, +) +_sym_db.RegisterEnumDescriptor(_DATATYPE) + +DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) +kInvalidDataType = 0 +kChar = 1 +kFloat = 2 +kDouble = 3 +kInt8 = 4 +kInt32 = 5 +kInt64 = 6 +kUInt8 = 7 +kOFRecord = 8 +kFloat16 = 9 +kTensorBuffer = 10 + + + +_SHAPE = _descriptor.Descriptor( + name='Shape', + full_name='Shape', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='dim', full_name='Shape.dim', index=0, + number=1, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=14, + serialized_end=34, +) + + +_META = _descriptor.Descriptor( + name='Meta', + full_name='Meta', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='shape', full_name='Meta.shape', index=0, + number=1, type=11, cpp_type=10, label=2, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + _descriptor.FieldDescriptor( + name='data_type', full_name='Meta.data_type', index=1, + number=2, type=14, cpp_type=8, label=2, + has_default_value=True, default_value=9, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=36, + serialized_end=105, +) + +_META.fields_by_name['shape'].message_type = _SHAPE +_META.fields_by_name['data_type'].enum_type = _DATATYPE +DESCRIPTOR.message_types_by_name['Shape'] = _SHAPE +DESCRIPTOR.message_types_by_name['Meta'] = _META +DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +Shape = _reflection.GeneratedProtocolMessageType('Shape', (_message.Message,), { + 'DESCRIPTOR' : _SHAPE, + '__module__' : 'meta_pb2' + # @@protoc_insertion_point(class_scope:Shape) + }) +_sym_db.RegisterMessage(Shape) + +Meta = _reflection.GeneratedProtocolMessageType('Meta', (_message.Message,), { + 'DESCRIPTOR' : _META, + '__module__' : 'meta_pb2' + # @@protoc_insertion_point(class_scope:Meta) + }) +_sym_db.RegisterMessage(Meta) + + +# @@protoc_insertion_point(module_scope)