Commit

Add count_null aggregate function (#206)

EnricoMi authored Nov 21, 2023
1 parent b2c506a commit 495b8d8

Showing 5 changed files with 59 additions and 6 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -50,6 +50,9 @@ i.doThis()
**Backticks:** `backticks(string: String, strings: String*): String`: Encloses the given column name with backticks (`` ` ``) when needed.
This is a handy way to ensure column names with special characters like dots (`.`) work with `col()` or `select()`.
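
For example (a minimal sketch; `df` is any `DataFrame` with a column literally named `a.column`):

```scala
import org.apache.spark.sql.functions.col
import uk.co.gresearch.spark.backticks

// Without quoting, col("a.column") reads the dot as a struct field accessor;
// backticks("a.column") returns "`a.column`", which resolves as one column.
df.select(col(backticks("a.column")))
```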

**Count null values:** `count_null(e: Column)`: an aggregation function like `count` that counts null values in column `e`.
This is equivalent to calling `count(when(e.isNull, lit(1)))`.
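
For example (a sketch; `df` and its `str` column are assumed):

```scala
import org.apache.spark.sql.functions.{col, count}
import uk.co.gresearch.spark.count_null

// count counts the non-null values of str, count_null the null ones;
// together they account for every row of df.
df.select(count(col("str")), count_null(col("str")))
```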

**.Net DateTime.Ticks:** Convert .Net (C#, F#, Visual Basic) `DateTime.Ticks` into Spark timestamps, seconds and nanoseconds.

<details>
27 changes: 24 additions & 3 deletions python/gresearch/spark/__init__.py
@@ -13,17 +13,20 @@
# limitations under the License.

from contextlib import contextmanager
from typing import Any, Union, List, Optional, Mapping
from typing import Any, Union, List, Optional, Mapping, TYPE_CHECKING

from py4j.java_gateway import JVMView, JavaObject
from pyspark import SparkContext
from pyspark.context import SparkContext
from pyspark.sql import DataFrame
from pyspark.sql.column import Column, _to_java_column
from pyspark.sql.context import SQLContext
from pyspark.context import SparkContext
from pyspark.sql.functions import col, count, lit, when
from pyspark.sql.session import SparkSession
from pyspark.storagelevel import StorageLevel

if TYPE_CHECKING:
    from pyspark.sql._typing import ColumnOrName


def _to_seq(jvm: JVMView, list: List[Any]) -> JavaObject:
    array = jvm.java.util.ArrayList(list)
@@ -234,6 +237,24 @@ def unix_epoch_nanos_to_dotnet_ticks(unix_column: Union[str, Column]) -> Column:
    return Column(func(_to_java_column(unix_column)))


def count_null(e: "ColumnOrName") -> Column:
    """
    Aggregate function: returns the number of items in a group that are null.

    Parameters
    ----------
    e : :class:`~pyspark.sql.Column` or str
        target column to compute on.

    Returns
    -------
    :class:`~pyspark.sql.Column`
        column for computed results.
    """
    if isinstance(e, str):
        e = col(e)
    return count(when(e.isNull(), lit(1)))
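

# Illustrative usage (not part of this commit; `df` and the "str" column are
# hypothetical): count counts the non-null values, count_null the null ones, e.g.
#   df.select(count("str").alias("strs"), count_null("str").alias("null_strs"))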


def histogram(self: DataFrame,
              thresholds: List[Union[int, float]],
              value_column: str,
14 changes: 12 additions & 2 deletions python/test/test_package.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import unittest

from pyspark.sql import Row
from pyspark.sql.functions import col, count
from gresearch.spark import dotnet_ticks_to_timestamp, dotnet_ticks_to_unix_epoch, dotnet_ticks_to_unix_epoch_nanos, \
    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks
    timestamp_to_dotnet_ticks, unix_epoch_to_dotnet_ticks, unix_epoch_nanos_to_dotnet_ticks, count_null
from spark_common import SparkTest
from decimal import Decimal

@@ -141,6 +142,15 @@ def test_unix_epoch_nanos_to_dotnet_ticks(self):
        expected = self.unix_nanos.join(self.ticks_from_unix_nanos, "id").orderBy('id')
        self.compare_dfs(expected, timestamps)

    def test_count_null(self):
        actual = self.unix_nanos.select(
            count("id").alias("ids"),
            count(col("unix_nanos")).alias("nanos"),
            count_null("id").alias("null_ids"),
            count_null(col("unix_nanos")).alias("null_nanos"),
        ).collect()
        self.assertEqual([Row(ids=7, nanos=6, null_ids=0, null_nanos=1)], actual)


if __name__ == '__main__':
    SparkTest.main()
7 changes: 6 additions & 1 deletion src/main/scala/uk/co/gresearch/spark/package.scala
@@ -20,7 +20,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.functions.{col, when}
import org.apache.spark.sql.functions.{col, count, lit, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, LongType, TimestampType}
import org.apache.spark.storage.StorageLevel
@@ -75,6 +75,11 @@ package object spark extends Logging with SparkVersion with BuildVersion {
  def backticks(string: String, strings: String*): String =
    Backticks.column_name(string, strings: _*)

  /**
   * Aggregate function: returns the number of items in a group that are null.
   */
  def count_null(e: Column): Column = count(when(e.isNull, lit(1)))
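
  // Illustrative (not part of this commit): for any column e, count(e) and
  // count_null(e) partition the rows, so their sum equals the group's row count:
  //   df.select(count($"str") + count_null($"str"))  // == total number of rows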

  private val nanoSecondsPerDotNetTick: Long = 100
  private val dotNetTicksPerSecond: Long = 10000000
  private val unixEpochDotNetTicks: Long = 621355968000000000L
14 changes: 14 additions & 0 deletions src/test/scala/uk/co/gresearch/spark/SparkSuite.scala
@@ -142,6 +142,20 @@ class SparkSuite extends AnyFunSuite with SparkTestSession {
    assert(backticks("the.alias", "a.column", "a.field") === "`the.alias`.`a.column`.`a.field`")
  }

test("count_null") {
val df = Seq(
(1, "some"), (2, "text"), (3, "and"), (4, "some"), (5, "null"), (6, "values"), (7, null), (8, null)
).toDF("id", "str")
val actual =
df.select(
count($"id").as("ids"),
count($"str").as("strs"),
count_null($"id").as("null ids"),
count_null($"str").as("null strs")
).collect().head
assert(actual === Row(8, 6, 0, 2))
}

  def assertJobDescription(expected: Option[String]): Unit = {
    val descriptions = collectJobDescription(spark)
    assert(descriptions === 0.to(2).map(id => (id, id, expected.orNull)))
