Skip to content

Commit

Permalink
Merge branch 'main' into andrew/faster_chrome_CI
Browse files Browse the repository at this point in the history
  • Loading branch information
ayjayt committed Dec 3, 2024
2 parents 30c2477 + 9071bc7 commit e8e94a6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
16 changes: 9 additions & 7 deletions choreographer/pipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import sys
import json
import simplejson
import platform
import warnings
from threading import Lock
Expand All @@ -11,7 +11,7 @@ class BlockWarning(UserWarning):

# TODO: don't know about this
# TODO: use has_attr instead of np.integer, you'll be fine
class NumpyEncoder(json.JSONEncoder):
class MultiEncoder(simplejson.JSONEncoder):
"""Special json encoder for numpy types"""

def default(self, obj):
Expand All @@ -29,18 +29,20 @@ def default(self, obj):
return float(obj)
elif hasattr(obj, "dtype") and obj.shape != ():
return obj.tolist()
return json.JSONEncoder.default(self, obj)
elif hasattr(obj, "isoformat"):
return obj.isoformat()
return simplejson.JSONEncoder.default(self, obj)


class PipeClosedError(IOError):
pass

class Pipe:
def __init__(self, debug=False, cls=NumpyEncoder):
def __init__(self, debug=False, json_encoder=MultiEncoder):
self.read_from_chromium, self.write_from_chromium = list(os.pipe())
self.read_to_chromium, self.write_to_chromium = list(os.pipe())
self.debug = debug
self.cls=cls
self.json_encoder = json_encoder

# this is just a convenience to prevent multiple shutdowns
self.shutdown_lock = Lock()
Expand All @@ -51,7 +53,7 @@ def write_json(self, obj, debug=None):
if not debug: debug = self.debug
if debug:
print("write_json:", file=sys.stderr)
message = json.dumps(obj, ensure_ascii=False, cls=self.cls)
message = simplejson.dumps(obj, ensure_ascii=False, ignore_nan=True, cls=self.json_encoder)
encoded_message = message.encode("utf-8") + b"\0"
if debug:
print(f"write_json: {message}", file=sys.stderr)
Expand Down Expand Up @@ -112,7 +114,7 @@ def read_jsons(self, blocking=True, debug=None):
for raw_message in decoded_buffer.split("\0"):
if raw_message:
try:
jsons.append(json.loads(raw_message))
jsons.append(simplejson.loads(raw_message))
except BaseException as e:
if debug:
print(f"Problem with {raw_message} in json: {e}", file=sys.stderr)
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ authors = [
maintainers = [
{name = "Andrew Pikul", email = "ajpikul@gmail.com"},
]
dependencies = [
"simplejson"
]

[project.optional-dependencies]
dev = [
Expand Down

0 comments on commit e8e94a6

Please sign in to comment.