Skip to content

Commit

Permalink
JSON Schema updates and tests
Browse files Browse the repository at this point in the history
Signed-off-by: Clemens Vasters <clemens@vasters.com>
  • Loading branch information
clemensv committed Mar 4, 2024
1 parent a912c1f commit 6b47170
Show file tree
Hide file tree
Showing 12 changed files with 759 additions and 283 deletions.
44 changes: 43 additions & 1 deletion avrotize/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,46 @@ def generic_type() -> list[str | dict]:
"type": "map",
"values": l2
}])
return l1
return l1

def find_schema_node(test, avro_schema, recursion_stack=None):
    """Find the first schema node in ``avro_schema`` matching ``test``.

    The search is depth-first. A dict is tested itself before its nested
    dict/list values are searched in iteration order; a list is searched
    element by element. Only dict nodes are ever passed to ``test``.

    Args:
        test: Callable that receives a dict node and returns a truthy value
            when the node matches.
        avro_schema: The schema fragment (dict or list) to search.
        recursion_stack: The nodes currently being visited, used for cycle
            and depth detection. Callers normally omit it.

    Returns:
        The matching dict (the very same object found in the schema), or
        ``None`` when nothing matches.

    Raises:
        ValueError: If a node contains itself (directly or indirectly), or
            if a node is reached with more than 30 enclosing nodes.
    """
    # A fresh list per top-level call; a mutable default would be one object
    # shared by every call of the function.
    if recursion_stack is None:
        recursion_stack = []

    if any(avro_schema is ancestor for ancestor in recursion_stack):
        raise ValueError('Cyclical reference detected in schema')
    if len(recursion_stack) > 30:
        raise ValueError('Maximum recursion depth 30 exceeded in schema')

    recursion_stack.append(avro_schema)
    try:
        if isinstance(avro_schema, dict):
            if test(avro_schema):
                return avro_schema
            children = [v for v in avro_schema.values() if isinstance(v, (dict, list))]
        elif isinstance(avro_schema, list):
            children = avro_schema
        else:
            children = ()

        for child in children:
            node = find_schema_node(test, child, recursion_stack)
            # Compare against None explicitly: a matching node may be an
            # empty (falsy) dict and must still be reported.
            if node is not None:
                return node
        return None
    finally:
        # Always unwind, including on early return or error, so the stack
        # only ever holds the current ancestors.
        recursion_stack.pop()

def set_schema_node(test, replacement, avro_schema):
    """Overwrite matching dict nodes of ``avro_schema`` with ``replacement``.

    The schema is walked depth-first. When ``test`` returns a truthy value
    for a dict node, that dict is emptied and refilled in place from
    ``replacement``, and the walk does not descend into it. Sibling nodes
    are still visited, so every outermost matching node is overwritten.
    Values that are neither dict nor list are left alone. Returns ``None``.
    """
    if isinstance(avro_schema, list):
        for element in avro_schema:
            set_schema_node(test, replacement, element)
        return

    if not isinstance(avro_schema, dict):
        return

    if test(avro_schema):
        # Replace the contents in place so existing references to this
        # dict observe the new definition.
        avro_schema.clear()
        avro_schema.update(replacement)
        return

    for child in avro_schema.values():
        if isinstance(child, (dict, list)):
            set_schema_node(test, replacement, child)
135 changes: 124 additions & 11 deletions avrotize/dependency_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,112 @@
from typing import List



