Skip to content

Commit

Permalink
Formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
gshtras committed Nov 15, 2024
1 parent 413d3bc commit 1c8fea8
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 4 deletions.
1 change: 1 addition & 0 deletions gradlib/GemmTuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import torch
import torch.nn.functional as F

import vllm._gradlib_C # noqa: F401

rtol = 1e-5
Expand Down
6 changes: 4 additions & 2 deletions gradlib/gemm_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys

import pandas as pd
import torch
import torch.nn.functional as F

import vllm._gradlib_C # noqa: F401
import pandas as pd

torch.ops._gradlib_C.rocb_create_extension()
torch.ops._gradlib_C.hipb_create_extension()
Expand Down Expand Up @@ -37,7 +38,8 @@ def mm(self, inp, weights):
n=inp.shape[0],
k=inp.shape[1])
if soltype == 1:
out = torch.ops._gradlib_C.hipb_mm(inp, weights.t(), solidx, None, None, None, None, None)
out = torch.ops._gradlib_C.hipb_mm(inp, weights.t(), solidx, None,
None, None, None, None)
elif soltype == 2:
out = torch.ops._gradlib_C.rocb_mm(inp, weights.t(), solidx)
else:
Expand Down
2 changes: 1 addition & 1 deletion gradlib/gemm_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from pathlib import Path

import torch # isort: split
import vllm._gradlib_C
import pandas as pd

import vllm._gradlib_C # noqa: F401
from gradlib.GemmTuner import GemmTuner

torch.ops._gradlib_C.rocb_create_extension()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ exclude = [

[tool.codespell]
ignore-words-list = "dout, te, indicies, subtile"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build,./gradlib,./csrc/rocm"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build,./csrc/gradlib,./csrc/rocm"

[tool.isort]
use_parentheses = true
Expand Down
1 change: 1 addition & 0 deletions vllm/model_executor/layers/tuned_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def create_ds(self):
soltype = 2
solds[key] = (soltype, int(ds['solidx']))
self.solids = solds

def query_sol(self, m, n, k, bias, dtype):
return self.solids.get((m, n, k, bias, str(dtype)), (0, 0))

Expand Down

0 comments on commit 1c8fea8

Please sign in to comment.