I don't quite get the logic behind this. Wouldn't this allocate a lot of memory on each device? What happens on the devices afterwards that would require the result to be present on all of them? Could this be a source of memory issues?
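This back-of-the-envelope sketch makes the memory concern concrete. It assumes that the result of `run_sharded` is gathered (replicated) onto every device instead of being left sharded along its first axis. The helper `per_device_bytes` and the example shape are hypothetical and are not taken from #147.

```python
import math


def per_device_bytes(shape, itemsize, n_devices, replicated):
    """Bytes a single device holds for an array of `shape` (hypothetical helper).

    replicated=True: every device keeps the full array (e.g. after an all-gather).
    replicated=False: the array is split along axis 0, so each device keeps
    about 1/n_devices of it (rounded up when the split is uneven).
    """
    total = math.prod(shape)
    if replicated:
        return total * itemsize
    rows_per_device = math.ceil(shape[0] / n_devices)
    return rows_per_device * math.prod(shape[1:]) * itemsize


# Illustrative numbers only: a float32 result of shape (8192, 4096) on 8 devices.
shape, itemsize, n = (8192, 4096), 4, 8
sharded = per_device_bytes(shape, itemsize, n, replicated=False)
replicated = per_device_bytes(shape, itemsize, n, replicated=True)
print(f"sharded:    {sharded / 2**20:.0f} MiB per device")     # 16 MiB
print(f"replicated: {replicated / 2**20:.0f} MiB per device")  # 128 MiB
```

With replication, per-device memory for the result grows by a factor of `n_devices` compared with keeping it sharded. That is why it matters whether the downstream code really needs the full result on every device.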
I would propose refactoring the `run_sharded` function to support a more flexible approach here.
Originally posted by @MaHaWo in #147 (comment)