diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c index c0f27c2..1050859 100644 --- a/tree_sitter/binding/query.c +++ b/tree_sitter/binding/query.c @@ -351,6 +351,7 @@ PyObject *query_new(PyTypeObject *cls, PyObject *args, PyObject *Py_UNUSED(kwarg QueryPredicateAnyOf *predicate = PyObject_New(QueryPredicateAnyOf, state->query_predicate_anyof_type); + predicate->capture_id = (predicate_step + 1)->value_id; predicate->is_positive = is_positive; predicate->values = values; PyObject *predicate_obj = diff --git a/tree_sitter/binding/query_predicates.c b/tree_sitter/binding/query_predicates.c index 7226591..8347187 100644 --- a/tree_sitter/binding/query_predicates.c +++ b/tree_sitter/binding/query_predicates.c @@ -48,17 +48,25 @@ static inline bool satisfies_anyof(ModuleState *state, QueryPredicateAnyOf *pred PyObject *nodes = nodes_for_capture_index(state, predicate->capture_id, match, tree); for (size_t i = 0, l = (size_t)PyList_Size(nodes); i < l; ++i) { Node *node = (Node *)PyList_GetItem(nodes, i); - PyObject *text1 = node_get_text(node, NULL), *text2; + PyObject *text1 = node_get_text(node, NULL); + bool found_match = false; + for (size_t j = 0, k = (size_t)PyList_Size(predicate->values); j < k; ++j) { - text2 = PyList_GetItem(predicate->values, j); - if (PREDICATE_CMP(text1, text2, predicate) != 1) { - Py_DECREF(text1); - Py_DECREF(nodes); - return false; + PyObject *text2 = PyList_GetItem(predicate->values, j); + if (PREDICATE_CMP(text1, text2, predicate) == 1) { + found_match = true; + break; } } + Py_DECREF(text1); + + if (!found_match) { + Py_DECREF(nodes); + return false; + } } + Py_DECREF(nodes); return true; }