|
| 1 | +package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; |
| 2 | + |
| 3 | +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; |
| 4 | +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; |
| 5 | +import org.apache.iotdb.commons.client.IClientManager; |
| 6 | +import org.apache.iotdb.commons.exception.IoTDBRuntimeException; |
| 7 | +import org.apache.iotdb.db.exception.sql.SemanticException; |
| 8 | +import org.apache.iotdb.db.protocol.client.an.AINodeClient; |
| 9 | +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; |
| 10 | +import org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender; |
| 11 | +import org.apache.iotdb.rpc.TSStatusCode; |
| 12 | +import org.apache.iotdb.udf.api.exception.UDFException; |
| 13 | +import org.apache.iotdb.udf.api.relational.TableFunction; |
| 14 | +import org.apache.iotdb.udf.api.relational.access.Record; |
| 15 | +import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis; |
| 16 | +import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle; |
| 17 | +import org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider; |
| 18 | +import org.apache.iotdb.udf.api.relational.table.argument.Argument; |
| 19 | +import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema; |
| 20 | +import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument; |
| 21 | +import org.apache.iotdb.udf.api.relational.table.argument.TableArgument; |
| 22 | +import org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor; |
| 23 | +import org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification; |
| 24 | +import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification; |
| 25 | +import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification; |
| 26 | +import org.apache.iotdb.udf.api.type.Type; |
| 27 | + |
| 28 | +import org.apache.tsfile.block.column.Column; |
| 29 | +import org.apache.tsfile.block.column.ColumnBuilder; |
| 30 | +import org.apache.tsfile.enums.TSDataType; |
| 31 | +import org.apache.tsfile.read.common.block.TsBlock; |
| 32 | +import org.apache.tsfile.read.common.block.TsBlockBuilder; |
| 33 | +import org.apache.tsfile.read.common.block.column.TsBlockSerde; |
| 34 | +import org.apache.tsfile.utils.PublicBAOS; |
| 35 | +import org.apache.tsfile.utils.ReadWriteIOUtils; |
| 36 | + |
| 37 | +import java.io.DataOutputStream; |
| 38 | +import java.io.IOException; |
| 39 | +import java.nio.ByteBuffer; |
| 40 | +import java.util.ArrayList; |
| 41 | +import java.util.Arrays; |
| 42 | +import java.util.Collections; |
| 43 | +import java.util.HashSet; |
| 44 | +import java.util.LinkedList; |
| 45 | +import java.util.List; |
| 46 | +import java.util.Locale; |
| 47 | +import java.util.Map; |
| 48 | +import java.util.Objects; |
| 49 | +import java.util.Optional; |
| 50 | +import java.util.Set; |
| 51 | +import java.util.stream.Collectors; |
| 52 | + |
| 53 | +import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex; |
| 54 | +import static org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender; |
| 55 | +import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; |
| 56 | + |
| 57 | +public class ClassifyTableFunction implements TableFunction { |
| 58 | + |
| 59 | + public static class ClassifyTableFunctionHandle implements TableFunctionHandle { |
| 60 | + String modelId; |
| 61 | + int maxInputLength; |
| 62 | + List<Type> inputColumnTypes; |
| 63 | + |
| 64 | + public ClassifyTableFunctionHandle() {} |
| 65 | + |
| 66 | + public ClassifyTableFunctionHandle( |
| 67 | + String modelId, int maxInputLength, List<Type> inputColumnTypes) { |
| 68 | + this.modelId = modelId; |
| 69 | + this.maxInputLength = maxInputLength; |
| 70 | + this.inputColumnTypes = inputColumnTypes; |
| 71 | + } |
| 72 | + |
| 73 | + @Override |
| 74 | + public byte[] serialize() { |
| 75 | + try (PublicBAOS publicBAOS = new PublicBAOS(); |
| 76 | + DataOutputStream outputStream = new DataOutputStream(publicBAOS)) { |
| 77 | + ReadWriteIOUtils.write(modelId, outputStream); |
| 78 | + ReadWriteIOUtils.write(maxInputLength, outputStream); |
| 79 | + ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream); |
| 80 | + for (Type type : inputColumnTypes) { |
| 81 | + ReadWriteIOUtils.write(type.getType(), outputStream); |
| 82 | + } |
| 83 | + outputStream.flush(); |
| 84 | + return publicBAOS.toByteArray(); |
| 85 | + } catch (IOException e) { |
| 86 | + throw new IoTDBRuntimeException( |
| 87 | + String.format( |
| 88 | + "Error occurred while serializing ForecastTableFunctionHandle: %s", e.getMessage()), |
| 89 | + TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + @Override |
| 94 | + public void deserialize(byte[] bytes) { |
| 95 | + ByteBuffer buffer = ByteBuffer.wrap(bytes); |
| 96 | + this.modelId = ReadWriteIOUtils.readString(buffer); |
| 97 | + this.maxInputLength = ReadWriteIOUtils.readInt(buffer); |
| 98 | + int size = ReadWriteIOUtils.readInt(buffer); |
| 99 | + this.inputColumnTypes = new ArrayList<>(size); |
| 100 | + for (int i = 0; i < size; i++) { |
| 101 | + inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer))); |
| 102 | + } |
| 103 | + } |
| 104 | + |
| 105 | + @Override |
| 106 | + public boolean equals(Object o) { |
| 107 | + if (this == o) return true; |
| 108 | + if (o == null || getClass() != o.getClass()) return false; |
| 109 | + ClassifyTableFunctionHandle that = (ClassifyTableFunctionHandle) o; |
| 110 | + return maxInputLength == that.maxInputLength |
| 111 | + && Objects.equals(modelId, that.modelId) |
| 112 | + && Objects.equals(inputColumnTypes, that.inputColumnTypes); |
| 113 | + } |
| 114 | + |
| 115 | + @Override |
| 116 | + public int hashCode() { |
| 117 | + return Objects.hash(modelId, maxInputLength, inputColumnTypes); |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + private static final String INPUT_PARAMETER_NAME = "INPUT"; |
| 122 | + private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID"; |
| 123 | + public static final String TIMECOL_PARAMETER_NAME = "TIMECOL"; |
| 124 | + private static final String DEFAULT_TIME_COL = "time"; |
| 125 | + private static final String DEFAULT_OUTPUT_COLUMN_NAME = "category"; |
| 126 | + private static final int MAX_INPUT_LENGTH = 2880; |
| 127 | + |
| 128 | + private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>(); |
| 129 | + |
| 130 | + static { |
| 131 | + ALLOWED_INPUT_TYPES.add(Type.INT32); |
| 132 | + ALLOWED_INPUT_TYPES.add(Type.INT64); |
| 133 | + ALLOWED_INPUT_TYPES.add(Type.FLOAT); |
| 134 | + ALLOWED_INPUT_TYPES.add(Type.DOUBLE); |
| 135 | + } |
| 136 | + |
| 137 | + @Override |
| 138 | + public List<ParameterSpecification> getArgumentsSpecifications() { |
| 139 | + return Arrays.asList( |
| 140 | + TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(), |
| 141 | + ScalarParameterSpecification.builder() |
| 142 | + .name(MODEL_ID_PARAMETER_NAME) |
| 143 | + .type(Type.STRING) |
| 144 | + .build(), |
| 145 | + ScalarParameterSpecification.builder() |
| 146 | + .name(TIMECOL_PARAMETER_NAME) |
| 147 | + .type(Type.STRING) |
| 148 | + .defaultValue(DEFAULT_TIME_COL) |
| 149 | + .build()); |
| 150 | + } |
| 151 | + |
  /**
   * Validates the arguments and derives the output schema and required input columns.
   *
   * <p>Output schema is the time column followed by a single INT32 "category" column. Required
   * input columns are the time column plus every value column that is neither the time column nor
   * a PARTITION BY column.
   *
   * @throws SemanticException if MODEL_ID or TIMECOL is empty, or a value column has a
   *     non-numeric type
   */
  @Override
  public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws UDFException {
    TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
    String modelId = (String) ((ScalarArgument) arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
    // modelId should never be null or empty
    if (modelId == null || modelId.isEmpty()) {
      throw new SemanticException(
          String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME));
    }

    // Column matching is case-insensitive: normalize the configured time column name.
    String timeColumn =
        ((String) ((ScalarArgument) arguments.get(TIMECOL_PARAMETER_NAME)).getValue())
            .toLowerCase(Locale.ENGLISH);

    if (timeColumn.isEmpty()) {
      throw new SemanticException(
          String.format("%s should never be null or empty.", TIMECOL_PARAMETER_NAME));
    }

    // Columns sent to the model must exclude PARTITION BY columns and the time column.
    Set<String> excludedColumns =
        input.getPartitionBy().stream()
            .map(s -> s.toLowerCase(Locale.ENGLISH))
            .collect(Collectors.toSet());
    excludedColumns.add(timeColumn);
    // The time column must exist and be of TIMESTAMP type.
    int timeColumnIndex = findColumnIndex(input, timeColumn, Collections.singleton(Type.TIMESTAMP));

    // requiredIndexList collects the input column indices the processor will receive, starting
    // with the time column so the processor can read timestamps at record index 0.
    List<Integer> requiredIndexList = new ArrayList<>();
    requiredIndexList.add(timeColumnIndex);
    DescribedSchema.Builder properColumnSchemaBuilder =
        new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);

    List<Type> inputColumnTypes = new ArrayList<>();
    List<Optional<String>> allInputColumnsName = input.getFieldNames();
    List<Type> allInputColumnsType = input.getFieldTypes();

    for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
      Optional<String> fieldName = allInputColumnsName.get(i);
      // All non-excluded value columns are required as model input for classification;
      // unnamed columns are included as well.
      if (!fieldName.isPresent()
          || !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
        Type columnType = allInputColumnsType.get(i);
        checkType(columnType, fieldName.orElse(""));
        inputColumnTypes.add(columnType);
        requiredIndexList.add(i);
      }
    }
    // Single output value column holding the predicted class id.
    properColumnSchemaBuilder.addField(DEFAULT_OUTPUT_COLUMN_NAME, Type.INT32);

    ClassifyTableFunctionHandle functionHandle =
        new ClassifyTableFunctionHandle(modelId, MAX_INPUT_LENGTH, inputColumnTypes);

    // outputColumnSchema
    return TableFunctionAnalysis.builder()
        .properColumnSchema(properColumnSchemaBuilder.build())
        .handle(functionHandle)
        .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
        .build();
  }
