Skip to content

Commit

Permalink
Merge pull request #77 from alibaba/gpu
Browse files Browse the repository at this point in the history
Gpu
  • Loading branch information
wuchaochen1 authored Nov 26, 2020
2 parents f5985d2 + 02393a5 commit aa84180
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 10 deletions.
2 changes: 1 addition & 1 deletion flink-ml-examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-planner-blink_2.11</artifactId>
<artifactId>flink-table-planner-blink_${scala.major.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down
8 changes: 8 additions & 0 deletions flink-ml-framework/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@
<groupId>org.apache.curator</groupId>
<artifactId>curator-test</artifactId>
</exclusion>
<exclusion>
<artifactId>commons-cli</artifactId>
<groupId>commons-cli</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>
Expand All @@ -140,6 +144,10 @@
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
</exclusion>
<exclusion>
<artifactId>commons-cli</artifactId>
<groupId>commons-cli</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,6 @@ public class MLConstants {
public static final String FAILOVER_RESTART_INDIVIDUAL_STRATEGY = "individual";
public static final String FAILOVER_STRATEGY_DEFAULT = FAILOVER_RESTART_ALL_STRATEGY;
public static final String PYTHON_VERSION = "python.version";

public static final String GPU_INFO = "gpu_info";
}
2 changes: 1 addition & 1 deletion flink-ml-lib/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-planner-blink_2.11</artifactId>
<artifactId>flink-table-planner-blink_${scala.major.version}</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public MLMapFunction(ExecutionMode mode, BaseRole role, MLConfig config, TypeInf
* @throws Exception
*/
public void open(RuntimeContext runtimeContext) throws Exception {
ResourcesUtils.parseGpuInfo(runtimeContext, config);
mlContext = new MLContext(mode, config, role.name(), runtimeContext.getIndexOfThisSubtask(),
config.getEnvPath(), null);
PythonFileUtil.preparePythonFilesForExec(runtimeContext, mlContext);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.alibaba.flink.ml.operator.ops;

import org.apache.flink.api.common.externalresource.ExternalResourceInfo;
import org.apache.flink.api.common.functions.RuntimeContext;

import com.alibaba.flink.ml.cluster.MLConfig;
import com.alibaba.flink.ml.util.MLConstants;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;

/**
 * Static helpers that copy external resource information from a Flink
 * {@link RuntimeContext} into an {@link MLConfig}.
 */
public class ResourcesUtils {

    /**
     * Records the GPU indices visible to the current task in the given config.
     *
     * <p>The external resources registered under the name {@code "gpu"} are read from the
     * runtime context. For every resource that exposes an {@code "index"} property, the
     * property value is collected; the collected values are sorted in natural
     * {@link String} order (so {@code "10"} sorts before {@code "2"}) and joined with
     * commas. The result is stored in the config properties under
     * {@link MLConstants#GPU_INFO}. When the context reports no GPU resources
     * ({@code null} or an empty set), an empty string is stored instead.
     *
     * @param runtimeContext the runtime context queried for {@code "gpu"} resources
     * @param mlConfig the config whose properties receive the {@link MLConstants#GPU_INFO} entry
     */
    public static void parseGpuInfo(RuntimeContext runtimeContext, MLConfig mlConfig) {
        Set<ExternalResourceInfo> gpuResources = runtimeContext.getExternalResourceInfos("gpu");

        // Nothing allocated: still write the key so readers always find an entry.
        if (gpuResources == null || gpuResources.isEmpty()) {
            mlConfig.getProperties().put(MLConstants.GPU_INFO, "");
            return;
        }

        List<String> indices = new ArrayList<>();
        for (ExternalResourceInfo resource : gpuResources) {
            // Resources without an "index" property are silently left out.
            resource.getProperty("index").ifPresent(indices::add);
        }

        // Sort so the resulting string does not depend on the set's iteration order.
        Collections.sort(indices);
        String joinedIndices = String.join(",", indices);
        mlConfig.getProperties().put(MLConstants.GPU_INFO, joinedIndices);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.alibaba.flink.ml.cluster.rpc.NodeServer;
import com.alibaba.flink.ml.data.DataExchange;
import com.alibaba.flink.ml.operator.hook.FlinkOpHookManager;
import com.alibaba.flink.ml.operator.ops.ResourcesUtils;
import com.alibaba.flink.ml.operator.util.ColumnInfos;
import com.alibaba.flink.ml.operator.util.PythonFileUtil;
import com.alibaba.flink.ml.cluster.role.AMRole;
Expand All @@ -34,7 +35,6 @@
import org.apache.flink.api.common.io.statistics.BaseStatistics;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.io.InputSplit;
import org.apache.flink.core.io.InputSplitAssigner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -122,6 +122,7 @@ public InputSplitAssigner getInputSplitAssigner(MLInputSplit[] inputSplits) {
*/
@Override
public void open(MLInputSplit split) throws IOException {
ResourcesUtils.parseGpuInfo(getRuntimeContext(), mlConfig);
mlContext = new MLContext(mode, mlConfig, role.name(), split.getSplitNumber(),
mlConfig.getEnvPath(), ColumnInfos.dummy().getNameToTypeMap());

Expand Down
4 changes: 2 additions & 2 deletions flink-ml-pytorch/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-api-java-bridge_2.11</artifactId>
<artifactId>flink-table-api-java-bridge_${scala.major.version}</artifactId>
</dependency>

<dependency>
Expand Down Expand Up @@ -64,7 +64,7 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-planner-blink_2.11</artifactId>
<artifactId>flink-table-planner-blink_${scala.major.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down
11 changes: 8 additions & 3 deletions flink-ml-tensorflow/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,21 @@
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow-hadoop</artifactId>
<exclusions>
<exclusion>
<artifactId>commons-cli</artifactId>
<groupId>commons-cli</groupId>
</exclusion>
</exclusions>
</dependency>

<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-api-java</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-api-java-bridge_2.11</artifactId>
<artifactId>flink-table-api-java-bridge_${scala.major.version}</artifactId>
</dependency>

<dependency>
Expand All @@ -77,7 +82,7 @@
</dependency>
<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-table-planner-blink_2.11</artifactId>
<artifactId>flink-table-planner-blink_${scala.major.version}</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,6 @@ class MLCONSTANTS(object):
FAILOVER_RESTART_INDIVIDUAL_STRATEGY = str(ml_constants.FAILOVER_RESTART_INDIVIDUAL_STRATEGY)
FAILOVER_STRATEGY_DEFAULT = str(ml_constants.FAILOVER_STRATEGY_DEFAULT)

PYTHON_VERSION = str(ml_constants.PYTHON_VERSION)
PYTHON_VERSION = str(ml_constants.PYTHON_VERSION)

GPU_INFO = str(ml_constants.GPU_INFO)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
queue_library = os.path.join(resource_loader.get_data_files_path(),
"libflink_ops.so")
print(queue_library)
flink_ops = tf.load_op_library(queue_library)
try:
flink_ops = tf.load_op_library(queue_library)
except Exception as e:
flink_ops = tf.load_op_library(queue_library)
print("load libflink_ops.so success")


Expand Down
2 changes: 2 additions & 0 deletions flink-ml-tensorflow/src/test/python/add.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@ def map_func(context):
job_name = tf_context.get_role_name()
index = tf_context.get_index()
cluster_json = tf_context.get_tf_cluster()
gpu_info = tf_context.get_property("gpu_info")
print ("cluster:" + str(cluster_json))
print ("job name:" + job_name)
print ("current index:" + str(index))
print ("gpu info: " + gpu_info)
sys.stdout.flush()
cluster = tf.train.ClusterSpec(cluster=cluster_json)
server = tf.train.Server(cluster, job_name=job_name, task_index=index)
Expand Down
12 changes: 12 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,23 @@
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>${hadoop.version}</version>
<exclusions>
<exclusion>
<artifactId>commons-cli</artifactId>
<groupId>commons-cli</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-hdfs</artifactId>
<version>${hadoop.version}</version>
<exclusions>
<exclusion>
<artifactId>commons-cli</artifactId>
<groupId>commons-cli</groupId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.maven</groupId>
Expand Down

0 comments on commit aa84180

Please sign in to comment.