Merge changes in pytorch tests workflow

reuvenp committed Sep 18, 2024
2 parents 6dc9f31 + bb35a81 commit ebb2be6
Showing 54 changed files with 3,058 additions and 348 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/run_pytorch_tests.yml
@@ -24,10 +24,11 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions
pip install pytest
- name: Run unittests
run: |
python -m unittest discover tests/pytorch_tests -v
pytest tests_pytest/pytorch
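
The workflow now runs the new pytest-based suite right after the unittest discovery. For local reproduction, a minimal sketch of the equivalent programmatic invocation (assuming pytest is installed and the repository root is the working directory):

    # Local equivalent of the new CI step (a sketch).
    import sys

    import pytest

    # pytest.main returns an exit code; 0 means all tests passed.
    exit_code = pytest.main(["tests_pytest/pytorch", "-v"])
    sys.exit(int(exit_code))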
3 changes: 3 additions & 0 deletions .github/workflows/run_tests_suite_coverage.yml
@@ -24,6 +24,7 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install coverage
pip install pytest
- name: Prepare TF env
run: pip install tensorflow==2.13.*
- name: Run tensorflow testsuite
@@ -32,6 +33,8 @@
run: pip uninstall tensorflow -y && pip install torch==2.0.* torchvision onnx onnxruntime onnxruntime-extensions
- name: Run torch testsuite
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest tests/test_suite.py -v
- name: Run torch pytest
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/pytorch
- name: Combine Multiple Coverage Files
run: coverage combine
- name: Run Coverage HTML
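For reference, a sketch of what these coverage steps amount to through the coverage.py API rather than the CLI (the omit/include patterns are copied from the workflow; everything else is illustrative):

    import coverage

    # --parallel-mode: data_suffix=True writes a distinct .coverage.* file per run.
    cov = coverage.Coverage(data_suffix=True,
                            omit=["*__init__.py"],
                            include=["model_compression_toolkit/**/*.py"])
    cov.start()
    # ... run one test suite here (unittest or pytest) ...
    cov.stop()
    cov.save()

    # After all runs: "coverage combine" merges the parallel data files,
    # then an HTML report can be produced.
    cov.combine()
    cov.html_report()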
3 changes: 3 additions & 0 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -40,6 +40,7 @@ def __init__(self,
layer_class: type,
reuse: bool = False,
reuse_group: str = None,
inputs_as_list: bool = False,
quantization_attr: Dict[str, Any] = None,
has_activation: bool = True,
is_custom: bool = False
@@ -58,6 +59,7 @@ def __init__(self,
layer_class: Class path of the layer this node represents.
reuse: Whether this node was duplicated and represents a reused layer.
reuse_group: Name of group of nodes from the same reused layer.
inputs_as_list: Whether to pass the node's input tensors as a list when calling the layer.
quantization_attr: Attributes the node holds regarding how it should be quantized.
has_activation: Whether the node has activations that we might want to quantize.
is_custom: Whether the node is a custom layer or not.
@@ -71,6 +73,7 @@ def __init__(self,
self.layer_class = layer_class
self.reuse = reuse
self.reuse_group = reuse_group
self.inputs_as_list = inputs_as_list
self.final_weights_quantization_cfg = None
self.final_activation_quantization_cfg = None
self.candidates_quantization_cfg = None
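The new inputs_as_list flag records whether the layer expects all of its input tensors packed in a single list, as Keras merge layers do. A minimal sketch of the calling convention the flag captures:

    import tensorflow as tf

    x1 = tf.keras.Input(shape=(8,))
    x2 = tf.keras.Input(shape=(8,))

    # Merge layers take their input tensors as one list...
    y = tf.keras.layers.Concatenate(axis=-1)([x1, x2])

    # ...while most layers take a single tensor argument.
    z = tf.keras.layers.Dense(4)(y)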
@@ -55,13 +55,13 @@ def __init__(self,
layer_class,
reuse,
reuse_group,
inputs_as_list,
quantization_attr,
has_activation=has_activation)

self.op_call_kwargs = op_call_kwargs
self.op_call_args = list(op_call_args)
self.functional_op = functional_op
self.inputs_as_list = inputs_as_list
self.tensor_input_allocs = [] if tensor_input_allocs is None else tensor_input_allocs

@property
@@ -308,7 +308,7 @@ def _run_operation(self,
else:
# If operator expects a single input tensor, it cannot be a list as it should
# have a dtype field.
if len(input_tensors) == 1:
if len(input_tensors) == 1 and not n.inputs_as_list:
input_tensors = input_tensors[0]
out_tensors_of_n_float = op_func(input_tensors)

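The added `not n.inputs_as_list` check stops the builder from unwrapping a one-element input list for list-input layers: such layers must still receive a list, even when substitutions leave only a single tensor in it. A small sketch of the case being protected (illustrative shapes):

    import tensorflow as tf

    x = tf.ones((2, 3))

    # A list-input op keeps its inputs wrapped in a list even when the list
    # has a single element; unwrapping to a bare tensor would change the
    # call signature the node was traced with.
    y = tf.concat([x], axis=0)   # valid: one-element list, y equals x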
24 changes: 23 additions & 1 deletion model_compression_toolkit/core/keras/reader/node_builder.py
@@ -30,10 +30,12 @@
from keras.src.layers.core import TFOpLambda, SlicingOpLambda
from keras.src.engine.keras_tensor import KerasTensor
from keras.src.engine.node import Node as KerasNode
from keras.src.layers.merging.base_merge import _Merge
else:
from keras.layers.core import TFOpLambda, SlicingOpLambda
from keras.engine.keras_tensor import KerasTensor
from keras.engine.node import Node as KerasNode
from keras.layers.merging.base_merge import _Merge

from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
@@ -287,6 +289,7 @@ def build_node(node: KerasNode,
for i, arg in enumerate(op_call_args[0]):
if is_const(arg):
weights.update({i: to_numpy(arg, is_single_tensor=True)})
inputs_as_list = __is_node_inputs_a_list(op_call_args, keras_layer)

node = BaseNode(node_name,
layer_config,
@@ -296,6 +299,7 @@
layer_class,
is_reused,
reuse_group,
inputs_as_list,
is_custom=is_keras_custom_layer(layer_class))

node_name_to_node[node_name] = node
@@ -316,6 +320,24 @@ def __is_functional_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""

return (keras_layer.symbol in
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol,TFOpLambda(tf.add_n).symbol] and
[TFOpLambda(tf.concat).symbol, TFOpLambda(tf.stack).symbol, TFOpLambda(tf.add_n).symbol] and
len(op_call_args) > 0 and
isinstance(op_call_args[0], list))


def __is_node_inputs_a_list(op_call_args: Any, keras_layer: Any) -> bool:
"""
Check whether the input tensors should be passed as a list or not. This is relevant
only for layers that inherit from _Merge such as Concatenate and Add.
Args:
op_call_args: Arguments list to check.
keras_layer: Keras layer.
Returns:
Whether the input tensors should be passed as a list or not.
"""

return (isinstance(keras_layer, _Merge) and
len(op_call_args) > 0 and
isinstance(op_call_args[0], (list, tuple)))
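
A toy restatement of the helper's condition, showing what it distinguishes (the function below re-implements the check for illustration only; the _Merge import path follows the version-dependent imports in the diff above):

    import tensorflow as tf
    from keras.layers.merging.base_merge import _Merge  # path per the imports above

    def inputs_passed_as_list(op_call_args, keras_layer) -> bool:
        # Same condition as __is_node_inputs_a_list.
        return (isinstance(keras_layer, _Merge)
                and len(op_call_args) > 0
                and isinstance(op_call_args[0], (list, tuple)))

    t = tf.ones((1, 4))
    print(inputs_passed_as_list([[t, t]], tf.keras.layers.Add()))   # True
    print(inputs_passed_as_list([t], tf.keras.layers.Dense(4)))     # False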
@@ -139,7 +139,11 @@ def _run_operation(n: BaseNode,
_tensor_input_allocs = None

if isinstance(n, FunctionalNode) and n.inputs_as_list:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
if isinstance(op_func, PytorchQuantizationWrapper):
# In wrapped nodes, the op args & kwargs are already stored in the PytorchQuantizationWrapper.
out_tensors_of_n_float = op_func(*input_tensors)
else:
out_tensors_of_n_float = op_func(input_tensors, *op_call_args, **functional_kwargs)
else:
merged_inputs, functional_kwargs = _merge_inputs(n, input_tensors, op_call_args, functional_kwargs.copy(),
tensor_input_allocs=_tensor_input_allocs)
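The split matters for list-input torch ops such as torch.cat: an unwrapped node is called with the tensor list plus its recorded args, while a PytorchQuantizationWrapper already stores those args internally and receives the tensors unpacked. A small illustration of the underlying op:

    import torch

    x1, x2 = torch.ones(1, 4), torch.zeros(1, 4)

    # torch.cat is a list-input functional op: tensors arrive packed in one
    # list, followed by non-tensor call args (here, the dim).
    y = torch.cat([x1, x2], dim=-1)   # shape (1, 8)

    # Once the node is wrapped, dim lives inside the wrapper, so the wrapped
    # module is called as wrapper(x1, x2) - the new branch above (a sketch).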
17 changes: 13 additions & 4 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
@@ -232,10 +232,19 @@ def nodes_builder(model: GraphModule,

# Add constants to weights dictionary.
if node.op != PLACEHOLDER:
for i, input_node in enumerate(node.all_input_nodes):
if input_node in consts_dict:
used_consts.add(input_node)
weights.update({i: consts_dict[input_node]})
if len(node.args) and isinstance(node.args[0], (list, tuple)):
# Handle weights in nodes with a list input, especially when the same tensor appears
# more than once in the input list (e.g. torch.concat([const1, x, const2, x, const3], 1)).
for input_node in node.all_input_nodes:
for i, input_arg in enumerate(node.args[0]):
if input_node is input_arg and input_node in consts_dict:
used_consts.add(input_node)
weights.update({i: consts_dict[input_node]})
else:
for i, input_node in enumerate(node.all_input_nodes):
if input_node in consts_dict:
used_consts.add(input_node)
weights.update({i: consts_dict[input_node]})

# Extract input and output shapes of the node.
input_shape, output_shape = _extract_input_and_output_shapes(node)
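The list-input branch is needed because torch.fx deduplicates node.all_input_nodes, so enumerating it loses both the positions and the repetitions of tensors in the input list, while node.args[0] preserves them. A minimal demonstration (assuming standard torch.fx tracing):

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x):
            # The same tensor appears twice in the concat input list.
            return torch.cat([x, x], dim=1)

    gm = symbolic_trace(M())
    cat_node = next(n for n in gm.graph.nodes if n.target is torch.cat)

    print(cat_node.args[0])          # [x, x] - duplicates and positions preserved
    print(cat_node.all_input_nodes)  # [x]    - deduplicated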
@@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Tuple, Callable
from typing import Tuple, Callable, Union
from model_compression_toolkit.core import common
from model_compression_toolkit.core.common import Graph
from model_compression_toolkit.verify_packages import FOUND_TF
@@ -25,10 +25,12 @@
import tensorflow as tf
from tensorflow.keras.layers import Layer
from model_compression_toolkit.core.keras.back2framework.keras_model_builder import KerasModelBuilder
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode
from mct_quantizers import KerasQuantizationWrapper
from mct_quantizers import KerasActivationQuantizationHolder
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS

def _get_wrapper(node: common.BaseNode,
def _get_wrapper(node: Union[common.BaseNode, FunctionalNode],
layer: Layer,
fw_impl=None) -> Layer:
"""
@@ -45,9 +47,16 @@ def _get_wrapper(node: common.BaseNode,
# for positional weights we need to extract the weight's value.
weights_values = {attr: node.get_weights_by_keys(attr)
for attr in weights_quantizers if isinstance(attr, int)}
# When wrapping functional nodes, the call args/kwargs must be set in the wrapper,
# because they are used during the wrapper's call method.
func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
OP_CALL_KWARGS: node.op_call_kwargs
} if isinstance(node, FunctionalNode) else {}
return KerasQuantizationWrapper(layer,
weights_quantizers,
weights_values)
weights_values,
is_inputs_as_list=node.inputs_as_list,
**func_node_kwargs)
return layer


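The wrapper needs op_call_args/op_call_kwargs because part of a functional call, such as the axis of a tf.concat, is not an input tensor and must be replayed at call time; the PyTorch wrapper below mirrors the same pattern. A sketch of the call being reconstructed (illustrative, not the wrapper's actual code):

    import tensorflow as tf

    # What the traced model originally executed:
    a, b = tf.ones((1, 4)), tf.zeros((1, 4))
    y = tf.concat([a, b], axis=-1)

    # What the wrapper therefore stores in order to reproduce the call:
    op_call_args, op_call_kwargs = (), {"axis": -1}
    is_inputs_as_list = True   # the tensors arrive packed in one list

    # Replaying the call from the stored pieces (a sketch of the wrapper's job):
    y_replayed = tf.concat([a, b], *op_call_args, **op_call_kwargs)
    assert bool(tf.reduce_all(y == y_replayed))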
@@ -24,7 +24,9 @@
if FOUND_TORCH:
import torch
from mct_quantizers import PytorchQuantizationWrapper, PytorchActivationQuantizationHolder
from mct_quantizers.common.constants import OP_CALL_ARGS, OP_CALL_KWARGS
from model_compression_toolkit.core.pytorch.back2framework.pytorch_model_builder import PyTorchModelBuilder
from model_compression_toolkit.core.common.graph.functional_node import FunctionalNode


def fully_quantized_wrapper(node: common.BaseNode,
@@ -46,7 +48,14 @@ def fully_quantized_wrapper(node: common.BaseNode,
# for positional weights we need to extract the weight's value.
weights_values = {attr: fw_impl.to_tensor(node.get_weights_by_keys(attr))
for attr in weight_quantizers if isinstance(attr, int)}
return PytorchQuantizationWrapper(module, weight_quantizers, weights_values)
# When wrapping functional nodes, the call args/kwargs must be set in the wrapper,
# because they are used during the wrapper's call method.
func_node_kwargs = {OP_CALL_ARGS: node.op_call_args,
OP_CALL_KWARGS: node.op_call_kwargs
} if isinstance(node, FunctionalNode) else {}
return PytorchQuantizationWrapper(module, weight_quantizers, weights_values,
is_inputs_as_list=node.inputs_as_list,
**func_node_kwargs)
return module


22 changes: 17 additions & 5 deletions model_compression_toolkit/gptq/__init__.py
@@ -13,8 +13,20 @@
# limitations under the License.
# ==============================================================================

from model_compression_toolkit.gptq.common.gptq_config import GradientPTQConfig, RoundingType, GPTQHessianScoresConfig
from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config
from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
from model_compression_toolkit.gptq.common.gptq_config import (
GradientPTQConfig,
RoundingType,
GPTQHessianScoresConfig,
GradualActivationQuantizationConfig,
QFractionLinearAnnealingConfig
)

from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH

if FOUND_TF:
from model_compression_toolkit.gptq.keras.quantization_facade import keras_gradient_post_training_quantization
from model_compression_toolkit.gptq.keras.quantization_facade import get_keras_gptq_config

if FOUND_TORCH:
from model_compression_toolkit.gptq.pytorch.quantization_facade import pytorch_gradient_post_training_quantization
from model_compression_toolkit.gptq.pytorch.quantization_facade import get_pytorch_gptq_config
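
Guarding the framework facades behind FOUND_TF/FOUND_TORCH lets the gptq package import cleanly when only one framework is installed. A quick sketch of the resulting behavior (all names below come from the diff):

    from model_compression_toolkit.verify_packages import FOUND_TF, FOUND_TORCH
    import model_compression_toolkit.gptq as gptq

    # Framework-independent configs are always importable:
    print(gptq.RoundingType, gptq.GPTQHessianScoresConfig)

    # The framework facades exist only when their package was found:
    if FOUND_TORCH:
        print(gptq.get_pytorch_gptq_config)
    if FOUND_TF:
        print(gptq.get_keras_gptq_config)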