Skip to content

Commit 4a15632

Browse files
author
Liu Zhengyun
committed
add classify tvf
1 parent ef84301 commit 4a15632

File tree

5 files changed

+501
-125
lines changed

5 files changed

+501
-125
lines changed

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction;
2626
import org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction;
2727
import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction;
28+
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ClassifyTableFunction;
2829
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
2930
import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction;
3031
import org.apache.iotdb.udf.api.relational.TableFunction;
@@ -42,7 +43,8 @@ public enum TableBuiltinTableFunction {
4243
VARIATION("variation"),
4344
CAPACITY("capacity"),
4445
FORECAST("forecast"),
45-
PATTERN_MATCH("pattern_match");
46+
PATTERN_MATCH("pattern_match"),
47+
CLASSIFY("classify");
4648

4749
private final String functionName;
4850

@@ -86,6 +88,8 @@ public static TableFunction getBuiltinTableFunction(String functionName) {
8688
return new CapacityTableFunction();
8789
case "forecast":
8890
return new ForecastTableFunction();
91+
case "classify":
92+
return new ClassifyTableFunction();
8993
default:
9094
throw new UnsupportedOperationException("Unsupported table function: " + functionName);
9195
}
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;
2+
3+
import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
4+
import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
5+
import org.apache.iotdb.commons.client.IClientManager;
6+
import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
7+
import org.apache.iotdb.db.exception.sql.SemanticException;
8+
import org.apache.iotdb.db.protocol.client.an.AINodeClient;
9+
import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
10+
import org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender;
11+
import org.apache.iotdb.rpc.TSStatusCode;
12+
import org.apache.iotdb.udf.api.exception.UDFException;
13+
import org.apache.iotdb.udf.api.relational.TableFunction;
14+
import org.apache.iotdb.udf.api.relational.access.Record;
15+
import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
16+
import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle;
17+
import org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider;
18+
import org.apache.iotdb.udf.api.relational.table.argument.Argument;
19+
import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema;
20+
import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument;
21+
import org.apache.iotdb.udf.api.relational.table.argument.TableArgument;
22+
import org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor;
23+
import org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification;
24+
import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification;
25+
import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
26+
import org.apache.iotdb.udf.api.type.Type;
27+
28+
import org.apache.tsfile.block.column.Column;
29+
import org.apache.tsfile.block.column.ColumnBuilder;
30+
import org.apache.tsfile.enums.TSDataType;
31+
import org.apache.tsfile.read.common.block.TsBlock;
32+
import org.apache.tsfile.read.common.block.TsBlockBuilder;
33+
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
34+
import org.apache.tsfile.utils.PublicBAOS;
35+
import org.apache.tsfile.utils.ReadWriteIOUtils;
36+
37+
import java.io.DataOutputStream;
38+
import java.io.IOException;
39+
import java.nio.ByteBuffer;
40+
import java.util.ArrayList;
41+
import java.util.Arrays;
42+
import java.util.Collections;
43+
import java.util.HashSet;
44+
import java.util.LinkedList;
45+
import java.util.List;
46+
import java.util.Locale;
47+
import java.util.Map;
48+
import java.util.Objects;
49+
import java.util.Optional;
50+
import java.util.Set;
51+
import java.util.stream.Collectors;
52+
53+
import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
54+
import static org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender;
55+
import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
56+
57+
public class ClassifyTableFunction implements TableFunction {
58+
59+
public static class ClassifyTableFunctionHandle implements TableFunctionHandle {
60+
String modelId;
61+
int maxInputLength;
62+
List<Type> inputColumnTypes;
63+
64+
public ClassifyTableFunctionHandle() {}
65+
66+
public ClassifyTableFunctionHandle(
67+
String modelId, int maxInputLength, List<Type> inputColumnTypes) {
68+
this.modelId = modelId;
69+
this.maxInputLength = maxInputLength;
70+
this.inputColumnTypes = inputColumnTypes;
71+
}
72+
73+
@Override
74+
public byte[] serialize() {
75+
try (PublicBAOS publicBAOS = new PublicBAOS();
76+
DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
77+
ReadWriteIOUtils.write(modelId, outputStream);
78+
ReadWriteIOUtils.write(maxInputLength, outputStream);
79+
ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream);
80+
for (Type type : inputColumnTypes) {
81+
ReadWriteIOUtils.write(type.getType(), outputStream);
82+
}
83+
outputStream.flush();
84+
return publicBAOS.toByteArray();
85+
} catch (IOException e) {
86+
throw new IoTDBRuntimeException(
87+
String.format(
88+
"Error occurred while serializing ForecastTableFunctionHandle: %s", e.getMessage()),
89+
TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
90+
}
91+
}
92+
93+
@Override
94+
public void deserialize(byte[] bytes) {
95+
ByteBuffer buffer = ByteBuffer.wrap(bytes);
96+
this.modelId = ReadWriteIOUtils.readString(buffer);
97+
this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
98+
int size = ReadWriteIOUtils.readInt(buffer);
99+
this.inputColumnTypes = new ArrayList<>(size);
100+
for (int i = 0; i < size; i++) {
101+
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer)));
102+
}
103+
}
104+
105+
@Override
106+
public boolean equals(Object o) {
107+
if (this == o) return true;
108+
if (o == null || getClass() != o.getClass()) return false;
109+
ClassifyTableFunctionHandle that = (ClassifyTableFunctionHandle) o;
110+
return maxInputLength == that.maxInputLength
111+
&& Objects.equals(modelId, that.modelId)
112+
&& Objects.equals(inputColumnTypes, that.inputColumnTypes);
113+
}
114+
115+
@Override
116+
public int hashCode() {
117+
return Objects.hash(modelId, maxInputLength, inputColumnTypes);
118+
}
119+
}
120+
121+
private static final String INPUT_PARAMETER_NAME = "INPUT";
122+
private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
123+
public static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
124+
private static final String DEFAULT_TIME_COL = "time";
125+
private static final String DEFAULT_OUTPUT_COLUMN_NAME = "category";
126+
private static final int MAX_INPUT_LENGTH = 2880;
127+
128+
private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
129+
130+
static {
131+
ALLOWED_INPUT_TYPES.add(Type.INT32);
132+
ALLOWED_INPUT_TYPES.add(Type.INT64);
133+
ALLOWED_INPUT_TYPES.add(Type.FLOAT);
134+
ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
135+
}
136+
137+
@Override
138+
public List<ParameterSpecification> getArgumentsSpecifications() {
139+
return Arrays.asList(
140+
TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(),
141+
ScalarParameterSpecification.builder()
142+
.name(MODEL_ID_PARAMETER_NAME)
143+
.type(Type.STRING)
144+
.build(),
145+
ScalarParameterSpecification.builder()
146+
.name(TIMECOL_PARAMETER_NAME)
147+
.type(Type.STRING)
148+
.defaultValue(DEFAULT_TIME_COL)
149+
.build());
150+
}
151+
152+
@Override
153+
public TableFunctionAnalysis analyze(Map<String, Argument> arguments) throws UDFException {
154+
TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
155+
String modelId = (String) ((ScalarArgument) arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
156+
// modelId should never be null or empty
157+
if (modelId == null || modelId.isEmpty()) {
158+
throw new SemanticException(
159+
String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME));
160+
}
161+
162+
String timeColumn =
163+
((String) ((ScalarArgument) arguments.get(TIMECOL_PARAMETER_NAME)).getValue())
164+
.toLowerCase(Locale.ENGLISH);
165+
166+
if (timeColumn.isEmpty()) {
167+
throw new SemanticException(
168+
String.format("%s should never be null or empty.", TIMECOL_PARAMETER_NAME));
169+
}
170+
171+
// predicated columns should never contain partition by columns and time column
172+
Set<String> excludedColumns =
173+
input.getPartitionBy().stream()
174+
.map(s -> s.toLowerCase(Locale.ENGLISH))
175+
.collect(Collectors.toSet());
176+
excludedColumns.add(timeColumn);
177+
int timeColumnIndex = findColumnIndex(input, timeColumn, Collections.singleton(Type.TIMESTAMP));
178+
179+
List<Integer> requiredIndexList = new ArrayList<>();
180+
requiredIndexList.add(timeColumnIndex);
181+
DescribedSchema.Builder properColumnSchemaBuilder =
182+
new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
183+
184+
List<Type> inputColumnTypes = new ArrayList<>();
185+
List<Optional<String>> allInputColumnsName = input.getFieldNames();
186+
List<Type> allInputColumnsType = input.getFieldTypes();
187+
188+
for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
189+
Optional<String> fieldName = allInputColumnsName.get(i);
190+
// All input value columns are required for model forecasting
191+
if (!fieldName.isPresent()
192+
|| !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) {
193+
Type columnType = allInputColumnsType.get(i);
194+
checkType(columnType, fieldName.orElse(""));
195+
inputColumnTypes.add(columnType);
196+
requiredIndexList.add(i);
197+
}
198+
}
199+
properColumnSchemaBuilder.addField(DEFAULT_OUTPUT_COLUMN_NAME, Type.INT32);
200+
201+
ClassifyTableFunctionHandle functionHandle =
202+
new ClassifyTableFunctionHandle(modelId, MAX_INPUT_LENGTH, inputColumnTypes);
203+
204+
// outputColumnSchema
205+
return TableFunctionAnalysis.builder()
206+
.properColumnSchema(properColumnSchemaBuilder.build())
207+
.handle(functionHandle)
208+
.requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
209+
.build();
210+
}
211+
212+
// only allow for INT32, INT64, FLOAT, DOUBLE
213+
private void checkType(Type type, String columnName) {
214+
if (!ALLOWED_INPUT_TYPES.contains(type)) {
215+
throw new SemanticException(
216+
String.format(
217+
"The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed",
218+
columnName, type));
219+
}
220+
}
221+
222+
@Override
223+
public TableFunctionHandle createTableFunctionHandle() {
224+
return new ClassifyTableFunctionHandle();
225+
}
226+
227+
@Override
228+
public TableFunctionProcessorProvider getProcessorProvider(
229+
TableFunctionHandle tableFunctionHandle) {
230+
return new TableFunctionProcessorProvider() {
231+
@Override
232+
public TableFunctionDataProcessor getDataProcessor() {
233+
return new ClassifyDataProcessor((ClassifyTableFunctionHandle) tableFunctionHandle);
234+
}
235+
};
236+
}
237+
238+
private static class ClassifyDataProcessor implements TableFunctionDataProcessor {
239+
240+
private static final TsBlockSerde SERDE = new TsBlockSerde();
241+
private static final IClientManager<Integer, AINodeClient> CLIENT_MANAGER =
242+
AINodeClientManager.getInstance();
243+
244+
private final String modelId;
245+
private final int maxInputLength;
246+
private final LinkedList<Record> inputRecords;
247+
private final TsBlockBuilder inputTsBlockBuilder;
248+
private final List<ResultColumnAppender> inputColumnAppenderList;
249+
private final List<ResultColumnAppender> resultColumnAppenderList;
250+
251+
public ClassifyDataProcessor(ClassifyTableFunctionHandle functionHandle) {
252+
this.modelId = functionHandle.modelId;
253+
this.maxInputLength = functionHandle.maxInputLength;
254+
this.inputRecords = new LinkedList<>();
255+
List<TSDataType> inputTsDataTypeList =
256+
new ArrayList<>(functionHandle.inputColumnTypes.size());
257+
for (Type type : functionHandle.inputColumnTypes) {
258+
// AINode currently only accept double input
259+
inputTsDataTypeList.add(TSDataType.DOUBLE);
260+
}
261+
this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList);
262+
this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size());
263+
for (Type type : functionHandle.inputColumnTypes) {
264+
// AINode currently only accept double input
265+
inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE));
266+
}
267+
this.resultColumnAppenderList = new ArrayList<>(1);
268+
this.resultColumnAppenderList.add(createResultColumnAppender(Type.INT32));
269+
}
270+
271+
@Override
272+
public void process(
273+
Record input,
274+
List<ColumnBuilder> properColumnBuilders,
275+
ColumnBuilder passThroughIndexBuilder) {
276+
// only keep at most maxInputLength rows
277+
if (maxInputLength != 0 && inputRecords.size() == maxInputLength) {
278+
inputRecords.removeFirst();
279+
}
280+
inputRecords.add(input);
281+
}
282+
283+
@Override
284+
public void finish(
285+
List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
286+
287+
// time column
288+
long inputStartTime = inputRecords.getFirst().getLong(0);
289+
long inputEndTime = inputRecords.getLast().getLong(0);
290+
if (inputEndTime < inputStartTime) {
291+
throw new SemanticException(
292+
String.format(
293+
"input end time should never less than start time, start time is %s, end time is %s",
294+
inputStartTime, inputEndTime));
295+
}
296+
int outputLength = inputRecords.size();
297+
for (Record inputRecord : inputRecords) {
298+
properColumnBuilders.get(0).writeLong(inputRecord.getLong(0));
299+
}
300+
301+
// predicated columns
302+
TsBlock predicatedResult = classify();
303+
if (predicatedResult.getPositionCount() != outputLength) {
304+
throw new IoTDBRuntimeException(
305+
String.format(
306+
"Model %s output length is %s, doesn't equal to specified %s",
307+
modelId, predicatedResult.getPositionCount(), outputLength),
308+
TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
309+
}
310+
311+
for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
312+
columnIndex <= size;
313+
columnIndex++) {
314+
Column column = predicatedResult.getColumn(columnIndex - 1);
315+
ColumnBuilder builder = properColumnBuilders.get(columnIndex);
316+
ResultColumnAppender appender = resultColumnAppenderList.get(columnIndex - 1);
317+
for (int row = 0; row < outputLength; row++) {
318+
if (column.isNull(row)) {
319+
builder.appendNull();
320+
} else {
321+
// convert double to real type
322+
appender.writeDouble(column.getDouble(row), builder);
323+
}
324+
}
325+
}
326+
}
327+
328+
private TsBlock classify() {
329+
int outputLength = inputRecords.size();
330+
while (!inputRecords.isEmpty()) {
331+
Record row = inputRecords.removeFirst();
332+
inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
333+
for (int i = 1, size = row.size(); i < size; i++) {
334+
// we set null input to 0.0
335+
if (row.isNull(i)) {
336+
inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0);
337+
} else {
338+
// need to transform other types to DOUBLE
339+
inputTsBlockBuilder
340+
.getColumnBuilder(i - 1)
341+
.writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i));
342+
}
343+
}
344+
inputTsBlockBuilder.declarePosition();
345+
}
346+
TsBlock inputData = inputTsBlockBuilder.build();
347+
348+
TForecastResp resp;
349+
try (AINodeClient client =
350+
CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
351+
resp = client.forecast(new TForecastReq(modelId, SERDE.serialize(inputData), outputLength));
352+
} catch (Exception e) {
353+
throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode());
354+
}
355+
356+
if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
357+
String message =
358+
String.format(
359+
"Error occurred while executing classify:[%s]", resp.getStatus().getMessage());
360+
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
361+
}
362+
363+
return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
364+
}
365+
}
366+
}

0 commit comments

Comments
 (0)