From a9255b734cfbfa9685af7694d6675ebaca98c5fc Mon Sep 17 00:00:00 2001 From: Paul King Date: Tue, 28 May 2024 15:01:05 +1000 Subject: [PATCH] minor refactor --- subprojects/WhiskeyWayang/build.gradle | 39 ++++++---- .../src/main/groovy/WhiskeyWayang.groovy | 78 ++++++++----------- .../src/main/notebook/WhiskeyWayang.ipynb | 24 +++--- 3 files changed, 68 insertions(+), 73 deletions(-) diff --git a/subprojects/WhiskeyWayang/build.gradle b/subprojects/WhiskeyWayang/build.gradle index 7b80e56..830196b 100644 --- a/subprojects/WhiskeyWayang/build.gradle +++ b/subprojects/WhiskeyWayang/build.gradle @@ -14,29 +14,15 @@ * limitations under the License. */ apply plugin: 'groovy' -apply plugin: 'application' repositories { +// mavenLocal() mavenCentral() - mavenLocal() // maven { // url 'https://repository.apache.org/content/repositories/orgapachewayang-1017' // } } -ext.appName = 'WhiskeyWayang' - -application { - mainClass = appName - if (JavaVersion.current().java11Compatible) { - applicationDefaultJvmArgs = ['--add-exports=java.base/sun.nio.ch=ALL-UNNAMED'] - } -} - -tasks.named('run').configure { - description = "Run $appName as a JVM application/Groovy script" -} - ext { wayangVersion = '0.7.1' hadoopVersion = '3.4.0' @@ -45,7 +31,7 @@ ext { } dependencies { - implementation "org.apache.groovy:groovy:5.0.0-alpha-8" + implementation "org.apache.groovy:groovy:$groovy5Version" implementation "org.apache.wayang:wayang-api-scala-java_$scalaMajorVersion:$wayangVersion" implementation "org.apache.wayang:wayang-java:$wayangVersion" implementation("org.apache.wayang:wayang-ml4all:$wayangVersion") { @@ -85,4 +71,23 @@ tasks.register('versionInfo') { } } -run.dependsOn versionInfo +def runAll = tasks.register('runAll') { + group 'Application' + dependsOn versionInfo +} + +FileUtil.baseNames(sourceSets.main.allSource.files).each { name -> + def subtask = tasks.register("run$name", JavaExec) { + dependsOn compileGroovy + group 'Application' + description "Run ${name}.groovy as a JVM application/Groovy script" + classpath = sourceSets.main.runtimeClasspath + mainClass = name + if (JavaVersion.current().java11Compatible) { + jvmArgs ['--add-exports=java.base/sun.nio.ch=ALL-UNNAMED'] + } + } + runAll.configure { + dependsOn subtask + } +} diff --git a/subprojects/WhiskeyWayang/src/main/groovy/WhiskeyWayang.groovy b/subprojects/WhiskeyWayang/src/main/groovy/WhiskeyWayang.groovy index 8c75ec8..f10f597 100644 --- a/subprojects/WhiskeyWayang/src/main/groovy/WhiskeyWayang.groovy +++ b/subprojects/WhiskeyWayang/src/main/groovy/WhiskeyWayang.groovy @@ -18,7 +18,6 @@ // https://github.com/apache/incubator-wayang/blob/main/README.md#k-means import org.apache.wayang.api.JavaPlanBuilder -import org.apache.wayang.core.api.Configuration import org.apache.wayang.core.api.WayangContext import org.apache.wayang.core.function.ExecutionContext import org.apache.wayang.core.function.FunctionDescriptor.ExtendedSerializableFunction @@ -33,57 +32,51 @@ record Point(double[] pts) implements Serializable { new Point(line.split(',')[2..-1] as double[]) } } -record TaggedPointCounter(double[] pts, int cluster, long count) implements Serializable { - TaggedPointCounter(List pts, int cluster, long count) { +record PointGrouping(double[] pts, int cluster, long count) implements Serializable { + PointGrouping(List pts, int cluster, long count) { this(pts as double[], cluster, count) } - TaggedPointCounter plus(TaggedPointCounter that) { + PointGrouping plus(PointGrouping that) { var newPts = pts.indices.collect{ pts[it] + that.pts[it] } - new TaggedPointCounter(newPts, cluster, count + that.count) + new PointGrouping(newPts, cluster, count + that.count) } - TaggedPointCounter average() { - new TaggedPointCounter(pts.collect{ double d -> d/count }, cluster, count) + PointGrouping average() { + new PointGrouping(pts.collect{ double d -> d/count }, cluster, count) } } -class SelectNearestCentroid implements - ExtendedSerializableFunction { - Iterable centroids +class SelectNearestCentroid implements ExtendedSerializableFunction { + Iterable centroids void open(ExecutionContext context) { - centroids = context.getBroadcast("centroids") + centroids = context.getBroadcast('centroids') } - TaggedPointCounter apply(Point p) { + PointGrouping apply(Point p) { var minDistance = Double.POSITIVE_INFINITY var nearestCentroidId = -1 for (c in centroids) { - var distance = sqrt((0.. { - Integer apply(TaggedPointCounter tpc) { tpc.cluster() } -} - -class Average implements SerializableFunction { - TaggedPointCounter apply(TaggedPointCounter tpc) { tpc.average() } -} - -class Plus implements SerializableBinaryOperator { - TaggedPointCounter apply(TaggedPointCounter tpc1, TaggedPointCounter tpc2) { tpc1 + tpc2 } +class PipelineOps { + public static SerializableFunction cluster = tpc -> tpc.cluster + public static SerializableFunction average = tpc -> tpc.average() + public static SerializableBinaryOperator plus = (tpc1, tpc2) -> tpc1 + tpc2 } +import static PipelineOps.* int k = 5 -int iterations = 20 +int iterations = 10 // read in data from our file var url = WhiskeyWayang.classLoader.getResource('whiskey.csv').file @@ -95,9 +88,7 @@ var r = new Random() var randomPoint = { (0.. new TaggedPointCounter(initPts[idx], idx, 0) }) - .withName("Load random centroids") + .loadCollection((0.. new PointGrouping(initPts[idx], idx, 0) }) + .withName('Load random centroids') -var finalCentroids = initialCentroids - .repeat(iterations, currentCentroids -> - points.map(new SelectNearestCentroid()) - .withBroadcast(currentCentroids, "centroids").withName("Find nearest centroid") - .reduceByKey(new Cluster(), new Plus()).withName("Add up points") - .map(new Average()).withName("Average points") - .withOutputClass(TaggedPointCounter)).withName("Loop").collect() +var finalCentroids = initialCentroids.repeat(iterations, currentCentroids -> + points.map(new SelectNearestCentroid()) + .withBroadcast(currentCentroids, 'centroids').withName('Find nearest centroid') + .reduceByKey(cluster, plus).withName('Aggregate points') + .map(average).withName('Average points') + .withOutputClass(PointGrouping) +).withName('Loop') println 'Centroids:' -finalCentroids.each { c -> - var pts = c.pts.collect{ sprintf '%.2f', it }.join(', ') +finalCentroids.forEach { c -> + var pts = c.pts.collect { sprintf '%.2f', it }.join(', ') println "Cluster$c.cluster ($c.count points): $pts" } /* Centroids: -Cluster0 (24 points): 2.79, 2.42, 1.46, 0.04, 0.00, 1.88, 1.67, 1.96, 1.92, 2.08, 2.17, 1.71 -Cluster1 (6 points): 3.67, 1.50, 3.67, 3.33, 0.67, 0.17, 1.67, 0.50, 1.17, 1.33, 1.17, 0.17 -Cluster2 (15 points): 1.80, 1.93, 1.93, 1.13, 0.20, 1.20, 1.33, 0.80, 1.60, 1.80, 1.00, 1.13 -Cluster3 (2 points): 2.00, 1.50, 2.50, 0.50, 0.00, 0.00, 2.50, 0.50, 0.00, 1.00, 2.00, 2.00 -Cluster4 (39 points): 1.49, 2.51, 1.05, 0.21, 0.08, 1.10, 1.13, 0.54, 1.26, 1.74, 1.97, 2.13 +Cluster0 (20 points): 2.00, 2.50, 1.55, 0.35, 0.20, 1.15, 1.55, 0.95, 0.90, 1.80, 1.35, 1.35 +Cluster2 (21 points): 2.81, 2.43, 1.52, 0.05, 0.00, 1.90, 1.67, 2.05, 2.10, 2.10, 2.19, 1.76 +Cluster3 (34 points): 1.38, 2.32, 1.09, 0.26, 0.03, 1.15, 1.09, 0.47, 1.38, 1.74, 2.03, 2.24 +Cluster4 (11 points): 2.91, 1.55, 2.91, 2.73, 0.45, 0.45, 1.45, 0.55, 1.55, 1.45, 1.18, 0.55 */ diff --git a/subprojects/WhiskeyWayang/src/main/notebook/WhiskeyWayang.ipynb b/subprojects/WhiskeyWayang/src/main/notebook/WhiskeyWayang.ipynb index 8b0c1fb..d1d9f33 100644 --- a/subprojects/WhiskeyWayang/src/main/notebook/WhiskeyWayang.ipynb +++ b/subprojects/WhiskeyWayang/src/main/notebook/WhiskeyWayang.ipynb @@ -84,7 +84,7 @@ "id": "366b792b", "metadata": {}, "source": [ - "Next we'll define `Point` and `TaggedPointCounter` classes:" + "Next we'll define `Point` and `PointGrouping` classes:" ] }, { @@ -97,14 +97,14 @@ "class Point { double[] pts }\n", "\n", "@TupleConstructor(includeSuperProperties=true)\n", - "class TaggedPointCounter extends Point {\n", + "class PointGrouping extends Point {\n", " int cluster\n", " long count\n", - " TaggedPointCounter plus(TaggedPointCounter other) {\n", - " new TaggedPointCounter((0.. {\n", - " Iterable centroids\n", + "class SelectNearestCentroid implements ExtendedSerializableFunction {\n", + " Iterable centroids\n", "\n", " void open(ExecutionContext context) {\n", " centroids = context.getBroadcast(\"centroids\")\n", " }\n", "\n", - " TaggedPointCounter apply(Point p) {\n", + " PointGrouping apply(Point p) {\n", " def minDistance = Double.POSITIVE_INFINITY\n", " def nearestCentroidId = -1\n", " for (c in centroids) {\n", @@ -144,7 +144,7 @@ " nearestCentroidId = c.cluster\n", " }\n", " }\n", - " new TaggedPointCounter(p.pts, nearestCentroidId, 1)\n", + " new PointGrouping(p.pts, nearestCentroidId, 1)\n", " }\n", "}\n", "OutputCell.HIDDEN" @@ -198,14 +198,14 @@ " .loadCollection(pointsData)\n", "\n", "def initialCentroids = planBuilder\n", - " .loadCollection((0..\n", " points.map(new SelectNearestCentroid())\n", " .withBroadcast(currentCentroids, \"centroids\").withName(\"Find nearest centroid\")\n", " .reduceByKey({ tpc -> tpc.cluster }, { tpc1, tpc2 -> tpc1.plus(tpc2) }).withName(\"Add up points\")\n", - " .map({ tpc -> tpc.average() }).withName(\"Average points\").withOutputClass(TaggedPointCounter)\n", + " .map({ tpc -> tpc.average() }).withName(\"Average points\").withOutputClass(PointGrouping)\n", "}).withName(\"Loop\").collect()\n", "\n", "cols = [\"Body\", \"Sweetness\", \"Smoky\", \"Medicinal\", \"Tobacco\", \"Honey\",\n",