Skip to content

Commit

Permalink
Make docstring optional (#259)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #259

* Refactor docstring functions: combines two functions that retrieve docstring into one
* Make docstring optional
* Remove docstring validator

Git issue: #253

Reviewed By: kiukchung

Differential Revision: D31671125

fbshipit-source-id: 2da71fcecf0d05f03c04dcc29b44ec43ab919eaa
  • Loading branch information
aivanou authored and facebook-github-bot committed Oct 16, 2021
1 parent bd1c36d commit 74e07d8
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 224 deletions.
8 changes: 8 additions & 0 deletions docs/source/component_best_practices.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ others to understand how to use it.
return AppDef(roles=[Role(..., num_replicas=num_replicas)])
Documentation
^^^^^^^^^^^^^^^^^^^^^
Documentation is optional, but it is best practice to keep component functions documented,
especially if you want to share your components. See :ref:`Component Authoring<components/overview:Authoring>`
for more details.
Named Resources
-----------------
Expand Down
30 changes: 11 additions & 19 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Generic,
Iterator,
List,
Mapping,
Optional,
Tuple,
Type,
Expand All @@ -28,8 +27,7 @@
)

import yaml
from pyre_extensions import none_throws
from torchx.specs.file_linter import parse_fn_docstring
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
from torchx.util.types import decode_from_string, decode_optional, is_bool, is_primitive


Expand Down Expand Up @@ -748,22 +746,21 @@ def get_argparse_param_type(parameter: inspect.Parameter) -> Callable[[str], obj
return str


def _create_args_parser(
fn_name: str,
parameters: Mapping[str, inspect.Parameter],
function_desc: str,
args_desc: Dict[str, str],
) -> argparse.ArgumentParser:
def _create_args_parser(app_fn: Callable[..., AppDef]) -> argparse.ArgumentParser:
parameters = inspect.signature(app_fn).parameters
function_desc, args_desc = get_fn_docstring(app_fn)
script_parser = argparse.ArgumentParser(
prog=f"torchx run ...torchx_params... {fn_name} ",
description=f"App spec: {function_desc}",
prog=f"torchx run <<torchx_params>> {app_fn.__name__} ",
description=f"AppDef for {function_desc}",
formatter_class=TorchXArgumentHelpFormatter,
)

remainder_arg = []

for param_name, parameter in parameters.items():
param_desc = args_desc[parameter.name]
args: Dict[str, Any] = {
"help": args_desc[param_name],
"help": param_desc,
"type": get_argparse_param_type(parameter),
}
if parameter.default != inspect.Parameter.empty:
Expand All @@ -788,20 +785,15 @@ def _create_args_parser(
def _get_function_args(
app_fn: Callable[..., AppDef], app_args: List[str]
) -> Tuple[List[object], List[str], Dict[str, object]]:
docstring = none_throws(inspect.getdoc(app_fn))
function_desc, args_desc = parse_fn_docstring(docstring)

parameters = inspect.signature(app_fn).parameters
script_parser = _create_args_parser(
app_fn.__name__, parameters, function_desc, args_desc
)
script_parser = _create_args_parser(app_fn)

parsed_args = script_parser.parse_args(app_args)

function_args = []
var_arg = []
kwargs = {}

parameters = inspect.signature(app_fn).parameters
for param_name, parameter in parameters.items():
arg_value = getattr(parsed_args, param_name)
parameter_type = parameter.annotation
Expand Down
137 changes: 56 additions & 81 deletions torchx/specs/file_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
# LICENSE file in the root directory of this source tree.

import abc
import argparse
import ast
import inspect
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, cast
from typing import Dict, List, Optional, Tuple, cast, Callable

from docstring_parser import parse
from pyre_extensions import none_throws
Expand All @@ -18,53 +20,66 @@
# pyre-ignore-all-errors[16]


def get_arg_names(app_specs_func_def: ast.FunctionDef) -> List[str]:
arg_names = []
fn_args = app_specs_func_def.args
for arg_def in fn_args.args:
arg_names.append(arg_def.arg)
if fn_args.vararg:
arg_names.append(fn_args.vararg.arg)
for arg in fn_args.kwonlyargs:
arg_names.append(arg.arg)
return arg_names
def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]:
parameters = inspect.signature(fn).parameters
args_decs = {}
for parameter_name in parameters.keys():
# argparse ignores None or empty-string help values when rendering help, so use a single space as a placeholder
args_decs[parameter_name] = " "
return args_decs


def parse_fn_docstring(func_description: str) -> Tuple[str, Dict[str, str]]:
class TorchXArgumentHelpFormatter(argparse.HelpFormatter):
"""Help message formatter which adds default values and required to argument help.
If the argument is required, the class appends `(required)` at the end of the help message.
If the argument has default value, the class appends `(default: $DEFAULT)` at the end.
The formatter is designed to be used only for the torchx components functions.
These functions do not have both required and default arguments.
"""
Given a docstring in a google-style format, returns the function description and
description of all arguments.
See: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html

def _get_help_string(self, action: argparse.Action) -> str:
help = action.help or ""
# Only `--help` has a default of SUPPRESS, so its help text is returned unchanged
if action.default is argparse.SUPPRESS:
return help
if action.required:
help += " (required)"
else:
help += f" (default: {action.default})"
return help


def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
"""
args_description = {}
Parses the function and arguments description from the provided function. Docstring should be in
`google-style format <https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
If the function has no docstring, the function description will be the name of the function followed by
a TIP on how to improve the help message, and each argument description will be a single-space placeholder.
The arguments that are not present in the docstring will contain default/required information
Args:
fn: Function with or without docstring
Returns:
function description, arguments description where key is the name of the argument and value
is the description
"""
default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
args_description = _get_default_arguments_descriptions(fn)
func_description = inspect.getdoc(fn)
if not func_description:
return default_fn_desc, args_description
docstring = parse(func_description)
for param in docstring.params:
args_description[param.arg_name] = param.description
short_func_description = docstring.short_description
return (short_func_description or "", args_description)


def _get_fn_docstring(
source: str, function_name: str
) -> Optional[Tuple[str, Dict[str, str]]]:
module = ast.parse(source)
for expr in module.body:
if type(expr) == ast.FunctionDef:
func_def = cast(ast.FunctionDef, expr)
if func_def.name == function_name:
docstring = ast.get_docstring(func_def)
if not docstring:
return None
return parse_fn_docstring(docstring)
return None


def get_short_fn_description(path: str, function_name: str) -> Optional[str]:
source = read_conf_file(path)
docstring = _get_fn_docstring(source, function_name)
if not docstring:
return None
return docstring[0]
short_func_description = docstring.short_description or default_fn_desc
if docstring.long_description:
short_func_description += " ..."
return (short_func_description or default_fn_desc, args_description)


@dataclass
Expand All @@ -91,38 +106,6 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
)


