diff --git a/pom.xml b/pom.xml
index b66fd018ba0..de0250ff5d4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -68,6 +68,8 @@
 		<spark.version>1.4.1</spark.version>
 		<scala.version>2.10.5</scala.version>
 		<scala.binary.version>2.10</scala.binary.version>
+		<flink.version>1.0.0</flink.version>
+
@@ -754,6 +756,35 @@
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-java</artifactId>
+			<version>${flink.version}</version>
+			<scope>provided</scope>
+			<exclusions>
+				<exclusion>
+					<groupId>javax.servlet</groupId>
+					<artifactId>servlet-api</artifactId>
+				</exclusion>
+				<exclusion>
+					<groupId>org.apache.flink</groupId>
+					<artifactId>flink-shaded-hadoop2</artifactId>
+				</exclusion>
+			</exclusions>
+		</dependency>
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-clients_${scala.binary.version}</artifactId>
+			<version>${flink.version}</version>
+			<scope>provided</scope>
+			<exclusions>
+				<exclusion>
+					<groupId>org.apache.flink</groupId>
+					<artifactId>flink-shaded-hadoop2</artifactId>
+				</exclusion>
+			</exclusions>
+		</dependency>
+
 		<dependency>
 			<groupId>org.apache.spark</groupId>
 			<artifactId>spark-core_2.10</artifactId>
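
For context on the two new provided-scope dependencies: flink-java brings in the DataSet API, flink-clients the local/remote job submission. A minimal smoke test that the dependencies resolve — hypothetical class, not part of this patch, assuming the Flink 1.0 DataSet API:

    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;

    public class FlinkDependencySmokeTest {
        public static void main(String[] args) throws Exception {
            // local environment when run from the IDE, cluster environment when submitted
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Integer> data = env.fromElements(1, 2, 3);
            System.out.println(data.count()); // triggers execution; prints 3
        }
    }
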
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java
index 570fb07cc3a..563c22f8444 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -69,6 +69,7 @@
import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
+import org.apache.sysml.runtime.controlprogram.context.FlinkExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -93,7 +94,9 @@ public enum RUNTIME_PLATFORM {
SINGLE_NODE, // execute all matrix operations in CP
HYBRID, // execute matrix operations in CP or MR
HYBRID_SPARK, // execute matrix operations in CP or Spark
- SPARK // execute matrix operations in Spark
+	SPARK, // execute matrix operations in Spark
+	FLINK, // execute matrix operations in Flink
+	HYBRID_FLINK // execute matrix operations in CP or Flink
}
public static RUNTIME_PLATFORM rtplatform = RUNTIME_PLATFORM.HYBRID; //default exec mode
@@ -520,6 +523,10 @@ else if ( platform.equalsIgnoreCase("spark"))
lrtplatform = RUNTIME_PLATFORM.SPARK;
else if ( platform.equalsIgnoreCase("hybrid_spark"))
lrtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
+ else if ( platform.equalsIgnoreCase("flink"))
+ lrtplatform = RUNTIME_PLATFORM.FLINK;
+ else if ( platform.equalsIgnoreCase("hybrid_flink"))
+ lrtplatform = RUNTIME_PLATFORM.HYBRID_FLINK;
else
System.err.println("ERROR: Unknown runtime platform: " + platform);
@@ -674,8 +681,11 @@ private static void execute(String dmlScriptStr, String fnameOptConfig, Map<String, String> argVals
diff --git a/src/main/java/org/apache/sysml/api/FlinkMLOutput.java b/src/main/java/org/apache/sysml/api/FlinkMLOutput.java
deleted file mode 100644
--- a/src/main/java/org/apache/sysml/api/FlinkMLOutput.java
+++ /dev/null
-public class FlinkMLOutput extends MLOutput {
-
-	HashMap<String, DataSet<Tuple2<MatrixIndexes, MatrixBlock>>> _outputs;
-
-	public FlinkMLOutput(HashMap<String, DataSet<Tuple2<MatrixIndexes, MatrixBlock>>> outputs,
-			HashMap<String, MatrixCharacteristics> outMetadata) {
-		super(outMetadata);
-		this._outputs = outputs;
-	}
-
-	public DataSet<Tuple2<MatrixIndexes, MatrixBlock>> getBinaryBlockedDataSet(
-			String varName) throws DMLRuntimeException {
-		if (_outputs.containsKey(varName)) {
-			return _outputs.get(varName);
-		}
-		throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
-	}
-
-	// TODO this should be refactored (Superclass MLOutput with Spark and Flink specific subclasses...)
-	@Override
-	public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
-		throw new DMLRuntimeException("FlinkOutput can't return Spark RDDs!");
-	}
-}
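
The TODO in the deleted class already names the cleaner design: a backend-neutral MLOutput with Spark- and Flink-specific subclasses, so neither side has to throw on the other backend's getter. A sketch of that refactoring — hypothetical, not part of this patch:

    import java.util.HashMap;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.sysml.runtime.DMLRuntimeException;
    import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
    import org.apache.sysml.runtime.matrix.data.MatrixBlock;
    import org.apache.sysml.runtime.matrix.data.MatrixIndexes;

    abstract class MLOutput {
        protected final HashMap<String, MatrixCharacteristics> _outMetadata;
        protected MLOutput(HashMap<String, MatrixCharacteristics> outMetadata) {
            _outMetadata = outMetadata;
        }
    }

    class FlinkMLOutput extends MLOutput {
        private final HashMap<String, DataSet<Tuple2<MatrixIndexes, MatrixBlock>>> _outputs;

        FlinkMLOutput(HashMap<String, DataSet<Tuple2<MatrixIndexes, MatrixBlock>>> outputs,
                      HashMap<String, MatrixCharacteristics> outMetadata) {
            super(outMetadata);
            _outputs = outputs;
        }

        // the Flink-specific accessor lives only on the Flink subclass
        DataSet<Tuple2<MatrixIndexes, MatrixBlock>> getBinaryBlockedDataSet(String varName)
                throws DMLRuntimeException {
            if (_outputs.containsKey(varName))
                return _outputs.get(varName);
            throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
        }
    }
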
diff --git a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
index 8a8cffc2d73..b5a053ae561 100644
--- a/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggBinaryOp.java
@@ -48,6 +48,7 @@
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
+import org.apache.sysml.runtime.controlprogram.context.FlinkExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -229,6 +230,29 @@ else if( et == ExecType.SPARK )
throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
}
}
+ else if( et == ExecType.FLINK )
+ {
+ //matrix mult operation selection part 3 (Flink type)
+ boolean tmmRewrite = input1 instanceof ReorgOp && ((ReorgOp)input1).getOp()==ReOrgOp.TRANSPOSE;
+ _method = optFindMMultMethodFlink (
+ input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(),
+ input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(),
+ mmtsj, chain, _hasLeftPMInput, tmmRewrite );
+
+ //dispatch Flink lops construction
+ switch( _method )
+ {
+ case TSMM:
+ constructFlinkLopsTSMM( mmtsj );
+ break;
+ case MAPMM_L:
+ case MAPMM_R:
+ constructFlinkLopsMapMM( _method );
+ break;
+ default:
+				throw new HopsException(this.printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing FLINK lops.");
+ }
+ }
else if( et == ExecType.MR )
{
//matrix mult operation selection part 3 (MR type)
@@ -407,7 +431,7 @@ protected ExecType optFindExecType()
{
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
@@ -682,7 +706,92 @@ private Lop constructCPLopsMMWithLeftTransposeRewrite()
return out;
}
-
+
+ //////////////////////////
+ // Flink Lops generation
+ /////////////////////////
+
+	/**
+	 * Constructs the Flink transpose-self matrix multiply (TSMM) lop.
+	 *
+	 * @param mmtsj transpose-self type (left: t(X)%*%X, right: X%*%t(X))
+	 * @throws HopsException
+	 * @throws LopsException
+	 */
+ private void constructFlinkLopsTSMM(MMTSJType mmtsj)
+ throws HopsException, LopsException
+ {
+ Hop input = getInput().get(mmtsj.isLeft()?1:0);
+ MMTSJ tsmm = new MMTSJ(input.constructLops(), getDataType(), getValueType(), ExecType.FLINK, mmtsj);
+ setOutputDimensions(tsmm);
+ setLineNumbers(tsmm);
+ setLops(tsmm);
+ }
+
+	/**
+	 * Constructs the Flink broadcast-based (map-side) matrix multiply lop.
+	 *
+	 * @param method the selected method (MAPMM_L or MAPMM_R)
+	 * @throws LopsException
+	 * @throws HopsException
+	 */
+ private void constructFlinkLopsMapMM(MMultMethod method)
+ throws LopsException, HopsException
+ {
+ Lop mapmult = null;
+ if( isLeftTransposeRewriteApplicable(false, false) )
+ {
+ mapmult = constructFlinkLopsMapMMWithLeftTransposeRewrite();
+ }
+ else
+ {
+ // If number of columns is smaller than block size then explicit aggregation is not required.
+ // i.e., entire matrix multiplication can be performed in the mappers.
+ boolean needAgg = requiresAggregation(method);
+ SparkAggType aggtype = getSparkMMAggregationType(needAgg);
+ _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
+
+ //core matrix mult
+ mapmult = new MapMult( getInput().get(0).constructLops(), getInput().get(1).constructLops(),
+ getDataType(), getValueType(), (method==MMultMethod.MAPMM_R), false,
+ _outputEmptyBlocks, aggtype, ExecType.FLINK);
+ }
+ setOutputDimensions(mapmult);
+ setLineNumbers(mapmult);
+ setLops(mapmult);
+ }
+
+	/**
+	 * Constructs the Flink mapmm lop with the left-transpose rewrite
+	 * t(X)%*%y -> t(t(y)%*%X) applied.
+	 *
+	 * @return root lop of the rewritten matrix multiply
+	 * @throws HopsException
+	 * @throws LopsException
+	 */
+ private Lop constructFlinkLopsMapMMWithLeftTransposeRewrite()
+ throws HopsException, LopsException
+ {
+		Hop X = getInput().get(0).getInput().get(0); //guaranteed to exist
+ Hop Y = getInput().get(1);
+
+ //right vector transpose
+ Lop tY = new Transform(Y.constructLops(), OperationTypes.Transpose, getDataType(), getValueType(), ExecType.CP);
+ tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getRowsInBlock(), getColsInBlock(), Y.getNnz());
+ setLineNumbers(tY);
+
+		//matrix mult flink
+ boolean needAgg = requiresAggregation(MMultMethod.MAPMM_R);
+ SparkAggType aggtype = getSparkMMAggregationType(needAgg);
+ _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
+
+ Lop mult = new MapMult( tY, X.constructLops(), getDataType(), getValueType(),
+ false, false, _outputEmptyBlocks, aggtype, ExecType.FLINK);
+ mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
+ setLineNumbers(mult);
+
+ //result transpose (dimensions set outside)
+ Lop out = new Transform(mult, OperationTypes.Transpose, getDataType(), getValueType(), ExecType.CP);
+
+ return out;
+ }
+
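
For reference, the rewrite used above rests on the standard transpose identity; transposing the skinny y locally in CP and broadcasting it avoids a distributed transpose of the large X:

    X^{\top} y \;=\; \left( y^{\top} X \right)^{\top}

i.e., in DML terms, t(X) %*% y is computed as t( t(y) %*% X ), where only the small intermediates are transposed.
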
//////////////////////////
// Spark Lops generation
/////////////////////////
@@ -728,7 +837,7 @@ private void constructSparkLopsMapMM(MMultMethod method)
//core matrix mult
mapmult = new MapMult( getInput().get(0).constructLops(), getInput().get(1).constructLops(),
getDataType(), getValueType(), (method==MMultMethod.MAPMM_R), false,
- _outputEmptyBlocks, aggtype);
+ _outputEmptyBlocks, aggtype, ExecType.SPARK);
}
setOutputDimensions(mapmult);
setLineNumbers(mapmult);
@@ -758,7 +867,7 @@ private Lop constructSparkLopsMapMMWithLeftTransposeRewrite()
_outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
Lop mult = new MapMult( tY, X.constructLops(), getDataType(), getValueType(),
- false, false, _outputEmptyBlocks, aggtype);
+ false, false, _outputEmptyBlocks, aggtype, ExecType.SPARK);
mult.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
setLineNumbers(mult);
@@ -871,7 +980,6 @@ private Lop constructSparkLopsCPMMWithLeftTransposeRewrite()
/**
*
- * @param chain
* @throws LopsException
* @throws HopsException
*/
@@ -1872,6 +1980,143 @@ else if( (chainType==ChainType.XtwXv || chainType==ChainType.XtXvy )
return MMultMethod.RMM;
}
+
+	/**
+	 * Selects the matrix multiplication method for Flink based on size,
+	 * sparsity, and memory-budget information (mirrors the Spark variant).
+	 *
+	 * @param m1_rows
+	 * @param m1_cols
+	 * @param m1_rpb
+	 * @param m1_cpb
+	 * @param m1_nnz
+	 * @param m2_rows
+	 * @param m2_cols
+	 * @param m2_rpb
+	 * @param m2_cpb
+	 * @param m2_nnz
+	 * @param mmtsj
+	 * @param chainType
+	 * @param leftPMInput
+	 * @param tmmRewrite
+	 * @return the chosen MMultMethod
+	 */
+ private MMultMethod optFindMMultMethodFlink( long m1_rows, long m1_cols, long m1_rpb, long m1_cpb, long m1_nnz,
+ long m2_rows, long m2_cols, long m2_rpb, long m2_cpb, long m2_nnz,
+ MMTSJType mmtsj, ChainType chainType, boolean leftPMInput, boolean tmmRewrite )
+ {
+ //Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
+ //and needs to fit once in executor broadcast memory.
+ double memBudgetExec = MAPMULT_MEM_MULTIPLIER * FlinkExecutionContext.getUDFMemoryBudget();
+ double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
+
+ //reset broadcast memory information (for concurrent parfor jobs, awareness of additional
+ //cp memory requirements on flink dataset operations with broadcasts)
+ _spBroadcastMemEstimate = 0;
+
+ // Step 0: check for forced mmultmethod
+ if( FORCED_MMULT_METHOD !=null )
+ return FORCED_MMULT_METHOD;
+
+ // Step 1: check TSMM
+ // If transpose self pattern and result is single block:
+ // use specialized TSMM method (always better than generic jobs)
+ if( ( mmtsj == MMTSJType.LEFT && m2_cols>=0 && m2_cols <= m2_cpb )
+ || ( mmtsj == MMTSJType.RIGHT && m1_rows>=0 && m1_rows <= m1_rpb ) )
+ {
+ return MMultMethod.TSMM;
+ }
+
+ // Step 2: check MapMMChain
+ // If mapmultchain pattern and result is a single block:
+ // use specialized mapmult method
+ if( OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES )
+ {
+ //matmultchain if dim2(X)<=blocksize and all vectors fit in mappers
+ //(X: m1_cols x m1_rows, v: m1_rows x m2_cols, w: m1_cols x m2_cols)
+		//NOTE: generalization possible: m2_cols>=0 && m2_cols<=m2_cpb
+ if( chainType!=ChainType.NONE && m1_rows >=0 && m1_rows <= m1_rpb && m2_cols==1 )
+ {
+ if( chainType==ChainType.XtXv && m1_rows>=0 && m2_cols>=0
+ && OptimizerUtils.estimateSize(m1_rows, m2_cols ) < memBudgetExec )
+ {
+ return MMultMethod.MAPMM_CHAIN;
+ }
+ else if( (chainType==ChainType.XtwXv || chainType==ChainType.XtXvy )
+ && m1_rows>=0 && m2_cols>=0 && m1_cols>=0
+ && OptimizerUtils.estimateSize(m1_rows, m2_cols)
+ + OptimizerUtils.estimateSize(m1_cols, m2_cols) < memBudgetExec
+ && 2*(OptimizerUtils.estimateSize(m1_rows, m2_cols)
+ + OptimizerUtils.estimateSize(m1_cols, m2_cols)) < memBudgetLocal )
+ {
+ _spBroadcastMemEstimate = 2*(OptimizerUtils.estimateSize(m1_rows, m2_cols)
+ + OptimizerUtils.estimateSize(m1_cols, m2_cols));
+ return MMultMethod.MAPMM_CHAIN;
+ }
+ }
+ }
+
+ // Step 3: check for PMM (permutation matrix needs to fit into mapper memory)
+		// (needs to be checked before mapmult for consistency with removeEmpty compilation)
+ double footprintPM1 = getMapmmMemEstimate(m1_rows, 1, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, true);
+ double footprintPM2 = getMapmmMemEstimate(m2_rows, 1, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, true);
+ if( (footprintPM1 < memBudgetExec && m1_rows>=0 || footprintPM2 < memBudgetExec && m2_rows>=0)
+ && 2*OptimizerUtils.estimateSize(m1_rows, 1) < memBudgetLocal
+ && leftPMInput )
+ {
+ _spBroadcastMemEstimate = 2*OptimizerUtils.estimateSize(m1_rows, 1);
+ return MMultMethod.PMM;
+ }
+
+ // Step 4: check MapMM
+ // If the size of one input is small, choose a method that uses broadcast variables to prevent shuffle
+
+ //memory estimates for local partitioning (mb -> partitioned mb)
+ double m1Size = OptimizerUtils.estimateSizeExactSparsity(m1_rows, m1_cols, m1_nnz); //m1 single block
+ double m2Size = OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols, m2_nnz); //m2 single block
+ double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz); //m1 partitioned
+ double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz); //m2 partitioned
+
+ //memory estimates for remote execution (broadcast and outputs)
+ double footprint1 = getMapmmMemEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, false);
+ double footprint2 = getMapmmMemEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 2, false);
+
+ if ( (footprint1 < memBudgetExec && m1Size+m1SizeP < memBudgetLocal && m1_rows>=0 && m1_cols>=0)
+ || (footprint2 < memBudgetExec && m2Size+m2SizeP < memBudgetLocal && m2_rows>=0 && m2_cols>=0) )
+ {
+ //apply map mult if one side fits in remote task memory
+ //(if so pick smaller input for distributed cache)
+ if( m1SizeP < m2SizeP && m1_rows>=0 && m1_cols>=0) {
+ _spBroadcastMemEstimate = m1Size+m1SizeP;
+ return MMultMethod.MAPMM_L;
+ }
+ else {
+ _spBroadcastMemEstimate = m2Size+m2SizeP;
+ return MMultMethod.MAPMM_R;
+ }
+ }
+
+ // Step 5: check for unknowns
+ // If the dimensions are unknown at compilation time, simply assume
+ // the worst-case scenario and produce the most robust plan -- which is CPMM
+ if ( m1_rows == -1 || m1_cols == -1 || m2_rows == -1 || m2_cols == -1 )
+ return MMultMethod.CPMM;
+
+ // Step 6: check for ZIPMM
+		// If t(X)%*%y -> t(t(y)%*%X) rewrite and ncol(X)<blocksize
+		if( tmmRewrite && m1_rows >= 0 && m1_rows <= m1_rpb //blocksize constraint left
+ && m2_cols >= 0 && m2_cols <= m2_cpb ) //blocksize constraint right
+ {
+ return MMultMethod.ZIPMM;
+ }
+
+ // Step 7: Decide CPMM vs RMM based on io costs
+ //estimate shuffle costs weighted by parallelism
+		//TODO currently we reuse the mr estimates, these need to be fine-tuned for our flink operators
+ double rmm_costs = getRMMCostEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m2_rows, m2_cols, m2_rpb, m2_cpb);
+ double cpmm_costs = getCPMMCostEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m2_rows, m2_cols, m2_rpb, m2_cpb);
+
+ //final mmult method decision
+ if ( cpmm_costs < rmm_costs )
+ return MMultMethod.CPMM;
+ else
+ return MMultMethod.RMM;
+ }
+
/**
*
* @param m1_rows
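
To make the selection cascade in optFindMMultMethodFlink concrete, a worked example with made-up sizes and budgets — illustrative only, names below are not part of the patch:

    public class MMultMethodExample {
        public static void main(String[] args) {
            // X: 10^6 x 100 with 1000 x 1000 blocks; computing t(X) %*% X
            long resultCols = 100, colsPerBlock = 1000;
            // Step 1: the 100 x 100 result is a single column block -> TSMM
            boolean tsmm = (resultCols >= 0 && resultCols <= colsPerBlock);
            // Step 4 fallback: dense size of the smaller side vs. broadcast budget;
            // 100 x 100 doubles ~ 80 KB broadcasts easily -> MAPMM before CPMM/RMM
            double smallSideBytes = 100 * 100 * 8d;
            System.out.println(tsmm ? "TSMM" : (smallSideBytes < 512e6 ? "MAPMM" : "CPMM/RMM"));
        }
    }
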
diff --git a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
index 0152d5428bf..a412e7bec83 100644
--- a/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/AggUnaryOp.java
@@ -38,6 +38,7 @@
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysml.runtime.controlprogram.parfor.opt.OptimizationWrapper;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -276,7 +277,29 @@ else if( isUnaryAggregateOuterSPRewriteApplicable() )
setLops(unary1);
}
}
- }
+ } else if (et == ExecType.FLINK) {
+ // TODO this is just copied from spark above -- eventually make flink-specific
+ OperationTypes op = HopsAgg2Lops.get(_op);
+ DirectionTypes dir = HopsDirection2Lops.get(_direction);
+
+ //unary aggregate default
+ boolean needAgg = requiresAggregation(input, _direction);
+ SparkAggType aggtype = getSparkUnaryAggregationType(needAgg);
+
+			PartialAggregate aggregate = new PartialAggregate(input.constructLops(),
+				op, dir, DataType.MATRIX, getValueType(), aggtype, et);
+ aggregate.setDimensionsBasedOnDirection(getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock());
+ setLineNumbers(aggregate);
+ setLops(aggregate);
+
+ if (getDataType() == DataType.SCALAR) {
+ UnaryCP unary1 = new UnaryCP(aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR),
+ getDataType(), getValueType());
+ unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1);
+ setLineNumbers(unary1);
+ setLops(unary1);
+ }
+ }
}
catch (Exception e) {
throw new HopsException(this.printErrorLocation() + "In AggUnary Hop, error constructing Lops " , e);
@@ -403,7 +426,7 @@ protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
//forced / memory-based / threshold-based decision
if( _etypeForced != null )
@@ -435,10 +458,11 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto
if( _etype == ExecType.CP && _etypeForced != ExecType.CP
&& !(getInput().get(0) instanceof DataOp) //input is not checkpoint
&& getInput().get(0).getParent().size()==1 //uagg is only parent
- && getInput().get(0).optFindExecType() == ExecType.SPARK )
+ && (getInput().get(0).optFindExecType() == ExecType.SPARK
+ || getInput().get(0).optFindExecType() == ExecType.FLINK))
{
- //pull unary aggregate into spark
- _etype = ExecType.SPARK;
+ //pull unary aggregate into spark/flink
+ _etype = OptimizerUtils.getRemoteExecType();
}
//mark for recompile (forever)
diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java
index 94de0e7469c..3b6de14815c 100644
--- a/src/main/java/org/apache/sysml/hops/BinaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java
@@ -77,7 +77,8 @@ public enum AppendMethod {
MR_MAPPEND, //map-only append (rhs must be vector and fit in mapper mem)
MR_RAPPEND, //reduce-only append (output must have at most one column block)
MR_GAPPEND, //map-reduce general case append (map-extend, aggregate)
- SP_GAlignedAppend // special case for general case in Spark where left.getCols() % left.getColsPerBlock() == 0
+ SP_GAlignedAppend, // special case for general case in Spark where left.getCols() % left.getColsPerBlock() == 0
+ FL_GAlignedAppend
};
private enum MMBinaryMethod{
@@ -527,6 +528,11 @@ else if(et == ExecType.SPARK)
append = constructSPAppendLop(getInput().get(0), getInput().get(1), getDataType(), getValueType(), cbind, this);
append.getOutputParameters().setDimensions(rlen, clen, getRowsInBlock(), getColsInBlock(), getNnz());
}
+ else if(et == ExecType.FLINK)
+ {
+ append = constructFLAppendLop(getInput().get(0), getInput().get(1), getDataType(), getValueType(), cbind, this);
+ append.getOutputParameters().setDimensions(rlen, clen, getRowsInBlock(), getColsInBlock(), getNnz());
+ }
else //CP
{
Lop offset = createOffsetLop( getInput().get(0), cbind ); //offset 1st input
@@ -585,8 +591,7 @@ else if( op==OpOp2.MULT && right instanceof LiteralOp && ((LiteralOp)right).getD
ot = Unary.OperationTypes.MULTIPLY2;
else //general case
ot = HopsOpOp2LopsU.get(op);
-
-
+
Unary unary1 = new Unary(getInput().get(0).constructLops(),
getInput().get(1).constructLops(), ot, getDataType(), getValueType(), et);
@@ -637,6 +642,35 @@ else if (mbin == MMBinaryMethod.MR_BINARY_M) {
setLineNumbers(binary);
setLops(binary);
}
+ else if(et == ExecType.FLINK)
+ {
+ Hop left = getInput().get(0);
+ Hop right = getInput().get(1);
+ MMBinaryMethod mbin = optFindMMBinaryMethodFlink(left, right);
+
+ Lop binary = null;
+ if( mbin == MMBinaryMethod.MR_BINARY_UAGG_CHAIN ) {
+ AggUnaryOp uRight = (AggUnaryOp)right;
+ binary = new BinaryUAggChain(left.constructLops(), HopsOpOp2LopsB.get(op),
+ HopsAgg2Lops.get(uRight.getOp()), HopsDirection2Lops.get(uRight.getDirection()),
+ getDataType(), getValueType(), et);
+ }
+ else if (mbin == MMBinaryMethod.MR_BINARY_M) {
+ boolean partitioned = false;
+ boolean isColVector = (right.getDim2()==1 && left.getDim1()==right.getDim1());
+
+ binary = new BinaryM(left.constructLops(), right.constructLops(),
+ HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et, partitioned, isColVector);
+ }
+ else {
+ binary = new Binary(left.constructLops(), right.constructLops(),
+ HopsOpOp2LopsB.get(op), getDataType(), getValueType(), et);
+ }
+
+ setOutputDimensions(binary);
+ setLineNumbers(binary);
+ setLops(binary);
+ }
else //MR
{
Hop left = getInput().get(0);
@@ -798,8 +832,6 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz )
ret = OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
}
-
-
return ret;
}
@@ -933,7 +965,7 @@ protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
DataType dt1 = getInput().get(0).getDataType();
DataType dt2 = getInput().get(1).getDataType();
@@ -998,6 +1030,20 @@ && getInput().get(dt1.isScalar()?1:0).optFindExecType() == ExecType.SPARK )
_etype = ExecType.SPARK;
}
+		//flink-specific decision refinement (execute unary scalar w/ flink input and
+		//single parent also in flink because it's likely cheap and reduces intermediates)
+ if( _etype == ExecType.CP && _etypeForced != ExecType.CP
+ && getDataType().isMatrix() && (dt1.isScalar() || dt2.isScalar())
+ && supportsMatrixScalarOperations() //scalar operations
+ && !(getInput().get(dt1.isScalar()?1:0) instanceof DataOp) //input is not checkpoint
+ && getInput().get(dt1.isScalar()?1:0).getParent().size()==1 //unary scalar is only parent
+ && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar()?1:0)) //single block triggered exec
+ && getInput().get(dt1.isScalar()?1:0).optFindExecType() == ExecType.FLINK )
+ {
+			//pull unary scalar operation into flink
+ _etype = ExecType.FLINK;
+ }
+
//mark for recompile (forever)
if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE ) {
setRequiresRecompile();
@@ -1159,8 +1205,67 @@ public static Lop constructSPAppendLop( Hop left, Hop right, DataType dt, ValueT
}
ret.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), current.getEndColumn());
-
-
+
+
+ return ret;
+ }
+
+
+	/**
+	 * Constructs the Flink append lop (map-only, reduce-only, or general
+	 * map-reduce append, depending on optFindAppendFLMethod).
+	 *
+	 * @param left
+	 * @param right
+	 * @param dt
+	 * @param vt
+	 * @param cbind
+	 * @param current
+	 * @return
+	 * @throws HopsException
+	 * @throws LopsException
+	 */
+ public static Lop constructFLAppendLop( Hop left, Hop right, DataType dt, ValueType vt, boolean cbind, Hop current )
+ throws HopsException, LopsException
+ {
+ Lop ret = null;
+
+ Lop offset = createOffsetLop( left, cbind ); //offset 1st input
+ AppendMethod am = optFindAppendFLMethod(left.getDim1(), left.getDim2(), right.getDim1(), right.getDim2(),
+ right.getRowsInBlock(), right.getColsInBlock(), right.getNnz(), cbind);
+
+ switch( am )
+ {
+ case MR_MAPPEND: //special case map-only append
+ {
+ ret = new AppendM(left.constructLops(), right.constructLops(), offset,
+ current.getDataType(), current.getValueType(), cbind, false, ExecType.FLINK);
+ break;
+ }
+ case MR_RAPPEND: //special case reduce append w/ one column block
+ {
+ ret = new AppendR(left.constructLops(), right.constructLops(),
+ current.getDataType(), current.getValueType(), cbind, ExecType.FLINK);
+ break;
+ }
+ case MR_GAPPEND:
+ {
+ Lop offset2 = createOffsetLop( right, cbind ); //offset second input
+ ret = new AppendG(left.constructLops(), right.constructLops(), offset, offset2,
+ current.getDataType(), current.getValueType(), cbind, ExecType.FLINK);
+ break;
+ }
+			/* TODO requires a flink-specific aligned append operator; note that
+			 * optFindAppendFLMethod may still return FL_GAlignedAppend, which
+			 * currently falls through to the exception in the default case.
+			case SP_GAlignedAppend:
+			{
+				ret = new AppendGAlignedSP(left.constructLops(), right.constructLops(), offset,
+					current.getDataType(), current.getValueType(), cbind);
+				break;
+			}*/
+ default:
+ throw new HopsException("Invalid SP append method: "+am);
+ }
+
+ ret.setAllPositions(current.getBeginLine(), current.getBeginColumn(), current.getEndLine(), current.getEndColumn());
+
+
return ret;
}
@@ -1317,6 +1422,42 @@ private static AppendMethod optFindAppendSPMethod( long m1_dim1, long m1_dim2, l
return AppendMethod.MR_GAPPEND;
}
+ private static AppendMethod optFindAppendFLMethod( long m1_dim1, long m1_dim2, long m2_dim1, long m2_dim2, long m1_rpb, long m1_cpb, long m2_nnz, boolean cbind )
+ {
+ if(FORCED_APPEND_METHOD != null) {
+ return FORCED_APPEND_METHOD;
+ }
+
+ //check for best case (map-only w/o shuffle)
+ if( m2_dim1 >= 1 && m2_dim2 >= 1 //rhs dims known
+ && (cbind && m2_dim2 <= m1_cpb //rhs is smaller than column block
+ || !cbind && m2_dim1 <= m1_rpb) ) //rhs is smaller than row block
+ {
+ if( OptimizerUtils.checkFlinkBroadcastMemoryBudget(m2_dim1, m2_dim2, m1_rpb, m1_cpb, m2_nnz) ) {
+ return AppendMethod.MR_MAPPEND;
+ }
+ }
+
+ //check for in-block append (reduce-only)
+ if( cbind && m1_dim2 >= 1 && m2_dim2 >= 0 //column dims known
+ && m1_dim2+m2_dim2 <= m1_cpb //output has one column block
+ ||!cbind && m1_dim1 >= 1 && m2_dim1 >= 0 //row dims known
+			&& m1_dim1+m2_dim1 <= m1_rpb ) //output has one row block
+ {
+ return AppendMethod.MR_RAPPEND;
+ }
+
+ // if(mc1.getCols() % mc1.getColsPerBlock() == 0) {
+ if( cbind && m1_dim2 % m1_cpb == 0
+ || !cbind && m1_dim1 % m1_rpb == 0 )
+ {
+ return AppendMethod.FL_GAlignedAppend;
+ }
+
+ //general case (map and reduce)
+ return AppendMethod.MR_GAPPEND;
+ }
+
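
A worked example of optFindAppendFLMethod with made-up dimensions, illustrative only:

    public class AppendMethodExample {
        public static void main(String[] args) {
            // cbind(X, y): X is 10^6 x 999, y is 10^6 x 1, blocks are 1000 x 1000
            long m1Dim2 = 999, m2Dim2 = 1, cpb = 1000;
            // map-only append: y spans at most one column block
            // (must additionally pass the broadcast budget check, omitted here)
            boolean mapAppend = (m2Dim2 >= 1 && m2Dim2 <= cpb);   // true
            // reduce-only append: the 10^6 x 1000 result has one column block
            boolean singleBlock = (m1Dim2 + m2Dim2 <= cpb);       // true
            // aligned append would need ncol(X) % blocksize == 0
            boolean aligned = (m1Dim2 % cpb == 0);                // false (999 % 1000)
            System.out.println(mapAppend + " " + singleBlock + " " + aligned);
        }
    }
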
/**
*
* @param rightInput
@@ -1371,6 +1512,37 @@ private MMBinaryMethod optFindMMBinaryMethodSpark(Hop left, Hop right) {
//MR_BINARY_R as robust fallback strategy
return MMBinaryMethod.MR_BINARY_R;
}
+
+ private MMBinaryMethod optFindMMBinaryMethodFlink(Hop left, Hop right) {
+ long m1_dim1 = left.getDim1();
+ long m1_dim2 = left.getDim2();
+ long m2_dim1 = right.getDim1();
+ long m2_dim2 = right.getDim2();
+ long m1_rpb = left.getRowsInBlock();
+ long m1_cpb = left.getColsInBlock();
+
+ //MR_BINARY_UAGG_CHAIN only applied if result is column/row vector of MV binary operation.
+ if( right instanceof AggUnaryOp && right.getInput().get(0) == left //e.g., P / rowSums(P)
+ && ((((AggUnaryOp) right).getDirection() == Direction.Row && m1_dim2 > 1 && m1_dim2 <= m1_cpb ) //single column block
+ || (((AggUnaryOp) right).getDirection() == Direction.Col && m1_dim1 > 1 && m1_dim1 <= m1_rpb ))) //single row block
+ {
+ return MMBinaryMethod.MR_BINARY_UAGG_CHAIN;
+ }
+
+ //MR_BINARY_M currently only applied for MV because potential partitioning job may cause additional latency for VV.
+ if( m2_dim1 >= 1 && m2_dim2 >= 1 // rhs dims known
+ && ((m1_dim2 >= 1 && m2_dim2 == 1) //rhs column vector
+ ||(m1_dim1 >= 1 && m2_dim1 == 1 )) ) //rhs row vector
+ {
+ double size = OptimizerUtils.estimateSize(m2_dim1, m2_dim2);
+ if( OptimizerUtils.checkFlinkBroadcastMemoryBudget(size) ) { //TODO: check this
+ return MMBinaryMethod.MR_BINARY_M;
+ }
+ }
+
+ //MR_BINARY_R as robust fallback strategy
+ return MMBinaryMethod.MR_BINARY_R;
+ }
/**
*
diff --git a/src/main/java/org/apache/sysml/hops/DataGenOp.java b/src/main/java/org/apache/sysml/hops/DataGenOp.java
index 4c81f77dbd9..c17f91611b8 100644
--- a/src/main/java/org/apache/sysml/hops/DataGenOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataGenOp.java
@@ -166,7 +166,7 @@ else if( cur.getKey().equals(DataExpression.RAND_COLS) && _dim2>0 )
(getRowsInBlock()>0)?getRowsInBlock():ConfigurationManager.getBlocksize(),
(getColsInBlock()>0)?getColsInBlock():ConfigurationManager.getBlocksize(),
//actual rand nnz might differ (in cp/mr they are corrected after execution)
- (_op==DataGenMethod.RAND && et==ExecType.SPARK && getNnz()!=0) ? -1 : getNnz(),
+ (_op==DataGenMethod.RAND && (et==ExecType.SPARK || et == ExecType.FLINK) && getNnz()!=0) ? -1 : getNnz(),
getUpdateInPlace());
setLineNumbers(rnd);
@@ -276,7 +276,7 @@ protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
_etype = _etypeForced;
diff --git a/src/main/java/org/apache/sysml/hops/DataOp.java b/src/main/java/org/apache/sysml/hops/DataOp.java
index a9cb8b97099..e356715cff5 100644
--- a/src/main/java/org/apache/sysml/hops/DataOp.java
+++ b/src/main/java/org/apache/sysml/hops/DataOp.java
@@ -428,7 +428,7 @@ protected ExecType optFindExecType()
//for example for sum(X) where the memory consumption is solely determined by the DataOp
ExecType letype = (OptimizerUtils.isMemoryBasedOptLevel()) ? findExecTypeByMemEstimate() : null;
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
//NOTE: independent of etype executed in MR (piggybacked) if input to persistent write is MR
if( _dataop == DataOpTypes.PERSISTENTWRITE || _dataop == DataOpTypes.TRANSIENTWRITE )
diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java
index 43193a9e1ee..6f659567cf7 100644
--- a/src/main/java/org/apache/sysml/hops/Hop.java
+++ b/src/main/java/org/apache/sysml/hops/Hop.java
@@ -199,6 +199,8 @@ else if ( DMLScript.rtplatform == RUNTIME_PLATFORM.HADOOP )
_etypeForced = ExecType.MR;
else if ( DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK )
_etypeForced = ExecType.SPARK;
+ else if ( DMLScript.rtplatform == RUNTIME_PLATFORM.FLINK )
+ _etypeForced = ExecType.FLINK;
}
/**
@@ -229,6 +231,8 @@ public void checkAndSetInvalidCPDimsAndSize()
_etype = ExecType.MR;
else if( DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK )
_etype = ExecType.SPARK;
+ else if( DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_FLINK)
+ _etype = ExecType.FLINK;
}
}
}
@@ -300,7 +304,7 @@ private void constructAndSetReblockLopIfRequired()
if( DMLScript.rtplatform != RUNTIME_PLATFORM.SINGLE_NODE
&& !(getDataType()==DataType.SCALAR) )
{
- et = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ et = OptimizerUtils.getRemoteExecType();
}
//add reblock lop to output if required
@@ -346,8 +350,7 @@ private void constructAndSetCheckpointLopIfRequired()
{
//determine execution type
ExecType et = ExecType.CP;
- if( OptimizerUtils.isSparkExecutionMode()
- && getDataType()!=DataType.SCALAR )
+ if (OptimizerUtils.isSparkExecutionMode() && getDataType() != DataType.SCALAR)
{
//conditional checkpoint based on memory estimate in order to
//(1) avoid unnecessary persist and unpersist calls, and
@@ -785,6 +788,8 @@ protected ExecType findExecTypeByMemEstimate() {
et = ExecType.MR;
else if( DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.HYBRID_SPARK )
et = ExecType.SPARK;
+ else if( DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.HYBRID_FLINK )
+ et = ExecType.FLINK;
c = '*';
}
@@ -928,7 +933,6 @@ private void resetRecompilationFlag( ExecType et )
/**
* Test and debugging only.
*
- * @param h
* @throws HopsException
*/
public void checkParentChildPointers( )
diff --git a/src/main/java/org/apache/sysml/hops/IndexingOp.java b/src/main/java/org/apache/sysml/hops/IndexingOp.java
index 6a6da5deac4..75793755bac 100644
--- a/src/main/java/org/apache/sysml/hops/IndexingOp.java
+++ b/src/main/java/org/apache/sysml/hops/IndexingOp.java
@@ -340,7 +340,7 @@ protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
diff --git a/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java b/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java
index 07091e7b441..6ccde11a48f 100644
--- a/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java
+++ b/src/main/java/org/apache/sysml/hops/LeftIndexingOp.java
@@ -365,7 +365,7 @@ protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
diff --git a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
index e9beafbe058..c5a2f1dd82d 100644
--- a/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysml/hops/OptimizerUtils.java
@@ -39,6 +39,7 @@
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
+import org.apache.sysml.runtime.controlprogram.context.FlinkExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysml.runtime.instructions.cp.Data;
@@ -431,6 +432,37 @@ public static boolean checkSparkBroadcastMemoryBudget( double size )
//memory and hand it over to the spark context as in-memory object
return ( size < memBudgetExec && 2*size < memBudgetLocal );
}
+
+	/**
+	 * Checks whether a broadcast of the given size fits once into the remote
+	 * UDF memory budget and twice into the local memory budget.
+	 *
+	 * @param size broadcast size estimate in bytes
+	 * @return true if the broadcast fits the flink memory budgets
+	 */
+ public static boolean checkFlinkBroadcastMemoryBudget( double size )
+ {
+ double memBudgetExec = FlinkExecutionContext.getUDFMemoryBudget(); //TODO: change this
+ double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
+
+		//basic requirement: the broadcast needs to fit once in the remote broadcast memory
+		//and twice into the local memory budget because we have to create a partitioned broadcast
+		//in cp memory and hand it over to the flink runtime as in-memory object
+ return ( size < memBudgetExec && 2*size < memBudgetLocal );
+ }
+
+ public static boolean checkFlinkBroadcastMemoryBudget(long rlen, long clen, long brlen, long bclen, long nnz) {
+ double memBudgetExec = FlinkExecutionContext.getUDFMemoryBudget();
+ double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
+
+ double sp = getSparsity(rlen, clen, nnz);
+ double size = estimateSizeExactSparsity(rlen, clen, sp);
+ double sizeP = estimatePartitionedSizeExactSparsity(rlen, clen, brlen, bclen, sp);
+
+		//basic requirement: the broadcast needs to fit once in the remote broadcast memory
+		//and twice into the local memory budget because we have to create a partitioned broadcast
+		//in cp memory and hand it over to the flink runtime as in-memory object
+ return ( OptimizerUtils.isValidCPDimensions(rlen, clen)
+ && sizeP < memBudgetExec && size+sizeP < memBudgetLocal );
+ }
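
Numerically, the two-budget rule looks as follows (hypothetical budgets; the real values come from FlinkExecutionContext.getUDFMemoryBudget() and OptimizerUtils.getLocalMemBudget()):

    public class BroadcastBudgetExample {
        public static void main(String[] args) {
            double size = 100_000 * 8d;                  // dense 10^5 x 1 vector, ~0.8 MB
            double memBudgetExec  = 512d * 1024 * 1024;  // assumed remote UDF budget
            double memBudgetLocal = 1024d * 1024 * 1024; // assumed local budget
            // once remotely, twice locally (original + partitioned copy)
            System.out.println(size < memBudgetExec && 2 * size < memBudgetLocal); // true
        }
    }
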
/**
*
@@ -548,6 +580,20 @@ public static boolean isSparkExecutionMode() {
return ( DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK
|| DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK);
}
+
+ public static boolean isFlinkExecutionMode() {
+ return ( DMLScript.rtplatform == RUNTIME_PLATFORM.FLINK
+ || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_FLINK);
+ }
+
+ public static ExecType getRemoteExecType() {
+ if (isSparkExecutionMode())
+ return ExecType.SPARK;
+ else if (isFlinkExecutionMode())
+ return ExecType.FLINK;
+ else
+ return ExecType.MR;
+ }
/**
*
@@ -555,7 +601,8 @@ public static boolean isSparkExecutionMode() {
*/
public static boolean isHybridExecutionMode() {
return ( DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID
- || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK );
+ || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK
+ || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_FLINK);
}
/**
diff --git a/src/main/java/org/apache/sysml/hops/QuaternaryOp.java b/src/main/java/org/apache/sysml/hops/QuaternaryOp.java
index 6fb140bc5c2..84ba4eab498 100644
--- a/src/main/java/org/apache/sysml/hops/QuaternaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/QuaternaryOp.java
@@ -50,6 +50,7 @@
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock.PDataPartitionFormat;
+import org.apache.sysml.runtime.controlprogram.context.FlinkExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.mapred.DistributedCacheInput;
@@ -228,6 +229,8 @@ else if( et == ExecType.MR )
constructMRLopsWeightedSquaredLoss(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedSquaredLoss(wtype);
+ else if ( et == ExecType.FLINK )
+ constructFlinkLopsWeightedSquaredLoss(wtype);
else
throw new HopsException("Unsupported quaternaryop-wsloss exec type: "+et);
break;
@@ -567,6 +570,63 @@ private void constructSparkLopsWeightedSquaredLoss(WeightsType wtype)
}
}
+
+	/**
+	 * Constructs lops for the Flink weighted squared loss (wsloss) operation.
+	 *
+	 * @param wtype weights type of the wsloss operation
+	 * @throws HopsException
+	 * @throws LopsException
+	 */
+ private void constructFlinkLopsWeightedSquaredLoss(WeightsType wtype)
+ throws HopsException, LopsException
+ {
+ //NOTE: the common case for wsloss are factors U/V with a rank of 10s to 100s; the current runtime only
+ //supports single block outer products (U/V rank <= blocksize, i.e., 1000 by default); we enforce this
+ //by applying the hop rewrite for Weighted Squared Loss only if this constraint holds.
+
+ //Notes: Any broadcast needs to fit twice in local memory because we partition the input in cp,
+	//and needs to fit once in the remote UDF memory budget (the Spark-specific 2GB
+	//broadcast constraint does not apply to the flink backend).
+ double memBudgetExec = FlinkExecutionContext.getUDFMemoryBudget();
+ double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
+
+ Hop X = getInput().get(0);
+ Hop U = getInput().get(1);
+ Hop V = getInput().get(2);
+ Hop W = getInput().get(3);
+
+ //MR operator selection, part1
+ double m1Size = OptimizerUtils.estimateSize(U.getDim1(), U.getDim2()); //size U
+ double m2Size = OptimizerUtils.estimateSize(V.getDim1(), V.getDim2()); //size V
+ boolean isMapWsloss = (!wtype.hasFourInputs() && m1Size+m2Size < memBudgetExec
+ && 2*m1Size < memBudgetLocal && 2*m2Size < memBudgetLocal);
+
+ if( !FORCE_REPLICATION && isMapWsloss ) //broadcast
+ {
+ //map-side wsloss always with broadcast
+ Lop wsloss = new WeightedSquaredLoss( X.constructLops(), U.constructLops(), V.constructLops(), W.constructLops(),
+ DataType.SCALAR, ValueType.DOUBLE, wtype, ExecType.FLINK);
+ setOutputDimensions(wsloss);
+ setLineNumbers(wsloss);
+ setLops(wsloss);
+ }
+ else //general case
+ {
+ //MR operator selection part 2
+ boolean cacheU = !FORCE_REPLICATION && (m1Size < memBudgetExec && 2*m1Size < memBudgetLocal);
+ boolean cacheV = !FORCE_REPLICATION && ((!cacheU && m2Size < memBudgetExec )
+ || (cacheU && m1Size+m2Size < memBudgetExec)) && 2*m2Size < memBudgetLocal;
+
+ //reduce-side wsloss w/ or without broadcast
+ Lop wsloss = new WeightedSquaredLossR(
+ X.constructLops(), U.constructLops(), V.constructLops(), W.constructLops(),
+ DataType.SCALAR, ValueType.DOUBLE, wtype, cacheU, cacheV, ExecType.FLINK);
+ setOutputDimensions(wsloss);
+ setLineNumbers(wsloss);
+ setLops(wsloss);
+ }
+ }
+
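
The map- vs. reduce-side decision above in numbers (made-up factor sizes and budgets, illustrative only):

    public class WslossBroadcastExample {
        public static void main(String[] args) {
            // rank-100 factors for X: 10^6 x 10^5
            double m1Size = 1e6 * 100 * 8;  // U: ~800 MB dense
            double m2Size = 1e5 * 100 * 8;  // V: ~80 MB dense
            double memBudgetExec = 512d * 1024 * 1024, memBudgetLocal = 1024d * 1024 * 1024;
            // map-side wsloss needs U and V together -> fails here (~880 MB > 512 MB)
            boolean isMapWsloss = (m1Size + m2Size < memBudgetExec
                && 2 * m1Size < memBudgetLocal && 2 * m2Size < memBudgetLocal);
            // the reduce-side variant can still broadcast V alone (cacheV)
            boolean cacheV = (m2Size < memBudgetExec && 2 * m2Size < memBudgetLocal);
            System.out.println(isMapWsloss + " " + cacheV); // false true
        }
    }
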
/**
*
* @param wtype
@@ -1602,7 +1662,7 @@ protected ExecType optFindExecType()
{
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
diff --git a/src/main/java/org/apache/sysml/hops/ReorgOp.java b/src/main/java/org/apache/sysml/hops/ReorgOp.java
index d81e777fa08..93ed196afa8 100644
--- a/src/main/java/org/apache/sysml/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysml/hops/ReorgOp.java
@@ -321,7 +321,7 @@ public Lop constructLops()
}
else //CP or Spark
{
- if( et==ExecType.SPARK && !FORCE_DIST_SORT_INDEXES)
+ if( (et==ExecType.SPARK || et==ExecType.FLINK) && !FORCE_DIST_SORT_INDEXES)
bSortSPRewriteApplicable = isSortSPRewriteApplicable();
Lop transform1 = constructCPOrSparkSortLop(input, by, desc, ixret, et, bSortSPRewriteApplicable);
@@ -482,7 +482,7 @@ protected ExecType optFindExecType() throws HopsException {
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
@@ -676,4 +676,4 @@ private boolean isSortSPRewriteApplicable()
return ret;
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysml/hops/TernaryOp.java b/src/main/java/org/apache/sysml/hops/TernaryOp.java
index e3532730831..36137059692 100644
--- a/src/main/java/org/apache/sysml/hops/TernaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/TernaryOp.java
@@ -386,7 +386,7 @@ private void constructLopsCtable() throws HopsException, LopsException {
//reset reblock requirement (see MR ctable / construct lops)
setRequiresReblock( false );
- if ( et == ExecType.CP || et == ExecType.SPARK)
+ if ( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.FLINK)
{
//for CP we support only ctable expand left
Ternary.OperationTypes tertiaryOp = isSequenceRewriteApplicable(true) ? Ternary.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT : tertiaryOpOrig;
@@ -404,7 +404,7 @@ private void constructLopsCtable() throws HopsException, LopsException {
tertiary.setAllPositions(this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn());
//force blocked output in CP (see below), otherwise binarycell
- if ( et == ExecType.SPARK ) {
+ if ( et == ExecType.SPARK || et == ExecType.FLINK) {
tertiary.getOutputParameters().setDimensions(_dim1, _dim2, -1, -1, -1);
setRequiresReblock( true );
}
@@ -757,7 +757,7 @@ protected ExecType optFindExecType()
{
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
diff --git a/src/main/java/org/apache/sysml/hops/UnaryOp.java b/src/main/java/org/apache/sysml/hops/UnaryOp.java
index 4ed0225363d..ac4dc6a5d40 100644
--- a/src/main/java/org/apache/sysml/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysml/hops/UnaryOp.java
@@ -155,10 +155,15 @@ else if(_op == Hop.OpOp1.MEDIAN) {
{
//TODO additional physical operation if offsets fit in memory
Lop cumsumLop = null;
- if( et == ExecType.MR )
+ if( et == ExecType.MR ) {
cumsumLop = constructLopsMRCumulativeUnary();
- else
+ }
+ else if (et == ExecType.FLINK) {
+ cumsumLop = constructLopsFlinkCumulativeUnary();
+ }
+ else {
cumsumLop = constructLopsSparkCumulativeUnary();
+ }
setLops(cumsumLop);
}
else //default unary
@@ -444,6 +449,72 @@ private Lop constructLopsMRCumulativeUnary()
return TEMP;
}
+
+
+	/**
+	 * Constructs lops for a Flink cumulative unary operation (e.g., cumsum)
+	 * via recursive preaggregation and offset propagation.
+	 *
+	 * @return root lop of the cumulative operation
+	 * @throws HopsException
+	 * @throws LopsException
+	 */
+ private Lop constructLopsFlinkCumulativeUnary()
+ throws HopsException, LopsException
+ {
+ Hop input = getInput().get(0);
+ long rlen = input.getDim1();
+ long clen = input.getDim2();
+ long brlen = input.getRowsInBlock();
+ long bclen = input.getColsInBlock();
+ boolean force = !dimsKnown() || _etypeForced == ExecType.FLINK;
+ OperationTypes aggtype = getCumulativeAggType();
+
+ Lop X = input.constructLops();
+ Lop TEMP = X;
+		ArrayList<Lop> DATA = new ArrayList<Lop>();
+ int level = 0;
+
+ //recursive preaggregation until aggregates fit into CP memory budget
+ while( ((2*OptimizerUtils.estimateSize(TEMP.getOutputParameters().getNumRows(), clen) + OptimizerUtils.estimateSize(1, clen))
+ > OptimizerUtils.getLocalMemBudget()
+ && TEMP.getOutputParameters().getNumRows()>1) || force )
+ {
+ DATA.add(TEMP);
+
+ //preaggregation per block (for flink, the CumulativePartialAggregate subsumes both
+ //the preaggregation and subsequent block aggregation)
+ long rlenAgg = (long)Math.ceil((double)TEMP.getOutputParameters().getNumRows()/brlen);
+ Lop preagg = new CumulativePartialAggregate(TEMP, DataType.MATRIX, ValueType.DOUBLE, aggtype, ExecType.FLINK);
+ preagg.getOutputParameters().setDimensions(rlenAgg, clen, brlen, bclen, -1);
+ setLineNumbers(preagg);
+
+ TEMP = preagg;
+ level++;
+ force = false; //in case of unknowns, generate one level
+ }
+
+ //in-memory cum sum (of partial aggregates)
+ if( TEMP.getOutputParameters().getNumRows()!=1 ) {
+ int k = OptimizerUtils.getConstrainedNumThreads( _maxNumThreads );
+ Unary unary1 = new Unary( TEMP, HopsOpOp1LopsU.get(_op), DataType.MATRIX, ValueType.DOUBLE, ExecType.CP, k);
+ unary1.getOutputParameters().setDimensions(TEMP.getOutputParameters().getNumRows(), clen, brlen, bclen, -1);
+ setLineNumbers(unary1);
+ TEMP = unary1;
+ }
+
+		//split, group and flink cumsum
+ while( level-- > 0 ) {
+ //(for flink, the CumulativeOffsetBinary subsumes both the split aggregate and
+ //the subsequent offset binary apply of split aggregates against the original data)
+ double initValue = getCumulativeInitValue();
+ CumulativeOffsetBinary binary = new CumulativeOffsetBinary(DATA.get(level), TEMP,
+ DataType.MATRIX, ValueType.DOUBLE, initValue, aggtype, ExecType.FLINK);
+ binary.getOutputParameters().setDimensions(rlen, clen, brlen, bclen, -1);
+ setLineNumbers(binary);
+ TEMP = binary;
+ }
+
+ return TEMP;
+ }
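
For intuition on the recursion depth: each preaggregation level shrinks the row count by a factor of brlen, so even very tall inputs terminate after a handful of levels (the real loop above stops as soon as the partial aggregates fit into the CP memory budget):

    public class CumsumLevelExample {
        public static void main(String[] args) {
            long rows = 1_000_000_000L, brlen = 1000;
            while (rows > 1) {
                rows = (long) Math.ceil((double) rows / brlen);
                System.out.println(rows); // 10^6, then 10^3, then 1
            }
        }
    }
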
/**
*
@@ -646,7 +717,7 @@ protected ExecType optFindExecType()
{
checkAndSetForcedPlatform();
- ExecType REMOTE = OptimizerUtils.isSparkExecutionMode() ? ExecType.SPARK : ExecType.MR;
+ ExecType REMOTE = OptimizerUtils.getRemoteExecType();
if( _etypeForced != null )
{
@@ -686,6 +757,18 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent
//pull unary operation into spark
_etype = ExecType.SPARK;
}
+
+ if( _etype == ExecType.CP && _etypeForced != ExecType.CP
+ && getInput().get(0).optFindExecType() == ExecType.FLINK
+ && getDataType().isMatrix()
+ && !isCumulativeUnaryOperation() && !isCastUnaryOperation()
+ && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM
+ && !(getInput().get(0) instanceof DataOp) //input is not checkpoint
+ && getInput().get(0).getParent().size()==1 ) //unary is only parent
+ {
+ //pull unary operation into flink
+ _etype = ExecType.FLINK;
+ }
//mark for recompile (forever)
if( ConfigurationManager.isDynamicRecompilation() && !dimsKnown(true) && _etype==REMOTE )
diff --git a/src/main/java/org/apache/sysml/lops/BinaryM.java b/src/main/java/org/apache/sysml/lops/BinaryM.java
index a911e5cff59..bb243538baf 100644
--- a/src/main/java/org/apache/sysml/lops/BinaryM.java
+++ b/src/main/java/org/apache/sysml/lops/BinaryM.java
@@ -79,6 +79,10 @@ else if(et == ExecType.SPARK) {
lps.addCompatibility(JobType.INVALID);
lps.setProperties( inputs, ExecType.SPARK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
}
+ else if(et == ExecType.FLINK) {
+ lps.addCompatibility(JobType.INVALID);
+ lps.setProperties( inputs, ExecType.FLINK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ }
else {
throw new LopsException("Incorrect execution type for BinaryM lop:" + et.name());
}
@@ -217,4 +221,4 @@ public int[] distributedCacheInputIndex()
// second input is from distributed cache
return new int[]{2};
}
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysml/lops/CSVReBlock.java b/src/main/java/org/apache/sysml/lops/CSVReBlock.java
index 5bca2bb8019..cbee1e2daf7 100644
--- a/src/main/java/org/apache/sysml/lops/CSVReBlock.java
+++ b/src/main/java/org/apache/sysml/lops/CSVReBlock.java
@@ -19,6 +19,7 @@
package org.apache.sysml.lops;
+import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.LopProperties.ExecLocation;
import org.apache.sysml.lops.LopProperties.ExecType;
import org.apache.sysml.lops.ParameterizedBuiltin.OperationTypes;
@@ -74,6 +75,9 @@ public CSVReBlock(Lop input, Long rows_per_block, Long cols_per_block, DataType
else if(et == ExecType.SPARK) {
this.lps.setProperties( inputs, ExecType.SPARK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
}
+ else if(et == ExecType.FLINK) {
+ this.lps.setProperties( inputs, ExecType.FLINK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ }
else {
throw new LopsException("Incorrect execution type for CSVReblock:" + et);
}
@@ -162,7 +166,7 @@ public String getInstructions(int input_index, int output_index) throws LopsExce
@Override
public String getInstructions(String input1, String output) throws LopsException {
- if(getExecType() != ExecType.SPARK) {
+ if(!(OptimizerUtils.isSparkExecutionMode() || OptimizerUtils.isFlinkExecutionMode())) {
throw new LopsException("The method getInstructions(String,String) for CSVReblock should be called only for Spark execution type");
}
diff --git a/src/main/java/org/apache/sysml/lops/Data.java b/src/main/java/org/apache/sysml/lops/Data.java
index 3b936d3b32a..0d571de27e0 100644
--- a/src/main/java/org/apache/sysml/lops/Data.java
+++ b/src/main/java/org/apache/sysml/lops/Data.java
@@ -404,6 +404,8 @@ public String getInstructions(String input1, String input2)
StringBuilder sb = new StringBuilder();
if(this.getExecType() == ExecType.SPARK)
sb.append( "SPARK" );
+ else if(this.getExecType() == ExecType.FLINK)
+ sb.append( "FLINK" );
else
sb.append( "CP" );
sb.append( OPERAND_DELIMITOR );
@@ -481,7 +483,7 @@ else if ( oparams.getFormat() == Format.BINARY ){
sb.append(OPERAND_DELIMITOR);
sb.append(sparseLop.getBooleanValue());
- if ( this.getExecType() == ExecType.SPARK )
+ if ( this.getExecType() == ExecType.SPARK || this.getExecType() == ExecType.FLINK )
{
boolean isInputMatrixBlock = true;
Lop input = getInputs().get(0);
diff --git a/src/main/java/org/apache/sysml/lops/DataGen.java b/src/main/java/org/apache/sysml/lops/DataGen.java
index ec8fb9c8158..cc0fa940a55 100644
--- a/src/main/java/org/apache/sysml/lops/DataGen.java
+++ b/src/main/java/org/apache/sysml/lops/DataGen.java
@@ -208,7 +208,7 @@ private String getCPInstruction_Rand(String output)
sb.append(iLop.prepScalarLabel());
sb.append(OPERAND_DELIMITOR);
- if ( getExecType() == ExecType.MR || getExecType() == ExecType.SPARK ) {
+ if ( getExecType() == ExecType.MR || getExecType() == ExecType.SPARK || getExecType() == ExecType.FLINK ) {
sb.append(baseDir);
sb.append(OPERAND_DELIMITOR);
}
diff --git a/src/main/java/org/apache/sysml/lops/Lop.java b/src/main/java/org/apache/sysml/lops/Lop.java
index 8424e284d0c..e49b4964694 100644
--- a/src/main/java/org/apache/sysml/lops/Lop.java
+++ b/src/main/java/org/apache/sysml/lops/Lop.java
@@ -668,7 +668,7 @@ public String prepScalarOperand(ExecType et, String label) {
boolean isLiteral = (isData && ((Data)this).isLiteral());
StringBuilder sb = new StringBuilder("");
- if ( et == ExecType.CP || et == ExecType.SPARK || (isData && isLiteral)) {
+ if ( et == ExecType.CP || et == ExecType.SPARK || et == ExecType.FLINK || (isData && isLiteral)) {
sb.append(label);
}
else {
diff --git a/src/main/java/org/apache/sysml/lops/LopProperties.java b/src/main/java/org/apache/sysml/lops/LopProperties.java
index 8d9e9160f93..cd46c8986ce 100644
--- a/src/main/java/org/apache/sysml/lops/LopProperties.java
+++ b/src/main/java/org/apache/sysml/lops/LopProperties.java
@@ -27,7 +27,7 @@
public class LopProperties
{
- public enum ExecType { CP, CP_FILE, MR, SPARK, INVALID };
+ public enum ExecType { CP, CP_FILE, MR, SPARK, FLINK, INVALID };
public enum ExecLocation {INVALID, RecordReader, Map, MapOrReduce, MapAndReduce, Reduce, Data, ControlProgram };
// static variable to assign an unique ID to every lop that is created
diff --git a/src/main/java/org/apache/sysml/lops/MapMult.java b/src/main/java/org/apache/sysml/lops/MapMult.java
index 9c4e4f346ee..9eeaa5537cd 100644
--- a/src/main/java/org/apache/sysml/lops/MapMult.java
+++ b/src/main/java/org/apache/sysml/lops/MapMult.java
@@ -95,7 +95,7 @@ public MapMult(Lop input1, Lop input2, DataType dt, ValueType vt, boolean rightC
* @param et
* @throws LopsException
*/
- public MapMult(Lop input1, Lop input2, DataType dt, ValueType vt, boolean rightCache, boolean partitioned, boolean emptyBlocks, SparkAggType aggtype)
+ public MapMult(Lop input1, Lop input2, DataType dt, ValueType vt, boolean rightCache, boolean partitioned, boolean emptyBlocks, SparkAggType aggtype, ExecType t)
throws LopsException
{
super(Lop.Type.MapMult, dt, vt);
@@ -117,7 +117,7 @@ public MapMult(Lop input1, Lop input2, DataType dt, ValueType vt, boolean rightC
boolean aligner = false;
boolean definesMRJob = false;
lps.addCompatibility(JobType.INVALID);
- lps.setProperties( inputs, ExecType.SPARK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ lps.setProperties( inputs, t, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
}
public String toString() {
diff --git a/src/main/java/org/apache/sysml/lops/PartialAggregate.java b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
index fdcc5b08564..80623014d0d 100644
--- a/src/main/java/org/apache/sysml/lops/PartialAggregate.java
+++ b/src/main/java/org/apache/sysml/lops/PartialAggregate.java
@@ -303,7 +303,7 @@ public String getInstructions(String input1, String output)
sb.append( this.prepOutputOperand(output) );
//in case of spark, we also compile the optional aggregate flag into the instruction.
- if( getExecType() == ExecType.SPARK ) {
+ if( getExecType() == ExecType.SPARK || getExecType() == ExecType.FLINK) {
sb.append( OPERAND_DELIMITOR );
sb.append( _aggtype );
}
diff --git a/src/main/java/org/apache/sysml/lops/ReBlock.java b/src/main/java/org/apache/sysml/lops/ReBlock.java
index 5a62ed7ca1c..7fe8f358d03 100644
--- a/src/main/java/org/apache/sysml/lops/ReBlock.java
+++ b/src/main/java/org/apache/sysml/lops/ReBlock.java
@@ -67,6 +67,8 @@ public ReBlock(Lop input, Long rows_per_block, Long cols_per_block, DataType dt,
lps.setProperties( inputs, ExecType.MR, ExecLocation.MapAndReduce, breaksAlignment, aligner, definesMRJob );
else if(et == ExecType.SPARK)
lps.setProperties( inputs, ExecType.SPARK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
+ else if(et == ExecType.FLINK)
+ lps.setProperties( inputs, ExecType.FLINK, ExecLocation.ControlProgram, breaksAlignment, aligner, definesMRJob );
else
throw new LopsException("Incorrect execution type for Reblock:" + et);
}
@@ -106,7 +108,7 @@ public String getInstructions(int input_index, int output_index) throws LopsExce
@Override
public String getInstructions(String input1, String output) throws LopsException {
- if(getExecType() != ExecType.SPARK) {
+ if(getExecType() != ExecType.SPARK && getExecType() != ExecType.FLINK) {
throw new LopsException("The method getInstructions(String,String) for Reblock should be called only for Spark execution type");
}
@@ -167,4 +169,4 @@ private Format getChildFormat(Lop node) throws LopsException {
-}
\ No newline at end of file
+}
diff --git a/src/main/java/org/apache/sysml/lops/Unary.java b/src/main/java/org/apache/sysml/lops/Unary.java
index 8ef3e7b5066..2039def6b28 100644
--- a/src/main/java/org/apache/sysml/lops/Unary.java
+++ b/src/main/java/org/apache/sysml/lops/Unary.java
@@ -132,7 +132,7 @@ private void init(Lop input1, OperationTypes op, DataType dt, ValueType vt, Exec
{
//sanity check
if ( (op == OperationTypes.INVERSE || op == OperationTypes.CHOLESKY)
- && (et == ExecType.SPARK || et == ExecType.MR) ) {
+ && (et == ExecType.SPARK || et == ExecType.MR || et == ExecType.FLINK) ) {
throw new LopsException("Invalid exection type "+et.toString()+" for operation "+op.toString());
}
diff --git a/src/main/java/org/apache/sysml/lops/compile/Dag.java b/src/main/java/org/apache/sysml/lops/compile/Dag.java
index 2d24dadbd3b..a880c52c2a2 100644
--- a/src/main/java/org/apache/sysml/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysml/lops/compile/Dag.java
@@ -66,15 +66,11 @@
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
-import org.apache.sysml.runtime.instructions.CPInstructionParser;
-import org.apache.sysml.runtime.instructions.Instruction;
+import org.apache.sysml.runtime.instructions.*;
import org.apache.sysml.runtime.instructions.Instruction.INSTRUCTION_TYPE;
-import org.apache.sysml.runtime.instructions.InstructionParser;
-import org.apache.sysml.runtime.instructions.SPInstructionParser;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction.CPINSTRUCTION_TYPE;
-import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
@@ -1260,7 +1256,7 @@ private static void excludeRemoveInstruction(String varName, ArrayList<Instruction>
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
--- a/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/InstructionUtils.java
-			ret[index++] = tmp.substring(0,((ix>=0)?ix:tmp.length()));
+			ret[index++] = tmp.substring(0,((ix>=0)?ix:tmp.length()));
}
return ret;
}
-
+
/**
- * Given an instruction string, this function strips-off the
- * execution type (CP or MR) and returns the remaining parts,
+	 * Given an instruction string, this function strips off the
+	 * execution type (e.g. CP, MR, SPARK, or FLINK) and returns the remaining parts,
* which include the opcode as well as the input and output operands.
* Each returned part will have the datatype and valuetype associated
* with the operand.
- *
+ *
* This function is invoked mainly for parsing CPInstructions.
- *
+ *
* @param str
* @return
*/
- public static String[] getInstructionPartsWithValueType( String str )
+ public static String[] getInstructionPartsWithValueType( String str )
{
//note: split required for empty tokens
String[] parts = str.split(Instruction.OPERAND_DELIM, -1);
@@ -208,57 +209,68 @@ public static String[] getInstructionPartsWithValueType( String str )
ret[0] = parts[1]; // opcode
		for( int i=1; i<parts.length-1; i++ )
			ret[i] = parts[i+1];
		return ret;
}
-
+
/**
* Evaluates if at least one instruction of the given instruction set
* used the distributed cache; this call can also be used for individual
- * instructions.
- *
+ * instructions.
+ *
* @param str
* @return
*/
- public static boolean isDistributedCacheUsed(String str)
- {
+ public static boolean isDistributedCacheUsed(String str)
+ {
String[] parts = str.split(Instruction.INSTRUCTION_DELIM);
- for(String inst : parts)
+ for(String inst : parts)
{
String opcode = getOpCode(inst);
- if( opcode.equalsIgnoreCase(AppendM.OPCODE)
- || opcode.equalsIgnoreCase(MapMult.OPCODE)
- || opcode.equalsIgnoreCase(MapMultChain.OPCODE)
- || opcode.equalsIgnoreCase(PMMJ.OPCODE)
- || opcode.equalsIgnoreCase(UAggOuterChain.OPCODE)
- || opcode.equalsIgnoreCase(GroupedAggregateM.OPCODE)
- || isDistQuaternaryOpcode( opcode ) //multiple quaternary opcodes
- || BinaryM.isOpcode( opcode ) ) //multiple binary opcodes
+ if( opcode.equalsIgnoreCase(AppendM.OPCODE)
+ || opcode.equalsIgnoreCase(MapMult.OPCODE)
+ || opcode.equalsIgnoreCase(MapMultChain.OPCODE)
+ || opcode.equalsIgnoreCase(PMMJ.OPCODE)
+ || opcode.equalsIgnoreCase(UAggOuterChain.OPCODE)
+ || opcode.equalsIgnoreCase(GroupedAggregateM.OPCODE)
+ || isDistQuaternaryOpcode( opcode ) //multiple quaternary opcodes
+ || BinaryM.isOpcode( opcode ) ) //multiple binary opcodes
{
return true;
}
}
return false;
}
-
+
/**
- *
+ *
* @param opcode
* @return
*/
public static AggregateUnaryOperator parseBasicAggregateUnaryOperator(String opcode)
{
AggregateUnaryOperator aggun = null;
-
+
if ( opcode.equalsIgnoreCase("uak+") ) {
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uark+") ) {
// RowSums
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uack+") ) {
// ColSums
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
@@ -343,12 +355,12 @@ else if ( opcode.equalsIgnoreCase("uamean") ) {
// Mean
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS);
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uarmean") ) {
// RowMeans
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOCOLUMNS);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uacmean") ) {
// ColMeans
AggregateOperator agg = new AggregateOperator(0, Mean.getMeanFnObject(), true, CorrectionLocationType.LASTTWOROWS);
@@ -378,12 +390,12 @@ else if ( opcode.equalsIgnoreCase("uacvar") ) {
else if ( opcode.equalsIgnoreCase("ua+") ) {
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uar+") ) {
// RowSums
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uac+") ) {
// ColSums
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
@@ -392,7 +404,7 @@ else if ( opcode.equalsIgnoreCase("uac+") ) {
else if ( opcode.equalsIgnoreCase("ua*") ) {
AggregateOperator agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uamax") ) {
AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
@@ -400,19 +412,19 @@ else if ( opcode.equalsIgnoreCase("uamax") ) {
else if ( opcode.equalsIgnoreCase("uamin") ) {
AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uatrace") ) {
AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uaktrace") ) {
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceDiag.getReduceDiagFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uarmax") ) {
AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
- }
+ }
else if (opcode.equalsIgnoreCase("uarimax") ) {
AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("maxindex"), true, CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
@@ -420,7 +432,7 @@ else if (opcode.equalsIgnoreCase("uarimax") ) {
else if ( opcode.equalsIgnoreCase("uarmin") ) {
AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
- }
+ }
else if (opcode.equalsIgnoreCase("uarimin") ) {
AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("minindex"), true, CorrectionLocationType.LASTCOLUMN);
aggun = new AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject());
@@ -428,17 +440,17 @@ else if (opcode.equalsIgnoreCase("uarimin") ) {
else if ( opcode.equalsIgnoreCase("uacmax") ) {
AggregateOperator agg = new AggregateOperator(-Double.MAX_VALUE, Builtin.getBuiltinFnObject("max"));
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("uacmin") ) {
AggregateOperator agg = new AggregateOperator(Double.MAX_VALUE, Builtin.getBuiltinFnObject("min"));
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
-
+
return aggun;
}
-
+
/**
- *
+ *
* @param opcode
* @param corrExists
* @param corrLoc
@@ -447,7 +459,7 @@ else if ( opcode.equalsIgnoreCase("uacmin") ) {
public static AggregateOperator parseAggregateOperator(String opcode, String corrExists, String corrLoc)
{
AggregateOperator agg = null;
-
+
if ( opcode.equalsIgnoreCase("ak+") || opcode.equalsIgnoreCase("aktrace") ) {
boolean lcorrExists = (corrExists==null) ? true : Boolean.parseBoolean(corrExists);
CorrectionLocationType lcorrLoc = (corrLoc==null) ? CorrectionLocationType.LASTCOLUMN : CorrectionLocationType.valueOf(corrLoc);
@@ -460,7 +472,7 @@ else if ( opcode.equalsIgnoreCase("asqk+") ) {
}
else if ( opcode.equalsIgnoreCase("a+") ) {
agg = new AggregateOperator(0, Plus.getPlusFnObject());
- }
+ }
else if ( opcode.equalsIgnoreCase("a*") ) {
agg = new AggregateOperator(1, Multiply.getMultiplyFnObject());
}
@@ -492,85 +504,85 @@ else if ( opcode.equalsIgnoreCase("avar") ) {
return agg;
}
-
+
/**
- *
+ *
* @param uop
* @return
*/
public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(UnaryOperator uop)
{
Builtin f = (Builtin)uop.fn;
-
- if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMSUM )
+
+ if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMSUM )
return parseCumulativeAggregateUnaryOperator("ucumack+") ;
- else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMPROD )
+ else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMPROD )
return parseCumulativeAggregateUnaryOperator("ucumac*") ;
- else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMIN )
+ else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMIN )
return parseCumulativeAggregateUnaryOperator("ucumacmin") ;
- else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMAX )
+ else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMAX )
return parseCumulativeAggregateUnaryOperator("ucumacmax" ) ;
-
+
throw new RuntimeException("Unsupported cumulative aggregate unary operator: "+f.getBuiltinFunctionCode());
}
-
+
/**
- *
+ *
* @param uop
* @return
*/
public static AggregateUnaryOperator parseBasicCumulativeAggregateUnaryOperator(UnaryOperator uop)
{
Builtin f = (Builtin)uop.fn;
-
- if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMSUM )
+
+ if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMSUM )
return parseBasicAggregateUnaryOperator("uack+") ;
- else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMPROD )
+ else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMPROD )
return parseBasicAggregateUnaryOperator("uac*") ;
- else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMIN )
+ else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMIN )
return parseBasicAggregateUnaryOperator("uacmin") ;
- else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMAX )
+ else if( f.getBuiltinFunctionCode()==BuiltinFunctionCode.CUMMAX )
return parseBasicAggregateUnaryOperator("uacmax" ) ;
-
+
throw new RuntimeException("Unsupported cumulative aggregate unary operator: "+f.getBuiltinFunctionCode());
}
-
+
/**
- *
+ *
* @param opcode
* @return
*/
public static AggregateUnaryOperator parseCumulativeAggregateUnaryOperator(String opcode)
{
AggregateUnaryOperator aggun = null;
- if( "ucumack+".equals(opcode) ) {
+ if( "ucumack+".equals(opcode) ) {
AggregateOperator agg = new AggregateOperator(0, KahanPlus.getKahanPlusFnObject(), true, CorrectionLocationType.LASTROW);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
- else if ( "ucumac*".equals(opcode) ) {
+ else if ( "ucumac*".equals(opcode) ) {
AggregateOperator agg = new AggregateOperator(0, Multiply.getMultiplyFnObject(), false, CorrectionLocationType.NONE);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
- else if ( "ucumacmin".equals(opcode) ) {
+ else if ( "ucumacmin".equals(opcode) ) {
AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("min"), false, CorrectionLocationType.NONE);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
- else if ( "ucumacmax".equals(opcode) ) {
+ else if ( "ucumacmax".equals(opcode) ) {
AggregateOperator agg = new AggregateOperator(0, Builtin.getBuiltinFnObject("max"), false, CorrectionLocationType.NONE);
aggun = new AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject());
}
-
+
return aggun;
}
-
+
/**
- *
+ *
* @param opcode
* @return
* @throws DMLRuntimeException
*/
- public static BinaryOperator parseBinaryOperator(String opcode)
- throws DMLRuntimeException
+ public static BinaryOperator parseBinaryOperator(String opcode)
+ throws DMLRuntimeException
{
if(opcode.equalsIgnoreCase("=="))
return new BinaryOperator(Equals.getEqualsFnObject());
@@ -596,7 +608,7 @@ else if(opcode.equalsIgnoreCase("*"))
return new BinaryOperator(Multiply.getMultiplyFnObject());
else if(opcode.equalsIgnoreCase("1-*"))
return new BinaryOperator(Minus1Multiply.getMinus1MultiplyFnObject());
- else if ( opcode.equalsIgnoreCase("*2") )
+ else if ( opcode.equalsIgnoreCase("*2") )
return new BinaryOperator(Multiply2.getMultiply2FnObject());
else if(opcode.equalsIgnoreCase("/"))
return new BinaryOperator(Divide.getDivideFnObject());
@@ -608,34 +620,34 @@ else if(opcode.equalsIgnoreCase("^"))
return new BinaryOperator(Power.getPowerFnObject());
else if ( opcode.equalsIgnoreCase("^2") )
return new BinaryOperator(Power2.getPower2FnObject());
- else if ( opcode.equalsIgnoreCase("max") )
+ else if ( opcode.equalsIgnoreCase("max") )
return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
- else if ( opcode.equalsIgnoreCase("min") )
+ else if ( opcode.equalsIgnoreCase("min") )
return new BinaryOperator(Builtin.getBuiltinFnObject("min"));
-
+
throw new DMLRuntimeException("Unknown binary opcode " + opcode);
}
-
+
/**
* scalar-matrix operator
- *
+ *
* @param opcode
* @param arg1IsScalar
* @return
* @throws DMLRuntimeException
*/
- public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar)
- throws DMLRuntimeException
+ public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar)
+ throws DMLRuntimeException
{
//for all runtimes that set constant dynamically (cp/spark)
double default_constant = 0;
-
+
return parseScalarBinaryOperator(opcode, arg1IsScalar, default_constant);
}
-
+
/**
* scalar-matrix operator
- *
+ *
* @param opcode
* @param arg1IsScalar
* @param constant
@@ -643,15 +655,15 @@ public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean ar
* @throws DMLRuntimeException
*/
public static ScalarOperator parseScalarBinaryOperator(String opcode, boolean arg1IsScalar, double constant)
- throws DMLRuntimeException
+ throws DMLRuntimeException
{
//commutative operators
- if ( opcode.equalsIgnoreCase("+") ){
- return new RightScalarOperator(Plus.getPlusFnObject(), constant);
+ if ( opcode.equalsIgnoreCase("+") ){
+ return new RightScalarOperator(Plus.getPlusFnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("*") ) {
return new RightScalarOperator(Multiply.getMultiplyFnObject(), constant);
- }
+ }
//non-commutative operators
else if ( opcode.equalsIgnoreCase("-") ) {
if(arg1IsScalar)
@@ -666,7 +678,7 @@ else if ( opcode.equalsIgnoreCase("/") ) {
if(arg1IsScalar)
return new LeftScalarOperator(Divide.getDivideFnObject(), constant);
else return new RightScalarOperator(Divide.getDivideFnObject(), constant);
- }
+ }
else if ( opcode.equalsIgnoreCase("%%") ) {
if(arg1IsScalar)
return new LeftScalarOperator(Modulus.getModulusFnObject(), constant);
@@ -723,38 +735,38 @@ else if ( opcode.equalsIgnoreCase("!=") ) {
return new LeftScalarOperator(NotEquals.getNotEqualsFnObject(), constant);
return new RightScalarOperator(NotEquals.getNotEqualsFnObject(), constant);
}
-
+
//operations that only exist for performance purposes (all unary or commutative operators)
else if ( opcode.equalsIgnoreCase("*2") ) {
return new RightScalarOperator(Multiply2.getMultiply2FnObject(), constant);
- }
+ }
else if ( opcode.equalsIgnoreCase("^2") ){
return new RightScalarOperator(Power2.getPower2FnObject(), constant);
}
else if ( opcode.equalsIgnoreCase("1-*") ) {
return new RightScalarOperator(Minus1Multiply.getMinus1MultiplyFnObject(), constant);
}
-
+
//operations that only exist in mr
else if ( opcode.equalsIgnoreCase("s-r") ) {
return new LeftScalarOperator(Minus.getMinusFnObject(), constant);
- }
+ }
else if ( opcode.equalsIgnoreCase("so") ) {
return new LeftScalarOperator(Divide.getDivideFnObject(), constant);
}
-
+
throw new DMLRuntimeException("Unknown binary opcode " + opcode);
- }
+ }
/**
- *
+ *
* @param opcode
* @return
* @throws DMLRuntimeException
*/
- public static BinaryOperator parseExtendedBinaryOperator(String opcode)
- throws DMLRuntimeException
+ public static BinaryOperator parseExtendedBinaryOperator(String opcode)
+ throws DMLRuntimeException
{
if(opcode.equalsIgnoreCase("==") || opcode.equalsIgnoreCase("map=="))
return new BinaryOperator(Equals.getEqualsFnObject());
@@ -780,7 +792,7 @@ else if(opcode.equalsIgnoreCase("*") || opcode.equalsIgnoreCase("map*"))
return new BinaryOperator(Multiply.getMultiplyFnObject());
else if(opcode.equalsIgnoreCase("1-*") || opcode.equalsIgnoreCase("map1-*"))
return new BinaryOperator(Minus1Multiply.getMinus1MultiplyFnObject());
- else if ( opcode.equalsIgnoreCase("*2") )
+ else if ( opcode.equalsIgnoreCase("*2") )
return new BinaryOperator(Multiply2.getMultiply2FnObject());
else if(opcode.equalsIgnoreCase("/") || opcode.equalsIgnoreCase("map/"))
return new BinaryOperator(Divide.getDivideFnObject());
@@ -792,17 +804,17 @@ else if(opcode.equalsIgnoreCase("^") || opcode.equalsIgnoreCase("map^"))
return new BinaryOperator(Power.getPowerFnObject());
else if ( opcode.equalsIgnoreCase("^2") )
return new BinaryOperator(Power2.getPower2FnObject());
- else if ( opcode.equalsIgnoreCase("max") || opcode.equalsIgnoreCase("mapmax") )
+ else if ( opcode.equalsIgnoreCase("max") || opcode.equalsIgnoreCase("mapmax") )
return new BinaryOperator(Builtin.getBuiltinFnObject("max"));
- else if ( opcode.equalsIgnoreCase("min") || opcode.equalsIgnoreCase("mapmin") )
+ else if ( opcode.equalsIgnoreCase("min") || opcode.equalsIgnoreCase("mapmin") )
return new BinaryOperator(Builtin.getBuiltinFnObject("min"));
-
+
throw new DMLRuntimeException("Unknown binary opcode " + opcode);
}
-
-
+
+
/**
- *
+ *
* @param opcode
* @return
*/
@@ -820,7 +832,7 @@ else if ( opcode.equalsIgnoreCase("ua+") || opcode.equalsIgnoreCase("uar+") || o
return "a+";
else if ( opcode.equalsIgnoreCase("ua*") )
return "a*";
- else if ( opcode.equalsIgnoreCase("uatrace") || opcode.equalsIgnoreCase("uaktrace") )
+ else if ( opcode.equalsIgnoreCase("uatrace") || opcode.equalsIgnoreCase("uaktrace") )
return "aktrace";
else if ( opcode.equalsIgnoreCase("uamax") || opcode.equalsIgnoreCase("uarmax") || opcode.equalsIgnoreCase("uacmax") )
return "amax";
@@ -830,12 +842,12 @@ else if (opcode.equalsIgnoreCase("uarimax") )
return "arimax";
else if (opcode.equalsIgnoreCase("uarimin") )
return "arimin";
-
+
return null;
}
/**
- *
+ *
* @param opcode
* @return
*/
@@ -857,26 +869,26 @@ else if ( opcode.equalsIgnoreCase("uacvar") )
return CorrectionLocationType.LASTFOURROWS;
else if (opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin") )
return CorrectionLocationType.LASTCOLUMN;
-
+
return CorrectionLocationType.NONE;
}
/**
- *
+ *
* @param opcode
* @return
*/
- public static boolean isDistQuaternaryOpcode(String opcode)
+ public static boolean isDistQuaternaryOpcode(String opcode)
{
return WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) //mapwsloss
- || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) //redwsloss
- || WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode) //mapwsigmoid
- || WeightedSigmoidR.OPCODE.equalsIgnoreCase(opcode) //redwsigmoid
- || WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //mapwdivmm
- || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) //redwdivmm
- || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode) //mapwcemm
- || WeightedCrossEntropyR.OPCODE.equalsIgnoreCase(opcode) //redwcemm
- || WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //mapwumm
- || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode); //redwumm
+ || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode) //redwsloss
+ || WeightedSigmoid.OPCODE.equalsIgnoreCase(opcode) //mapwsigmoid
+ || WeightedSigmoidR.OPCODE.equalsIgnoreCase(opcode) //redwsigmoid
+ || WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) //mapwdivmm
+ || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode) //redwdivmm
+ || WeightedCrossEntropy.OPCODE.equalsIgnoreCase(opcode) //mapwcemm
+ || WeightedCrossEntropyR.OPCODE.equalsIgnoreCase(opcode) //redwcemm
+ || WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) //mapwumm
+ || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode); //redwumm
}
}
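Since a Flink instruction parser can reuse all of these helpers unchanged, a hedged usage sketch may help reviewers; the opcodes and the Left/RightScalarOperator split are exactly as dispatched above, and the field names follow the operator classes referenced there:

    import org.apache.sysml.runtime.instructions.InstructionUtils;
    import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator;
    import org.apache.sysml.runtime.matrix.operators.ScalarOperator;

    public class InstructionUtilsSketch {
        public static void main(String[] args) throws Exception {
            // "uark+" = Kahan-compensated row sums, per parseBasicAggregateUnaryOperator
            AggregateUnaryOperator auop =
                    InstructionUtils.parseBasicAggregateUnaryOperator("uark+");
            System.out.println(auop.aggOp.correctionExists); // true: correction column present

            // for non-commutative ops, the scalar side selects Left vs. Right operators
            ScalarOperator sLeft  = InstructionUtils.parseScalarBinaryOperator("-", true, 5.0);  // 5 - X
            ScalarOperator sRight = InstructionUtils.parseScalarBinaryOperator("-", false, 5.0); // X - 5
            System.out.println(sLeft.getClass().getSimpleName()
                    + " / " + sRight.getClass().getSimpleName());
        }
    }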
diff --git a/src/main/java/org/apache/sysml/utils/Explain.java b/src/main/java/org/apache/sysml/utils/Explain.java
index 6b266ea0cf0..865bdeb0d14 100644
--- a/src/main/java/org/apache/sysml/utils/Explain.java
+++ b/src/main/java/org/apache/sysml/utils/Explain.java
@@ -63,6 +63,7 @@
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.MRJobInstruction;
import org.apache.sysml.runtime.instructions.cp.CPInstruction;
+import org.apache.sysml.runtime.instructions.flink.FLInstruction;
import org.apache.sysml.runtime.instructions.spark.CSVReblockSPInstruction;
import org.apache.sysml.runtime.instructions.spark.ReblockSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
@@ -915,7 +916,7 @@ private static String explainGenericInstruction( Instruction inst, int level )
String tmp = null;
if( inst instanceof MRJobInstruction )
tmp = explainMRJobInstruction((MRJobInstruction)inst, level+1);
- else if ( inst instanceof SPInstruction || inst instanceof CPInstruction)
+ else if ( inst instanceof SPInstruction || inst instanceof CPInstruction || inst instanceof FLInstruction)
tmp = inst.toString();
if( REPLACE_SPECIAL_CHARACTERS ){
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java b/src/main/java/org/apache/sysml/utils/Statistics.java
index edb34939a3b..621125a4d04 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -61,6 +61,10 @@ public class Statistics
private static int iNoOfExecutedSPInst = 0;
private static int iNoOfCompiledSPInst = 0;
+ // number of compiled/executed FL instructions
+ private static int iNoOfExecutedFLInst = 0;
+ private static int iNoOfCompiledFLInst = 0;
+
//JVM stats
private static long jitCompileTime = 0; //in milli sec
private static long jvmGCTime = 0; //in milli sec
@@ -164,6 +168,30 @@ public static synchronized int getNoOfCompiledSPInst() {
return iNoOfCompiledSPInst;
}
+ public static synchronized void setNoOfExecutedFLInst(int numJobs) {
+ iNoOfExecutedFLInst = numJobs;
+ }
+
+ public static synchronized int getNoOfExecutedFLInst() {
+ return iNoOfExecutedFLInst;
+ }
+
+ public static synchronized void incrementNoOfExecutedFLInst() {
+ iNoOfExecutedFLInst ++;
+ }
+
+ public static synchronized void decrementNoOfExecutedFLInst() {
+ iNoOfExecutedFLInst --;
+ }
+
+ public static synchronized void setNoOfCompiledFLInst(int numJobs) {
+ iNoOfCompiledFLInst = numJobs;
+ }
+
+ public static synchronized int getNoOfCompiledFLInst() {
+ return iNoOfCompiledFLInst;
+ }
+
public static synchronized void incrementNoOfCompiledSPInst() {
iNoOfCompiledSPInst ++;
}
@@ -220,11 +248,18 @@ public static void resetNoOfExecutedJobs( int count )
if(OptimizerUtils.isSparkExecutionMode()) {
setNoOfExecutedSPInst(count);
- setNoOfExecutedMRJobs(0);
+ setNoOfExecutedMRJobs(0);
+ setNoOfExecutedFLInst(0);
+ }
+ else if (OptimizerUtils.isFlinkExecutionMode()) {
+ setNoOfExecutedFLInst(count);
+ setNoOfExecutedMRJobs(0);
+ setNoOfExecutedSPInst(0);
}
else {
setNoOfExecutedMRJobs(count);
setNoOfExecutedSPInst(0);
+ setNoOfExecutedFLInst(0);
}
}
@@ -591,6 +626,11 @@ public static String display()
sb.append("Number of compiled Spark inst:\t" + getNoOfCompiledSPInst() + ".\n");
sb.append("Number of executed Spark inst:\t" + getNoOfExecutedSPInst() + ".\n");
}
+ else if ( OptimizerUtils.isFlinkExecutionMode() ) {
+ if( DMLScript.STATISTICS ) //moved into stats on Shiv's request
+ sb.append("Number of compiled Flink inst:\t" + getNoOfCompiledFLInst() + ".\n");
+ sb.append("Number of executed Flink inst:\t" + getNoOfExecutedFLInst() + ".\n");
+ }
else {
if( DMLScript.STATISTICS ) //moved into stats on Shiv's request
sb.append("Number of compiled MR Jobs:\t" + getNoOfCompiledMRJobs() + ".\n");
diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index eb8719747c0..5b6a3a0f13f 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -1144,6 +1144,10 @@ else if (rtplatform == RUNTIME_PLATFORM.SPARK)
args.add("spark");
else if (rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)
args.add("hybrid_spark");
+ else if (rtplatform == RUNTIME_PLATFORM.FLINK)
+ args.add("flink");
+ else if (rtplatform == RUNTIME_PLATFORM.HYBRID_FLINK)
+ args.add("hybrid_flink");
else {
throw new RuntimeException("Unknown runtime platform: " + rtplatform);
}
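These branches feed DMLScript's existing -exec option, so a FLINK-platform test effectively launches with an argument vector like the following (script path illustrative):

    public class FlinkPlatformArgsSketch {
        public static void main(String[] args) {
            // hedged example of the arguments assembled for a Flink test run
            String[] dmlArgs = { "-f", "test.dml", "-exec", "flink" };
            System.out.println(String.join(" ", dmlArgs));
        }
    }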
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullDistributedMatrixMultiplicationTest.java b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullDistributedMatrixMultiplicationTest.java
index 933f36f6e6c..48ce3823614 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullDistributedMatrixMultiplicationTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/binary/matrix_full_other/FullDistributedMatrixMultiplicationTest.java
@@ -220,7 +220,30 @@ public void testSparseSparseRmmSpark()
{
runDistributedMatrixMatrixMultiplicationTest(true, true, MMultMethod.RMM, ExecType.SPARK);
}
-
+
+ @Test
+ public void testDenseDenseMapmmFlink()
+ {
+ runDistributedMatrixMatrixMultiplicationTest(false, false, MMultMethod.MAPMM_R, ExecType.FLINK);
+ }
+
+ @Test
+ public void testDenseSparseMapmmFlink()
+ {
+ runDistributedMatrixMatrixMultiplicationTest(false, true, MMultMethod.MAPMM_R, ExecType.FLINK);
+ }
+
+ @Test
+ public void testSparseDenseMapmmFlink()
+ {
+ runDistributedMatrixMatrixMultiplicationTest(true, false, MMultMethod.MAPMM_R, ExecType.FLINK);
+ }
+
+ @Test
+ public void testSparseSparseMapmmFlink()
+ {
+ runDistributedMatrixMatrixMultiplicationTest(true, true, MMultMethod.MAPMM_R, ExecType.FLINK);
+	}
+
/**
*
@@ -235,6 +258,7 @@ private void runDistributedMatrixMatrixMultiplicationTest( boolean sparseM1, boo
switch( instType ){
case MR: rtplatform = RUNTIME_PLATFORM.HADOOP; break;
case SPARK: rtplatform = RUNTIME_PLATFORM.SPARK; break;
+ case FLINK: rtplatform = RUNTIME_PLATFORM.FLINK; break;
default: rtplatform = RUNTIME_PLATFORM.HYBRID; break;
}