Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PT2] MinMax #3166

Merged
merged 28 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
1b2bf15
init
AlexanderDokuchaev Dec 23, 2024
a4366ab
Merge branch 'develop' into ad/pt2_minmax
AlexanderDokuchaev Dec 23, 2024
516c716
efficientnet_pytorch==0.7.1
AlexanderDokuchaev Dec 24, 2024
ab0ceec
addict
AlexanderDokuchaev Dec 24, 2024
af939a4
dot
AlexanderDokuchaev Dec 24, 2024
873a1df
mypy
AlexanderDokuchaev Dec 25, 2024
6c62051
c
AlexanderDokuchaev Dec 25, 2024
3301f7b
Merge branch 'develop' into ad/pt2_minmax
AlexanderDokuchaev Dec 25, 2024
c3c809c
mypy
AlexanderDokuchaev Dec 25, 2024
cc9e0fc
f
AlexanderDokuchaev Dec 25, 2024
b18c9a2
none
AlexanderDokuchaev Dec 26, 2024
40466e2
no cache for gpu
AlexanderDokuchaev Dec 26, 2024
656ca73
p
AlexanderDokuchaev Dec 26, 2024
5d038c2
revert
AlexanderDokuchaev Dec 26, 2024
1f8ce90
Merge branch 'develop' into ad/pt2_minmax
AlexanderDokuchaev Dec 27, 2024
8693917
rename
AlexanderDokuchaev Dec 29, 2024
e2f892e
Merge branch 'develop' into ad/pt2_minmax
AlexanderDokuchaev Jan 24, 2025
b91cb9e
rm torch2 backend
AlexanderDokuchaev Jan 26, 2025
2cf3eff
c
AlexanderDokuchaev Jan 27, 2025
4e39113
dub
AlexanderDokuchaev Jan 27, 2025
f9e8200
com
AlexanderDokuchaev Jan 27, 2025
9a8fc76
c
AlexanderDokuchaev Jan 28, 2025
c1151b8
f
AlexanderDokuchaev Jan 28, 2025
287e624
iter
AlexanderDokuchaev Jan 28, 2025
2a33082
Merge branch 'develop' into ad/pt2_minmax
AlexanderDokuchaev Jan 28, 2025
33713a1
comments
AlexanderDokuchaev Jan 29, 2025
98030f0
Merge branch 'develop' into ad/pt2_minmax
AlexanderDokuchaev Jan 29, 2025
2f36c9f
mypy
AlexanderDokuchaev Jan 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions nncf/common/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar, cast
from typing import Any, TypeVar, cast

import nncf
from nncf.common.engine import Engine
Expand All @@ -20,6 +20,7 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.data.dataset import Dataset
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled

TModel = TypeVar("TModel")

Expand Down Expand Up @@ -52,18 +53,23 @@ def create(model: TModel) -> NNCFGraph:
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter as FXGraphConverter

return FXGraphConverter.create_nncf_graph(cast(GraphModule, model))
if model_backend == BackendType.TORCH:
if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
from nncf.torch.nncf_network import NNCFNetwork

return cast(NNCFNetwork, model).nncf.get_graph()
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

return cast(GraphModelWrapper, model).build_nncf_graph()

raise nncf.UnsupportedBackendError(
"Cannot create backend-specific graph because {} is not supported!".format(model_backend.value)
)


class ModelTransformerFactory:
@staticmethod
def create(model: TModel, inplace: bool = False) -> ModelTransformer:
def create(model: TModel, inplace: bool = False) -> ModelTransformer[Any]:
"""
Factory method to create backend-specific ModelTransformer instance based on the input model.

Expand All @@ -84,11 +90,18 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer:
from nncf.openvino.graph.model_transformer import OVModelTransformer

return OVModelTransformer(cast(Model, model), inplace=inplace)
if model_backend == BackendType.TORCH:
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper
from nncf.experimental.torch2.model_transformer import PT2ModelTransformer

return PT2ModelTransformer(cast(GraphModelWrapper, model))

if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
from nncf.torch.model_transformer import PTModelTransformer
from nncf.torch.nncf_network import NNCFNetwork

return PTModelTransformer(cast(NNCFNetwork, model))

if model_backend == BackendType.TORCH_FX:
from torch.fx import GraphModule

Expand Down Expand Up @@ -122,14 +135,19 @@ def create(model: TModel) -> Engine:
from nncf.openvino.engine import OVNativeEngine

return OVNativeEngine(cast(Model, model))
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
from nncf.experimental.torch2.engine import PT2Engine
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

return PT2Engine(cast(GraphModelWrapper, model))
if model_backend in (BackendType.TORCH, BackendType.TORCH_FX):
from torch.nn import Module

from nncf.torch.engine import PTEngine

return PTEngine(cast(Module, model))
raise nncf.UnsupportedBackendError(
alexsu52 marked this conversation as resolved.
Show resolved Hide resolved
"Cannot create backend-specific engine because {} is not supported!".format(model_backend.value)
f"Cannot create backend-specific engine because {model_backend.value} is not supported!"
)


Expand Down Expand Up @@ -176,10 +194,14 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator:
from nncf.openvino.statistics.aggregator import OVStatisticsAggregator

return OVStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH:
if model_backend == BackendType.TORCH and not is_experimental_torch_tracing_enabled():
from nncf.torch.statistics.aggregator import PTStatisticsAggregator

return PTStatisticsAggregator(dataset)
if model_backend == BackendType.TORCH and is_experimental_torch_tracing_enabled():
from nncf.experimental.torch2.statistics.aggregator import PT2StatisticsAggregator

return PT2StatisticsAggregator(dataset)
if model_backend == BackendType.TORCH_FX:
from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator

Expand Down
6 changes: 3 additions & 3 deletions nncf/common/graph/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TypeVar
from typing import Generic, TypeVar

from nncf.common.graph.transformations.layout import TransformationLayout

TModel = TypeVar("TModel")


class ModelTransformer:
class ModelTransformer(Generic[TModel]):
"""
Applies transformations to the model.
"""
Expand All @@ -29,7 +29,7 @@ def __init__(self, model: TModel):
"""
self._model = model

