diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py index 4b1b5365146..d9f790343d9 100644 --- a/src/sparseml/modifiers/__init__.py +++ b/src/sparseml/modifiers/__init__.py @@ -18,5 +18,5 @@ from .logarithmic_equalization import * from .obcq import * from .pruning import * -from .quantization_legacy import * +from .quantization import * from .smoothquant import * diff --git a/src/sparseml/modifiers/quantization/__init__.py b/src/sparseml/modifiers/quantization/__init__.py index 9cdf715c135..2e1cdc0d24c 100644 --- a/src/sparseml/modifiers/quantization/__init__.py +++ b/src/sparseml/modifiers/quantization/__init__.py @@ -14,4 +14,5 @@ # flake8: noqa -from .base import * +from .gptq import * +from .quantization import * diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/__init__.py b/src/sparseml/modifiers/quantization/gptq/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/__init__.py rename to src/sparseml/modifiers/quantization/gptq/__init__.py diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/base.py b/src/sparseml/modifiers/quantization/gptq/base.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/base.py rename to src/sparseml/modifiers/quantization/gptq/base.py diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py b/src/sparseml/modifiers/quantization/gptq/pytorch.py similarity index 97% rename from src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py rename to src/sparseml/modifiers/quantization/gptq/pytorch.py index c76382db647..e9e3f715625 100644 --- a/src/sparseml/modifiers/quantization_legacy/gptq/pytorch.py +++ b/src/sparseml/modifiers/quantization/gptq/pytorch.py @@ -19,8 +19,8 @@ from sparseml.core.model import ModifiableModel from sparseml.core.state import State -from sparseml.modifiers.quantization_legacy.gptq.base import GPTQModifier -from sparseml.modifiers.quantization_legacy.gptq.utils.gptq_wrapper import GPTQWrapper +from sparseml.modifiers.quantization.gptq.base import GPTQModifier +from sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper from sparseml.modifiers.utils.layer_compressor import LayerCompressor from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward from sparseml.utils.fsdp.context import fix_fsdp_module_name diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/utils/__init__.py b/src/sparseml/modifiers/quantization/gptq/utils/__init__.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/utils/__init__.py rename to src/sparseml/modifiers/quantization/gptq/utils/__init__.py diff --git a/src/sparseml/modifiers/quantization_legacy/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py similarity index 100% rename from src/sparseml/modifiers/quantization_legacy/gptq/utils/gptq_wrapper.py rename to src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py diff --git a/src/sparseml/modifiers/quantization/quantization/__init__.py b/src/sparseml/modifiers/quantization/quantization/__init__.py new file mode 100644 index 00000000000..9cdf715c135 --- /dev/null +++ b/src/sparseml/modifiers/quantization/quantization/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa + +from .base import * diff --git a/src/sparseml/modifiers/quantization/base.py b/src/sparseml/modifiers/quantization/quantization/base.py similarity index 100% rename from src/sparseml/modifiers/quantization/base.py rename to src/sparseml/modifiers/quantization/quantization/base.py diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/quantization/pytorch.py similarity index 98% rename from src/sparseml/modifiers/quantization/pytorch.py rename to src/sparseml/modifiers/quantization/quantization/pytorch.py index 8761b16007a..246fd3ce52a 100644 --- a/src/sparseml/modifiers/quantization/pytorch.py +++ b/src/sparseml/modifiers/quantization/quantization/pytorch.py @@ -23,7 +23,7 @@ set_module_for_calibration, ) from sparseml.core import Event, EventType, State -from sparseml.modifiers.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization.quantization.base import QuantizationModifier from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward diff --git a/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py b/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py index b2dc72cfb83..97a1f1022da 100644 --- a/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py +++ b/src/sparseml/modifiers/quantization_legacy/modification/modify_model.py @@ -15,7 +15,9 @@ import logging import os -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) _LOGGER = logging.getLogger(__name__) diff --git a/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py b/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py index e97771c437d..f235cbfdf8c 100644 --- a/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/quantization_scheme.py @@ -30,7 +30,9 @@ except Exception: torch_quantization = None -from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import FakeQuantizeWrapper +from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import ( + FakeQuantizeWrapper, +) __all__ = [ diff --git a/src/sparseml/modifiers/quantization_legacy/utils/quantize.py b/src/sparseml/modifiers/quantization_legacy/utils/quantize.py index 89e9d2faaa9..038ae5cab92 100644 --- a/src/sparseml/modifiers/quantization_legacy/utils/quantize.py +++ b/src/sparseml/modifiers/quantization_legacy/utils/quantize.py @@ -26,13 +26,17 @@ FUSED_MODULE_NAMES, NON_QUANTIZABLE_MODULE_NAMES, ) -from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import FakeQuantizeWrapper +from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import ( + FakeQuantizeWrapper, +) from sparseml.modifiers.quantization_legacy.utils.helpers import ( QATWrapper, configure_module_default_qconfigs, prepare_embeddings_qat, ) -from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import QuantizationScheme +from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import ( + QuantizationScheme, +) from sparseml.pytorch.utils import get_layer from sparseml.utils.fsdp.context import fix_fsdp_module_name diff --git a/src/sparseml/transformers/sparsification/modification/modifying_bert.py b/src/sparseml/transformers/sparsification/modification/modifying_bert.py index 2632600ed8c..fccb65ea885 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_bert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_bert.py @@ -25,8 +25,12 @@ from torch import nn from transformers.models.bert.modeling_bert import BertSelfAttention -from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATMatMul -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( + QATMatMul, +) +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py index 9aa0389590c..d2bf92dd637 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_distilbert.py @@ -27,8 +27,12 @@ MultiHeadSelfAttention, ) -from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATMatMul -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( + QATMatMul, +) +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_llama.py b/src/sparseml/transformers/sparsification/modification/modifying_llama.py index 5e480376a74..d7aea9ac1c6 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_llama.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_llama.py @@ -36,7 +36,9 @@ QuantizableIdentity, QuantizableMatMul, ) -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py index a50f31eb588..a27a75d5992 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mistral.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mistral.py @@ -36,7 +36,9 @@ QuantizableIdentity, QuantizableMatMul, ) -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py index 62fd6bef7f8..2ab9d819fb5 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_mobilebert.py @@ -20,8 +20,12 @@ from torch import nn from transformers.models.mobilebert.modeling_mobilebert import MobileBertEmbeddings -from sparseml.modifiers.quantization_legacy.modification.modification_objects import QATLinear -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.modification_objects import ( + QATLinear, +) +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/src/sparseml/transformers/sparsification/modification/modifying_opt.py b/src/sparseml/transformers/sparsification/modification/modifying_opt.py index fb448316cab..eb42dd6d686 100644 --- a/src/sparseml/transformers/sparsification/modification/modifying_opt.py +++ b/src/sparseml/transformers/sparsification/modification/modifying_opt.py @@ -27,7 +27,9 @@ QuantizableBatchMatmul, QuantizableIdentity, ) -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparseml.pytorch.utils.helpers import swap_modules from sparseml.transformers.sparsification.modification.base import ( check_transformers_version, diff --git a/tests/sparseml/modifiers/quantization/modification/test_modify_model.py b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py index 16c13af7207..4ad1cb6580b 100644 --- a/tests/sparseml/modifiers/quantization/modification/test_modify_model.py +++ b/tests/sparseml/modifiers/quantization/modification/test_modify_model.py @@ -18,7 +18,9 @@ import pytest from sparseml.modifiers.quantization_legacy.modification import modify_model -from sparseml.modifiers.quantization_legacy.modification.registry import ModificationRegistry +from sparseml.modifiers.quantization_legacy.modification.registry import ( + ModificationRegistry, +) from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, standardize_lookup_name diff --git a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 01c5fb0cbf9..0fcb66eee9c 100644 --- a/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -20,9 +20,11 @@ from sparseml.core.framework import Framework from sparseml.core.model import ModifiableModel from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.modifiers.quantization_legacy.gptq.pytorch import GPTQModifierPyTorch -from sparseml.modifiers.quantization_legacy.pytorch import LegacyQuantizationModifierPyTorch -from sparseml.modifiers.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch +from sparseml.modifiers.quantization.quantization.base import QuantizationModifier +from sparseml.modifiers.quantization_legacy.pytorch import ( + LegacyQuantizationModifierPyTorch, +) from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory from tests.sparseml.pytorch.helpers import LinearNet from tests.testing_utils import requires_torch diff --git a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py index b1327f4ce3d..2e9750c60c7 100644 --- a/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py +++ b/tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py @@ -21,7 +21,9 @@ from sparseml.core.event import Event, EventType from sparseml.core.factory import ModifierFactory from sparseml.core.framework import Framework -from sparseml.modifiers.quantization_legacy.pytorch import LegacyQuantizationModifierPyTorch +from sparseml.modifiers.quantization_legacy.pytorch import ( + LegacyQuantizationModifierPyTorch, +) from sparseml.pytorch.sparsification.quantization.quantize import ( is_qat_helper_module, is_quantizable_module,