Commit

fix code style after rebase (#7)
eaidova authored Dec 20, 2024
1 parent 630d36a commit 7607f45
Showing 1 changed file with 0 additions and 55 deletions.
55 changes: 0 additions & 55 deletions tests/openvino/utils_tests.py
@@ -295,58 +295,3 @@ def compare_num_quantized_nodes_per_model(
        expected_num_weight_nodes.update({k: 0 for k in set(num_weight_nodes) - set(expected_num_weight_nodes)})
        actual_num_weights_per_model.append(num_weight_nodes)
    test_case.assertEqual(expected_num_weight_nodes_per_model, actual_num_weights_per_model)


@contextmanager
def mock_torch_cuda_is_available(to_patch):
    original_is_available = torch.cuda.is_available
    if to_patch:
        torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        if to_patch:
            torch.cuda.is_available = original_is_available


@contextmanager
def patch_awq_for_inference(to_patch):
    orig_gemm_forward = None
    if to_patch:
        # patch GEMM module to allow inference without CUDA GPU
        from awq.modules.linear.gemm import WQLinearMMFunction
        from awq.utils.packing_utils import dequantize_gemm

        def new_forward(
            ctx,
            x,
            qweight,
            qzeros,
            scales,
            w_bit=4,
            group_size=128,
            bias=None,
            out_features=0,
        ):
            ctx.out_features = out_features

            out_shape = x.shape[:-1] + (out_features,)
            x = x.to(torch.float16)

            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
            out = torch.matmul(x, out)

            out = out + bias if bias is not None else out
            out = out.reshape(out_shape)

            if len(out.shape) == 2:
                out = out.unsqueeze(0)
            return out

        orig_gemm_forward = WQLinearMMFunction.forward
        WQLinearMMFunction.forward = new_forward
    try:
        yield
    finally:
        if orig_gemm_forward is not None:
            WQLinearMMFunction.forward = orig_gemm_forward

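For context, both helpers removed above follow the same pattern: temporarily monkey-patch an attribute inside a context manager and restore the original in the finally block, so tests that exercise CUDA-only code paths (such as AWQ inference) can run on CPU-only machines. Below is a minimal, self-contained sketch of that pattern; the fake_cuda_available name and the demo at the bottom are illustrative only and are not part of this repository.

from contextlib import contextmanager

import torch


@contextmanager
def fake_cuda_available(enabled):
    # Temporarily make torch.cuda.is_available() report True so CUDA-gated
    # code paths can be exercised on a machine without a GPU.
    original_is_available = torch.cuda.is_available
    if enabled:
        torch.cuda.is_available = lambda: True
    try:
        yield
    finally:
        if enabled:
            torch.cuda.is_available = original_is_available


if __name__ == "__main__":
    with fake_cuda_available(True):
        print(torch.cuda.is_available())  # True, even on a CPU-only host
    print(torch.cuda.is_available())  # back to the real answer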