diff --git a/tests/kernel/indexing_test.py b/tests/kernel/indexing_test.py index 07b93d7ea..107e17676 100644 --- a/tests/kernel/indexing_test.py +++ b/tests/kernel/indexing_test.py @@ -155,6 +155,15 @@ def testDimExpressionBackedDynamicDimInferenceMismatch(self): ): c.finalize() + def testDependentDynamicDims(self): + c = IndexingContext() + inst = object() + kb1 = KernelBuffer[M, M * 4] + c.bind_shaped(inst, kb1, (c.next_dyn_dim(), c.next_dyn_dim())) + c.finalize() + self.assertEqual(c.dyn_dims[0], c.eval_dim(inst, kb1, 0)) + self.assertEqual(c.dyn_dims[0] * 4, c.eval_dim(inst, kb1, 1)) + if __name__ == "__main__": unittest.main()