fp16fixes.py
import torch


def fp16_fixes():
    if torch.backends.mps.is_available():
        # torch.empty returns uninitialized memory, which some MPS fp16 ops
        # mishandle; zero-initialize allocations instead.
        torch.empty = torch.zeros

    _torch_layer_norm = torch.nn.functional.layer_norm

    def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
        # Work around fp16 layer_norm issues on MPS: upcast input, weight,
        # and bias to fp32, normalize, then cast the result back to fp16.
        if input.device.type == "mps" and input.dtype == torch.float16:
            input = input.float()
            if weight is not None:
                weight = weight.float()
            if bias is not None:
                bias = bias.float()
            return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
        else:
            return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

    torch.nn.functional.layer_norm = new_layer_norm

    def new_torch_tensor_permute(input, *dims):
        result = torch.permute(input, tuple(dims))
        # Permuted (non-contiguous) fp16 tensors can misbehave on MPS,
        # so force a contiguous copy after permuting.
        if input.device.type == "mps" and input.dtype == torch.float16:
            result = result.contiguous()
        return result

    torch.Tensor.permute = new_torch_tensor_permute
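
A minimal usage sketch, assuming the script above is importable as fp16fixes and an MPS device is available (the tensor shape and layer_norm call are illustrative, not part of this file). The patches must be applied once, before any fp16 work runs on MPS:

    # usage_example.py -- hypothetical caller, not part of fp16fixes.py
    import torch
    from fp16fixes import fp16_fixes

    # Install the monkey-patches before running any fp16 model code.
    fp16_fixes()

    if torch.backends.mps.is_available():
        x = torch.randn(2, 8, device="mps", dtype=torch.float16)
        # Goes through the patched layer_norm, which upcasts to fp32
        # internally and returns fp16.
        y = torch.nn.functional.layer_norm(x, (8,))
        print(y.dtype)  # torch.float16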