consider value matrix shapes for Jax conversion#10
consider value matrix shapes for Jax conversion#10William-Baker wants to merge 2 commits intogoogle-deepmind:mainfrom
Conversation
For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compliation error. By considering both ov and qk matrices when padding the Jax model we can resolve this
|
Thanks for this! This fix makes sense. Could you add a test case that catches the previous bug? Easiest would be to add a minimal version of your example to the test cases in test_cases.py. This will add it to all our integration tests. In particular, rasp_to_craft_integration_test.py should fail without your change and pass after making it. The proper thing to do would be to also add a unit test to assemble_test.py. I don't think we need to do this here, but feel free to give it a shot if you want. |
|
Is this still blocked on the tests? Would be great if it could be merged soon. |
|
I haven't added any test cases yet as there are still discrepancies with RASP and CRAFT that should be resolved first, then we can check for consistency. I will prepare some test cases in the mean time that demonstrate the issue under the current compiler but things may change depending on issue #14 |
|
I have added the test cases to test cases.py and checked that the validator doesn't flag issues, since these are categorical aggregates, I do think this only happens with categorical aggregation programs... |
david-lindner
left a comment
There was a problem hiding this comment.
ov_test_case_1 still fails -- please fix
For some programs, the value matrix could have a dimension larger than the largest key-query matrix resulting in a compilation error. By considering both ov and qk matrices when padding the Jax model we can resolve this.
Examples of programs that trigger this error are:
I have verified that this change does not break any of the test cases proposed in issue #9