diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1 @@ + diff --git a/test/benchmarks/run_tests.sh b/test/benchmarks/run_tests.sh index 408e8022b245..075c71deea28 100755 --- a/test/benchmarks/run_tests.sh +++ b/test/benchmarks/run_tests.sh @@ -4,6 +4,9 @@ CDIR="$(cd "$(dirname "$0")" ; pwd -P)" LOGFILE=/tmp/pytorch_benchmarks_test.log VERBOSITY=0 +# Make benchmark module available as it is not part of torch_xla. +export PYTHONPATH=$PYTHONPATH:$CDIR/../../benchmarks/ + # Note [Keep Going] # # Set the `CONTINUE_ON_ERROR` flag to `true` to make the CircleCI tests continue on error. @@ -19,10 +22,10 @@ do case $OPTION in L) LOGFILE= - ;; + ;; V) VERBOSITY=$OPTARG - ;; + ;; esac done shift $(($OPTIND - 1)) @@ -35,8 +38,13 @@ function run_make_tests { make -C $CDIR $MAKE_V all } +function run_python_tests { + python3 "$CDIR/test_benchmark_experiment.py" +} + function run_tests { run_make_tests + run_python_tests } if [ "$LOGFILE" != "" ]; then diff --git a/test/benchmarks/test_benchmark_experiment.py b/test/benchmarks/test_benchmark_experiment.py new file mode 100644 index 000000000000..578ea79eed37 --- /dev/null +++ b/test/benchmarks/test_benchmark_experiment.py @@ -0,0 +1,24 @@ +import unittest + +from benchmark_experiment import BenchmarkExperiment + + +class BenchmarkExperimentTest(unittest.TestCase): + + def test_to_dict(self): + be = BenchmarkExperiment("some name", "cpu", "PJRT", "some xla_flags", + "openxla", "train", "123") + actual = be.to_dict() + self.assertEqual(8, len(actual)) + self.assertEqual("some name", actual["experiment_name"]) + self.assertEqual("cpu", actual["accelerator"]) + self.assertTrue("accelerator_model" in actual) + self.assertEqual("PJRT", actual["xla"]) + self.assertEqual("some xla_flags", actual["xla_flags"]) + self.assertEqual("openxla", actual["dynamo"]) + self.assertEqual("train", actual["test"]) + self.assertEqual("123", actual["batch_size"]) + + +if __name__ == '__main__': + unittest.main()