Skip to content

Commit

Permalink
Whiskey Beam example
Browse files Browse the repository at this point in the history
  • Loading branch information
paulk-asert committed May 29, 2024
1 parent 6551ccc commit 2f71a63
Show file tree
Hide file tree
Showing 10 changed files with 480 additions and 0 deletions.
1 change: 1 addition & 0 deletions settings.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def subprojects = [
'LanguageProcessingSparkNLP',
'Mnist',
'Whiskey',
'WhiskeyBeam',
'WhiskeyFlink',
'WhiskeyIgnite',
'WhiskeySpark',
Expand Down
47 changes: 47 additions & 0 deletions subprojects/WhiskeyBeam/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
apply plugin: 'groovy'

sourceCompatibility = 1.8

dependencies {
implementation "org.apache.beam:beam-sdks-java-core:$beamVersion"
implementation "org.apache.beam:beam-runners-direct-java:$beamVersion"
implementation "org.slf4j:slf4j-api:$slf4jVersion"
implementation "org.apache.groovy:groovy:$groovy4Version"
// implementation("com.github.haifengl:smile-core:$smileVersion") {
// transitive = false
// }
// implementation("com.github.haifengl:smile-base:$smileVersion") {
// transitive = false
// }
implementation "org.apache.commons:commons-csv:$commonsCsvVersion"
implementation "org.apache.commons:commons-math4-legacy:$commonsMath4Version"
runtimeOnly "org.slf4j:slf4j-jdk14:$slf4jVersion"
// runtimeOnly "org.bytedeco:openblas-platform:$openblasPlatformVersion"
}

FileUtil.baseNames(sourceSets.main.allSource.files).each { name ->
if (name.startsWith('Whiskey')) {
tasks.register("run$name", JavaExec) {
dependsOn compileGroovy
group 'Application'
description "Run ${name}.groovy as a JVM application/Groovy script"
classpath = sourceSets.main.runtimeClasspath
mainClass = name
}
}
}
36 changes: 36 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/AssignClusters.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import groovy.lang.Closure;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;

import java.io.IOException;

public class AssignClusters extends DoFn<Point, KV<Integer, Point>> {
final private PCollectionView<Points> centroidsView;
final private Closure<KV<Integer, Point>> clos;

public AssignClusters(PCollectionView<Points> centroidsView, Closure<KV<Integer, Point>> clos) {
this.centroidsView = centroidsView;
this.clos = clos;
}

@ProcessElement
public void processElement(@Element Point pt, OutputReceiver<KV<Integer, Point>> out, ProcessContext c) throws IOException {
out.output(clos.call(pt, c.sideInput(centroidsView)));
}
}
43 changes: 43 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/Log.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.beam.sdk.transforms.DoFn
import org.apache.beam.sdk.transforms.PTransform
import org.apache.beam.sdk.transforms.ParDo
import org.apache.beam.sdk.values.PCollection
import org.slf4j.Logger
import org.slf4j.LoggerFactory

class Log {
private static final Logger LOGGER = LoggerFactory.getLogger(Log.class)
private Log() { }

static <T> PTransform<PCollection<T>, PCollection<T>> ofElements() {
new LoggingTransform<>()
}

private static class LoggingTransform<T> extends PTransform<PCollection<T>, PCollection<T>> {
@Override
PCollection<T> expand(PCollection<T> input) {
return input.apply(ParDo.of(new DoFn<T, T>() {
@DoFn.ProcessElement
void processElement(@DoFn.Element T element, DoFn.OutputReceiver<T> out) {
LOGGER.info(element.toString())
out.output(element)
}
}))
}
}
}
34 changes: 34 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/MeanDoubleArrayCols.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.apache.beam.sdk.transforms.SerializableFunction

class MeanDoubleArrayCols implements SerializableFunction<Iterable<Point>, Point> {
@Override
Point apply(Iterable<Point> inputs) {
double[] result = new double[12]
int count = 0
for (Point input : inputs) {
result.indices.each {
result[it] += input.pts()[it]
}
count++
}
result.indices.each {
result[it] /= count
}
new Point(result)
}
}
30 changes: 30 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/Point.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

record Point(double[] pts) implements Serializable {
private static Random r = new Random()
private static Closure<double[]> randomPoint = { dims ->
(1..dims).collect { r.nextGaussian() + 2 } as double[]
}

static Point ofRandom(int dims) {
new Point(randomPoint(dims))
}

String toString() {
"Point[${pts.collect{ sprintf '%.2f', it }.join('. ')}]"
}
}
17 changes: 17 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/Points.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

record Points(List<Point> pts) implements Serializable { }
61 changes: 61 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/Squash.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import groovy.transform.CompileStatic
import groovy.transform.stc.POJO
import org.apache.beam.sdk.transforms.Combine
import org.apache.beam.sdk.values.KV

@CompileStatic
@POJO
class Squash extends Combine.CombineFn<KV<Integer, Point>, Accum, Points> {
int k, dims

@Override
Accum createAccumulator() {
new Accum()
}

@Override
Accum addInput(Accum mutableAccumulator, KV<Integer, Point> input) {
mutableAccumulator.pts << input.value
mutableAccumulator
}

@Override
Accum mergeAccumulators(Iterable<Accum> accumulators) {
Accum result = createAccumulator()
accumulators.each {
result.pts += it.pts
}
result
}

@Override
Points extractOutput(Accum accumulator) {
var pts = accumulator.pts
if (k && dims) {
while (pts.size() < k) {
pts << Point.ofRandom(dims)
}
}
new Points(pts)
}

static class Accum implements Serializable {
List<Point> pts = []
}
}
104 changes: 104 additions & 0 deletions subprojects/WhiskeyBeam/src/main/groovy/WhiskeyBeam.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import groovy.transform.CompileStatic
import org.apache.beam.sdk.Pipeline
import org.apache.beam.sdk.transforms.Combine
import org.apache.beam.sdk.transforms.Create
import org.apache.beam.sdk.transforms.DoFn
import org.apache.beam.sdk.transforms.DoFn.Element
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver
import org.apache.beam.sdk.transforms.DoFn.ProcessElement
import org.apache.beam.sdk.transforms.ParDo
import org.apache.beam.sdk.transforms.View
import org.apache.beam.sdk.values.KV

import static java.lang.Math.sqrt
import static java.util.logging.Level.INFO
import static java.util.logging.Level.SEVERE
import static java.util.logging.Logger.getLogger
import static org.apache.commons.csv.CSVFormat.RFC4180 as CSV
import static org.apache.commons.math4.legacy.stat.StatUtils.sumSq

@CompileStatic
static buildPipeline(Pipeline p, String filename, int k, int iterations, int dims) {
var readCsv = new DoFn<String, Point>() {
@ProcessElement
void processElement(@Element String path, OutputReceiver<Point> receiver) throws IOException {
def parser= CSV.builder().setHeader().setSkipHeaderRecord(true).build()
def records= new File(path).withReader{ rdr -> parser.parse(rdr).records*.toList() }
records.each { receiver.output(new Point(it[2..-1] as double[])) }
}
}

var pointArray2out = new DoFn<Points, String>() {
@ProcessElement
void processElement(@Element Points pts, OutputReceiver<String> out) {
String log = "Centroids:\n${pts.pts()*.toString().join('\n')}"
out.output(log)
}
}

var assign = { Point pt, Points centroids ->
var minDistance = Double.POSITIVE_INFINITY
var nearestCentroidId = -1
var idxs = pt.pts().indices
centroids.pts().eachWithIndex { Point next, int cluster ->
var distance = sqrt(sumSq(idxs.collect { pt.pts()[it] - next.pts()[it] } as double[]))
if (distance < minDistance) {
minDistance = distance
nearestCentroidId = cluster
}
}
KV.of(nearestCentroidId, pt)
}

Points initCentroids = new Points((1..k).collect{ Point.ofRandom(dims) })

var points = p
.apply(Create.of(filename))
.apply('Read points', ParDo.of(readCsv))

var centroids = p.apply(Create.of(initCentroids))

iterations.times {
var centroidsView = centroids
.apply(View.<Points> asSingleton())

centroids = points
.apply('Assign clusters', ParDo.of(new AssignClusters(centroidsView, assign)).withSideInputs(centroidsView))
.apply('Calculate new centroids', Combine.<Integer, Point> perKey(new MeanDoubleArrayCols()))
.apply('As Points', Combine.<KV<Integer, Point>, Points> globally(new Squash(k: k, dims: dims)))
// Uncomment below to log intermediate centroid calculations:
// centroids
// .apply('Current centroids', ParDo.of(pointArray2out)).apply(Log.ofElements())
}
centroids
.apply('Display centroids', ParDo.of(pointArray2out)).apply(Log.ofElements())
}

getLogger(getClass().name).info 'Creating pipeline ...'
var pipeline = Pipeline.create()
getLogger('').level = SEVERE // quieten root logging

int k = 5
int iterations = 10
int dims = 12

def csv = getClass().classLoader.getResource('whiskey.csv').path
buildPipeline(pipeline, csv, k, iterations, dims)
getLogger(Log.name).level = INFO // logging on for us
pipeline.run().waitUntilFinish()
Loading

0 comments on commit 2f71a63

Please sign in to comment.