Skip to content

Commit

Permalink
[PT2] MinMax (#3166)
Browse files Browse the repository at this point in the history
### Changes

Introduce TORCH2 backend
MinMax algorithms for torch2 backend
Add handle_torch_function for quantization function to trace it by
torch_function

### Related tickets

152996

### Tests
[test
install](https://github.com/openvinotoolkit/nncf/actions/runs/13014595451)
  • Loading branch information
AlexanderDokuchaev authored Jan 29, 2025
1 parent 6565033 commit 5ad9bc4
Show file tree
Hide file tree
Showing 46 changed files with 14,748 additions and 125 deletions.
38 changes: 30 additions & 8 deletions nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar, cast
from typing import Any, TypeVar, cast

import nncf
from nncf.common.engine import Engine
Expand All @@ -20,6 +20,7 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.data.dataset import Dataset
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled

TModel = TypeVar("TModel")

Expand Down Expand Up @@ -53,17 +54,22 @@ def create(model: TModel) -> NNCFGraph:

return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
if model_backend == BackendType.TORCH:
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.torch.nncf_network import NNCFNetwork

return cast(NNCFNetwork, model).nncf.get_graph()
if isinstance(model, GraphModelWrapper):
return model.build_graph()
if isinstance(model, NNCFNetwork):
return model.nncf.get_graph()
raise nncf.InternalError(f"Unexpected type of model {type(model)} for TORCH backend")
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
f"Cannot create backend-specific graph because {model_backend.value} is not supported!"
)


class ModelTransformerFactory:
@staticmethod
def create(model: TModel, inplace: bool = False) -> ModelTransformer:
def create(model: TModel, inplace: bool = False) -> ModelTransformer[Any]:
"""
Factory method to create backend-specific ModelTransformer instance based on the input model.
Expand All @@ -84,11 +90,18 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.openvino.graph.model_transformer import OVModelTransformer

return OVModelTransformer(cast(Model, model), inplace=inplace)
if model_backend == BackendType.TORCH:
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.experimental.torch2.model_transformer import PT2ModelTransformer

return PT2ModelTransformer(cast(GraphModelWrapper, model))

if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork

return PTModelTransformer(cast(NNCFNetwork, model))

if model_backend == BackendType.TORCH_FX:
from torch.fx import GraphModule

Expand Down Expand Up @@ -125,11 +138,16 @@ def create(model: TModel) -> Engine:
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from torch.nn import Module

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.torch.engine import PTEngine

return PTEngine(cast(Module, model))
if isinstance(model, GraphModelWrapper):
pt_model = model.model
else:
pt_model = cast(Module, model)
return PTEngine(pt_model)
raise nncf.UnsupportedBackendError(
"Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
f"Cannot create backend-specific engine because {model_backend.value} is not supported!"
)


Expand Down Expand Up @@ -176,10 +194,14 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.openvino.statistics.aggregator import OVStatisticsAggregator

return OVStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH:
if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator

return PT2StatisticsAggregator(dataset)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

Expand Down
6 changes: 3 additions & 3 deletions nncf/common/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar
from typing import Generic, TypeVar

from nncf.common.graph.transformations.layout import TransformationLayout

TModel = TypeVar("TModel")


class ModelTransformer:
class ModelTransformer(Generic[TModel]):
"""
Applies transformations to the model.
"""
Expand All @@ -29,7 +29,7 @@ def __init__(self, model: TModel):
"""
self._model = model

def transform(self, transformation_layout: TransformationLayout) -> TModel: # type:ignore
def transform(self, transformation_layout: TransformationLayout) -> TModel:
"""
Applies transformations to the model.
Expand Down
8 changes: 4 additions & 4 deletions nncf/common/graph/transformations/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from typing import List

from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.commands import Command


class TransformationLayout:
Expand All @@ -27,13 +27,13 @@ def __init__(self) -> None:
"""
Initialize Transformation Layout.
"""
self._transformations: List[TransformationCommand] = []
self._transformations: List[Command] = []

@property
def transformations(self) -> List[TransformationCommand]:
def transformations(self) -> List[Command]:
return self._transformations

def register(self, transformation: TransformationCommand) -> None:
def register(self, transformation: Command) -> None:
"""
Registers the transformation command in the transformation layout.
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizer_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,6 @@ def revert_operations_to_floating_point_precision(
)

model_transformer = ModelTransformerFactory.create(quantized_model)
transformed_model = model_transformer.transform(transformation_layout) # type: ignore[var-annotated]
transformed_model = model_transformer.transform(transformation_layout)

return cast(TModel, transformed_model)
6 changes: 3 additions & 3 deletions nncf/common/scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

import re
from typing import Iterable, List, Optional, Sequence, Union
from typing import Iterable, List, Optional, Union

import nncf
from nncf.common.graph import NNCFGraph
Expand Down Expand Up @@ -52,8 +52,8 @@ def matches_any(tested_str: str, strs_to_match_to: Union[Iterable[str], str, Non

def should_consider_scope(
serializable_id: Union[QuantizerId, NNCFNodeName],
ignored_scopes: Optional[Sequence[str]],
target_scopes: Optional[Sequence[str]] = None,
ignored_scopes: Optional[Iterable[str]],
target_scopes: Optional[Iterable[str]] = None,
) -> bool:
"""
Used when an entity arising during compression has to be compared to an allowlist or a denylist of strings.
Expand Down
6 changes: 3 additions & 3 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.tensor_statistics.statistics_serializer import dump_statistics
from nncf.common.tensor_statistics.statistics_serializer import load_statistics
from nncf.common.utils.backend import BackendType
from nncf.data.dataset import Dataset
from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
from nncf.tensor import Tensor

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
Expand Down Expand Up @@ -165,7 +165,7 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer)
self.stat_subset_size = max(self.stat_subset_size, tensor_collector.num_samples)

@abstractmethod
def _register_statistics(self, outputs: Dict[str, NNCFTensor], statistic_points: StatisticPointsContainer) -> None:
def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
"""
Process prepared raw model outputs and statistic points for the further usage.
Expand Down Expand Up @@ -203,7 +203,7 @@ def _get_merged_statistic_points(

@staticmethod
@abstractmethod
def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
def _process_outputs(outputs: Any) -> Dict[str, Tensor]:
"""
Post-process model outputs for the further statistics collection.
Expand Down
6 changes: 6 additions & 0 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Callable, TypeVar, cast

import nncf
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled

try:
import openvino # type: ignore # noqa: F401
Expand Down Expand Up @@ -53,6 +54,11 @@ def is_torch_model(model: Any) -> bool:
import torch
import torch.fx

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

if is_experimental_torch_tracing_enabled():
return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)

return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


Expand Down
21 changes: 21 additions & 0 deletions nncf/experimental/common/check_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os


def is_experimental_torch_tracing_enabled() -> bool:
"""
Checks if experimental torch tracing is enabled by environment variable NNCF_EXPERIMENTAL_TORCH_TRACING.
:return: True if experimental torch tracing is enabled, False otherwise.
"""
return os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is not None
43 changes: 43 additions & 0 deletions nncf/experimental/torch2/commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from torch import nn

from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationType
from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
from nncf.torch.graph.transformations.commands import PTTargetPoint


class PT2InsertionCommand(Command):
"""
Insertion operation to the models.
"""

def __init__(
self,
target_points: List[PTTargetPoint],
hook_module: nn.Module,
*,
handle_storage: Optional[List[RemovableHookHandle]] = None,
):
"""
:param target_points: The list of target points for the command.
:param hook_module: The hook module for the command that will be inserted into the model
to execute at the target points.
:param handle_storage: The handle storage for the command to collect RemovableHookHandle. Defaults to None.
"""
super().__init__(TransformationType.INSERT)
self.target_points = target_points
self.hook_module = hook_module
self.handle_storage = handle_storage
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

"""
This module implements selected functions from the `torch` module, excluding the `hand_function` mechanism.
This module implements selected functions from the `torch` module, excluding the `handle_torch_function` function.
It processes inner functions to handle exception hooks and graph analysis. The implementation is designed
to support custom handling of inner function exceptions for specific functions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import nncf
import nncf.torch.graph.operator_metatypes as om
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.layer_attributes import Dtype
Expand All @@ -30,6 +29,7 @@
from nncf.experimental.torch2.function_hook.graph.graph_utils import InOutMeta
from nncf.experimental.torch2.function_hook.graph.graph_utils import NodeType
from nncf.experimental.torch2.function_hook.nncf_graph.layer_attributes import PT2OpLayerAttributes
from nncf.torch.graph.graph import PTNNCFGraph


def get_node_type(type: NodeType, meta: Union[ConstMeta, FunctionMeta, InOutMeta]) -> str:
Expand Down Expand Up @@ -159,14 +159,14 @@ def get_layer_attributes(
return None


def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> PTNNCFGraph:
"""
Converts a graph to an NNCFGraph.
Converts a graph to an PTNNCFGraph.
:param nx_graph: The graph to convert.
:return: The converted NNCFGraph.
"""
nncf_graph = NNCFGraph()
nncf_graph = PTNNCFGraph()

map_nx_node_to_nncf_node: Dict[int, NNCFNode] = {}
for node, data in nx_graph.nodes(data=True):
Expand All @@ -178,10 +178,11 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
meta_type = get_meta_type(node_type, meta)
layer_attributes = get_layer_attributes(nx_graph, node, meta)
nncf_node = nncf_graph.add_nncf_node(
layer_attributes=layer_attributes,
layer_name=node_name,
node_metatype=meta_type,
node_name=node_name,
node_type=node_type,
node_metatype=meta_type,
layer_attributes=layer_attributes,
)
map_nx_node_to_nncf_node[node] = nncf_node

Expand All @@ -207,7 +208,7 @@ def convert_to_nncf_graph(nx_graph: nx.MultiDiGraph) -> NNCFGraph:
return nncf_graph


def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> NNCFGraph:
def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> PTNNCFGraph:
"""
Builds an NNCF graph from the given PyTorch model.
Expand All @@ -218,3 +219,35 @@ def build_nncf_graph(model: nn.Module, *args: Any, **kwargs: Any) -> NNCFGraph:
"""
graph = build_graph(model, *args, **kwargs)
return convert_to_nncf_graph(graph)


class GraphModelWrapper:
"""
A class that wraps a PyTorch model with examples inputs and provides an interface
to build a computational graph of the model.
:param model: The PyTorch model to be wrapped.
:param example_input: A tuple of example input for the model.
"""

def __init__(self, model: nn.Module, example_input: Any) -> None:
"""
Initialize the GraphModelWrapper.
"""
self.model = model
self.example_input = example_input

def build_graph(self) -> PTNNCFGraph:
"""
Constructs a computational graph of the given model.
This function builds a directed graph `PTNNCFGraph` representing the operations
and data flow within the model by leveraging hooks by using GraphBuilderMode.
:return: A PTNNCFGraph where nodes represent operations of model.
"""
if isinstance(self.example_input, dict):
return build_nncf_graph(self.model, **self.example_input)
if isinstance(self.example_input, tuple):
return build_nncf_graph(self.model, *self.example_input)
return build_nncf_graph(self.model, self.example_input)
Loading

0 comments on commit 5ad9bc4

Please sign in to comment.