| 211 | + |
| 212 | + // only allow for INT32, INT64, FLOAT, DOUBLE |
| 213 | + private void checkType(Type type, String columnName) { |
| 214 | + if (!ALLOWED_INPUT_TYPES.contains(type)) { |
| 215 | + throw new SemanticException( |
| 216 | + String.format( |
| 217 | + "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed", |
| 218 | + columnName, type)); |
| 219 | + } |
| 220 | + } |
| 221 | + |
  /** Creates an empty handle to be filled by {@link ClassifyTableFunctionHandle#deserialize}. */
  @Override
  public TableFunctionHandle createTableFunctionHandle() {
    return new ClassifyTableFunctionHandle();
  }
| 226 | + |
  /**
   * Returns a provider whose data processor performs the actual classification using the analyzed
   * arguments carried by {@code tableFunctionHandle}.
   */
  @Override
  public TableFunctionProcessorProvider getProcessorProvider(
      TableFunctionHandle tableFunctionHandle) {
    return new TableFunctionProcessorProvider() {
      @Override
      public TableFunctionDataProcessor getDataProcessor() {
        // A fresh processor per invocation; it accumulates rows until finish() is called.
        return new ClassifyDataProcessor((ClassifyTableFunctionHandle) tableFunctionHandle);
      }
    };
  }
