Skip to content

Commit

Permalink
Minor changes addressing comments
Browse files Browse the repository at this point in the history
 - Adds a TODO clarifying why the "_params." prefix is skipped in mm_group_quant for name-matching purposes
 - Removes the use of `args` inside the pipeline to make external (non-CLI) use easier
  • Loading branch information
IanNod committed Nov 17, 2023
1 parent 3460784 commit d98e705
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def rewrite(self, mr: TransposedMMResult):
# TODO (ian): make generalizable and not specific for brevitas
if "lm_head.weight" not in mr.param_name:
inline_module_asm = GROUP_MATMUL_TEMPLATE.format(
# TODO (ian): Fix skipping the "_params." portion of the name to match safetensor format with RenameParametersPass
param_name=mr.param_name[8:],
lowp_type="i4",
m=none_to_q(mr.m),
Expand Down
2 changes: 1 addition & 1 deletion python/turbine_models/custom_models/stateless_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def forward(token0: torch.Tensor, *state0_flat):

# TODO: Integrate with external parameters to actually be able to run
# TODO: Make more generalizable to be able to quantize with all compile_to options
if args.quantization == "int4" and not compile_to == "linalg":
if quantization == "int4" and not compile_to == "linalg":
from shark_turbine.transforms.quantization import mm_group_quant

mm_group_quant.MMGroupQuantRewriterPass(
Expand Down

0 comments on commit d98e705

Please sign in to comment.