Commit 70e6293: add openelm (#17)

ToluClassics authored May 3, 2024
1 parent e63b5e9 commit 70e6293
Showing 11 changed files with 954 additions and 11 deletions.
124 changes: 124 additions & 0 deletions examples/text_generation/openelm_generation.py
@@ -0,0 +1,124 @@
import argparse
import os
import time
from typing import Tuple

import mlx.core as mx
from transformers import AutoTokenizer, AutoConfig

from mlx_transformers.models import OpenELMForCausalLM as MlxOpenELMForCausalLM


def tic():
    "Return the current time in seconds"
    return time.time()


def toc(msg, start):
    "Return a message with the elapsed time in seconds"
    end = time.time()
    return f"[INFO] {msg}: {end - start:.3f} s"


def load_model(
    model_name: str, mlx_model_class
) -> Tuple[MlxOpenELMForCausalLM, AutoTokenizer]:
    """
    Load an OpenELM model and tokenizer from the given model name.

    Args:
        model_name (str): Name of the OpenELM model to load
        mlx_model_class: MLX model class to instantiate

    Returns:
        Tuple[MlxOpenELMForCausalLM, AutoTokenizer]: The loaded model and tokenizer
    """
    config = AutoConfig.from_pretrained(model_name)
    os.path.dirname(os.path.realpath(__file__))

    model = mlx_model_class(config)
    model.from_pretrained(
        model_name,
        huggingface_model_architecture="AutoModelForCausalLM",
        trust_remote_code=True,
    )

    # OpenELM checkpoints ship without a tokenizer of their own, so the Llama-2 tokenizer is used.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    return model, tokenizer


def generate(model: MlxOpenELMForCausalLM, tokenizer: AutoTokenizer, args):
    print(args.prompt)
    inputs = tokenizer(args.prompt, return_tensors="np", truncation=True)

    inputs = {key: mx.array(v) for key, v in inputs.items()}
    print(inputs["input_ids"][0])

    skip = 0
    prompt_processing = None
    tokens = []
    start = tic()
    for token in model.generate(inputs, max_length=args.max_tokens, temp=args.temp):
        tokens.append(token)

        if len(tokens) == 1:
            # Actually perform the computation to measure the prompt processing time
            mx.eval(token)
            prompt_processing = toc("Prompt processing", start)

        if len(tokens) >= args.max_tokens:
            break

        elif (len(tokens) % args.write_every) == 0:
            # It is perfectly ok to eval things we have already eval-ed.
            mx.eval(tokens)
            s = tokenizer.decode([t.item() for t in tokens])
            print(s[skip:], end="", flush=True)
            skip = len(s)

    mx.eval(tokens)
    full_gen = toc("Full generation", start)
    s = tokenizer.decode([t.item() for t in tokens])
    print(s[skip:], flush=True)
    print("------")
    print(prompt_processing)
    print(full_gen)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="OpenELM inference script")
    parser.add_argument(
        "--model-name",
        help="The model name to load",
        default="apple/OpenELM-1_1B-Instruct",
    )
    parser.add_argument(
        "--prompt",
        help="The message to be processed by the model.",
        default="Who is your daddy and what does he do?",
    )
    parser.add_argument(
        "--max-tokens", "-m", type=int, default=100, help="How many tokens to generate"
    )
    parser.add_argument(
        "--write-every", type=int, default=5, help="After how many tokens to detokenize"
    )
    parser.add_argument(
        "--temp", type=float, default=0.0, help="The sampling temperature"
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")

    args = parser.parse_args()

    mx.random.seed(args.seed)
    mx.set_default_device(mx.gpu)

    model, tokenizer = load_model(
        args.model_name,
        MlxOpenELMForCausalLM,
    )

    generate(model, tokenizer, args)
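
For quick reference, the core flow of the new example condenses to the minimal sketch below. It reuses only the calls exercised by the script above; the model name, prompt, and token limit are illustrative.

# Minimal sketch of the OpenELM loading/generation flow shown above.
import mlx.core as mx
from transformers import AutoConfig, AutoTokenizer

from mlx_transformers.models import OpenELMForCausalLM

model_name = "apple/OpenELM-1_1B-Instruct"  # illustrative checkpoint name
config = AutoConfig.from_pretrained(model_name)
model = OpenELMForCausalLM(config)
model.from_pretrained(
    model_name,
    huggingface_model_architecture="AutoModelForCausalLM",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

inputs = tokenizer("Hello, world", return_tensors="np", truncation=True)
inputs = {key: mx.array(v) for key, v in inputs.items()}

tokens = []
for token in model.generate(inputs, max_length=32, temp=0.0):
    tokens.append(token)
    if len(tokens) >= 32:  # stop after the token budget, as the script above does
        break
mx.eval(tokens)
print(tokenizer.decode([t.item() for t in tokens]))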
examples/text_generation/phi3_generation.py
@@ -21,7 +21,7 @@ def toc(msg, start):


 def load_model(
-    model_name: str, mlx_model_class
+    model_name: str, mlx_model_class, fp16: bool = False
 ) -> Tuple[MlxPhi3ForCausalLM, AutoTokenizer]:
     """
     Load a llama model and tokenizer from the given model name and weights.
@@ -43,6 +43,7 @@ def load_model(
         model_name,
         huggingface_model_architecture="AutoModelForCausalLM",
         trust_remote_code=True,
+        fp16=fp16,
     )

     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -87,7 +88,7 @@ def generate(model: MlxPhi3ForCausalLM, tokenizer: AutoTokenizer, args):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Llama inference script")
+    parser = argparse.ArgumentParser(description="Phi3 inference script")
     parser.add_argument(
         "--model-name",
         help="The model name to load",
@@ -108,6 +109,9 @@ def generate(model: MlxPhi3ForCausalLM, tokenizer: AutoTokenizer, args):
         "--temp", type=float, default=0.0, help="The sampling temperature"
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "--fp16", action="store_true", help="Use mixed precision for inference"
+    )

     args = parser.parse_args()

9 changes: 6 additions & 3 deletions examples/text_generation/phi_generation.py
@@ -21,7 +21,7 @@ def toc(msg, start):


 def load_model(
-    model_name: str, mlx_model_class
+    model_name: str, mlx_model_class, fp16: bool = False
 ) -> Tuple[MlxPhiForCausalLM, AutoTokenizer]:
     """
     Load a llama model and tokenizer from the given model name and weights.
@@ -39,7 +39,7 @@ def load_model(
     os.path.dirname(os.path.realpath(__file__))

     model = mlx_model_class(config)
-    model.from_pretrained(model_name)
+    model.from_pretrained(model_name, fp16=fp16)

     tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -83,7 +83,7 @@ def generate(model: MlxPhiForCausalLM, tokenizer: AutoTokenizer, args):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Llama inference script")
+    parser = argparse.ArgumentParser(description="Phi inference script")
     parser.add_argument(
         "--model-name",
         help="The model name to load",
@@ -104,6 +104,9 @@ def generate(model: MlxPhiForCausalLM, tokenizer: AutoTokenizer, args):
         "--temp", type=float, default=0.0, help="The sampling temperature"
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "--fp16", action="store_true", help="Use mixed precision for inference"
+    )

     args = parser.parse_args()

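
Both generation scripts now accept an --fp16 switch and thread it into load_model. The call site that forwards the flag falls in the collapsed part of the diff, so the one-liner below is only a hedged sketch of how it would presumably look (shown for the Phi script; the Phi3 script would mirror it).

# Hypothetical call site (not visible in the diff above): forward the new flag.
model, tokenizer = load_model(
    args.model_name,
    MlxPhiForCausalLM,
    fp16=args.fp16,
)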
1 change: 1 addition & 0 deletions src/mlx_transformers/models/__init__.py
@@ -9,6 +9,7 @@
 )
 from .llama import LlamaForCausalLM, LlamaModel
 from .m2m_100 import M2M100ForConditionalGeneration
+from .openelm import OpenELMForCausalLM, OpenELMModel
 from .phi import PhiForCausalLM, PhiModel
 from .phi3 import Phi3ForCausalLM, Phi3Model
 from .persimmon import PersimmonForCausalLM
2 changes: 2 additions & 0 deletions src/mlx_transformers/models/fuyu.py
@@ -258,6 +258,8 @@ def sample(logits):
         next_token_logits = output.logits[:, -1, :]
         next_token = sample(next_token_logits)

+        yield next_token
+
         while True:
             # Update the prompt
             next_token = mx.expand_dims(next_token, axis=0)
2 changes: 2 additions & 0 deletions src/mlx_transformers/models/llama.py
@@ -814,6 +814,8 @@ def sample(logits):
         next_token_logits = output.logits[:, -1, :]
         next_token = sample(next_token_logits)

+        yield next_token
+
         while True:
             # Update the prompt
             next_token = mx.expand_dims(next_token, axis=0)
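
In both fuyu.py and llama.py, the added yield next_token streams the token produced by the initial prompt pass to the caller; previously that first sampled token only seeded the while True decoding loop and was never emitted. A simplified sketch of the resulting generator shape follows; the forward call and cache handling are placeholders, not the library's exact internals.

# Simplified sketch of the streaming pattern the change completes; self(...) and
# cache handling are placeholders rather than the real model signatures.
def generate(self, inputs, max_length: int, temp: float = 0.0):
    def sample(logits):
        # Greedy decoding at temp == 0, temperature sampling otherwise.
        if temp == 0:
            return mx.argmax(logits, axis=-1)
        return mx.random.categorical(logits * (1 / temp))

    output = self(**inputs)                       # one pass over the full prompt
    next_token = sample(output.logits[:, -1, :])
    yield next_token                              # the added line: emit the first token too

    while True:                                   # then decode one token at a time
        next_token = mx.expand_dims(next_token, axis=0)
        output = self(input_ids=next_token, cache=output.cache)  # placeholder signature
        next_token = sample(output.logits[:, -1, :])
        yield next_token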