add restful api server
ycycse committed Sep 17, 2024
1 parent 86f8333 commit d898cdb
Showing 5 changed files with 190 additions and 33 deletions.
DirectPushRestfulApi.java
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.engineplugin.spark;

import org.apache.linkis.engineplugin.spark.utils.DataFrameResponse;
import org.apache.linkis.engineplugin.spark.utils.DirectPushCache;
import org.apache.linkis.server.Message;

import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

import javax.servlet.http.HttpServletRequest;

import java.util.Map;

import io.swagger.annotations.Api;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Api(tags = "DirectPush")
@RestController
@RequestMapping(path = "directpush")
public class DirectPushRestfulApi {
private static final Logger logger = LoggerFactory.getLogger(DirectPushRestfulApi.class);

@RequestMapping(path = "pull", method = RequestMethod.POST)
public Message getDirectPushResult(
HttpServletRequest req, @RequestBody Map<String, Object> json) {
Message message = null;
try {
String taskId = (String) json.getOrDefault("taskId", null);
if (taskId == null) {
message = Message.error("taskId is null");
return message;
}
int fetchSize = (int) json.getOrDefault("fetchSize", 1000);

DataFrameResponse response = DirectPushCache.fetchResultSetOfDataFrame(taskId, fetchSize);
if (response.dataFrame() == null) {
message = Message.error("No result found for taskId: " + taskId);
} else {
message =
Message.ok()
.data("data", response.dataFrame())
.data("hasMoreData", response.hasMoreData());
}
} catch (Exception e) {
logger.error("Failed to get direct push result", e);
message = Message.error("Failed to get direct push result: " + e.getMessage());
}
return message;
}
}
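For reference, a minimal client-side sketch of calling the new pull endpoint. The host, port, and taskId below are placeholders, and the exact URL prefix depends on how the engine conn exposes its RESTful services; the returned Message carries "data" and "hasMoreData" as set by the controller above.

import java.net.URI
import java.net.http.{HttpClient, HttpRequest, HttpResponse}

object DirectPushPullClientSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder address; adjust to the actual engine conn endpoint.
    val url = "http://localhost:9001/directpush/pull"
    // taskId is a placeholder; fetchSize defaults to 1000 on the server if omitted.
    val body = """{"taskId": "task-123", "fetchSize": 1000}"""

    val request = HttpRequest
      .newBuilder()
      .uri(URI.create(url))
      .header("Content-Type", "application/json")
      .POST(HttpRequest.BodyPublishers.ofString(body))
      .build()

    val response = HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())
    // The Message JSON contains "data" (the serialized result batch) and "hasMoreData".
    println(response.body())
  }
}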
SparkSqlExecutor.scala
@@ -18,12 +18,11 @@
package org.apache.linkis.engineplugin.spark.executor

import org.apache.linkis.common.utils.Utils
import org.apache.linkis.engineconn.common.conf.{EngineConnConf, EngineConnConstant}
import org.apache.linkis.engineconn.computation.executor.execute.EngineExecutionContext
import org.apache.linkis.engineplugin.spark.common.{Kind, SparkSQL}
import org.apache.linkis.engineplugin.spark.config.SparkConfiguration
import org.apache.linkis.engineplugin.spark.entity.SparkEngineSession
import org.apache.linkis.engineplugin.spark.utils.{ArrowUtils, EngineUtils}
import org.apache.linkis.engineplugin.spark.utils.{ArrowUtils, DirectPushCache, EngineUtils}
import org.apache.linkis.governance.common.constant.job.JobRequestConstants
import org.apache.linkis.governance.common.paser.SQLCodeParser
import org.apache.linkis.scheduler.executer._
@@ -32,19 +31,10 @@ import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.sql.DataFrame

import java.lang.reflect.InvocationTargetException
import java.util.concurrent.TimeUnit

import com.google.common.cache.{Cache, CacheBuilder}

class SparkSqlExecutor(sparkEngineSession: SparkEngineSession, id: Long)
extends SparkEngineConnExecutor(sparkEngineSession.sparkContext, id) {

private val resultSetIteratorCache: Cache[String, DataFrame] = CacheBuilder
.newBuilder()
.expireAfterAccess(EngineConnConf.ENGINE_TASK_EXPIRE_TIME.getValue, TimeUnit.MILLISECONDS)
.maximumSize(EngineConnConstant.MAX_TASK_NUM)
.build()

override def init(): Unit = {

setCodeParser(new SQLCodeParser)
@@ -57,34 +47,30 @@ class SparkSqlExecutor(sparkEngineSession: SparkEngineSession, id: Long)
// Only used in the direct-push scenario, where the dataFrame is not fetched all at once.
// The lazy dataFrame is cached in memory and results are returned incrementally as the client pulls them.
private def submitResultSetIterator(taskId: String, df: DataFrame): Unit = {
if (resultSetIteratorCache.getIfPresent(taskId) == null) {
resultSetIteratorCache.put(taskId, df)
if (!DirectPushCache.isTaskCached(taskId)) {
DirectPushCache.submitExecuteResult(taskId, df)
} else {
logger.error(s"Task $taskId already exists in resultSet cache.")
}
}

override def isFetchMethodOfDirectPush(taskId: String): Boolean = {
resultSetIteratorCache.getIfPresent(taskId) != null
DirectPushCache.isTaskCached(taskId)
}

// This method is not idempotent: each call returns up to fetchSize rows and removes them from the cached result, so the same rows are not served twice.
override def fetchMoreResultSet(taskId: String, fetchSize: Int): FetchResultResponse = {
val df = resultSetIteratorCache.getIfPresent(taskId)
if (df == null) {
throw new IllegalAccessException(s"Task $taskId not exists in resultSet cache.")
} else {
val batchDf = df.limit(fetchSize)
if (batchDf.count() < fetchSize) {
// All the data in df has been consumed.
val dataFrameResponse = DirectPushCache.fetchResultSetOfDataFrame(taskId, fetchSize)
if (dataFrameResponse.dataFrame != null) {
if (!dataFrameResponse.hasMoreData) {
succeedTasks.increase()
resultSetIteratorCache.invalidate(taskId)
FetchResultResponse(hasMoreData = false, null)
} else {
// Update df with consumed one.
resultSetIteratorCache.put(taskId, df.except(batchDf))
FetchResultResponse(hasMoreData = true, ArrowUtils.toArrow(batchDf))
}
FetchResultResponse(
hasMoreData = dataFrameResponse.hasMoreData,
ArrowUtils.toArrow(dataFrameResponse.dataFrame)
)
}
FetchResultResponse(hasMoreData = dataFrameResponse.hasMoreData, null)
}

override protected def runCode(
ArrowUtils.scala
@@ -27,11 +27,13 @@ import org.apache.arrow.vector.{
VarCharVector,
VectorSchemaRoot
}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, LongType, StringType}

import java.io.ByteArrayOutputStream
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.jdk.CollectionConverters.asScalaBufferConverter

object ArrowUtils {

@@ -111,4 +113,35 @@
(root, fieldVectors)
}

/**
 * Converts an Arrow stream byte array into a string: one row per line, with column values separated by ", ".
 */
def arrowToString(arrowBytes: Array[Byte]): String = {
val allocator = new RootAllocator(Long.MaxValue)
val byteArrayInputStream = new ByteArrayInputStream(arrowBytes)
val streamReader = new ArrowStreamReader(byteArrayInputStream, allocator)

val stringBuilder = new StringBuilder

try {
val root: VectorSchemaRoot = streamReader.getVectorSchemaRoot

while (streamReader.loadNextBatch()) {
for (i <- 0 until root.getRowCount) {
val row = root.getFieldVectors.asScala
.map { vector =>
vector.getObject(i).toString
}
.mkString(", ")
stringBuilder.append(row).append("\n")
}
}
} finally {
streamReader.close()
allocator.close()
}

stringBuilder.toString()
}

}
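A quick round-trip sketch of the new helper, assuming a local SparkSession is available and that ArrowUtils.toArrow returns the Arrow stream bytes consumed by arrowToString:

import org.apache.linkis.engineplugin.spark.utils.ArrowUtils
import org.apache.spark.sql.SparkSession

object ArrowRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Local session purely for illustration.
    val spark = SparkSession.builder().master("local[1]").appName("arrow-example").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "hangzhou"), (2, "shanghai")).toDF("id", "city")

    // Serialize the DataFrame to an Arrow stream, then render it row by row.
    val arrowBytes: Array[Byte] = ArrowUtils.toArrow(df)
    println(ArrowUtils.arrowToString(arrowBytes)) // e.g. "1, hangzhou" and "2, shanghai" on separate lines

    spark.stop()
  }
}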
DirectPushCache.scala
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.linkis.engineplugin.spark.utils

import org.apache.linkis.engineconn.common.conf.{EngineConnConf, EngineConnConstant}

import org.apache.spark.sql.DataFrame

import java.util.concurrent.TimeUnit

import com.google.common.cache.{Cache, CacheBuilder}

case class DataFrameResponse(dataFrame: DataFrame, hasMoreData: Boolean)

object DirectPushCache {

private val resultSet: Cache[String, DataFrame] = CacheBuilder
.newBuilder()
.expireAfterAccess(EngineConnConf.ENGINE_TASK_EXPIRE_TIME.getValue, TimeUnit.MILLISECONDS)
.maximumSize(EngineConnConstant.MAX_TASK_NUM)
.build()

// This method is not idempotent: each call returns up to fetchSize rows and removes them from the cached result, so the same rows are not served twice.
def fetchResultSetOfDataFrame(taskId: String, fetchSize: Int): DataFrameResponse = {
val df = DirectPushCache.resultSet.getIfPresent(taskId)
if (df == null) {
throw new IllegalAccessException(s"Task $taskId not exists in resultSet cache.")
} else {
val batchDf = df.limit(fetchSize)
if (batchDf.count() < fetchSize) {
// All the data in df has been consumed.
DirectPushCache.resultSet.invalidate(taskId)
DataFrameResponse(batchDf, hasMoreData = false)
} else {
// Update df with consumed one.
DirectPushCache.resultSet.put(taskId, df.except(batchDf))
DataFrameResponse(batchDf, hasMoreData = true)
}
}
}

def isTaskCached(taskId: String): Boolean = {
DirectPushCache.resultSet.getIfPresent(taskId) != null
}

def submitExecuteResult(taskId: String, df: DataFrame): Unit = {
DirectPushCache.resultSet.put(taskId, df)
}

}
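To illustrate the intended submit-then-drain cycle, a minimal sketch using only the methods defined above; the taskId and data are placeholders, and in the real flow SparkSqlExecutor submits the DataFrame while the pull endpoint drives the fetch loop:

import org.apache.linkis.engineplugin.spark.utils.DirectPushCache
import org.apache.spark.sql.SparkSession

object DirectPushCacheSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("direct-push-example").getOrCreate()
    import spark.implicits._

    val taskId = "task-123" // placeholder task id
    val df = (1 to 10).map(i => (i, s"row-$i")).toDF("id", "name")

    // Executor side: cache the lazy DataFrame once per task.
    if (!DirectPushCache.isTaskCached(taskId)) {
      DirectPushCache.submitExecuteResult(taskId, df)
    }

    // Client side: drain the result set in batches of fetchSize rows until hasMoreData is false.
    var hasMore = true
    while (hasMore) {
      val response = DirectPushCache.fetchResultSetOfDataFrame(taskId, fetchSize = 4)
      response.dataFrame.show()
      hasMore = response.hasMoreData
    }

    spark.stop()
  }
}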
TestArrowUtil.scala
@@ -17,18 +17,21 @@

package org.apache.linkis.engineplugin.spark.executor

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.linkis.DataWorkCloudApplication
import org.apache.linkis.common.conf.DWCArgumentsParser
import org.apache.linkis.engineplugin.spark.utils.ArrowUtils

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.sql.SparkSession
import org.junit.jupiter.api.{Assertions, Test}

import java.io.ByteArrayInputStream

import scala.collection.mutable

import org.junit.jupiter.api.{Assertions, Test}

class TestArrowUtil {

def initService(port: String): Unit = {
