Confusion about "append" mode in sow #78

Open
pwolle opened this issue Jun 17, 2024 · 0 comments

pwolle commented Jun 17, 2024

I am confused about how to use the "append" mode in sow. I would expect the following code

from oryx.core import sow, reap

def f(x):
    x = sow(x + 1.0, tag="tag", name="x", mode="append")
    x = sow(x + 1.0, tag="tag", name="x", mode="append")
    return x

print(reap(f, tag="tag")(1))

to output something similar to {'x': [2, 3]}, since the documentation says: "Another option is 'append', in which all sows of the same name will be appended into a growing array." However, I get the same error as in strict mode, which is

Traceback (most recent call last):
  File "/home/anton/flarenet/oryx_example_append.py", line 11, in <module>
    print(reap(f, tag="tag")(1))
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 868, in wrapped
    return call_and_reap(
           ^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 801, in wrapped
    out, reaps, preds = _call_and_reap(
                        ^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 838, in wrapped
    out_flat, reaps, preds = flat_fun.call_wrapped(flat_args)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/oryx_example_append.py", line 7, in f
    x = sow(x + 1.0, tag="tag", name="x", mode="append")
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 253, in sow
    return _sow(value, tag=tag, name=name, mode=mode, key=key)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 301, in _sow
    out_flat = sow_p.bind(*flat_args, name=name, tag=tag, mode=mode, tree=in_tree)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 396, in process_primitive
    return self.default_process_primitive(primitive, tracers, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 404, in default_process_primitive
    outvals = context.process_sow(*vals, **params)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 508, in process_sow
    return self.handle_sow(*values, name=name, tag=tag, tree=tree, mode=mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/anton/flarenet/.venv/lib/python3.12/site-packages/oryx/core/interpreters/harvest.py", line 567, in handle_sow
    raise ValueError(f'Variable has already been reaped: {name}')
ValueError: Variable has already been reaped: x

I am using Python 3.12.3 with jax==0.4.29 and jaxlib==0.4.29. The same code with mode set to "clobber" works as described in the documentation.
I would like a way to use the same name and tag multiple times while still being able to reap from, or plant into, specific sow sites, since I do not want to force users to wrap calls to functions that contain sow in nest.
What is the recommended way to do this and could you perhaps provide a minimal working example using the append mode?
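
To make the expected behaviour concrete, here is a toy, pure-Python model of the three modes as I read them from the documentation. It is only a sketch of the semantics I would expect, not oryx's implementation: `ToyReaper`, `reap_toy`, and the explicit `reaper` argument are hypothetical names invented for illustration, and the assumption that "clobber" keeps the last sown value is mine.

```python
# Toy model of the sow modes; NOT oryx code. All names here are hypothetical.


class ToyReaper:
    """Collects sown values by name, following the mode semantics I expect."""

    def __init__(self):
        self.reaps = {}

    def sow(self, value, *, name, mode="strict"):
        if mode == "strict":
            # A second sow under the same name is an error.
            if name in self.reaps:
                raise ValueError(f"Variable has already been reaped: {name}")
            self.reaps[name] = value
        elif mode == "clobber":
            # Assumption: the most recent sow wins.
            self.reaps[name] = value
        elif mode == "append":
            # Every sow under the same name is collected in order.
            self.reaps.setdefault(name, []).append(value)
        else:
            raise ValueError(f"Unknown mode: {mode}")
        return value


def f(reaper, x, mode):
    # Same structure as the example above: two sows under one name.
    x = reaper.sow(x + 1.0, name="x", mode=mode)
    x = reaper.sow(x + 1.0, name="x", mode=mode)
    return x


def reap_toy(x, mode):
    """Runs f and returns everything that was sown."""
    reaper = ToyReaper()
    f(reaper, x, mode)
    return reaper.reaps


print(reap_toy(1, "append"))   # {'x': [2.0, 3.0]}
print(reap_toy(1, "clobber"))  # {'x': 3.0}
```

In this toy model only "strict" raises on the second sow, whereas in oryx the "append" call above raises the same ValueError as "strict".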

Please let me know if you would like any further information.
Thank you very much!
