
Action masking (feature request) #77

Open
zengmao opened this issue Jan 13, 2025 · 7 comments

@zengmao

zengmao commented Jan 13, 2025

For example, consider a 2D maze environment. The state is the grid coordinates, and the actions are moving left, right, up, or down, but the only valid actions are those that do not cross a wall in the maze. It would be nice if the user could specify which actions are valid in each state. The invalid actions would be "masked" and ignored during NNPolicy action selection.
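
Something along these lines could express it, where MazeMDP, crosses_wall, and MOVES are made-up names just to sketch the idea:

using POMDPs

# Hypothetical maze model: states are grid coordinates, actions are moves (sketch only).
struct MazeMDP <: MDP{Tuple{Int,Int}, Symbol}
    walls::Set{NTuple{2,Tuple{Int,Int}}}   # pairs of adjacent cells separated by a wall
end

crosses_wall(m::MazeMDP, s, sp) = (s, sp) in m.walls || (sp, s) in m.walls

const MOVES = (left = (-1, 0), right = (1, 0), up = (0, 1), down = (0, -1))

POMDPs.actions(m::MazeMDP) = collect(keys(MOVES))

# Only the moves that do not cross a wall are valid in state s.
POMDPs.actions(m::MazeMDP, s) = [a for (a, d) in pairs(MOVES) if !crosses_wall(m, s, s .+ d)]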

@zsunberg
Member

This would be nice to have. One question is whether the mask should be part of the environment/MDP model or specified separately (or both).

Currently, POMDPs.jl has actions(m, state) and actions(m, belief) (https://juliapomdp.github.io/POMDPs.jl/stable/api/#POMDPs.actions) and CommonRLInterface has valid_action_mask.
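
For reference, the CommonRLInterface convention is a Boolean vector aligned with actions(env); rough usage sketch (assuming env is an environment whose actions(env) returns an AbstractVector):

using CommonRLInterface

mask = valid_action_mask(env)   # Bool vector with one entry per element of actions(env)
valid = actions(env)[mask]      # the currently valid subset of actions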

@zengmao
Author

zengmao commented Jan 18, 2025

To filter the NN output efficiently, the action mask should probably be a Boolean array, so POMDPs.actions(m, state) is not sufficient on its own, even though its return value implicitly contains all the information needed for the mask.
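
For illustration, a hypothetical helper (not part of POMDPs.jl) that builds such a mask from the existing interface could look like:

using POMDPs: actions, actionindex

function action_mask(m, s)
    mask = falses(length(actions(m)))   # one entry per action in the full action space
    for a in actions(m, s)              # only the valid actions in state s
        mask[actionindex(m, a)] = true
    end
    return mask
end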

@zsunberg
Member

zsunberg commented Jan 18, 2025

Right.

Having both efficiency and a simple interface is a difficult challenge. CommonRLInterface.valid_action_mask seeks to address it, but using it would require some refactoring of the current code.

The question should always be "what is the best development path?" We should probably not let efficiency concerns prevent us from implementing the feature; that would be premature optimization, and it is not clear what fraction of the time masking would take up anyway. A good maxim here is "make it work, make it right, make it fast" (in that order).

It may be that actions(m, state) can be reasonably efficient. For instance, if we get a vector actionvalues from the neural network, we can do

best_action = first(actions(m))   # placeholder; replaced by the first valid action below
best_action_value = -Inf
for a in actions(m, s)            # iterate only over the valid actions in state s
    ai = actionindex(m, a)        # index into the network's output vector
    if actionvalues[ai] >= best_action_value
        best_action = a
        best_action_value = actionvalues[ai]
    end
end

This is not as fast as the bit-array approach, but I would say it is also not egregious, and it can be optimized later if we find that it is a performance bottleneck.
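
For comparison, the bit-array version would be something like the following sketch (assuming a precomputed Boolean mask aligned with actionvalues, and that actionindex matches the order of actions(m)):

acts = collect(actions(m))
masked_values = ifelse.(mask, actionvalues, -Inf)   # invalid actions can never be the maximum
best_action = acts[argmax(masked_values)]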

We also may be able to cache masks if the same states are encountered often.

Bottom line, don't let efficiency concerns derail development 🙂

@zengmao
Author

zengmao commented Jan 18, 2025

To construct actionindex in the first place, I suppose there needs to be one more function in the POMDPs interface that returns all possible actions of a model, not just the valid actions of the initial state (which may already be partially masked and therefore only a subset).

@zsunberg
Member

zsunberg commented Jan 18, 2025

POMDPs.actionindex is already part of the POMDPs interface. (Also, POMDPs.actions(m) should return all of the actions.)

@zengmao
Author

zengmao commented Jan 18, 2025

The replay buffer may need to be modified to store the mask array for each state. Looking at line 265 of solver.jl,

q_sp_max = dropdims(maximum(target_q(sp_batch[i]), dims=1), dims=1)

This computes the maximum of $Q(s^\prime, a^\prime)$ over all possible $a^\prime$ for a fixed $s^\prime$, as in the Bellman equation. This line would need to be modified to compute the maximum only over the valid actions, and that is most straightforward if the mask (as a Boolean array) is precomputed and stored in the replay buffer. (Another motivation is keeping the vectorized code structure for processing the batch, which would be broken by manually writing a loop over the valid actions.)
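
A masked version of that line could look like this sketch, where mask_batch is a hypothetical addition to the replay buffer holding a Boolean matrix with the same layout as the Q-value output:

q_sp = target_q(sp_batch[i])                                    # Q(s', a') values for the batch
masked_q = ifelse.(mask_batch[i], q_sp, typemin(eltype(q_sp)))  # invalid a' can never be the max
q_sp_max = dropdims(maximum(masked_q, dims=1), dims=1)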

Therefore, some data structures in the current code would need to be modified to support masking. This may cause an overhead for solving models that do not use masking, though the overhead is likely small.

@zsunberg
Member

zsunberg commented Jan 18, 2025

I like the way you are thinking!

Another option (with pros and cons) would be to maintain a separate data structure parallel to the main (s, a, sp, r) buffer that just holds the masks. Pro: it could be more easily disabled when masks are not needed. Con: it is perhaps harder to maintain if we change the structure of the main buffer.
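
A rough sketch of that option, with all names hypothetical:

# Optional mask store that mirrors the indices of the main (s, a, sp, r) buffer.
struct MaskBuffer
    masks::Vector{BitVector}   # masks[i] is the valid-action mask for transition i
end

push_mask!(b::MaskBuffer, mask::BitVector) = push!(b.masks, mask)

# Assemble an |A| × batch BitMatrix for the sampled indices, matching the Q-value layout.
mask_batch(b::MaskBuffer, idxs) = reduce(hcat, b.masks[idxs])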