class TorchxDocstringValidator(TorchxFunctionValidator):
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
"""
Validates the docstring of the `get_app_spec` function. Criteria:
* There mast be google-style docstring
* If there are more than zero arguments, there mast be a `Args:` section defined
with all arguments included.
"""
docsting = ast.get_docstring(app_specs_func_def)
lineno = app_specs_func_def.lineno
if not docsting:
desc = (
f"`{app_specs_func_def.name}` is missing a Google Style docstring, please add one. "
"For more information on the docstring format see: "
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
)
return [self._gen_linter_message(desc, lineno)]

arg_names = get_arg_names(app_specs_func_def)
_, docstring_arg_defs = parse_fn_docstring(docsting)
missing_args = [
arg_name for arg_name in arg_names if arg_name not in docstring_arg_defs
]
if len(missing_args) > 0:
desc = (
f"`{app_specs_func_def.name}` not all function arguments are present"
f" in the docstring. Missing args: {missing_args}"
)
return [self._gen_linter_message(desc, lineno)]
return []


class TorchxFunctionArgsValidator(TorchxFunctionValidator):
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
linter_errors = []
Expand All @@ -149,7 +132,6 @@ def _validate_arg_def(
)
]
if isinstance(arg_def.annotation, ast.Name):
# TODO(aivanou): add support for primitive type check
return []
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
if complex_type_def.value.id == "Optional":
Expand Down Expand Up @@ -239,12 +221,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
Visitor that finds the component_function and runs registered validators on it.
Current registered validators:
* TorchxDocstringValidator - validates the docstring of the function.
Criteria:
* There format should be google-python
* If there are more than zero arguments defined, there
should be obligatory `Args:` section that describes each argument on a new line.
* TorchxFunctionArgsValidator - validates arguments of the function.
Criteria:
* Each argument should be annotated with the type
Expand All @@ -260,7 +236,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):

def __init__(self, component_function_name: str) -> None:
self.validators = [
TorchxDocstringValidator(),
TorchxFunctionArgsValidator(),
TorchxReturnValidator(),
]
Expand Down
11 changes: 6 additions & 5 deletions torchx/specs/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pyre_extensions import none_throws
from torchx.specs import AppDef
from torchx.specs.file_linter import get_short_fn_description, validate
from torchx.specs.file_linter import get_fn_docstring, validate
from torchx.util import entrypoints
from torchx.util.io import read_conf_file

Expand All @@ -40,14 +40,15 @@ class _Component:
Args:
name: The name of the component, which usually MODULE_PATH.FN_NAME
description: The description of the component, taken from the description
of the function that creates component
of the function that creates component. In case of no docstring, description
will be the same as name
fn_name: Function name that creates component
fn: Function that creates component
validation_errors: Validation errors
"""

name: str
description: Optional[str]
description: str
fn_name: str
fn: Callable[..., AppDef]
validation_errors: List[str]
Expand Down Expand Up @@ -150,7 +151,7 @@ def _get_components_from_module(
module_path = os.path.abspath(module.__file__)
for function_name, function in functions:
linter_errors = validate(module_path, function_name)
component_desc = get_short_fn_description(module_path, function_name)
component_desc, _ = get_fn_docstring(function)
component_def = _Component(
name=self._get_component_name(
base_module, module.__name__, function_name
Expand Down Expand Up @@ -197,7 +198,6 @@ def find(self) -> List[_Component]:
validation_errors = self._get_validation_errors(
self._filepath, self._function_name
)
fn_desc = get_short_fn_description(self._filepath, self._function_name)

file_source = read_conf_file(self._filepath)
namespace = globals()
Expand All @@ -207,6 +207,7 @@ def find(self) -> List[_Component]:
f"Function {self._function_name} does not exist in file {self._filepath}"
)
app_fn = namespace[self._function_name]
fn_desc, _ = get_fn_docstring(app_fn)
return [
_Component(
name=f"{self._filepath}:{self._function_name}",
Expand Down
Loading

0 comments on commit 74e07d8

Please sign in to comment.