Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic scala_test rule #959

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,32 @@ load("//test/proto_cross_repo_boundary:repo.bzl", "proto_cross_repo_boundary_rep

proto_cross_repo_boundary_repository()

# test sbt testing frameworks
scala_maven_import_external(
name = "org_scalacheck_scalacheck",
artifact = scala_mvn_artifact(
"org.scalacheck:scalacheck:1.14.3",
default_scala_major_version(),
),
artifact_sha256 = "3cbc95bb615f1a384b8c4406dfc42b225499f08adf7639de11566069e47d44cf",
licenses = ["notice"], # Apache 2.0
server_urls = [
"https://repo1.maven.org/maven2/",
"https://mirror.bazel.build/repo1.maven.org/maven2",
],
)

scala_maven_import_external(
name = "com_novocode_junit_interface",
artifact = "com.novocode:junit-interface:0.11",
artifact_sha256 = "29e923226a0d10e9142bbd81073ef52f601277001fcf9014389bf0af3dc33dc3",
licenses = ["notice"], # Apache 2.0
server_urls = [
"https://repo1.maven.org/maven2/",
"https://mirror.bazel.build/repo1.maven.org/maven2",
],
)

# test adding a scala jar:
jvm_maven_import_external(
name = "com_twitter__scalding_date",
Expand Down
40 changes: 40 additions & 0 deletions scala/defs.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Starlark rules for building Scala projects.

These are the core rules (library, binary, test) under active
development. Their APIs are not guaranteed stable and we anticipate
some breaking changes.

We do not yet recommend using these APIs for production codebases. Instead,
use the stable rules exported by scala.bzl:

```
load(
"@io_bazel_rules_scala//scala:scala.bzl",
"scala_library",
"scala_binary",
"scala_test"
)
```

"""

load(
"@io_bazel_rules_scala//scala/private:rules/scala_binary.bzl",
_make_scala_binary = "make_scala_binary",
)
load(
"@io_bazel_rules_scala//scala/private:rules/scala_library.bzl",
_make_scala_library = "make_scala_library",
)
load(
"@io_bazel_rules_scala//scala/private:rules/unstable_scala_test.bzl",
_make_scala_test = "make_scala_test",
)

make_scala_library = _make_scala_library
make_scala_binary = _make_scala_binary
make_scala_test = _make_scala_test

scala_library = _make_scala_library()
scala_binary = _make_scala_binary()
scala_test = _make_scala_test()
17 changes: 17 additions & 0 deletions scala/private/macros/scala_repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,23 @@ def scala_repositories(
fetch_sources = fetch_sources,
)

# used by the experimental scala_test rule
_scala_maven_import_external(
name = "io_bazel_rules_scala_classgraph",
artifact = "io.github.classgraph:classgraph:jar:4.8.60",
artifact_sha256 = "dacf7d7fec4088e674ee98155adbb74f30af2f8b64f8990d37c223d8b9047b72",
licenses = ["notice"],
server_urls = maven_servers,
)

_scala_maven_import_external(
name = "io_bazel_rules_scala_test_interface",
artifact = "org.scala-sbt:test-interface:jar:1.0",
artifact_sha256 = "15f70b38bb95f3002fec9aea54030f19bb4ecfbad64c67424b5e5fea09cd749e",
licenses = ["notice"],
server_urls = maven_servers,
)

if not native.existing_rule("com_google_protobuf"):
http_archive(
name = "com_google_protobuf",
Expand Down
24 changes: 24 additions & 0 deletions scala/private/phases/phase_collect_jars.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ def phase_collect_jars_scalatest(ctx, p):
)
return _phase_collect_jars_default(ctx, p, args)

def phase_collect_jars_unstable_scala_test(ctx, p):
args = struct(
base_classpath = p.scalac_provider.default_classpath,
extra_runtime_deps = [
ctx.attr._discover_tests_runner,
],
)
return _phase_collect_jars_default(ctx, p, args)

def phase_collect_jars_repl(ctx, p):
args = struct(
base_classpath = p.scalac_provider.default_repl_classpath,
Expand Down Expand Up @@ -45,6 +54,21 @@ def phase_collect_jars_common(ctx, p):
return _phase_collect_jars_default(ctx, p)

def _phase_collect_jars_default(ctx, p, _args = struct()):
extra_deps = []
extra_runtime_deps = []

phase_names = dir(p)
phase_names.remove("to_json")
phase_names.remove("to_proto")
for phase_name in phase_names:
phase = getattr(p, phase_name)

if hasattr(phase, "extra_deps"):
extra_deps.extend(phase.extra_deps)

if hasattr(phase, "extra_runtime_deps"):
extra_runtime_deps.extend(phase.extra_runtime_deps)

return _phase_collect_jars(
ctx,
p,
Expand Down
34 changes: 34 additions & 0 deletions scala/private/phases/phase_discover_tests.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
def phase_discover_tests(ctx, p):
worker = ctx.attr._discover_tests_worker
worker_inputs, _, worker_input_manifests = ctx.resolve_command(
tools = [worker],
)

output = ctx.actions.declare_file("{}_discovered_tests.bin".format(ctx.label.name))

args = ctx.actions.args()
args.set_param_file_format("multiline")
args.use_param_file("@%s", use_always = True)

args.add(output)
args.add_all(p.compile.files)
args.add("--")
args.add_all(p.collect_jars.transitive_runtime_jars)

ctx.actions.run(
mnemonic = "DiscoverTests",
inputs = worker_inputs + p.collect_jars.compile_jars.to_list() + p.compile.files.to_list(),
outputs = [output],
executable = worker.files_to_run.executable,
input_manifests = worker_input_manifests,
execution_requirements = {"supports-workers": "1"},
arguments = [args],
)

return struct(
files = depset([output]),
jvm_flags = [
"-DDiscoveredTestsResult={}".format(output.short_path),
],
runfiles = depset([output]),
)
15 changes: 13 additions & 2 deletions scala/private/phases/phase_write_executable.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,24 @@ def phase_write_executable_common(ctx, p):
return _phase_write_executable_default(ctx, p)

def _phase_write_executable_default(ctx, p, _args = struct()):
jvm_flags = []

phase_names = dir(p)
phase_names.remove("to_json")
phase_names.remove("to_proto")
for phase_name in phase_names:
phase = getattr(p, phase_name)

if hasattr(phase, "jvm_flags"):
jvm_flags.extend(phase.jvm_flags)

return _phase_write_executable(
ctx,
p,
_args.rjars if hasattr(_args, "rjars") else p.compile.rjars,
_args.jvm_flags if hasattr(_args, "jvm_flags") else ctx.attr.jvm_flags,
(_args.jvm_flags if hasattr(_args, "jvm_flags") else ctx.attr.jvm_flags) + jvm_flags,
_args.use_jacoco if hasattr(_args, "use_jacoco") else False,
_args.main_class if hasattr(_args, "main_class") else ctx.attr.main_class,
_args.main_class if hasattr(_args, "main_class") else ctx.attr._main_class if hasattr(ctx.attr, "_main_class") else ctx.attr.main_class,
)

def _phase_write_executable(
Expand Down
6 changes: 6 additions & 0 deletions scala/private/phases/phases.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ load(
_phase_collect_jars_macro_library = "phase_collect_jars_macro_library",
_phase_collect_jars_repl = "phase_collect_jars_repl",
_phase_collect_jars_scalatest = "phase_collect_jars_scalatest",
_phase_collect_jars_unstable_scala_test = "phase_collect_jars_unstable_scala_test",
)
load(
"@io_bazel_rules_scala//scala/private:phases/phase_compile.bzl",
Expand Down Expand Up @@ -63,6 +64,7 @@ load("@io_bazel_rules_scala//scala/private:phases/phase_declare_executable.bzl",
load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_merge_jars = "phase_merge_jars")
load("@io_bazel_rules_scala//scala/private:phases/phase_jvm_flags.bzl", _phase_jvm_flags = "phase_jvm_flags")
load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles")
load("@io_bazel_rules_scala//scala/private:phases/phase_discover_tests.bzl", _phase_discover_tests = "phase_discover_tests")
load("@io_bazel_rules_scala//scala/private:phases/phase_scalafmt.bzl", _phase_scalafmt = "phase_scalafmt")

# API
Expand Down Expand Up @@ -112,6 +114,7 @@ phase_java_wrapper_repl = _phase_java_wrapper_repl
phase_java_wrapper_common = _phase_java_wrapper_common

# collect_jars
phase_collect_jars_unstable_scala_test = _phase_collect_jars_unstable_scala_test
phase_collect_jars_scalatest = _phase_collect_jars_scalatest
phase_collect_jars_repl = _phase_collect_jars_repl
phase_collect_jars_macro_library = _phase_collect_jars_macro_library
Expand All @@ -136,5 +139,8 @@ phase_runfiles_common = _phase_runfiles_common
# default_info
phase_default_info = _phase_default_info

# discover_tests
phase_discover_tests = _phase_discover_tests

# scalafmt
phase_scalafmt = _phase_scalafmt
98 changes: 98 additions & 0 deletions scala/private/rules/unstable_scala_test.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Rules for writing tests with ScalaTest"""

load("@bazel_skylib//lib:dicts.bzl", _dicts = "dicts")
load(
"@io_bazel_rules_scala//scala/private:common_attributes.bzl",
"common_attrs",
"implicit_deps",
"launcher_template",
)
load("@io_bazel_rules_scala//scala/private:common.bzl", "sanitize_string_for_usage")
load("@io_bazel_rules_scala//scala/private:common_outputs.bzl", "common_outputs")
load(
"@io_bazel_rules_scala//scala/private:phases/phases.bzl",
"extras_phases",
"phase_collect_jars_unstable_scala_test",
"phase_compile_common",
"phase_coverage_common",
"phase_coverage_runfiles",
"phase_declare_executable",
"phase_default_info",
"phase_dependency_common",
"phase_discover_tests",
"phase_java_wrapper_common",
"phase_merge_jars",
"phase_runfiles_scalatest",
"phase_scalac_provider",
"phase_write_executable_scalatest",
"phase_write_manifest",
"run_phases",
)

def _scala_test_impl(ctx):
return run_phases(
ctx,
# customizable phases
[
("scalac_provider", phase_scalac_provider),
("write_manifest", phase_write_manifest),
("dependency", phase_dependency_common),
("collect_jars", phase_collect_jars_unstable_scala_test),
("java_wrapper", phase_java_wrapper_common),
("declare_executable", phase_declare_executable),
# no need to build an ijar for an executable
("compile", phase_compile_common),
("coverage", phase_coverage_common),
("merge_jars", phase_merge_jars),
("runfiles", phase_runfiles_scalatest),
("coverage_runfiles", phase_coverage_runfiles),
("discover_tests", phase_discover_tests),
("write_executable", phase_write_executable_scalatest),
("default_info", phase_default_info),
],
)

_scala_test_attrs = {
"_main_class": attr.string(
default = "io.bazel.rules_scala.discover_tests_runner.DiscoverTestsRunner",
),
"colors": attr.bool(default = True),
"full_stacktraces": attr.bool(default = True),
"jvm_flags": attr.string_list(),
"_jacocorunner": attr.label(
default = Label("@bazel_tools//tools/jdk:JacocoCoverage"),
),
"_lcov_merger": attr.label(
default = Label("@bazel_tools//tools/test/CoverageOutputGenerator/java/com/google/devtools/coverageoutputgenerator:Main"),
),
"_discover_tests_worker": attr.label(
default = Label("@io_bazel_rules_scala//src/scala/io/bazel/rules_scala/discover_tests_worker"),
),
"_discover_tests_runner": attr.label(
default = Label("@io_bazel_rules_scala//src/scala/io/bazel/rules_scala/discover_tests_runner"),
),
}

_scala_test_attrs.update(launcher_template)

_scala_test_attrs.update(implicit_deps)

_scala_test_attrs.update(common_attrs)

def make_scala_test(*extras):
return rule(
attrs = _dicts.add(
_scala_test_attrs,
extras_phases(extras),
*[extra["attrs"] for extra in extras if "attrs" in extra]
),
executable = True,
fragments = ["java"],
outputs = _dicts.add(
common_outputs,
*[extra["outputs"] for extra in extras if "outputs" in extra]
),
test = True,
toolchains = ["@io_bazel_rules_scala//scala:toolchain_type"],
implementation = _scala_test_impl,
)
12 changes: 12 additions & 0 deletions src/scala/io/bazel/rules_scala/discover_tests_runner/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
load("//scala:defs.bzl", "scala_library")

scala_library(
name = "discover_tests_runner",
srcs = ["DiscoverTestsRunner.scala"],
visibility = ["//visibility:public"],
deps = [
"//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java",
"//src/scala/io/bazel/rules_scala/discover_tests_worker:discovered_tests_java_proto",
"@io_bazel_rules_scala_test_interface//jar",
],
)
Loading