Skip to content

Commit

Permalink
Structure checking (#615)
Browse files Browse the repository at this point in the history
* Test on Python 3.13

Remove ruff flakiness
Better showing of deltas when codegen fails

* Allow refactoring in _api.py. apipatcher know understands helper functions.

* A few quick typos, while I'm at it.

* Forgot to make corresponding changes to _classes.py

* Forgot to make corresponding changes to _classes.py

* Removed 3.8
Replaced ruff temporary file with stdin/stdout

* Change requirements to >= 3.9

* Undo 3.8 => 3.9 change

* Typo

* Typo
  • Loading branch information
fyellin authored Oct 16, 2024
1 parent 52f6baa commit 6bb8bf6
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 165 deletions.
67 changes: 58 additions & 9 deletions codegen/apipatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
spec (IDL), and the backend implementations from the base API.
"""

from codegen.utils import print, format_code, to_snake_case, to_camel_case, Patcher
from codegen.idlparser import get_idl_parser, Attribute
from codegen.files import file_cache
import ast
from collections import defaultdict
from functools import cache

from codegen.files import file_cache
from codegen.idlparser import Attribute, get_idl_parser
from codegen.utils import Patcher, format_code, print, to_camel_case, to_snake_case

# In wgpu-py, we make some args optional, that are not optional in the
# IDL. Reasons may be because it makes sense to be able to omit them,
Expand Down Expand Up @@ -620,6 +623,8 @@ def apply(self, code):
all_structs = set()
ignore_structs = {"Extent3D", "Origin3D"}

structure_checks = self._get_structure_checks()

for classname, i1, i2 in self.iter_classes():
if classname not in idl.classes:
continue
Expand All @@ -639,12 +644,8 @@ def apply(self, code):
method_structs.update(self._get_sub_structs(idl, structname))
all_structs.update(method_structs)
# Collect structs being checked
checked = set()
for line in code.splitlines():
line = line.lstrip()
if line.startswith("check_struct("):
name = line.split("(")[1].split(",")[0].strip('"')
checked.add(name)
checked = structure_checks[classname, methodname]

# Test that a matching check is done
unchecked = method_structs.difference(checked)
unchecked = list(sorted(unchecked.difference(ignore_structs)))
Expand Down Expand Up @@ -674,3 +675,51 @@ def _get_sub_structs(self, idl, structname):
if structname2 in idl.structs:
structnames.update(self._get_sub_structs(idl, structname2))
return structnames

@staticmethod
def _get_structure_checks():
"""
Returns a map
(class_name, method_name) -> <list of structure names>
mapping each top-level method in _api.py to the calls to check_struct made by
that method or by any helper methods called by that method.
For now, the helper function must be methods within the same class. This code
does not yet deal with global functions or with methods in superclasses.
"""
module = ast.parse(file_cache.read("backends/wgpu_native/_api.py"))
# We only care about top-level classes and their top-level methods.
top_level_methods = {
# (class_name, method_name) -> method_ast
(class_ast.name, method_ast.name): method_ast
for class_ast in module.body
if isinstance(class_ast, ast.ClassDef)
for method_ast in class_ast.body
if isinstance(method_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
}

# (class_name, method_name) -> list of helper methods
method_helper_calls = defaultdict(list)
# (class_name, method_name) -> list of structures checked
structure_checks = defaultdict(list)

for key, method_ast in top_level_methods.items():
for node in ast.walk(method_ast):
if isinstance(node, ast.Call):
name = ast.unparse(node.func)
if name.startswith("self._"):
method_helper_calls[key].append(name[5:])
if name == "check_struct":
assert isinstance(node.args[0], ast.Constant)
struct_name = node.args[0].value
assert isinstance(struct_name, str)
structure_checks[key].append(struct_name)

@cache
def get_function_checks(class_name, method_name):
result = set(structure_checks[class_name, method_name])
for helper_method_name in method_helper_calls[class_name, method_name]:
result.update(get_function_checks(class_name, helper_method_name))
return sorted(result)

return {key: get_function_checks(*key) for key in top_level_methods.keys()}
6 changes: 3 additions & 3 deletions wgpu/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class GPU:
def request_adapter_sync(
self,
*,
power_preference: enums.GPUPowerPreference = None,
power_preference: enums.PowerPreference = None,
force_fallback_adapter: bool = False,
canvas=None,
):
Expand All @@ -114,7 +114,7 @@ def request_adapter_sync(
async def request_adapter_async(
self,
*,
power_preference: enums.GPUPowerPreference = None,
power_preference: enums.PowerPreference = None,
force_fallback_adapter: bool = False,
canvas=None,
):
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def create_render_pipeline(
layout (GPUPipelineLayout): The layout for the new pipeline.
vertex (structs.VertexState): Describes the vertex shader entry point of the
pipeline and its input buffer layouts.
primitive (structs.PrimitiveState): Describes the the primitive-related properties
primitive (structs.PrimitiveState): Describes the primitive-related properties
of the pipeline. If `strip_index_format` is present (which means the
primitive topology is a strip), and the drawCall is indexed, the
vertex index list is split into sub-lists using the maximum value of this
Expand Down
Loading

0 comments on commit 6bb8bf6

Please sign in to comment.