@@ -2162,7 +2162,6 @@ def test_aten_index_select_0(self):
2162
2162
kwargs = dict ()
2163
2163
run_export_and_compare (self , torch .ops .aten .index_select , args , kwargs )
2164
2164
2165
- @unittest .skip
2166
2165
def test_aten_index_select_1 (self ):
2167
2166
args = (
2168
2167
torch .randn ((2 , 10 )).to (torch .float16 ),
@@ -2181,6 +2180,33 @@ def test_aten_index_select_2(self):
2181
2180
kwargs = dict ()
2182
2181
run_export_and_compare (self , torch .ops .aten .index_select , args , kwargs )
2183
2182
2183
+ def test_aten_index_select_3 (self ):
2184
+ args = (
2185
+ torch .randn ((2 , 10 )).to (torch .float32 ),
2186
+ 1 ,
2187
+ torch .randint (0 , 10 , (2 ,)).to (torch .int32 ),
2188
+ )
2189
+ kwargs = dict ()
2190
+ run_export_and_compare (self , torch .ops .aten .index_select , args , kwargs )
2191
+
2192
+ def test_aten_index_select_4 (self ):
2193
+ args = (
2194
+ torch .randn ((2 , 10 )).to (torch .float16 ),
2195
+ 1 ,
2196
+ torch .randint (0 , 10 , (2 ,)).to (torch .int32 ),
2197
+ )
2198
+ kwargs = dict ()
2199
+ run_export_and_compare (self , torch .ops .aten .index_select , args , kwargs )
2200
+
2201
+ def test_aten_index_select_5 (self ):
2202
+ args = (
2203
+ torch .randint (0 , 10 , (2 , 10 )).to (torch .int32 ),
2204
+ 1 ,
2205
+ torch .randint (0 , 10 , (2 ,)).to (torch .int32 ),
2206
+ )
2207
+ kwargs = dict ()
2208
+ run_export_and_compare (self , torch .ops .aten .index_select , args , kwargs )
2209
+
2184
2210
@unittest .skip
2185
2211
def test_aten_index_Tensor_0 (self ):
2186
2212
args = (
@@ -2437,7 +2463,6 @@ def test_aten_logical_and_1(self):
2437
2463
kwargs = dict ()
2438
2464
run_export_and_compare (self , torch .ops .aten .logical_and , args , kwargs )
2439
2465
2440
- @unittest .skip
2441
2466
def test_aten_logical_and_2 (self ):
2442
2467
args = (
2443
2468
torch .randint (0 , 10 , (10 , 10 )).to (torch .int32 ),
@@ -2446,6 +2471,14 @@ def test_aten_logical_and_2(self):
2446
2471
kwargs = dict ()
2447
2472
run_export_and_compare (self , torch .ops .aten .logical_and , args , kwargs )
2448
2473
2474
+ def test_aten_logical_and_3 (self ):
2475
+ args = (
2476
+ torch .randint (0 , 2 , (10 , 10 )).to (torch .bool ),
2477
+ torch .randint (0 , 2 , (10 , 10 )).to (torch .bool ),
2478
+ )
2479
+ kwargs = dict ()
2480
+ run_export_and_compare (self , torch .ops .aten .logical_and , args , kwargs )
2481
+
2449
2482
def test_aten_logical_not_0 (self ):
2450
2483
args = (torch .randn ((10 , 10 )).to (torch .float32 ),)
2451
2484
kwargs = dict ()
0 commit comments