Skip to content

Commit 6e490f7

Browse files
Googlercopybara-github
authored andcommitted
Implement flag_group in the new rule-based toolchain.
BEGIN_PUBLIC Implement flag_group in the new rule-based toolchain. END_PUBLIC PiperOrigin-RevId: 622107179 Change-Id: I9e1971e279f313ce85537c899bcf80860616f8b7
1 parent 5467790 commit 6e490f7

File tree

6 files changed

+458
-35
lines changed

6 files changed

+458
-35
lines changed

cc/toolchains/args.bzl

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,43 +13,50 @@
1313
# limitations under the License.
1414
"""All providers for rule-based bazel toolchain config."""
1515

16-
load("//cc:cc_toolchain_config_lib.bzl", "flag_group")
16+
load("//cc/toolchains/impl:args_utils.bzl", "validate_nested_args")
1717
load(
1818
"//cc/toolchains/impl:collect.bzl",
1919
"collect_action_types",
2020
"collect_files",
2121
"collect_provider",
2222
)
23+
load(
24+
"//cc/toolchains/impl:nested_args.bzl",
25+
"NESTED_ARGS_ATTRS",
26+
"args_wrapper_macro",
27+
"nested_args_provider_from_ctx",
28+
)
2329
load(
2430
":cc_toolchain_info.bzl",
2531
"ActionTypeSetInfo",
2632
"ArgsInfo",
2733
"ArgsListInfo",
34+
"BuiltinVariablesInfo",
2835
"FeatureConstraintInfo",
29-
"NestedArgsInfo",
3036
)
3137

3238
visibility("public")
3339

3440
def _cc_args_impl(ctx):
35-
if not ctx.attr.args and not ctx.attr.env:
36-
fail("cc_args requires at least one of args and env")
37-
3841
actions = collect_action_types(ctx.attr.actions)
39-
files = collect_files(ctx.attr.data)
40-
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
42+
43+
if not ctx.attr.args and not ctx.attr.nested and not ctx.attr.env:
44+
fail("cc_args requires at least one of args, nested, and env")
4145

4246
nested = None
43-
if ctx.attr.args:
44-
# TODO: This is temporary until cc_nested_args is implemented.
45-
nested = NestedArgsInfo(
47+
if ctx.attr.args or ctx.attr.nested:
48+
nested = nested_args_provider_from_ctx(ctx)
49+
validate_nested_args(
50+
variables = ctx.attr._variables[BuiltinVariablesInfo].variables,
51+
nested_args = nested,
52+
actions = actions.to_list(),
4653
label = ctx.label,
47-
nested = tuple(),
48-
iterate_over = None,
49-
files = files,
50-
requires_types = {},
51-
legacy_flag_group = flag_group(flags = ctx.attr.args),
5254
)
55+
files = nested.files
56+
else:
57+
files = collect_files(ctx.attr.data)
58+
59+
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)
5360

5461
args = ArgsInfo(
5562
label = ctx.label,
@@ -72,7 +79,7 @@ def _cc_args_impl(ctx):
7279
),
7380
]
7481

