
Commit

Merge pull request #76 from pedohorse/attribute-serialization-error-catching

attribute serialization error catching
pedohorse authored Apr 15, 2024
2 parents 88135d3 + 8f061ef commit 7b175b3
Showing 27 changed files with 422 additions and 132 deletions.
20 changes: 20 additions & 0 deletions src/lifeblood/attribute_serialization.py
@@ -0,0 +1,20 @@
import asyncio
import json

from .common_serialization import AttribSerializer, AttribDeserializer


async def serialize_attributes(attributes: dict) -> str:
return await asyncio.get_event_loop().run_in_executor(None, serialize_attributes_core, attributes)


async def deserialize_attributes(attributes_serialized: str) -> dict:
return await asyncio.get_event_loop().run_in_executor(None, deserialize_attributes_core, attributes_serialized)


def serialize_attributes_core(attributes: dict) -> str:
return json.dumps(attributes, cls=AttribSerializer)


def deserialize_attributes_core(attributes_serialized: str) -> dict:
return json.loads(attributes_serialized, cls=AttribDeserializer)
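A minimal usage sketch (an illustration, not part of the diff) of the new module: the async wrappers simply run the synchronous core functions in the default executor, so either form can be used depending on the call site.

import asyncio
from lifeblood.attribute_serialization import (
    serialize_attributes,
    deserialize_attributes,
    serialize_attributes_core,
)

async def main():
    attrs = {'frames': (1, 5, 9), 'tags': {'sim', 'render'}}
    text = await serialize_attributes(attrs)         # json.dumps offloaded to an executor
    restored = await deserialize_attributes(text)    # tuples and sets survive the round trip
    assert restored == attrs
    serialize_attributes_core(attrs)                  # synchronous variant, same encoding

asyncio.run(main())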
4 changes: 0 additions & 4 deletions src/lifeblood/basenode.py
@@ -221,10 +221,6 @@ def apply_settings(self, settings: Dict[str, Dict[str, Any]]) -> None:
self.logger().warning(f'applying settings: skipping parameter "{param_name}": bad value type: {str(e)}')
continue

# # some helpers
# def _get_task_attributes(self, task_row):
# return json.loads(task_row.get('attributes', '{}'))

#
# Plugin info
#
55 changes: 10 additions & 45 deletions src/lifeblood/basenode_serializer_v2.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass, is_dataclass
import json
from .common_serialization import AttribSerializer, AttribDeserializer
from .basenode_serialization import NodeSerializerBase, IncompatibleDeserializationMethod, FailedToApplyNodeState, FailedToApplyParameters
from .basenode import BaseNode, NodeParameterType
from .uidata import ParameterFullValue
@@ -31,66 +32,30 @@ class NodeSerializerV2(NodeSerializerBase):
the final string though is json-compliant
"""

class Serializer(json.JSONEncoder):
def __reform(self, obj):
if type(obj) is set:
return {
'__special_object_type__': 'set',
'items': self.__reform(list(obj))
}
elif type(obj) is tuple:
return {
'__special_object_type__': 'tuple',
'items': self.__reform(list(obj))
}
elif type(obj) is dict: # int keys case
if any(isinstance(x, (int, float, tuple)) for x in obj.keys()):
return {
'__special_object_type__': 'kvp',
'items': self.__reform([[k, v] for k, v in obj.items()])
}
return {k: self.__reform(v) for k, v in obj.items()}
elif is_dataclass(obj):
dcs = self.__reform(obj.__dict__) # dataclasses.asdict is recursive, kills inner dataclasses
class Serializer(AttribSerializer):
def _reform(self, obj):
if is_dataclass(obj):
dcs = self._reform(obj.__dict__) # dataclasses.asdict is recursive, kills inner dataclasses
dcs['__dataclass__'] = obj.__class__.__name__
dcs['__special_object_type__'] = 'dataclass'
return dcs
elif isinstance(obj, NodeParameterType):
return {'value': obj.value,
'__special_object_type__': 'NodeParameterType'
}
elif isinstance(obj, list):
return [self.__reform(x) for x in obj]
elif isinstance(obj, (int, float, str, bool)) or obj is None:
return obj
raise NotImplementedError(f'serialization not implemented for type "{type(obj)}"')
return super()._reform(obj)

def encode(self, o):
return super().encode(self.__reform(o))

def default(self, obj):
return super(NodeSerializerV2.Serializer, self).default(obj)

class Deserializer(json.JSONDecoder):
def dedata(self, obj):
class Deserializer(AttribDeserializer):
def _dedata(self, obj):
special_type = obj.get('__special_object_type__')
if special_type == 'set':
return set(obj.get('items'))
elif special_type == 'tuple':
return tuple(obj.get('items'))
elif special_type == 'kvp':
return {k: v for k, v in obj.get('items')}
elif special_type == 'dataclass':
if special_type == 'dataclass':
data = globals()[obj['__dataclass__']](**{k: v for k, v in obj.items() if k not in ('__dataclass__', '__special_object_type__')})
if obj['__dataclass__'] == 'NodeData':
data.pos = tuple(data.pos)
return data
elif special_type == 'NodeParameterType':
return NodeParameterType(obj['value'])
return obj

def __init__(self):
super(NodeSerializerV2.Deserializer, self).__init__(object_hook=self.dedata)
return super()._dedata(obj)

def serialize(self, node: BaseNode) -> Tuple[bytes, Optional[bytes]]:
param_values = {}
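This refactor keeps only the node-specific cases (dataclasses and NodeParameterType) in NodeSerializerV2 and defers everything else to the shared base classes. Other encoders can follow the same subclass-and-defer pattern; here is a hedged sketch with a hypothetical extra type (pathlib.Path handling is not part of this commit and is used purely as an illustration):

import pathlib
from lifeblood.common_serialization import AttribSerializer, AttribDeserializer

class PathAwareSerializer(AttribSerializer):
    def _reform(self, obj):
        # handle the domain-specific type first, then defer to the base implementation
        if isinstance(obj, pathlib.Path):
            return {'__special_object_type__': 'path', 'value': str(obj)}
        return super()._reform(obj)

class PathAwareDeserializer(AttribDeserializer):
    def _dedata(self, obj):
        if obj.get('__special_object_type__') == 'path':
            return pathlib.Path(obj['value'])
        return super()._dedata(obj)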
48 changes: 48 additions & 0 deletions src/lifeblood/common_serialization.py
@@ -0,0 +1,48 @@
import json


class AttribSerializer(json.JSONEncoder):
def _reform(self, obj):
if type(obj) is set:
return {
'__special_object_type__': 'set',
'items': self._reform(list(obj))
}
elif type(obj) is tuple:
return {
'__special_object_type__': 'tuple',
'items': self._reform(list(obj))
}
elif type(obj) is dict: # int keys case
if any(isinstance(x, (int, float, tuple)) for x in obj.keys()):
return {
'__special_object_type__': 'kvp',
'items': self._reform([[k, v] for k, v in obj.items()])
}
return {k: self._reform(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self._reform(x) for x in obj]
elif isinstance(obj, (int, float, str, bool)) or obj is None:
return obj
raise NotImplementedError(f'serialization not implemented for type "{type(obj)}"')

def encode(self, o):
return super().encode(self._reform(o))

def default(self, obj):
return super().default(obj)


class AttribDeserializer(json.JSONDecoder):
def _dedata(self, obj):
special_type = obj.get('__special_object_type__')
if special_type == 'set':
return set(obj.get('items'))
elif special_type == 'tuple':
return tuple(obj.get('items'))
elif special_type == 'kvp':
return {k: v for k, v in obj.get('items')}
return obj

def __init__(self):
super().__init__(object_hook=self._dedata)
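A short round-trip sketch (illustration only, not part of the commit) of what the shared pair handles out of the box: sets, tuples, and dicts with non-string keys are reduced to json-compliant wrapper objects on encode and restored on decode.

import json
from lifeblood.common_serialization import AttribSerializer, AttribDeserializer

data = {
    'frames': (1, 5, 9),                                       # tuple -> 'tuple' wrapper
    'tags': {'sim', 'render'},                                  # set -> 'set' wrapper
    'resolutions': {(1920, 1080): 'hd', (3840, 2160): 'uhd'},   # non-string keys -> 'kvp' wrapper
}
text = json.dumps(data, cls=AttribSerializer)         # plain json string, no custom types left
restored = json.loads(text, cls=AttribDeserializer)
assert restored == data                               # wrappers are unwrapped back to python types
# any unsupported value (e.g. an arbitrary object) raises NotImplementedError on encode,
# which the validation calls added elsewhere in this commit rely on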
4 changes: 2 additions & 2 deletions src/lifeblood/core_nodes/parent_children_waiter.py
@@ -1,6 +1,6 @@
import dataclasses
from dataclasses import dataclass
import json
from lifeblood.attribute_serialization import deserialize_attributes_core
from lifeblood.basenode import BaseNode, ProcessingError
from lifeblood.nodethings import ProcessingResult
from lifeblood.taskspawn import TaskSpawn
@@ -177,7 +177,7 @@ def process_task(self, context) -> ProcessingResult:
self.__cache_children[parent_id] = ParentChildrenWaiterNode.Entry()
if task_id not in self.__cache_children[parent_id].children:
self.__cache_children[parent_id].children.add(task_id)
self.__cache_children[parent_id].all_children_dicts[task_id] = json.loads(context.task_field('attributes'))
self.__cache_children[parent_id].all_children_dicts[task_id] = dict(context.task_attributes())
self.__cache_children[parent_id].all_children_dicts[task_id]['_builtin_id'] = task_id
# promote children attribs up
if recursive and children_count > 0:
4 changes: 1 addition & 3 deletions src/lifeblood/core_nodes/split_waiter.py
@@ -1,6 +1,4 @@
from dataclasses import dataclass
import time
import json
from lifeblood.basenode import BaseNode
from lifeblood.nodethings import ProcessingResult
from lifeblood.taskspawn import TaskSpawn
@@ -134,7 +132,7 @@ def process_task(self, context) -> ProcessingResult: #TODO: not finished, attrib
if self.__cache[split_id].first_to_arrive is None and len(self.__cache[split_id].arrived) == 0:
self.__cache[split_id].first_to_arrive = task_id
if context.task_field('split_element') not in self.__cache[split_id].arrived:
self.__cache[split_id].arrived[context.task_field('split_element')] = json.loads(context.task_field('attributes'))
self.__cache[split_id].arrived[context.task_field('split_element')] = dict(context.task_attributes())
self.__cache[split_id].arrived[context.task_field('split_element')]['_builtin_id'] = task_id

# we will not wait in loop or we risk deadlocking threadpool
11 changes: 8 additions & 3 deletions src/lifeblood/environment_resolver.py
@@ -14,7 +14,6 @@
import asyncio
import os
import sys
import json
import inspect
import pathlib
import re
@@ -24,6 +23,7 @@
from types import MappingProxyType
from . import invocationjob, paths, logging
from .config import get_config
from .attribute_serialization import serialize_attributes_core, deserialize_attributes_core
from .toml_coders import TomlFlatConfigEncoder
from .process_utils import create_process, oh_no_its_windows
from .exceptions import ProcessInitializationError
@@ -70,6 +70,9 @@ def __init__(self, resolver_name=None, arguments: Optional[Mapping] = None):
if resolver_name is None and len(arguments) > 0:
raise ValueError('if name is None - no arguments are allowed')
self.__resolver_name = resolver_name
if arguments is not None:
# validate args
serialize_attributes_core(arguments)
self.__args = arguments

def name(self):
@@ -82,6 +85,8 @@ def arguments(self):
return MappingProxyType(self.__args)

def add_argument(self, name: str, value):
# validate value
serialize_attributes_core(value)
self.__args[name] = value

def remove_argument(self, name: str):
Expand All @@ -94,7 +99,7 @@ async def get_environment(self) -> "invocationjob.Environment":
return await get_resolver(self.name()).get_environment(self.arguments())

def serialize(self) -> bytes:
return json.dumps({
return serialize_attributes_core({
'_EnvironmentResolverArguments__resolver_name': self.__resolver_name,
'_EnvironmentResolverArguments__args': self.__args,
}).encode('utf-8')
@@ -105,7 +110,7 @@ async def serialize_async(self):
@classmethod
def deserialize(cls, data: bytes):
wrp = EnvironmentResolverArguments(None)
data_dict = json.loads(data.decode('utf-8'))
data_dict = deserialize_attributes_core(data.decode('utf-8'))
wrp.__resolver_name = data_dict['_EnvironmentResolverArguments__resolver_name']
wrp.__args = data_dict['_EnvironmentResolverArguments__args']
return wrp
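Given the validation calls above, a hedged sketch of the intended behaviour (the resolver name and argument values are illustrative): non-serializable resolver arguments now fail when they are set, rather than later when the scheduler serializes them for storage.

from lifeblood.environment_resolver import EnvironmentResolverArguments

args = EnvironmentResolverArguments('StandardEnvironmentResolver', {'packages': {'houdini': '20.0'}})
try:
    args.add_argument('callback', object())   # not representable by AttribSerializer
except NotImplementedError:
    print('rejected at add time, before anything reaches the task database')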
15 changes: 10 additions & 5 deletions src/lifeblood/nodethings.py
@@ -1,5 +1,4 @@
import json

from .attribute_serialization import serialize_attributes_core
from .invocationjob import InvocationJob
from .taskspawn import TaskSpawn
from .environment_resolver import EnvironmentResolverArguments
@@ -39,9 +38,13 @@ def remove_split(self, attributes_to_set=None):
"""
self.do_split_remove = True
if attributes_to_set is not None:
# validate attributes_to_set
serialize_attributes_core(attributes_to_set) # will raise in case of errors
self.split_attributes_to_set.update(attributes_to_set)

def set_attribute(self, key: str, value):
# validate value first
serialize_attributes_core({key: value}) # will raise in case of errors
self.attributes_to_set[key] = value

def remove_attribute(self, key: str):
@@ -65,16 +68,18 @@ def split_task(self, into: int):
self._split_attribs = [{} for _ in range(into)]

def set_split_task_attrib(self, split: int, attr_name: str, attr_value):
# validate attrs
try:
json.dumps(attr_value)
serialize_attributes_core({attr_name: attr_value})
except:
raise ValueError('attribs must be json-serializable dict')
raise ValueError('attr_value must be json-serializable')
self._split_attribs[split][attr_name] = attr_value

def set_split_task_attribs(self, split: int, attribs: dict):
# validate attrs
try:
assert isinstance(attribs, dict)
json.dumps(attribs)
serialize_attributes_core(attribs)
except:
raise ValueError('attribs must be json-serializable dict')
self._split_attribs[split] = attribs
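The same early-failure pattern, sketched for node processing results (assuming ProcessingResult can be constructed with defaults, as the core nodes do):

from lifeblood.nodethings import ProcessingResult

result = ProcessingResult()
result.set_attribute('rendered_frames', [1, 2, 3])        # json-representable value: accepted
try:
    result.set_attribute('callback', lambda task: task)   # functions cannot be serialized
except NotImplementedError:
    print('bad attribute value caught at set time, not at database write time')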
6 changes: 3 additions & 3 deletions src/lifeblood/processingcontext.py
@@ -1,7 +1,7 @@
import json
from types import MappingProxyType
import re

from .attribute_serialization import deserialize_attributes_core
from .config import get_config
from .environment_resolver import EnvironmentResolverArguments

@@ -15,7 +15,7 @@
class ProcessingContext:
class TaskWrapper:
def __init__(self, task_dict: dict):
self.__attributes = json.loads(task_dict.get('attributes', '{}'))
self.__attributes = deserialize_attributes_core(task_dict.get('attributes', '{}'))
self.__stuff = task_dict

def __getitem__(self, item):
@@ -59,7 +59,7 @@ def __getitem__(self, item):

def __init__(self, node: "BaseNode", task_dict: dict):
task_dict = dict(task_dict)
self.__task_attributes = json.loads(task_dict.get('attributes', '{}'))
self.__task_attributes = deserialize_attributes_core(task_dict.get('attributes', '{}'))
self.__task_dict = task_dict
self.__task_wrapper = ProcessingContext.TaskWrapper(task_dict)
self.__node_wrapper = ProcessingContext.NodeWrapper(node, self)
4 changes: 2 additions & 2 deletions src/lifeblood/scheduler/data_access.py
@@ -3,8 +3,8 @@
import sqlite3
import random
import struct
import json
from dataclasses import dataclass
from ..attribute_serialization import serialize_attributes
from ..db_misc import sql_init_script
from ..expiring_collections import ExpiringValuesSetMap
from ..config import get_config
@@ -114,7 +114,7 @@ async def create_task(self, newtask: TaskSpawnData, *, con: Optional[aiosqlite.C
return ret

async with con.execute('INSERT INTO tasks ("name", "attributes", "parent_id", "state", "node_id", "node_output_name", "environment_resolver_data") VALUES (?, ?, ?, ?, ?, ?, ?)',
(newtask.name, json.dumps(newtask.attributes), newtask.parent_id, # TODO: run dumps in executor
(newtask.name, await serialize_attributes(newtask.attributes), newtask.parent_id,
newtask.state.value,
newtask.node_id, newtask.node_output_name,
newtask.environment_resolver_arguments.serialize() if newtask.environment_resolver_arguments is not None else None)) as newcur:
9 changes: 5 additions & 4 deletions src/lifeblood/scheduler/scheduler.py
@@ -14,6 +14,7 @@
from .. import logging
from .. import paths
from ..nodegraph_holder_base import NodeGraphHolderBase
from ..attribute_serialization import serialize_attributes, deserialize_attributes
#from ..worker_task_protocol import WorkerTaskClient
from ..worker_messsage_processor import WorkerControlClient
from ..scheduler_task_protocol import SchedulerTaskProtocol, SpawnStatus
@@ -334,7 +335,7 @@ async def get_task_attributes(self, task_id: int) -> Tuple[Dict[str, Any], Opti
env_res_args = None
if res['environment_resolver_data'] is not None:
env_res_args = await EnvironmentResolverArguments.deserialize_async(res['environment_resolver_data'])
return await asyncio.get_event_loop().run_in_executor(None, json.loads, res['attributes']), env_res_args
return await deserialize_attributes(res['attributes']), env_res_args

async def get_task_fields(self, task_id: int) -> Dict[str, Any]:
"""
@@ -1268,12 +1269,12 @@ async def update_task_attributes(self, task_id: int, attributes_to_update: dict,
self.__logger.warning(f'update task attributes for {task_id} failed. task id not found.')
await con.commit()
return
attributes = await asyncio.get_event_loop().run_in_executor(None, json.loads, row['attributes'])
attributes = await deserialize_attributes(row['attributes'])
attributes.update(attributes_to_update)
for name in attributes_to_delete:
if name in attributes:
del attributes[name]
await con.execute('UPDATE tasks SET "attributes" = ? WHERE "id" = ?', (await asyncio.get_event_loop().run_in_executor(None, json.dumps, attributes),
await con.execute('UPDATE tasks SET "attributes" = ? WHERE "id" = ?', (await serialize_attributes(attributes),
task_id))
await con.commit()

@@ -1544,7 +1545,7 @@ async def _inner_shit() -> Tuple[Tuple[SpawnStatus, Optional[int]], ...]:
continue

async with con.execute('INSERT INTO tasks ("name", "attributes", "parent_id", "state", "node_id", "node_output_name", "environment_resolver_data") VALUES (?, ?, ?, ?, ?, ?, ?)',
(newtask.name(), json.dumps(newtask._attributes()), parent_task_id, # TODO: run dumps in executor
(newtask.name(), await serialize_attributes(newtask._attributes()), parent_task_id, # TODO: run dumps in executor
TaskState.SPAWNED.value if newtask.create_as_spawned() else TaskState.WAITING.value,
node_id, newtask.node_output_name(),
newtask.environment_arguments().serialize() if newtask.environment_arguments() is not None else None)) as newcur: