diff --git a/awq/models/base.py b/awq/models/base.py index 06410c44..16a63b67 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -309,6 +309,7 @@ def forward(self, x): max_shard_size=shard_size, safe_serialization=safetensors, force_contiguous=True, + shared_tensors_to_discard=self.model._tied_weights_keys, ) @classmethod diff --git a/setup.py b/setup.py index bc6c1d8f..77d74ed4 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ "accelerate", "datasets>=2.20", "zstandard", + "huggingface_hub @ git+https://github.com/huggingface/huggingface_hub@fix-discard-shared-tensors", ] setup(