Skip to content

Commit fc01a65

Browse files
committed
Add Column unbase64
1 parent cf321d7 commit fc01a65

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,14 @@ trait NonAggregateFunctions {
233233
def base64[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] =
234234
column.typed(sparkFunctions.base64(column.untyped))
235235

236+
/** Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column.
237+
* This is the reverse of base64.
238+
*
239+
* apache/spark
240+
*/
241+
def unbase64[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Array[Byte]] =
242+
column.typed(sparkFunctions.unbase64(column.untyped))
243+
236244
/** Non-Aggregate function: Concatenates multiple input string columns together into a single string column.
237245
* @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
238246
*

dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,37 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
536536
.run()
537537
.toList
538538

539-
res ?= resCompare
539+
val backAndForth = typedDS
540+
.select(base64(unbase64(base64(typedDS('a)))))
541+
.collect()
542+
.run()
543+
.toList
544+
545+
(res ?= resCompare) && (res ?= backAndForth)
546+
}
547+
548+
check(forAll(prop _))
549+
}
550+
551+
test("unbase64") {
552+
val spark = session
553+
import spark.implicits._
554+
555+
def prop(values: List[X1[String]])(implicit encX1: Encoder[X1[String]]) = {
556+
val cDS = session.createDataset(values)
557+
val resCompare = cDS
558+
.select(sparkFunctions.unbase64(cDS("a")))
559+
.map(_.getAs[Array[Byte]](0))
560+
.collect().toList
561+
562+
val typedDS = TypedDataset.create(values)
563+
val res = typedDS
564+
.select(unbase64(typedDS('a)))
565+
.collect()
566+
.run()
567+
.toList
568+
569+
res.map(_.toList) ?= resCompare.map(_.toList)
540570
}
541571

542572
check(forAll(prop _))

0 commit comments

Comments
 (0)