def adjust_resolved_dependencies(avro_schema: List[dict] | dict):
    """
    Re-arrange inlined type definitions so they appear before their references.

    After resolving dependencies, it may still be necessary to adjust them. The
    first pass of the algorithms below does inline all dependent types, but
    the resulting document may still have fields defined before the types they
    depend on because of the order in which the resolution happened, which
    necessarily re-sorts the graph. This function repeatedly walks the schema
    and adjusts the resolved dependencies until a full pass makes no change.

    The schema passed in is modified in place; nothing is returned.
    """

    class TreeWalker:
        # Bundles the two recursive walkers with the "did this pass change
        # anything" flag they share.

        def __init__(self):
            # Initialised to True; the driver loop at the bottom resets it to
            # False before every pass.
            self.found_something = True

        def swap_record_dependencies_above(self, current_node, record, avro_schema) -> str | None:
            """
            Replace the first by-name reference to ``record`` with a deep copy of it.

            ``current_node`` is searched for a string equal to the qualified
            name of ``record`` (``namespace + '.' + name``). In dicts, only
            string values under the keys 'type', 'values' and 'items' are
            compared, and the keys 'dependencies' and 'unmerged_types' are
            skipped entirely. The search gives up (returns None) when it reaches
            a dict that has 'name', 'namespace' and 'type' keys equal to those
            of ``record``.

            Returns the qualified name when a reference was replaced, else None.
            ``avro_schema`` is only passed along to the recursive calls.
            """
            if isinstance(current_node, dict):
                if 'name' in current_node and 'namespace' in current_node and 'type' in current_node and \
                    current_node['name'] == record['name'] and current_node.get('namespace','') == record.get('namespace','') and current_node['type'] == record['type']:
                    # We reached the record itself (same name, namespace and type): stop here.
                    return None
                for k, v in current_node.items():
                    if k in ['dependencies', 'unmerged_types']:
                        continue
                    if isinstance(v, (dict,list)):
                        # NOTE(review): this returns the result for the first nested
                        # dict/list value, so later keys of this dict are not
                        # examined in this call - confirm that is intended.
                        return self.swap_record_dependencies_above(v, record, avro_schema)
                    elif isinstance(v, str):
                        if k not in ['type', 'values', 'items']:
                            continue
                        qname = record.get('namespace','')+'.'+record['name']
                        if v == qname:
                            self.found_something = True
                            # Inline a copy of the definition where the name was referenced.
                            current_node[k] = copy.deepcopy(record)
                            return qname
            elif isinstance(current_node, list):
                for item in current_node:
                    if isinstance(item, (dict,list)):
                        # NOTE(review): as in the dict branch, the first nested
                        # dict/list element ends the scan of this list.
                        return self.swap_record_dependencies_above(item, record, avro_schema)
                    elif isinstance(item, str):
                        qname = record.get('namespace','')+'.'+record['name']
                        if item == qname:
                            self.found_something = True
                            # Swap the name for a copy of the definition at the same index.
                            idx = current_node.index(item)
                            current_node.remove(item)
                            current_node.insert(idx, copy.deepcopy(record))
                            return qname
            return None

        def walk_schema(self, current_node, avro_schema, record_list) -> str | None:
            """
            Walk ``current_node`` and move definitions ahead of their references.

            For a dict whose 'type' is 'record' or 'enum', the qualified name is
            computed. If that name is already in ``record_list`` the name is
            returned. Otherwise the name is appended to ``record_list`` and
            ``swap_record_dependencies_above`` is run from the schema root
            (``avro_schema``) for this node. Nested dict/list values are then
            walked; whenever a nested walk returns a name, the nested value is
            replaced by that name. Lists are de-duplicated after their elements
            were walked.

            Returns a qualified name when the caller should substitute this
            node by that name, else None.
            """
            found_record = None
            if isinstance(current_node, dict):
                if 'type' in current_node and (current_node['type'] == 'record' or current_node['type'] == 'enum'):
                    current_qname = current_node.get('namespace','')+'.'+current_node.get('name','')
                    if current_qname in record_list:
                        # This type name was already seen during this pass: report the
                        # name so the caller replaces this definition with a reference.
                        self.found_something = True
                        return current_qname
                    record_list.append(current_qname)
                    # Search from the root for a by-name reference to this type
                    # and inline a copy of the definition there.
                    found_record = self.swap_record_dependencies_above(avro_schema, current_node, avro_schema)
                for k, v in current_node.items():
                    if isinstance(v, (dict,list)):
                        qname = self.walk_schema(v, avro_schema, record_list)
                        if qname:
                            self.found_something = True
                            # Assigning to an existing key does not change the
                            # dict's size, so iterating .items() stays valid.
                            current_node[k] = qname
            elif isinstance(current_node, list):
                for item in current_node:
                    qname = self.walk_schema(item, avro_schema, record_list)
                    if qname:
                        self.found_something = True
                        # NOTE(review): index() finds the first element *equal* to
                        # item, which may not be this exact element if the list
                        # holds equal entries - confirm.
                        idx = current_node.index(item)
                        current_node.remove(item)
                        current_node.insert(idx, qname)
                # De-duplicate the list, keeping the first occurrence of each entry.
                new_list = []
                for item in current_node:
                    if not item in new_list:
                        new_list.append(item)
                current_node.clear()
                current_node.extend(new_list)
            return found_record

    # Repeat full passes until one completes without changing anything.
    tree_walker = TreeWalker()
    while True:
        tree_walker.found_something = False
        # Each pass starts with an empty list of already-seen type names.
        tree_walker.walk_schema(avro_schema, avro_schema, [])
        if not tree_walker.found_something:
            break