def transform(self, transformation_layout: TransformationLayout) -> TModel: # type:ignore
def transform(self, transformation_layout: TransformationLayout) -> TModel:
"""
Applies transformations to the model.

Expand Down
8 changes: 4 additions & 4 deletions nncf/common/graph/transformations/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from typing import List

from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.commands import Command


class TransformationLayout:
Expand All @@ -27,13 +27,13 @@ def __init__(self) -> None:
"""
Initialize Transformation Layout.
"""
self._transformations: List[TransformationCommand] = []
self._transformations: List[Command] = []

@property
def transformations(self) -> List[TransformationCommand]:
def transformations(self) -> List[Command]:
return self._transformations

def register(self, transformation: TransformationCommand) -> None:
def register(self, transformation: Command) -> None:
"""
Registers the transformation command in the transformation layout.

Expand Down
6 changes: 3 additions & 3 deletions nncf/common/tensor_statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.tensor_statistics.statistics_serializer import dump_statistics
from nncf.common.tensor_statistics.statistics_serializer import load_statistics
from nncf.common.utils.backend import BackendType
from nncf.data.dataset import Dataset
from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic
from nncf.tensor import Tensor

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
Expand Down Expand Up @@ -165,7 +165,7 @@ def register_statistic_points(self, statistic_points: StatisticPointsContainer)
self.stat_subset_size = max(self.stat_subset_size, tensor_collector.num_samples)

@abstractmethod
def _register_statistics(self, outputs: Dict[str, NNCFTensor], statistic_points: StatisticPointsContainer) -> None:
def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
"""
Process prepared raw model outputs and statistic points for the further usage.

Expand Down Expand Up @@ -203,7 +203,7 @@ def _get_merged_statistic_points(

@staticmethod
@abstractmethod
def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
def _process_outputs(outputs: Any) -> Dict[str, Tensor]:
"""
Post-process model outputs for the further statistics collection.

Expand Down
6 changes: 6 additions & 0 deletions nncf/common/utils/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Callable, TypeVar

import nncf
from nncf.experimental.common.check_feature import is_experimental_torch_tracing_enabled

TModel = TypeVar("TModel")

Expand Down Expand Up @@ -53,6 +54,11 @@ def is_torch_model(model: TModel) -> bool:
import torch
import torch.fx

from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper

if is_experimental_torch_tracing_enabled():
return isinstance(model, (GraphModelWrapper, torch.nn.Module)) and not isinstance(model, torch.fx.GraphModule)

return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module)


Expand Down
21 changes: 21 additions & 0 deletions nncf/experimental/common/check_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2025 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os


def is_experimental_torch_tracing_enabled() -> bool:
"""
Checks if experimental torch tracing is enabled by environment variable NNCF_EXPERIMENTAL_TORCH_TRACING.

:return: True if experimental torch tracing is enabled, False otherwise.
"""
return os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is not None
37 changes: 37 additions & 0 deletions nncf/experimental/torch2/commands.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from torch import nn

from nncf.common.graph.transformations.commands import Command
from nncf.common.graph.transformations.commands import TransformationType
from nncf.experimental.torch2.function_hook.hook_storage import RemovableHookHandle
from nncf.torch.graph.transformations.commands import PTTargetPoint


class PT2InsertionCommand(Command):
"""
Insertion operation to the models.
"""

def __init__(
self,
target_points: List[PTTargetPoint],
hook_module: nn.Module,
*,
handle_storage: Optional[List[RemovableHookHandle]] = None,
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
):
super().__init__(TransformationType.INSERT)
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved
self.target_points = target_points
self.hook_module = hook_module
self.handle_storage = handle_storage
47 changes: 47 additions & 0 deletions nncf/experimental/torch2/engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Tuple, Union

import torch

from nncf.common.engine import Engine
from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import GraphModelWrapper


class PT2Engine(Engine):
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
"""
Engine for the Pytorch backend.
"""

def __init__(self, model: GraphModelWrapper):
"""
Constructor.

:param model: Pytorch module to infer.
"""

self._model = model.model
self._model.eval()

def infer(self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]]) -> Any:
"""
Runs Torch model on the provided input.

:param input_data: Inputs for the model.
:return: Model outputs.
"""

if isinstance(input_data, dict):
return self._model(**input_data)
if isinstance(input_data, tuple):
return self._model(*input_data)
return self._model(input_data)
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

"""
This module implements selected functions from the `torch` module, excluding the `hand_function` mechanism.
This module implements selected functions from the `torch` module, excluding the `handle_torch_function` function.

It processes inner functions to handle exception hooks and graph analysis. The implementation is designed
to support custom handling of inner function exceptions for specific functions.
Expand Down
Loading