Skip to content

Commit

Permalink
fix more errors
Browse files Browse the repository at this point in the history
  • Loading branch information
dame-cell committed Oct 26, 2024
1 parent 6767588 commit 163614c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'triformer',
packages = find_packages(),
version = '1.3.1',
version = '1.3.2',
license='MIT',
description = 'Transformer components in Triton',
long_description=open('README.md').read(),
Expand Down
8 changes: 6 additions & 2 deletions triformer/forward_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ def grid(meta):
class TritonLinearFunction(Function):
@staticmethod
def forward(ctx, x, weight, bias, use_relu=True):
# Convert input to float16 if necessary
x = x.to(dtype=torch.float16)

# Add batch dimension if necessary
if x.ndim == 2:
x = x.unsqueeze(0)
Expand Down Expand Up @@ -149,11 +152,12 @@ def __init__(self, in_features, out_features, use_relu=True):
self.in_features = in_features
self.out_features = out_features
self.use_relu = use_relu
# Change weight and bias to float32
self.weight = nn.Parameter(
torch.empty(out_features, in_features, device='cuda', dtype=torch.float16)
torch.empty(out_features, in_features, device='cuda', dtype=torch.float32)
)
self.bias = nn.Parameter(
torch.zeros(out_features, device='cuda', dtype=torch.float16)
torch.zeros(out_features, device='cuda', dtype=torch.float32)
)
self.reset_parameters()

Expand Down
14 changes: 14 additions & 0 deletions triformer/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def calc_num_warps(block_size):
num_warps = 4
if block_size >= 2048:
num_warps = 8
if block_size >= 4096:
num_warps = 16
return num_warps

0 comments on commit 163614c

Please sign in to comment.