[llama] transpose global state (#104)
Transposed the global state to allow contiguous memory access while
updating. Previously only the 0th index of the 2nd dim was updated on
initialization. The update also needs to be stored back to the global
variable. A couple of other minor bugfixes.
dan-garvey authored Oct 31, 2023
1 parent 5ab12bf commit 0d69932
Showing 2 changed files with 53 additions and 54 deletions.
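To see why the transpose helps: with the sequence dimension moved ahead of the heads dimension, writing one decode step touches a single contiguous block per KV entry instead of HEADS strided blocks. A minimal PyTorch sketch (toy sizes, not from the commit):

import torch

HEADS, MAX_STEP_SEQ, HIDDEN_DIM, BATCH_SIZE = 4, 16, 8, 1  # toy sizes

# Old layout: (heads*2, batch, heads, max_seq, hidden).
old_pkv = torch.zeros(HEADS * 2, BATCH_SIZE, HEADS, MAX_STEP_SEQ, HIDDEN_DIM)
# New layout: (heads*2, batch, max_seq, heads, hidden).
new_pkv = torch.zeros(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)

step = 7  # decode position being written
# Old: one step is HEADS separate blocks of HIDDEN_DIM floats.
assert not old_pkv[0, 0, :, step, :].is_contiguous()
# New: one step is a single contiguous (HEADS, HIDDEN_DIM) block.
assert new_pkv[0, 0, step, :, :].is_contiguous()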
examples/aot_mlp/mlp_export_simple.py (1 change: 1 addition & 0 deletions)
@@ -34,6 +34,7 @@ def forward(self, x: torch.Tensor):
 model = MLP()
 example_x = torch.empty(97, 8, dtype=torch.float32)
 exported = aot.export(model, example_x)
+aot.CompiledModule.run_import(exported.compiled_module)
 exported.print_readable()
 compiled_binary = exported.compile(save_to=None)
examples/llama2_inference/stateless_llama.py (106 changes: 52 additions & 54 deletions)
@@ -49,7 +49,6 @@
 Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <</SYS>> hi what are you? [/INST]
 """
-

 class InferenceModel(torch.nn.Module):
     def __init__(
         self,
@@ -212,27 +211,15 @@ def slice_up_to_step(global_pkv, seq_step, heads, hidden_dim):
     all_pkv_tensors = []
     for i in range(heads * 2):
         sliced = IREE.tensor_slice(
-            global_pkv, i, 0, (0, heads), (0, seq_step), (0, hidden_dim)
+            global_pkv, i, 0, (0, seq_step), (0, heads), (0, hidden_dim)
         )  # sequence context dim
         all_pkv_tensors.append(
-            IREE.tensor_reshape(sliced, 1, heads, seq_step, hidden_dim)
+            IREE.tensor_reshape(sliced, 1, seq_step, heads, hidden_dim)
         )

     return all_pkv_tensors

-
-def update_state(state, state_updates, seq_step, heads, hidden_dim):
-    all_updates = []
-    for i in range(heads * 2):
-        update = IREE.tensor_reshape(
-            state_updates[i], 1, 1, heads, 1, hidden_dim
-        )
-        all_updates.append(
-            IREE.tensor_update(state, update, i, 0, 0, seq_step, 0)
-        )
-    return all_updates
-

 def export_transformer_model(
     state_schema_path, hf_model_name, hf_auth_token, compile_to
 ):
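Under the transposed layout, slice_up_to_step now pulls the first seq_step positions out of each of the heads*2 cache entries. As a rough PyTorch analogue (toy shapes; plain indexing standing in for the IREE tensor ops, so this is a sketch, not the module's actual API):

import torch

HEADS, MAX_STEP_SEQ, HIDDEN_DIM = 4, 16, 8  # toy sizes
global_pkv = torch.zeros(HEADS * 2, 1, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
seq_step = 5

# Keep positions [0, seq_step) of each cache entry and drop the leading
# index, mirroring the tensor_slice + tensor_reshape pair above.
all_pkv_tensors = [
    global_pkv[i, 0, :seq_step, :, :].reshape(1, seq_step, HEADS, HIDDEN_DIM)
    for i in range(HEADS * 2)
]
assert all_pkv_tensors[0].shape == (1, seq_step, HEADS, HIDDEN_DIM)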
@@ -262,7 +249,7 @@ def export_transformer_model(
     HIDDEN_DIM = 128
     BATCH_SIZE = 1
     global_pkv = torch.zeros(
-        size=(HEADS * 2, BATCH_SIZE, HEADS, MAX_STEP_SEQ, HIDDEN_DIM),
+        size=(HEADS * 2, BATCH_SIZE, MAX_STEP_SEQ, HEADS, HIDDEN_DIM),
         dtype=torch.float32,
     )
     seq_step = AbstractIndex
@@ -281,39 +268,37 @@ def run_initialize(
         token, *state = self.initialize(x, constraints=init_const)
         updates = []
         self.global_seq_step = IREE.tensor_dim(
-            state[0], 2
-        )  # 2nd dimension of arbitrarily 0th kv tensor
+            state[0], 1
+        )  # 1st dimension of arbitrarily chosen 0th kv tensor
         for i in range(HEADS * 2):
             slice_of_state = IREE.tensor_reshape(
-                state[i], 1, 1, HEADS, self.global_seq_step, HIDDEN_DIM
+                state[i], 1, 1, self.global_seq_step, HEADS, HIDDEN_DIM
             )
-            updates.append(
-                IREE.tensor_update(
-                    self.global_state, slice_of_state, i, 0, 0, 0, 0
-                )
+            self.global_state = IREE.tensor_update(
+                self.global_state, slice_of_state, i, 0, 0, 0, 0
             )
         return token

     def run_forward(self, x=AbstractTensor(1, None, dtype=torch.int64)):
         state_arg = slice_up_to_step(
             self.global_state, self.global_seq_step, HEADS, HIDDEN_DIM
         )
-        forw_const = [state_arg[0].dynamic_dim(2) < MAX_STEP_SEQ] + [
-            x.dynamic_dim(2) == (state_arg[0].dynamic_dim(2))
+        forw_const = [state_arg[0].dynamic_dim(1) < MAX_STEP_SEQ] + [
+            x.dynamic_dim(1) == (state_arg[0].dynamic_dim(1))
             for x in state_arg[1:]
         ]
         token, *state_update = self.forward(
             x, *state_arg, constraints=forw_const
         )
-        self.global_seq_step = self.global_seq_step + 1
-        res = update_state(
-            self.global_state,
-            state_update,
-            self.global_seq_step,
-            HEADS,
-            HIDDEN_DIM,
-        )
+        for i in range(HEADS * 2):
+            update = IREE.tensor_reshape(
+                state_update[i], 1, 1, 1, HEADS, HIDDEN_DIM
+            )
+            self.global_state = IREE.tensor_update(
+                self.global_state, update, i, 0, self.global_seq_step, 0, 0
+            )

+        self.global_seq_step = self.global_seq_step + 1
         return token

     def get_global_state(self):
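The loop above inlines what the deleted update_state helper did, with the two fixes from the commit message: the result of IREE.tensor_update is assigned back to self.global_state rather than dropped, and the step counter advances only after the write. A rough PyTorch analogue of one run_forward update (toy shapes; hypothetical stand-in values):

import torch

HEADS, MAX_STEP_SEQ, HIDDEN_DIM = 4, 16, 8  # toy sizes
global_state = torch.zeros(HEADS * 2, 1, MAX_STEP_SEQ, HEADS, HIDDEN_DIM)
global_seq_step = 5
# One new decode step per cache entry, shaped (1, 1, HEADS, HIDDEN_DIM).
state_update = [torch.randn(1, 1, HEADS, HIDDEN_DIM) for _ in range(HEADS * 2)]

for i in range(HEADS * 2):
    update = state_update[i].reshape(1, 1, 1, HEADS, HIDDEN_DIM)
    # Mirrors IREE.tensor_update(..., i, 0, global_seq_step, 0, 0): the write
    # lands at offset global_seq_step of the (contiguous) sequence dimension.
    global_state[i : i + 1, 0:1, global_seq_step : global_seq_step + 1] = update

global_seq_step += 1  # advance only after the state is written back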
@@ -328,15 +313,19 @@ def initialize(input_ids):
         state1_flat, _ = pytree.tree_flatten(result.past_key_values)
         token1 = torch.argmax(result.logits[:, -1, :], dim=1)
         token1 = token1[None, :]
+        state1_flat = [torch.transpose(x, 1, 2) for x in state1_flat]
         return token1, *state1_flat

     @jittable
     def forward(token0: torch.Tensor, *state0_flat):
         # Unpad the states.
+        state0_flat = [torch.transpose(x, 1, 2) for x in state0_flat]
         state0 = pytree.tree_unflatten(state0_flat, state_schema)
         result = mod.forward(token0, past_key_values=state0)
         state1_flat, _ = pytree.tree_flatten(result.past_key_values)
-        state1_flat = [x[:, :, -1:, :] for x in state1_flat]
+        state1_flat = [
+            torch.transpose(x[:, :, -1:, :], 1, 2) for x in state1_flat
+        ]
         token1 = torch.argmax(result.logits[:, -1, :], dim=1)
         token1 = token1[None, :]
         return token1, *state1_flat
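HF attention emits KV tensors as (batch, heads, seq, hidden); the cache now stores (batch, seq, heads, hidden), so the jittable wrappers transpose on the way in and out. A small sketch of that round-trip (toy shapes, not from the commit):

import torch

BATCH, HEADS, SEQ, HIDDEN_DIM = 1, 4, 5, 8
kv_from_model = torch.randn(BATCH, HEADS, SEQ, HIDDEN_DIM)

cached = torch.transpose(kv_from_model, 1, 2)  # store as (batch, seq, heads, hidden)
restored = torch.transpose(cached, 1, 2)       # feed back as (batch, heads, seq, hidden)
assert torch.equal(restored, kv_from_model)

# forward() keeps only the newest position before caching, as in the diff:
last_step = torch.transpose(kv_from_model[:, :, -1:, :], 1, 2)
assert last_step.shape == (BATCH, 1, HEADS, HIDDEN_DIM)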
@@ -399,39 +388,48 @@ def run_vmfb_comparison(args):
     example_input_id = initial_input.input_ids
     device_inputs = [ireert.asdevicearray(config.device, example_input_id)]

+    step0 = ModuleCompiled["get_seq_step"]()
+    print("step0 :"+str(step0))
     results = ModuleCompiled["run_initialize"](*device_inputs)

+    pkv = ModuleCompiled["get_global_state"]().to_host()
+    step = ModuleCompiled["get_seq_step"]()
+    print(f"step: {step}")
+    sliced = pkv[0,:,:,:step,:]

     def format_out(results):
         return torch.tensor(results.to_host()[0][0])

     # print(tokenizer.decode(format_out(results)))
     for i in range(10):
         results = ModuleCompiled["run_forward"](results)
+        step = ModuleCompiled["get_seq_step"]()
+        print(f"step: {step}")
         print(tokenizer.decode(format_out(results)))
     model = InferenceModel(args)

     def get_token_from_logits(logits):
         return torch.argmax(logits[:, -1, :], dim=1)

     base_model_results = model.base_model.forward(example_input_id)
     base_model_token = get_token_from_logits(base_model_results.logits)
     # print(tokenizer.decode(base_model_token))
     bm_pkv = base_model_results.past_key_values
+    turbine_results = []
+    torch_results = []
+    turbine_results.append(format_out(results))
+    torch_results.append(int(base_model_token))
     while base_model_token != 2:
         results = ModuleCompiled["run_forward"](results)
+        step = ModuleCompiled["get_seq_step"]()
+        pkv = ModuleCompiled["get_global_state"]().to_host()
         # print(f"turbine: {tokenizer.decode(format_out(results))}")
         base_model_results = model.base_model.forward(
             torch.unsqueeze(base_model_token, 0), past_key_values=bm_pkv
         )
         base_model_token = int(
             get_token_from_logits(base_model_results.logits)[0]
         )
         bm_pkv = base_model_results.past_key_values
         # print(f"pytorch: {tokenizer.decode(base_model_token)}")
+        turbine_results.append(format_out(results))
+        torch_results.append(base_model_token)

+    print("\n\n")
+    print("what is the best hardware company?")
+    print("\n\n")

+    print("turbine output: ")
+    print(tokenizer.decode(turbine_results))
+    print("torch output: ")
+    print(tokenizer.decode(torch_results))

+    matcher = base_model_results.past_key_values[0][0]
+    print(sliced)
+    print(matcher)
+    print(sliced.shape)
+    print(matcher.shape)

 if __name__ == "__main__":
     args = parser.parse_args()
