Commit
minor refactor
paulk-asert committed May 28, 2024
1 parent a2c8cf0 commit a9255b7
Showing 3 changed files with 68 additions and 73 deletions.
39 changes: 22 additions & 17 deletions subprojects/WhiskeyWayang/build.gradle
@@ -14,29 +14,15 @@
* limitations under the License.
*/
apply plugin: 'groovy'
apply plugin: 'application'

repositories {
// mavenLocal()
mavenCentral()
mavenLocal()
// maven {
// url 'https://repository.apache.org/content/repositories/orgapachewayang-1017'
// }
}

ext.appName = 'WhiskeyWayang'

application {
mainClass = appName
if (JavaVersion.current().java11Compatible) {
applicationDefaultJvmArgs = ['--add-exports=java.base/sun.nio.ch=ALL-UNNAMED']
}
}

tasks.named('run').configure {
description = "Run $appName as a JVM application/Groovy script"
}

ext {
wayangVersion = '0.7.1'
hadoopVersion = '3.4.0'
@@ -45,7 +31,7 @@ ext {
}

dependencies {
implementation "org.apache.groovy:groovy:5.0.0-alpha-8"
implementation "org.apache.groovy:groovy:$groovy5Version"
implementation "org.apache.wayang:wayang-api-scala-java_$scalaMajorVersion:$wayangVersion"
implementation "org.apache.wayang:wayang-java:$wayangVersion"
implementation("org.apache.wayang:wayang-ml4all:$wayangVersion") {
@@ -85,4 +71,23 @@ tasks.register('versionInfo') {
}
}

run.dependsOn versionInfo
def runAll = tasks.register('runAll') {
group 'Application'
dependsOn versionInfo
}

FileUtil.baseNames(sourceSets.main.allSource.files).each { name ->
def subtask = tasks.register("run$name", JavaExec) {
dependsOn compileGroovy
group 'Application'
description "Run ${name}.groovy as a JVM application/Groovy script"
classpath = sourceSets.main.runtimeClasspath
mainClass = name
if (JavaVersion.current().java11Compatible) {
jvmArgs ['--add-exports=java.base/sun.nio.ch=ALL-UNNAMED']
}
}
runAll.configure {
dependsOn subtask
}
}
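
Note on the build change: the refactor drops the 'application' plugin in favour of one JavaExec task per Groovy source file (run<Name>) plus an aggregate runAll task. FileUtil.baseNames is a build helper defined elsewhere in this repository and is not part of this diff; purely as an illustration, and not the repository's actual implementation, such a helper could be sketched as:

// Hypothetical sketch only -- the real FileUtil lives elsewhere in the build (e.g. buildSrc)
class FileUtil {
    // Return each file name with its extension stripped, e.g. WhiskeyWayang.groovy -> WhiskeyWayang
    static List<String> baseNames(Collection<File> files) {
        files.collect { f ->
            int dot = f.name.lastIndexOf('.')
            dot > 0 ? f.name.substring(0, dot) : f.name
        }.unique()
    }
}

With tasks registered this way, an invocation along the lines of 'gradle runWhiskeyWayang' (or 'gradle runAll' to run every script after versionInfo) would be expected, assuming standard Gradle conventions.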
78 changes: 34 additions & 44 deletions subprojects/WhiskeyWayang/src/main/groovy/WhiskeyWayang.groovy
@@ -18,7 +18,6 @@
// https://github.com/apache/incubator-wayang/blob/main/README.md#k-means

import org.apache.wayang.api.JavaPlanBuilder
import org.apache.wayang.core.api.Configuration
import org.apache.wayang.core.api.WayangContext
import org.apache.wayang.core.function.ExecutionContext
import org.apache.wayang.core.function.FunctionDescriptor.ExtendedSerializableFunction
@@ -33,57 +32,51 @@ record Point(double[] pts) implements Serializable {
new Point(line.split(',')[2..-1] as double[]) }
}

record TaggedPointCounter(double[] pts, int cluster, long count) implements Serializable {
TaggedPointCounter(List<Double> pts, int cluster, long count) {
record PointGrouping(double[] pts, int cluster, long count) implements Serializable {
PointGrouping(List<Double> pts, int cluster, long count) {
this(pts as double[], cluster, count)
}

TaggedPointCounter plus(TaggedPointCounter that) {
PointGrouping plus(PointGrouping that) {
var newPts = pts.indices.collect{ pts[it] + that.pts[it] }
new TaggedPointCounter(newPts, cluster, count + that.count)
new PointGrouping(newPts, cluster, count + that.count)
}

TaggedPointCounter average() {
new TaggedPointCounter(pts.collect{ double d -> d/count }, cluster, count)
PointGrouping average() {
new PointGrouping(pts.collect{ double d -> d/count }, cluster, count)
}
}

class SelectNearestCentroid implements
ExtendedSerializableFunction<Point, TaggedPointCounter> {
Iterable<TaggedPointCounter> centroids
class SelectNearestCentroid implements ExtendedSerializableFunction<Point, PointGrouping> {
Iterable<PointGrouping> centroids

void open(ExecutionContext context) {
centroids = context.getBroadcast("centroids")
centroids = context.getBroadcast('centroids')
}

TaggedPointCounter apply(Point p) {
PointGrouping apply(Point p) {
var minDistance = Double.POSITIVE_INFINITY
var nearestCentroidId = -1
for (c in centroids) {
var distance = sqrt((0..<p.pts.size()).collect{ p.pts[it] - c.pts[it] }.sum{ it ** 2 } as double)
var distance = sqrt(p.pts.indices.collect{ p.pts[it] - c.pts[it] }.sum{ it ** 2 } as double)
if (distance < minDistance) {
minDistance = distance
nearestCentroidId = c.cluster
}
}
new TaggedPointCounter(p.pts, nearestCentroidId, 1)
new PointGrouping(p.pts, nearestCentroidId, 1)
}
}

class Cluster implements SerializableFunction<TaggedPointCounter, Integer> {
Integer apply(TaggedPointCounter tpc) { tpc.cluster() }
}

class Average implements SerializableFunction<TaggedPointCounter, TaggedPointCounter> {
TaggedPointCounter apply(TaggedPointCounter tpc) { tpc.average() }
}

class Plus implements SerializableBinaryOperator<TaggedPointCounter> {
TaggedPointCounter apply(TaggedPointCounter tpc1, TaggedPointCounter tpc2) { tpc1 + tpc2 }
class PipelineOps {
public static SerializableFunction<PointGrouping, Integer> cluster = tpc -> tpc.cluster
public static SerializableFunction<PointGrouping, PointGrouping> average = tpc -> tpc.average()
public static SerializableBinaryOperator<PointGrouping> plus = (tpc1, tpc2) -> tpc1 + tpc2
}
import static PipelineOps.*

int k = 5
int iterations = 20
int iterations = 10

// read in data from our file
var url = WhiskeyWayang.classLoader.getResource('whiskey.csv').file
@@ -95,9 +88,7 @@ var r = new Random()
var randomPoint = { (0..<dims).collect { r.nextGaussian() + 2 } as double[] }
var initPts = (1..k).collect(randomPoint)

// create planbuilder with Java and Spark enabled
var configuration = new Configuration()
var context = new WayangContext(configuration)
var context = new WayangContext()
.withPlugin(Java.basicPlugin())
.withPlugin(Spark.basicPlugin())
var planBuilder = new JavaPlanBuilder(context, "KMeans ($url, k=$k, iterations=$iterations)")
@@ -106,27 +97,26 @@ var points = planBuilder
.loadCollection(pointsData).withName('Load points')

var initialCentroids = planBuilder
.loadCollection((0..<k).collect{ idx -> new TaggedPointCounter(initPts[idx], idx, 0) })
.withName("Load random centroids")
.loadCollection((0..<k).collect{ idx -> new PointGrouping(initPts[idx], idx, 0) })
.withName('Load random centroids')

var finalCentroids = initialCentroids
.repeat(iterations, currentCentroids ->
points.map(new SelectNearestCentroid())
.withBroadcast(currentCentroids, "centroids").withName("Find nearest centroid")
.reduceByKey(new Cluster(), new Plus()).withName("Add up points")
.map(new Average()).withName("Average points")
.withOutputClass(TaggedPointCounter)).withName("Loop").collect()
var finalCentroids = initialCentroids.repeat(iterations, currentCentroids ->
points.map(new SelectNearestCentroid())
.withBroadcast(currentCentroids, 'centroids').withName('Find nearest centroid')
.reduceByKey(cluster, plus).withName('Aggregate points')
.map(average).withName('Average points')
.withOutputClass(PointGrouping)
).withName('Loop')

println 'Centroids:'
finalCentroids.each { c ->
var pts = c.pts.collect{ sprintf '%.2f', it }.join(', ')
finalCentroids.forEach { c ->
var pts = c.pts.collect { sprintf '%.2f', it }.join(', ')
println "Cluster$c.cluster ($c.count points): $pts"
}
/*
Centroids:
Cluster0 (24 points): 2.79, 2.42, 1.46, 0.04, 0.00, 1.88, 1.67, 1.96, 1.92, 2.08, 2.17, 1.71
Cluster1 (6 points): 3.67, 1.50, 3.67, 3.33, 0.67, 0.17, 1.67, 0.50, 1.17, 1.33, 1.17, 0.17
Cluster2 (15 points): 1.80, 1.93, 1.93, 1.13, 0.20, 1.20, 1.33, 0.80, 1.60, 1.80, 1.00, 1.13
Cluster3 (2 points): 2.00, 1.50, 2.50, 0.50, 0.00, 0.00, 2.50, 0.50, 0.00, 1.00, 2.00, 2.00
Cluster4 (39 points): 1.49, 2.51, 1.05, 0.21, 0.08, 1.10, 1.13, 0.54, 1.26, 1.74, 1.97, 2.13
Cluster0 (20 points): 2.00, 2.50, 1.55, 0.35, 0.20, 1.15, 1.55, 0.95, 0.90, 1.80, 1.35, 1.35
Cluster2 (21 points): 2.81, 2.43, 1.52, 0.05, 0.00, 1.90, 1.67, 2.05, 2.10, 2.10, 2.19, 1.76
Cluster3 (34 points): 1.38, 2.32, 1.09, 0.26, 0.03, 1.15, 1.09, 0.47, 1.38, 1.74, 2.03, 2.24
Cluster4 (11 points): 2.91, 1.55, 2.91, 2.73, 0.45, 0.45, 1.45, 0.55, 1.55, 1.45, 1.18, 0.55
*/
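
For readers less familiar with the Wayang operators, each repeat iteration tags every point with its nearest centroid, sums the tagged points per cluster, and then averages each sum. A plain-Groovy sketch of one such iteration (no Wayang, assuming the Point, PointGrouping and SelectNearestCentroid definitions above) is roughly:

// Minimal sketch of what one loop iteration computes, without the Wayang plan builder
List<PointGrouping> kMeansIteration(List<Point> points, List<PointGrouping> centroids) {
    var nearest = new SelectNearestCentroid(centroids: centroids)
    points.collect { p -> nearest.apply(p) }     // map: tag each point with its nearest cluster
          .groupBy { it.cluster }                // reduceByKey: group tagged points by cluster id ...
          .collect { id, group -> group.sum() }  // ... and sum coordinates and counts via plus()
          .collect { it.average() }              // map: divide the summed coordinates by the count
}

This only illustrates what the map/reduceByKey/map pipeline computes; the Wayang version additionally lets the same plan execute on the Java or Spark backend.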
24 changes: 12 additions & 12 deletions subprojects/WhiskeyWayang/src/main/notebook/WhiskeyWayang.ipynb
@@ -84,7 +84,7 @@
"id": "366b792b",
"metadata": {},
"source": [
"Next we'll define `Point` and `TaggedPointCounter` classes:"
"Next we'll define `Point` and `PointGrouping` classes:"
]
},
{
@@ -97,14 +97,14 @@
"class Point { double[] pts }\n",
"\n",
"@TupleConstructor(includeSuperProperties=true)\n",
"class TaggedPointCounter extends Point {\n",
"class PointGrouping extends Point {\n",
" int cluster\n",
" long count\n",
" TaggedPointCounter plus(TaggedPointCounter other) {\n",
" new TaggedPointCounter((0..<pts.size()).collect{ pts[it] + other.pts[it] } as double[], cluster, count + other.count)\n",
" PointGrouping plus(PointGrouping other) {\n",
" new PointGrouping((0..<pts.size()).collect{ pts[it] + other.pts[it] } as double[], cluster, count + other.count)\n",
" }\n",
" TaggedPointCounter average() {\n",
" new TaggedPointCounter(pts.collect{ it/count } as double[], cluster, 0)\n",
" PointGrouping average() {\n",
" new PointGrouping(pts.collect{ it/count } as double[], cluster, 0)\n",
" }\n",
"}\n",
"OutputCell.HIDDEN"
@@ -127,14 +127,14 @@
"source": [
"import org.apache.wayang.core.function.FunctionDescriptor.ExtendedSerializableFunction\n",
"\n",
"class SelectNearestCentroid implements ExtendedSerializableFunction<Point, TaggedPointCounter> {\n",
" Iterable<TaggedPointCounter> centroids\n",
"class SelectNearestCentroid implements ExtendedSerializableFunction<Point, PointGrouping> {\n",
" Iterable<PointGrouping> centroids\n",
"\n",
" void open(ExecutionContext context) {\n",
" centroids = context.getBroadcast(\"centroids\")\n",
" }\n",
"\n",
" TaggedPointCounter apply(Point p) {\n",
" PointGrouping apply(Point p) {\n",
" def minDistance = Double.POSITIVE_INFINITY\n",
" def nearestCentroidId = -1\n",
" for (c in centroids) {\n",
@@ -144,7 +144,7 @@
" nearestCentroidId = c.cluster\n",
" }\n",
" }\n",
" new TaggedPointCounter(p.pts, nearestCentroidId, 1)\n",
" new PointGrouping(p.pts, nearestCentroidId, 1)\n",
" }\n",
"}\n",
"OutputCell.HIDDEN"
@@ -198,14 +198,14 @@
" .loadCollection(pointsData)\n",
"\n",
"def initialCentroids = planBuilder\n",
" .loadCollection((0..<k).collect{new TaggedPointCounter(initPts[it], it, 0)})\n",
" .loadCollection((0..<k).collect{new PointGrouping(initPts[it], it, 0)})\n",
" .withName(\"Load random centroids\")\n",
"\n",
"finalCentroids = initialCentroids.repeat(iterations, { currentCentroids ->\n",
" points.map(new SelectNearestCentroid())\n",
" .withBroadcast(currentCentroids, \"centroids\").withName(\"Find nearest centroid\")\n",
" .reduceByKey({ tpc -> tpc.cluster }, { tpc1, tpc2 -> tpc1.plus(tpc2) }).withName(\"Add up points\")\n",
" .map({ tpc -> tpc.average() }).withName(\"Average points\").withOutputClass(TaggedPointCounter)\n",
" .map({ tpc -> tpc.average() }).withName(\"Average points\").withOutputClass(PointGrouping)\n",
"}).withName(\"Loop\").collect()\n",
"\n",
"cols = [\"Body\", \"Sweetness\", \"Smoky\", \"Medicinal\", \"Tobacco\", \"Honey\",\n",