diff --git a/pyglove/core/symbolic/__init__.py b/pyglove/core/symbolic/__init__.py index 55fdf03..2ed8a8d 100644 --- a/pyglove/core/symbolic/__init__.py +++ b/pyglove/core/symbolic/__init__.py @@ -143,7 +143,7 @@ # TODO(daiyip): internal dependencies, remove later. from pyglove.core.symbolic.schema_utils import formalize_schema from pyglove.core.symbolic.schema_utils import augment_schema -from pyglove.core.symbolic.schema_utils import function_schema +from pyglove.core.symbolic.schema_utils import callable_schema from pyglove.core.symbolic.schema_utils import update_schema diff --git a/pyglove/core/symbolic/compounding.py b/pyglove/core/symbolic/compounding.py index e0a3aaa..32b0f25 100644 --- a/pyglove/core/symbolic/compounding.py +++ b/pyglove/core/symbolic/compounding.py @@ -127,7 +127,7 @@ def compound_class( if not inspect.isfunction(factory_fn): raise TypeError('Decorator `compound` is only applicable to functions.') - schema = schema_utils.function_schema( + schema = schema_utils.callable_schema( factory_fn, args=args, returns=pg_typing.Object(base_class) if base_class else None, diff --git a/pyglove/core/symbolic/functor.py b/pyglove/core/symbolic/functor.py index 0cdfb84..26893a7 100644 --- a/pyglove/core/symbolic/functor.py +++ b/pyglove/core/symbolic/functor.py @@ -684,8 +684,9 @@ def _call(self, *args, **kwargs): cls.auto_register = True # Apply function schema. - schema = schema_utils.function_schema( - func, args, returns, auto_doc=auto_doc, auto_typing=auto_typing) + schema = schema_utils.callable_schema( + func, args, returns, auto_doc=auto_doc, auto_typing=auto_typing + ) cls.apply_schema(schema) # Register functor class for deserialization if needed. diff --git a/pyglove/core/symbolic/schema_utils.py b/pyglove/core/symbolic/schema_utils.py index 6922dc2..252a20e 100644 --- a/pyglove/core/symbolic/schema_utils.py +++ b/pyglove/core/symbolic/schema_utils.py @@ -177,7 +177,7 @@ class B(A): cls.register_for_deserialization(serialization_key, additional_keys) -def function_schema( +def callable_schema( func: types.FunctionType, args: Optional[ List[ @@ -193,8 +193,9 @@ def function_schema( *, auto_typing: bool = True, auto_doc: bool = True, + remove_self: bool = False, ) -> pg_typing.Schema: - """Returns the schema from the signature of a function.""" + """Returns the schema from the signature of a callable.""" args_docstr = None description = None if auto_doc: @@ -210,15 +211,23 @@ def function_schema( raise ValueError('return value spec should not have default value.') returns = returns or signature.return_value + if remove_self and arg_fields and arg_fields[0].key == 'self': + arg_fields.pop(0) + # Generate init_arg_list from signature. init_arg_list = [arg.name for arg in signature.args] if signature.varargs: init_arg_list.append(f'*{signature.varargs.name}') + # Decide schema name. + module_name = getattr(func, '__module__', None) + func_name = func.__qualname__ + schema_name = f'{module_name}.{func_name}' if module_name else func_name + return formalize_schema( pg_typing.create_schema( fields=arg_fields, - name=f'{func.__module__}.{func.__name__}', + name=schema_name, metadata={ 'init_arg_list': init_arg_list, 'varargs_name': getattr(signature.varargs, 'name', None), diff --git a/pyglove/core/symbolic/schema_utils_test.py b/pyglove/core/symbolic/schema_utils_test.py index 7a91fe0..bca46b9 100644 --- a/pyglove/core/symbolic/schema_utils_test.py +++ b/pyglove/core/symbolic/schema_utils_test.py @@ -16,9 +16,91 @@ import unittest from pyglove.core import object_utils +from pyglove.core import typing as pg_typing +from pyglove.core.symbolic import list as pg_list # pylint: disable=unused-import from pyglove.core.symbolic import schema_utils +class CallableSchemaTest(unittest.TestCase): + """Tests for `callable_schema`.""" + + def test_function_schema(self): + def foo(x: int, *args, y: str, **kwargs) -> float: + """A function. + + Args: + x: Input 1. + *args: Variable positional args. + y: Input 2. + **kwargs: Variable keyword args. + + Returns: + The result. + """ + del x, y, args, kwargs + + schema = schema_utils.callable_schema(foo, auto_typing=True, auto_doc=True) + self.assertEqual(schema.name, f'{foo.__module__}.{foo.__qualname__}') + self.assertEqual(schema.description, 'A function.') + self.assertEqual( + list(schema.fields.values()), + [ + pg_typing.Field('x', pg_typing.Int(), description='Input 1.'), + pg_typing.Field( + 'args', + pg_typing.List(pg_typing.Any(), default=[]), + description='Variable positional args.', + ), + pg_typing.Field('y', pg_typing.Str(), description='Input 2.'), + pg_typing.Field( + pg_typing.StrKey(), + pg_typing.Any(), + description='Variable keyword args.', + ), + ], + ) + + def test_class_init_schema(self): + class A: + + def __init__(self, x: int, *args, y: str, **kwargs) -> float: + """Constructor. + + Args: + x: Input 1. + *args: Variable positional args. + y: Input 2. + **kwargs: Variable keyword args. + + Returns: + The result. + """ + del x, y, args, kwargs + + schema = schema_utils.callable_schema( + A.__init__, auto_typing=True, auto_doc=True, remove_self=True + ) + self.assertEqual(schema.name, f'{A.__module__}.{A.__init__.__qualname__}') + self.assertEqual(schema.description, 'Constructor.') + self.assertEqual( + list(schema.fields.values()), + [ + pg_typing.Field('x', pg_typing.Int(), description='Input 1.'), + pg_typing.Field( + 'args', + pg_typing.List(pg_typing.Any(), default=[]), + description='Variable positional args.', + ), + pg_typing.Field('y', pg_typing.Str(), description='Input 2.'), + pg_typing.Field( + pg_typing.StrKey(), + pg_typing.Any(), + description='Variable keyword args.', + ), + ], + ) + + class SchemaDescriptionFromDocStrTest(unittest.TestCase): """Tests for `schema_description_from_docstr`."""