diff --git a/README.md b/README.md
index a8e8062..b658aff 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ def foo():
if bar:
baz()
""",
- "utf8",
+ "utf8"
)
)
```
@@ -68,9 +68,9 @@ you can pass a "read" callable to the parse function.
The read callable can use either the byte offset or point tuple to read from
buffer and return source code as bytes object. An empty bytes object or None
-terminates parsing for that line. The bytes must encode the source as UTF-8.
+terminates parsing for that line. The bytes must be encoded as UTF-8 or UTF-16.
-For example, to use the byte offset:
+For example, to use the byte offset with UTF-8 encoding:
```python
src = bytes(
@@ -87,7 +87,7 @@ def read_callable_byte_offset(byte_offset, point):
return src[byte_offset : byte_offset + 1]
-tree = parser.parse(read_callable_byte_offset)
+tree = parser.parse(read_callable_byte_offset, encoding="utf8")
```
And to use the point:
@@ -103,7 +103,7 @@ def read_callable_point(byte_offset, point):
return src_lines[row][column:].encode("utf8")
-tree = parser.parse(read_callable_point)
+tree = parser.parse(read_callable_point, encoding="utf8")
```
Inspect the resulting `Tree`:
@@ -153,6 +153,27 @@ assert root_node.sexp() == (
)
```
+Or, to use the byte offset with UTF-16 encoding:
+
+```python
+parser.set_language(JAVASCRIPT)
+source_code = bytes("'😎' && '🐍'", "utf16")
+
+def read(byte_position, _):
+ return source_code[byte_position: byte_position + 2]
+
+tree = parser.parse(read, encoding="utf16")
+root_node = tree.root_node
+statement_node = root_node.children[0]
+binary_node = statement_node.children[0]
+snake_node = binary_node.children[2]
+snake = source_code[snake_node.start_byte:snake_node.end_byte]
+
+assert binary_node.type == "binary_expression"
+assert snake_node.type == "string"
+assert snake.decode("utf16") == "'🐍'"
+```
+
### Walking syntax trees
If you need to traverse a large number of nodes efficiently, you can use
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 2c921ce..00316d4 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -131,6 +131,28 @@ def read_callback(_, point):
+ " arguments: (argument_list))))))",
)
+ def test_parse_utf16_encoding(self):
+ source_code = bytes("'😎' && '🐍'", "utf16")
+ parser = Parser(self.javascript)
+
+ def read(byte_position, _):
+ return source_code[byte_position: byte_position + 2]
+
+ tree = parser.parse(read, encoding="utf-16")
+ root_node = tree.root_node
+ snake_node = root_node.children[0].children[0].children[2]
+ snake = source_code[snake_node.start_byte + 2:snake_node.end_byte - 2]
+
+ self.assertEqual(snake_node.type, "string")
+ self.assertEqual(snake.decode("utf16"), "🐍")
+
+
+ def test_parse_invalid_encoding(self):
+ parser = Parser(self.python)
+ with self.assertRaises(ValueError):
+ parser.parse(b"foo", encoding="ascii")
+
+
def test_parse_with_one_included_range(self):
source_code = b"hi"
parser = Parser(self.html)
diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi
index dd1f576..d0b9c21 100644
--- a/tree_sitter/__init__.pyi
+++ b/tree_sitter/__init__.pyi
@@ -1,11 +1,13 @@
from collections.abc import ByteString, Callable, Iterator, Sequence
-from typing import Annotated, Any, Final, NamedTuple, final, overload
+from typing import Annotated, Any, Final, Literal, NamedTuple, final, overload
from typing_extensions import deprecated
_Ptr = Annotated[int, "TSLanguage *"]
_ParseCB = Callable[[int, Point | tuple[int, int]], bytes]
+_Encoding = Literal["utf8", "utf16"]
+
_UINT32_MAX = 0xFFFFFFFF
class Point(NamedTuple):
@@ -247,6 +249,7 @@ class Parser:
source: ByteString | _ParseCB | None,
/,
old_tree: Tree | None = None,
+ encoding: _Encoding = "utf8",
) -> Tree: ...
@overload
@deprecated("`keep_text` will be removed")
@@ -255,6 +258,7 @@ class Parser:
source: ByteString | _ParseCB | None,
/,
old_tree: Tree | None = None,
+ encoding: _Encoding = "utf8",
keep_text: bool = True,
) -> Tree: ...
def reset(self) -> None: ...
diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c
index ab0c4e1..e6e4b09 100644
--- a/tree_sitter/binding/parser.c
+++ b/tree_sitter/binding/parser.c
@@ -1,5 +1,7 @@
#include "parser.h"
+#include
+
#define SET_ATTRIBUTE_ERROR(name) \
(name != NULL && name != Py_None && parser_set_##name(self, name, NULL) < 0)
@@ -75,7 +77,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo
Py_XDECREF(args);
// If error or None returned, we're done parsing.
- if (!rv || (rv == Py_None)) {
+ if (rv == NULL || rv == Py_None) {
Py_XDECREF(rv);
*bytes_read = 0;
return NULL;
@@ -84,7 +86,7 @@ static const char *parser_read_wrapper(void *payload, uint32_t byte_offset, TSPo
// If something other than None is returned, it must be a bytes object.
if (!PyBytes_Check(rv)) {
Py_XDECREF(rv);
- PyErr_SetString(PyExc_TypeError, "Read callable must return byte buffer");
+ PyErr_SetString(PyExc_TypeError, "read callable must return a bytestring");
*bytes_read = 0;
return NULL;
}
@@ -101,21 +103,62 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
PyObject *source_or_callback;
PyObject *old_tree_obj = NULL;
int keep_text = 1;
- char *keywords[] = {"", "old_tree", "keep_text", NULL};
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!p:parse", keywords, &source_or_callback,
- state->tree_type, &old_tree_obj, &keep_text)) {
+ const char *encoding = "utf8";
+ char *keywords[] = {"", "old_tree", "encoding", "keep_text", NULL};
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|O!sp:parse", keywords, &source_or_callback,
+ state->tree_type, &old_tree_obj, &encoding, &keep_text)) {
return NULL;
}
const TSTree *old_tree = old_tree_obj ? ((Tree *)old_tree_obj)->tree : NULL;
+ TSInputEncoding input_encoding;
+ if (strcmp(encoding, "utf8") == 0) {
+ input_encoding = TSInputEncodingUTF8;
+ } else if (strcmp(encoding, "utf16") == 0) {
+ input_encoding = TSInputEncodingUTF16;
+ } else {
+ // try to normalize the encoding and check again
+ PyObject *encodings = PyImport_ImportModule("encodings");
+ if (encodings == NULL) {
+ goto encoding_error;
+ }
+ PyObject *normalize_encoding = PyObject_GetAttrString(encodings, "normalize_encoding");
+ Py_DECREF(encodings);
+ if (normalize_encoding == NULL) {
+ goto encoding_error;
+ }
+ PyObject *arg = PyUnicode_DecodeASCII(encoding, strlen(encoding), NULL);
+ if (arg == NULL) {
+ goto encoding_error;
+ }
+ PyObject *normalized_obj = PyObject_CallOneArg(normalize_encoding, arg);
+ Py_DECREF(arg);
+ Py_DECREF(normalize_encoding);
+ if (normalized_obj == NULL) {
+ goto encoding_error;
+ }
+ const char *normalized_str = PyUnicode_AsUTF8(normalized_obj);
+ if (strcmp(normalized_str, "utf8") == 0 || strcmp(normalized_str, "utf_8") == 0) {
+ Py_DECREF(normalized_obj);
+ input_encoding = TSInputEncodingUTF8;
+ } else if (strcmp(normalized_str, "utf16") == 0 || strcmp(normalized_str, "utf_16") == 0) {
+ Py_DECREF(normalized_obj);
+ input_encoding = TSInputEncodingUTF16;
+ } else {
+ Py_DECREF(normalized_obj);
+ goto encoding_error;
+ }
+ }
+
TSTree *new_tree = NULL;
Py_buffer source_view;
if (PyObject_GetBuffer(source_or_callback, &source_view, PyBUF_SIMPLE) > -1) {
// parse a buffer
const char *source_bytes = (const char *)source_view.buf;
uint32_t length = (uint32_t)source_view.len;
- new_tree = ts_parser_parse_string(self->parser, old_tree, source_bytes, length);
+ new_tree = ts_parser_parse_string_encoding(self->parser, old_tree, source_bytes, length,
+ input_encoding);
PyBuffer_Release(&source_view);
} else if (PyCallable_Check(source_or_callback)) {
// clear the GetBuffer error
@@ -129,7 +172,7 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
TSInput input = {
.payload = &payload,
.read = parser_read_wrapper,
- .encoding = TSInputEncodingUTF8,
+ .encoding = input_encoding,
};
new_tree = ts_parser_parse(self->parser, old_tree, input);
Py_XDECREF(payload.previous_return_value);
@@ -156,6 +199,10 @@ PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
tree->source = keep_text ? source_or_callback : Py_None;
Py_INCREF(tree->source);
return PyObject_Init((PyObject *)tree, state->tree_type);
+
+encoding_error:
+ PyErr_Format(PyExc_ValueError, "encoding must be 'utf8' or 'utf16', not '%s'", encoding);
+ return NULL;
}
PyObject *parser_reset(Parser *self, void *Py_UNUSED(payload)) {
@@ -330,7 +377,7 @@ PyObject *parser_set_language_old(Parser *self, PyObject *arg) {
PyDoc_STRVAR(
parser_parse_doc,
- "parse(self, source, /, old_tree=None, keep_text=True)\n--\n\n"
+ "parse(self, source, /, old_tree=None, encoding=\"utf8\", keep_text=True)\n--\n\n"
"Parse a slice of a bytestring or bytes provided in chunks by a callback.\n\n"
"The callback function takes a byte offset and position and returns a bytestring starting "
"at that offset and position. The slices can be of any length. If the given position "