Skip to content

Commit

Permalink
Refactor TPC with supported input bit-width to each operator. (#1169)
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c authored Aug 15, 2024
1 parent 67ec854 commit 916db58
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.constants import REUSE, REUSE_GROUP
from model_compression_toolkit.core.keras.constants import KERNEL, BIAS, USE_BIAS, FILTERS, PADDING, \
KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE
KERNEL_SIZE, DEPTH_MULTIPLIER, STRIDES, DILATIONS, DILATION_RATE, DEPTHWISE_KERNEL, RATE, \
ACTIVATION, LINEAR


def extract_bias_node_data(_node: FunctionalNode, _graph: Graph) -> np.ndarray:
Expand Down Expand Up @@ -132,7 +133,7 @@ def substitute(self,

weights = {KERNEL: k}
# Create Conv2D layer attributes.
conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2]}
conv_fw_attr = {FILTERS: k.shape[3], KERNEL_SIZE: k.shape[:2], ACTIVATION: LINEAR}
if len(conv_func_node.op_call_args) > 0:
Logger.critical(f"node {conv_func_node.name} expected to have only kwargs but got args={conv_func_node.op_call_args}.") # pragma: no cover
if STRIDES in conv_func_node.op_call_kwargs:
Expand Down Expand Up @@ -209,7 +210,7 @@ def substitute(self,

weights = {DEPTHWISE_KERNEL: k}
k_shape = k.shape
conv_fw_attr = {DEPTH_MULTIPLIER: k_shape[3], KERNEL_SIZE: k_shape[:2]}
conv_fw_attr = {DEPTH_MULTIPLIER: k_shape[3], KERNEL_SIZE: k_shape[:2], ACTIVATION: LINEAR}
if len(dwconv_func_node.op_call_args) > 0:
Logger.critical(f"node {dwconv_func_node.name} expected to have only kwargs but got args={dwconv_func_node.op_call_args}.") # pragma: no cover
if STRIDES in dwconv_func_node.op_call_kwargs:
Expand Down

0 comments on commit 916db58

Please sign in to comment.