diff --git a/README.md b/README.md
index 7498b0db..ee06a998 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,9 @@ i.doThis()
 **Backticks:** `backticks(string: String, strings: String*): String)`: Encloses the given column name with backticks (`` ` ``) when needed. This is a handy way to ensure column names with special characters like dots (`.`) work with `col()` or `select()`.
 
+**Count null values:** `count_null(e: Column)`: An aggregate function like `count` that counts null values in column `e`.
+This is equivalent to calling `count(when(e.isNull, lit(1)))`.
+
 **.Net DateTime.Ticks:** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.
diff --git a/python/gresearch/spark/__init__.py b/python/gresearch/spark/__init__.py
index 5fe87e00..550f85d9 100644
--- a/python/gresearch/spark/__init__.py
+++ b/python/gresearch/spark/__init__.py
@@ -13,17 +13,20 @@
 # limitations under the License.
 
 from contextlib import contextmanager
-from typing import Any, Union, List, Optional, Mapping
+from typing import Any, Union, List, Optional, Mapping, TYPE_CHECKING
 
 from py4j.java_gateway import JVMView, JavaObject
-from pyspark import SparkContext
+from pyspark.context import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.sql.column import Column, _to_java_column
 from pyspark.sql.context import SQLContext
-from pyspark.context import SparkContext
+from pyspark.sql.functions import col, count, lit, when
 from pyspark.sql.session import SparkSession
 from pyspark.storagelevel import StorageLevel
 
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName
+
 
 def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:
     array = jvm.java.util.ArrayList(list)
@@ -234,6 +237,25 @@ def unix_epoch_nanos_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:
     return Column(func(_to_java_column(unix_column)))
 
 
+def count_null(e: "ColumnOrName") -> Column:
+    """
+    Aggregate function: returns the number of items in a group that are null.
+
+    Parameters
+    ----------
+    e : :class:`~pyspark.sql.Column` or str
+        target column to compute on.
+
+    Returns
+    -------
+    :class:`~pyspark.sql.Column`
+        column for computed results.
+    """
+    if isinstance(e, str):
+        e = col(e)
+    return count(when(e.isNull(), lit(1)))
+
+
 def histogram(self: DataFrame,
               thresholds: List[Union[int, float]],
               value_column: str,
diff --git a/python/test/test_package.py b/python/test/test_package.py
index dee3e0ca..0bb1f086 100644
--- a/python/test/test_package.py
+++ b/python/test/test_package.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import datetime
-import unittest
+from pyspark.sql import Row
+from pyspark.sql.functions import col, count
 
 from gresearch.spark import dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
-    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks
+    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, count_null
 from spark_common import SparkTest
 from decimal import Decimal
@@ -141,6 +142,15 @@ def test_unix_epoch_nanos_to_dotnet_ticks(self):
         expected = self.unix_nanos.join(self.ticks_from_unix_nanos, "id").orderBy('id')
         self.compare_dfs(expected, timestamps)
 
+    def test_count_null(self):
+        actual = self.unix_nanos.select(
+            count("id").alias("ids"),
+            count(col("unix_nanos")).alias("nanos"),
+            count_null("id").alias("null_ids"),
+            count_null(col("unix_nanos")).alias("null_nanos"),
+        ).collect()
+        self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual)
+
 
 if __name__ == '__main__':
     SparkTest.main()
diff --git a/src/main/scala/uk/co/gresearch/spark/package.scala b/src/main/scala/uk/co/gresearch/spark/package.scala
index 34f17ce8..2e228b54 100644
--- a/src/main/scala/uk/co/gresearch/spark/package.scala
+++ b/src/main/scala/uk/co/gresearch/spark/package.scala
@@ -20,7 +20,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.expressions.NamedExpression
-import org.apache.spark.sql.functions.{col, when}
+import org.apache.spark.sql.functions.{col, count, lit, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DecimalType, LongType, TimestampType}
 import org.apache.spark.storage.StorageLevel
@@ -75,6 +75,11 @@
   def backticks(string: String, strings: String*): String = Backticks.column_name(string, strings: _*)
 
+  /**
+   * Aggregate function: returns the number of items in a group that are null.
+   */
+  def count_null(e: Column): Column = count(when(e.isNull, lit(1)))
+
   private val nanoSecondsPerDotNetTick: Long = 100
   private val dotNetTicksPerSecond: Long = 10000000
   private val unixEpochDotNetTicks: Long = 621355968000000000L
diff --git a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
index 7dc7765b..0ec08820 100644
--- a/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
+++ b/src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
@@ -142,6 +142,20 @@
     assert(backticks("the.alias", "a.column", "a.field") === "`the.alias`.`a.column`.`a.field`")
   }
 
+  test("count_null") {
+    val df = Seq(
+      (1, "some"), (2, "text"), (3, "and"), (4, "some"), (5, "null"), (6, "values"), (7, null), (8, null)
+    ).toDF("id", "str")
+    val actual =
+      df.select(
+        count($"id").as("ids"),
+        count($"str").as("strs"),
+        count_null($"id").as("null ids"),
+        count_null($"str").as("null strs")
+      ).collect().head
+    assert(actual === Row(8, 6, 0, 2))
+  }
+
   def assertJobDescription(expected: Option[String]): Unit = {
     val descriptions = collectJobDescription(spark)
     assert(descriptions === 0.to(2).map(id => (id, id, expected.orNull)))
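
Usage sketch (not part of the diff above): a minimal, self-contained Scala example of how the new `count_null` aggregate can be called once this change is in place. The object name, the example data and the local-mode `SparkSession` are illustrative assumptions; the import path follows the `uk.co.gresearch.spark` package object modified above, and `count` keeps its stock Spark behaviour of ignoring nulls.

// Illustrative usage only: object name, data and local-mode session are assumptions.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.count
import uk.co.gresearch.spark.count_null

object CountNullExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("count_null example").getOrCreate()
    import spark.implicits._

    // "str" contains two nulls among five rows.
    val df = Seq((1, "a"), (2, null), (3, "b"), (4, null), (5, "c")).toDF("id", "str")

    df.select(
      count($"str").as("non_null_strs"),  // count skips nulls -> 3
      count_null($"str").as("null_strs")  // count_null counts only the nulls -> 2
    ).show()

    spark.stop()
  }
}

From PySpark, the same aggregation is available as `count_null("str")` or `count_null(col("str"))`, as exercised by the new `test_count_null` test.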