def inline_dependencies_of(avro_schema, record):
    """Inline the types named in ``record['dependencies']`` into ``record``.

    To break circular dependencies, every dependency that can be found in
    ``avro_schema`` (matched by bare name or by ``namespace.name``) is handed
    to ``swap_dependency_type`` for each field of the record. Dependencies
    that cannot be found are skipped. Afterwards the record's 'dependencies'
    entry is removed if present, and ``adjust_resolved_dependencies`` is run
    on the record.

    Args:
        avro_schema: List of top-level schema dicts; passed on to
            ``swap_dependency_type``.
        record: The record dict to process; modified in place.
    """
    # Iterate over a snapshot: the live 'dependencies' list is handed to
    # swap_dependency_type, which removes and appends entries.
    for dependency in copy.deepcopy(record.get('dependencies', [])):
        dependency_type = next(
            (x for x in avro_schema
             if x['name'] == dependency or x.get('namespace', '') + '.' + x['name'] == dependency),
            None)
        if not dependency_type:
            continue
        deps = record.get('dependencies', [])
        for field in record['fields']:
            # A fresh record stack per field, seeded with this record's
            # qualified name. The namespace is optional, as in the lookup above.
            swap_dependency_type(
                avro_schema, field, dependency, dependency_type, deps,
                [record.get('namespace', '') + '.' + record['name']])
    # The key may be absent (the loop above tolerates that), so guard the delete.
    if 'dependencies' in record:
        del record['dependencies']

    adjust_resolved_dependencies(record)



Expand Down Expand Up @@ -96,6 +193,8 @@ def sort_messages_by_dependencies(avro_schema):

if not found:
print('WARNING: There are circular dependencies in the schema, unable to resolve them: {}'.format([x['name'] for x in avro_schema if isinstance(x, dict) and 'dependencies' in x]))

adjust_resolved_dependencies(sorted_messages)
return sorted_messages

def swap_record_dependencies(avro_schema, record, record_stack: List[str], recursion_depth: int = 0):
Expand Down Expand Up @@ -169,11 +268,15 @@ def swap_dependency_type(avro_schema, field, dependency, dependency_type, depend
for field_type in field['type']:
if field_type == dependency:
if dependency_type in avro_schema:
index = field['type'].index(field_type)
field['type'].remove(field_type)
field['type'].append(dependency_type)
field['type'].insert(index, dependency_type)
avro_schema.remove(dependency_type)
dependencies.remove(dependency)
if dependency in dependencies:
dependencies.remove(dependency)
dependencies.extend(dependency_type.get('dependencies', []))
if 'dependencies' in dependency_type:
swap_record_dependencies(avro_schema, dependency_type, record_stack, recursion_depth + 1)
for field_type in field['type']:
if isinstance(field_type, dict):
swap_dependency_type(avro_schema, field_type, dependency, dependency_type, dependencies, record_stack, recursion_depth + 1)
Expand All @@ -186,19 +289,24 @@ def swap_dependency_type(avro_schema, field, dependency, dependency_type, depend
for item in field['items']:
if item == dependency:
if dependency_type in avro_schema:
index = field['items'].index(item)
field['items'].remove(item)
field['items'].append(dependency_type)
field['items'].insert(index, dependency_type)
avro_schema.remove(dependency_type)
dependencies.remove(dependency)
if dependency in dependencies:
dependencies.remove(dependency)
dependencies.extend(dependency_type.get('dependencies', []))
if 'dependencies' in dependency_type:
swap_record_dependencies(avro_schema, dependency_type, record_stack)
for item in field['items']:
if isinstance(item, dict):
swap_dependency_type(avro_schema, item, dependency, dependency_type, dependencies, record_stack, recursion_depth + 1)
elif field['items'] == dependency:
if dependency_type in avro_schema:
field['items'] = dependency_type
avro_schema.remove(dependency_type)
dependencies.remove(dependency)
if dependency in dependencies:
dependencies.remove(dependency)
dependencies.extend(dependency_type.get('dependencies', []))
if 'dependencies' in dependency_type:
swap_record_dependencies(avro_schema, dependency_type, record_stack)
Expand All @@ -209,19 +317,24 @@ def swap_dependency_type(avro_schema, field, dependency, dependency_type, depend
for item in field['values']:
if item == dependency:
if dependency_type in avro_schema:
index = field['values'].index(item)
field['values'].remove(item)
field['values'].append(dependency_type)
field['values'].insert(index, dependency_type)
avro_schema.remove(dependency_type)
dependencies.remove(dependency)
if dependency in dependencies:
dependencies.remove(dependency)
dependencies.extend(dependency_type.get('dependencies', []))
if 'dependencies' in dependency_type:
swap_record_dependencies(avro_schema, dependency_type, record_stack)
for item in field['values']:
if isinstance(item, dict):
swap_dependency_type(avro_schema, item, dependency, dependency_type, dependencies, record_stack, recursion_depth + 1)
if field['values'] == dependency:
if dependency_type in avro_schema:
field['values'] = dependency_type
avro_schema.remove(dependency_type)
dependencies.remove(dependency)
avro_schema.remove(dependency_type)
if dependency in dependencies:
dependencies.remove(dependency)
dependencies.extend(dependency_type.get('dependencies', []))
if 'dependencies' in dependency_type:
swap_record_dependencies(avro_schema, dependency_type, record_stack)
Expand Down
Loading

0 comments on commit 6b47170

Please sign in to comment.