| 237 | + |
| 238 | + private static class ClassifyDataProcessor implements TableFunctionDataProcessor { |
| 239 | + |
| 240 | + private static final TsBlockSerde SERDE = new TsBlockSerde(); |
| 241 | + private static final IClientManager<Integer, AINodeClient> CLIENT_MANAGER = |
| 242 | + AINodeClientManager.getInstance(); |
| 243 | + |
| 244 | + private final String modelId; |
| 245 | + private final int maxInputLength; |
| 246 | + private final LinkedList<Record> inputRecords; |
| 247 | + private final TsBlockBuilder inputTsBlockBuilder; |
| 248 | + private final List<ResultColumnAppender> inputColumnAppenderList; |
| 249 | + private final List<ResultColumnAppender> resultColumnAppenderList; |
| 250 | + |
| 251 | + public ClassifyDataProcessor(ClassifyTableFunctionHandle functionHandle) { |
| 252 | + this.modelId = functionHandle.modelId; |
| 253 | + this.maxInputLength = functionHandle.maxInputLength; |
| 254 | + this.inputRecords = new LinkedList<>(); |
| 255 | + List<TSDataType> inputTsDataTypeList = |
| 256 | + new ArrayList<>(functionHandle.inputColumnTypes.size()); |
| 257 | + for (Type type : functionHandle.inputColumnTypes) { |
| 258 | + // AINode currently only accept double input |
| 259 | + inputTsDataTypeList.add(TSDataType.DOUBLE); |
| 260 | + } |
| 261 | + this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList); |
| 262 | + this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size()); |
| 263 | + for (Type type : functionHandle.inputColumnTypes) { |
| 264 | + // AINode currently only accept double input |
| 265 | + inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE)); |
| 266 | + } |
| 267 | + this.resultColumnAppenderList = new ArrayList<>(1); |
| 268 | + this.resultColumnAppenderList.add(createResultColumnAppender(Type.INT32)); |
| 269 | + } |
| 270 | + |
| 271 | + @Override |
| 272 | + public void process( |
| 273 | + Record input, |
| 274 | + List<ColumnBuilder> properColumnBuilders, |
| 275 | + ColumnBuilder passThroughIndexBuilder) { |
| 276 | + // only keep at most maxInputLength rows |
| 277 | + if (maxInputLength != 0 && inputRecords.size() == maxInputLength) { |
| 278 | + inputRecords.removeFirst(); |
| 279 | + } |
| 280 | + inputRecords.add(input); |
| 281 | + } |
| 282 | + |
| 283 | + @Override |
| 284 | + public void finish( |
| 285 | + List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) { |
| 286 | + |
| 287 | + // time column |
| 288 | + long inputStartTime = inputRecords.getFirst().getLong(0); |
| 289 | + long inputEndTime = inputRecords.getLast().getLong(0); |
| 290 | + if (inputEndTime < inputStartTime) { |
| 291 | + throw new SemanticException( |
| 292 | + String.format( |
| 293 | + "input end time should never less than start time, start time is %s, end time is %s", |
| 294 | + inputStartTime, inputEndTime)); |
| 295 | + } |
| 296 | + int outputLength = inputRecords.size(); |
| 297 | + for (Record inputRecord : inputRecords) { |
| 298 | + properColumnBuilders.get(0).writeLong(inputRecord.getLong(0)); |
| 299 | + } |
| 300 | + |
| 301 | + // predicated columns |
| 302 | + TsBlock predicatedResult = classify(); |
| 303 | + if (predicatedResult.getPositionCount() != outputLength) { |
| 304 | + throw new IoTDBRuntimeException( |
| 305 | + String.format( |
| 306 | + "Model %s output length is %s, doesn't equal to specified %s", |
| 307 | + modelId, predicatedResult.getPositionCount(), outputLength), |
| 308 | + TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); |
| 309 | + } |
| 310 | + |
| 311 | + for (int columnIndex = 1, size = predicatedResult.getValueColumnCount(); |
| 312 | + columnIndex <= size; |
| 313 | + columnIndex++) { |
| 314 | + Column column = predicatedResult.getColumn(columnIndex - 1); |
| 315 | + ColumnBuilder builder = properColumnBuilders.get(columnIndex); |
| 316 | + ResultColumnAppender appender = resultColumnAppenderList.get(columnIndex - 1); |
| 317 | + for (int row = 0; row < outputLength; row++) { |
| 318 | + if (column.isNull(row)) { |
| 319 | + builder.appendNull(); |
| 320 | + } else { |
| 321 | + // convert double to real type |
| 322 | + appender.writeDouble(column.getDouble(row), builder); |
| 323 | + } |
| 324 | + } |
| 325 | + } |
| 326 | + } |
| 327 | + |
| 328 | + private TsBlock classify() { |
| 329 | + int outputLength = inputRecords.size(); |
| 330 | + while (!inputRecords.isEmpty()) { |
| 331 | + Record row = inputRecords.removeFirst(); |
| 332 | + inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0)); |
| 333 | + for (int i = 1, size = row.size(); i < size; i++) { |
| 334 | + // we set null input to 0.0 |
| 335 | + if (row.isNull(i)) { |
| 336 | + inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0); |
| 337 | + } else { |
| 338 | + // need to transform other types to DOUBLE |
| 339 | + inputTsBlockBuilder |
| 340 | + .getColumnBuilder(i - 1) |
| 341 | + .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i)); |
| 342 | + } |
| 343 | + } |
| 344 | + inputTsBlockBuilder.declarePosition(); |
| 345 | + } |
| 346 | + TsBlock inputData = inputTsBlockBuilder.build(); |
| 347 | + |
| 348 | + TForecastResp resp; |
| 349 | + try (AINodeClient client = |
| 350 | + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { |
| 351 | + resp = client.forecast(new TForecastReq(modelId, SERDE.serialize(inputData), outputLength)); |
| 352 | + } catch (Exception e) { |
| 353 | + throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode()); |
| 354 | + } |
| 355 | + |
| 356 | + if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { |
| 357 | + String message = |
| 358 | + String.format( |
| 359 | + "Error occurred while executing classify:[%s]", resp.getStatus().getMessage()); |
| 360 | + throw new IoTDBRuntimeException(message, resp.getStatus().getCode()); |
| 361 | + } |
| 362 | + |
| 363 | + return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult())); |
| 364 | + } |
| 365 | + } |
| 366 | +} |
0 commit comments