Skip to content

Commit

Permalink
Rename pg.symbolic.function_schema to pg.symbolic.callable_schema
Browse files Browse the repository at this point in the history
… and add `remove_self` flag.

PiperOrigin-RevId: 618037220
  • Loading branch information
daiyip authored and pyglove authors committed Mar 22, 2024
1 parent aad25de commit e15b9a2
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyglove/core/symbolic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
# TODO(daiyip): internal dependencies, remove later.
from pyglove.core.symbolic.schema_utils import formalize_schema
from pyglove.core.symbolic.schema_utils import augment_schema
from pyglove.core.symbolic.schema_utils import function_schema
from pyglove.core.symbolic.schema_utils import callable_schema
from pyglove.core.symbolic.schema_utils import update_schema


Expand Down
2 changes: 1 addition & 1 deletion pyglove/core/symbolic/compounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def compound_class(
if not inspect.isfunction(factory_fn):
raise TypeError('Decorator `compound` is only applicable to functions.')

schema = schema_utils.function_schema(
schema = schema_utils.callable_schema(
factory_fn,
args=args,
returns=pg_typing.Object(base_class) if base_class else None,
Expand Down
5 changes: 3 additions & 2 deletions pyglove/core/symbolic/functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,9 @@ def _call(self, *args, **kwargs):
cls.auto_register = True

# Apply function schema.
schema = schema_utils.function_schema(
func, args, returns, auto_doc=auto_doc, auto_typing=auto_typing)
schema = schema_utils.callable_schema(
func, args, returns, auto_doc=auto_doc, auto_typing=auto_typing
)
cls.apply_schema(schema)

# Register functor class for deserialization if needed.
Expand Down
15 changes: 12 additions & 3 deletions pyglove/core/symbolic/schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class B(A):
cls.register_for_deserialization(serialization_key, additional_keys)


def function_schema(
def callable_schema(
func: types.FunctionType,
args: Optional[
List[
Expand All @@ -193,8 +193,9 @@ def function_schema(
*,
auto_typing: bool = True,
auto_doc: bool = True,
remove_self: bool = False,
) -> pg_typing.Schema:
"""Returns the schema from the signature of a function."""
"""Returns the schema from the signature of a callable."""
args_docstr = None
description = None
if auto_doc:
Expand All @@ -210,15 +211,23 @@ def function_schema(
raise ValueError('return value spec should not have default value.')
returns = returns or signature.return_value

if remove_self and arg_fields and arg_fields[0].key == 'self':
arg_fields.pop(0)

# Generate init_arg_list from signature.
init_arg_list = [arg.name for arg in signature.args]
if signature.varargs:
init_arg_list.append(f'*{signature.varargs.name}')

# Decide schema name.
module_name = getattr(func, '__module__', None)
func_name = func.__qualname__
schema_name = f'{module_name}.{func_name}' if module_name else func_name

return formalize_schema(
pg_typing.create_schema(
fields=arg_fields,
name=f'{func.__module__}.{func.__name__}',
name=schema_name,
metadata={
'init_arg_list': init_arg_list,
'varargs_name': getattr(signature.varargs, 'name', None),
Expand Down
82 changes: 82 additions & 0 deletions pyglove/core/symbolic/schema_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,91 @@
import unittest

from pyglove.core import object_utils
from pyglove.core import typing as pg_typing
from pyglove.core.symbolic import list as pg_list # pylint: disable=unused-import
from pyglove.core.symbolic import schema_utils


class CallableSchemaTest(unittest.TestCase):
"""Tests for `callable_schema`."""

def test_function_schema(self):
def foo(x: int, *args, y: str, **kwargs) -> float:
"""A function.
Args:
x: Input 1.
*args: Variable positional args.
y: Input 2.
**kwargs: Variable keyword args.
Returns:
The result.
"""
del x, y, args, kwargs

schema = schema_utils.callable_schema(foo, auto_typing=True, auto_doc=True)
self.assertEqual(schema.name, f'{foo.__module__}.{foo.__qualname__}')
self.assertEqual(schema.description, 'A function.')
self.assertEqual(
list(schema.fields.values()),
[
pg_typing.Field('x', pg_typing.Int(), description='Input 1.'),
pg_typing.Field(
'args',
pg_typing.List(pg_typing.Any(), default=[]),
description='Variable positional args.',
),
pg_typing.Field('y', pg_typing.Str(), description='Input 2.'),
pg_typing.Field(
pg_typing.StrKey(),
pg_typing.Any(),
description='Variable keyword args.',
),
],
)

def test_class_init_schema(self):
class A:

def __init__(self, x: int, *args, y: str, **kwargs) -> float:
"""Constructor.
Args:
x: Input 1.
*args: Variable positional args.
y: Input 2.
**kwargs: Variable keyword args.
Returns:
The result.
"""
del x, y, args, kwargs

schema = schema_utils.callable_schema(
A.__init__, auto_typing=True, auto_doc=True, remove_self=True
)
self.assertEqual(schema.name, f'{A.__module__}.{A.__init__.__qualname__}')
self.assertEqual(schema.description, 'Constructor.')
self.assertEqual(
list(schema.fields.values()),
[
pg_typing.Field('x', pg_typing.Int(), description='Input 1.'),
pg_typing.Field(
'args',
pg_typing.List(pg_typing.Any(), default=[]),
description='Variable positional args.',
),
pg_typing.Field('y', pg_typing.Str(), description='Input 2.'),
pg_typing.Field(
pg_typing.StrKey(),
pg_typing.Any(),
description='Variable keyword args.',
),
],
)


class SchemaDescriptionFromDocStrTest(unittest.TestCase):
"""Tests for `schema_description_from_docstr`."""

Expand Down

0 comments on commit e15b9a2

Please sign in to comment.