75-
cc_args = rule(
82+
_cc_args = rule(
7683
implementation = _cc_args_impl,
7784
attrs = {
7885
"actions": attr.label_list(
@@ -82,21 +89,6 @@ cc_args = rule(
8289
8390
See @rules_cc//cc/toolchains/actions:all for valid options.
8491
""",
85-
),
86-
"args": attr.string_list(
87-
doc = """Arguments that should be added to the command-line.
88-
89-
These are evaluated in order, with earlier args appearing earlier in the
90-
invocation of the underlying tool.
91-
""",
92-
),
93-
"data": attr.label_list(
94-
allow_files = True,
95-
doc = """Files required to add this argument to the command-line.
96-
97-
For example, a flag that sets the header directory might add the headers in that
98-
directory as additional files.
99-
""",
10092
),
10193
"env": attr.string_dict(
10294
doc = "Environment variables to be added to the command-line.",
@@ -108,7 +100,10 @@ directory as additional files.
108100
If omitted, this flag set will be enabled unconditionally.
109101
""",
110102
),
111-
},
103+
"_variables": attr.label(
104+
default = "//cc/toolchains/variables:variables",
105+
),
106+
} | NESTED_ARGS_ATTRS,
112107
provides = [ArgsInfo],
113108
doc = """Declares a list of arguments bound to a set of actions.
114109
@@ -121,3 +116,5 @@ Examples:
121116
)
122117
""",
123118
)
119+
120+
cc_args = lambda **kwargs: args_wrapper_macro(rule = _cc_args, **kwargs)

cc/toolchains/impl/args_utils.bzl

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
"""."""
14+
"""Helper functions for working with args."""
15+
16+
load(":variables.bzl", "get_type")
17+
18+
visibility([
19+
"//cc/toolchains",
20+
"//tests/rule_based_toolchain/...",
21+
])
1522

1623
def get_action_type(args_list, action_type):
1724
"""Returns the corresponding entry in ArgsListInfo.by_action.
@@ -28,3 +35,87 @@ def get_action_type(args_list, action_type):
2835
return args
2936

3037
return struct(action = action_type, args = tuple(), files = depset([]))
38+
39+
def validate_nested_args(*, nested_args, variables, actions, label, fail = fail):
40+
"""Validates the typing for an nested_args invocation.
41+
42+
Args:
43+
nested_args: (NestedArgsInfo) The nested_args to validate
44+
variables: (Dict[str, VariableInfo]) A mapping from variable name to
45+
the metadata (variable type and valid actions).
46+
actions: (List[ActionTypeInfo]) The actions we require these variables
47+
to be valid for.
48+
label: (Label) The label of the rule we're currently validating.
49+
Used for error messages.
50+
fail: The fail function. Use for testing only.
51+
"""
52+
stack = [(nested_args, {})]
53+
54+
for _ in range(9999999):
55+
if not stack:
56+
break
57+
nested_args, overrides = stack.pop()
58+
if nested_args.iterate_over != None or nested_args.unwrap_options:
59+
# Make sure we don't keep using the same object.
60+
overrides = dict(**overrides)
61+
62+
if nested_args.iterate_over != None:
63+
type = get_type(
64+
name = nested_args.iterate_over,
65+
variables = variables,
66+
overrides = overrides,
67+
actions = actions,
68+
args_label = label,
69+
nested_label = nested_args.label,
70+
fail = fail,
71+
)
72+
if type["name"] == "list":
73+
# Rewrite the type of the thing we iterate over from a List[T]
74+
# to a T.
75+
overrides[nested_args.iterate_over] = type["elements"]
76+
elif type["name"] == "option" and type["elements"]["name"] == "list":
77+
# Rewrite Option[List[T]] to T.
78+
overrides[nested_args.iterate_over] = type["elements"]["elements"]
79+
else:
80+
fail("Attempting to iterate over %s, but it was not a list - it was a %s" % (nested_args.iterate_over, type["repr"]))
81+
82+
# 1) Validate variables marked with after_option_unwrap = False.
83+
# 2) Unwrap Option[T] to T as required.
84+
# 3) Validate variables marked with after_option_unwrap = True.
85+
for after_option_unwrap in [False, True]:
86+
for var_name, requirements in nested_args.requires_types.items():
87+
for requirement in requirements:
88+
if requirement.after_option_unwrap == after_option_unwrap:
89+
type = get_type(
90+
name = var_name,
91+
variables = variables,
92+
overrides = overrides,
93+
actions = actions,
94+
args_label = label,
95+
nested_label = nested_args.label,
96+
fail = fail,
97+
)
98+
if type["name"] not in requirement.valid_types:
99+
fail("{msg}, but {var_name} has type {type}".format(
100+
var_name = var_name,
101+
msg = requirement.msg,
102+
type = type["repr"],
103+
))
104+
105+
# Only unwrap the options after the first iteration of this loop.
106+
if not after_option_unwrap:
107+
for var in nested_args.unwrap_options:
108+
type = get_type(
109+
name = var,
110+
variables = variables,
111+
overrides = overrides,
112+
actions = actions,
113+
args_label = label,
114+
nested_label = nested_args.label,
115+
fail = fail,
116+
)
117+
if type["name"] == "option":
118+
overrides[var] = type["elements"]
119+
120+
for child in nested_args.nested:
121+
stack.append((child, overrides))

cc/toolchains/impl/nested_args.bzl

Lines changed: 123 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414
"""Helper functions for working with args."""
1515

16+
load("@bazel_skylib//lib:structs.bzl", "structs")
1617
load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
17-
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo")
18+
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo")
19+
load(":collect.bzl", "collect_files", "collect_provider")
1820

1921
visibility([
2022
"//cc/toolchains",
@@ -48,6 +50,126 @@ cc_args(
4850
iterate_over = "//toolchains/variables:foo_list",
4951
"""
5052

53+
# @unsorted-dict-items.
54+
NESTED_ARGS_ATTRS = {
55+
"args": attr.string_list(
56+
doc = """json-encoded arguments to be added to the command-line.
57+
58+
Usage:
59+
cc_args(
60+
...,
61+
args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")]
62+
)
63+
64+
This is equivalent to flag_group(flags = ["--foo", "%{foo}"])
65+
66+
Mutually exclusive with nested.
67+
""",
68+
),
69+
"nested": attr.label_list(
70+
providers = [NestedArgsInfo],
71+
doc = """nested_args that should be added on the command-line.
72+
73+
Mutually exclusive with args.""",
74+
),
75+
"data": attr.label_list(
76+
allow_files = True,
77+
doc = """Files required to add this argument to the command-line.
78+
79+
For example, a flag that sets the header directory might add the headers in that
80+
directory as additional files.
81+
""",
82+
),
83+
"variables": attr.label_list(
84+
providers = [VariableInfo],
85+
doc = "Variables to be used in substitutions",
86+
),
87+
"iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"),
88+
"requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"),
89+
"requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"),
90+
"requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"),
91+
"requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"),
92+
"requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"),
93+
"requires_equal_value": attr.string(),
94+
}
95+
96+
def args_wrapper_macro(*, name, rule, args = [], **kwargs):
97+
"""Invokes a rule by converting args to attributes.
98+
99+
Args:
100+
name: (str) The name of the target.
101+
rule: (rule) The rule to invoke. Either cc_args or cc_nested_args.
102+
args: (List[str|Formatted]) A list of either strings, or function calls
103+
from format.bzl. For example:
104+
["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")]
105+
**kwargs: kwargs to pass through into the rule invocation.
106+
"""
107+
out_args = []
108+
vars = []
109+
if type(args) != "list":
110+
fail("Args must be a list in %s" % native.package_relative_label(name))
111+
for arg in args:
112+
if type(arg) == "string":
113+
out_args.append(raw_string(arg))
114+
elif getattr(arg, "format_type") == "format_arg":
115+
arg = structs.to_dict(arg)
116+
if arg["value"] == None:
117+
out_args.append(arg)
118+
else:
119+
var = arg.pop("value")
120+
121+
# Swap the variable from a label to an index. This allows us to
122+
# actually get the providers in a rule.
123+
out_args.append(struct(value = len(vars), **arg))
124+
vars.append(var)
125+
else:
126+
fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg))
127+
128+
rule(
129+
name = name,
130+
args = [json.encode(arg) for arg in out_args],
131+
variables = vars,
132+
**kwargs
133+
)
134+
135+
def _var(target):
136+
if target == None:
137+
return None
138+
return target[VariableInfo].name
139+
140+
# TODO: Consider replacing this with a subrule in the future. However, maybe not
141+
# for a long time, since it'll break compatibility with all bazel versions < 7.
142+
def nested_args_provider_from_ctx(ctx):
143+
"""Gets the nested args provider from a rule that has NESTED_ARGS_ATTRS.
144+
145+
Args:
146+
ctx: The rule context
147+
Returns:
148+
NestedArgsInfo
149+
"""
150+
variables = collect_provider(ctx.attr.variables, VariableInfo)
151+
args = []
152+
for arg in ctx.attr.args:
153+
arg = json.decode(arg)
154+
if "value" in arg:
155+
if arg["value"] != None:
156+
arg["value"] = variables[arg["value"]]
157+
args.append(struct(**arg))
158+
159+
return nested_args_provider(
160+
label = ctx.label,
161+
args = args,
162+
nested = collect_provider(ctx.attr.nested, NestedArgsInfo),
163+
files = collect_files(ctx.attr.data),
164+
iterate_over = _var(ctx.attr.iterate_over),
165+
requires_not_none = _var(ctx.attr.requires_not_none),
166+
requires_none = _var(ctx.attr.requires_none),
167+
requires_true = _var(ctx.attr.requires_true),
168+
requires_false = _var(ctx.attr.requires_false),
169+
requires_equal = _var(ctx.attr.requires_equal),
170+
requires_equal_value = ctx.attr.requires_equal_value,
171+
)
172+
51173
def raw_string(s):
52174
"""Constructs metadata for creating a raw string.
53175

0 commit comments

Comments
 (0)