Skip to content

Commit 04e1238

Browse files
authored
[Core Aten] Add and enable tests for aten index_select and logical_and (pytorch#6293)
1 parent 8141078 commit 04e1238

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

test/test_core_aten_ops.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,7 +2162,6 @@ def test_aten_index_select_0(self):
21622162
kwargs = dict()
21632163
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)
21642164

2165-
@unittest.skip
21662165
def test_aten_index_select_1(self):
21672166
args = (
21682167
torch.randn((2, 10)).to(torch.float16),
@@ -2181,6 +2180,33 @@ def test_aten_index_select_2(self):
21812180
kwargs = dict()
21822181
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)
21832182

2183+
def test_aten_index_select_3(self):
2184+
args = (
2185+
torch.randn((2, 10)).to(torch.float32),
2186+
1,
2187+
torch.randint(0, 10, (2,)).to(torch.int32),
2188+
)
2189+
kwargs = dict()
2190+
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)
2191+
2192+
def test_aten_index_select_4(self):
2193+
args = (
2194+
torch.randn((2, 10)).to(torch.float16),
2195+
1,
2196+
torch.randint(0, 10, (2,)).to(torch.int32),
2197+
)
2198+
kwargs = dict()
2199+
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)
2200+
2201+
def test_aten_index_select_5(self):
2202+
args = (
2203+
torch.randint(0, 10, (2, 10)).to(torch.int32),
2204+
1,
2205+
torch.randint(0, 10, (2,)).to(torch.int32),
2206+
)
2207+
kwargs = dict()
2208+
run_export_and_compare(self, torch.ops.aten.index_select, args, kwargs)
2209+
21842210
@unittest.skip
21852211
def test_aten_index_Tensor_0(self):
21862212
args = (
@@ -2437,7 +2463,6 @@ def test_aten_logical_and_1(self):
24372463
kwargs = dict()
24382464
run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs)
24392465

2440-
@unittest.skip
24412466
def test_aten_logical_and_2(self):
24422467
args = (
24432468
torch.randint(0, 10, (10, 10)).to(torch.int32),
@@ -2446,6 +2471,14 @@ def test_aten_logical_and_2(self):
24462471
kwargs = dict()
24472472
run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs)
24482473

2474+
def test_aten_logical_and_3(self):
2475+
args = (
2476+
torch.randint(0, 2, (10, 10)).to(torch.bool),
2477+
torch.randint(0, 2, (10, 10)).to(torch.bool),
2478+
)
2479+
kwargs = dict()
2480+
run_export_and_compare(self, torch.ops.aten.logical_and, args, kwargs)
2481+
24492482
def test_aten_logical_not_0(self):
24502483
args = (torch.randn((10, 10)).to(torch.float32),)
24512484
kwargs = dict()

0 commit comments

Comments
 (0)