diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java index e118d6c3a98f..34fd7e85240c 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java @@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper { private static final String PROPERTIES_FILE = "iotdb-ainode.properties"; public static final String CONFIG_PATH = "conf"; public static final String SCRIPT_PATH = "sbin"; - public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights"; + public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin"; public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights"; private void replaceAttribute(String[] keys, String[] values, String filePath) { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java new file mode 100644 index 000000000000..5368c584443f --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeCallInferenceIT { + + private static final String[] WRITE_SQL_IN_TREE = + new String[] { + "CREATE DATABASE root.AI", + "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", + }; + + private static final String CALL_INFERENCE_SQL_TEMPLATE = + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, 
outputLength=%d)"; + private static final int DEFAULT_OUTPUT_LENGTH = 48; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareData(WRITE_SQL_IN_TREE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void callInferenceTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + callInferenceTest(statement, modelInfo); + } + } + } + + public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException { + // Invoke call inference for specified models, there should exist result. 
+ for (int i = 0; i < 4; i++) { + String callInferenceSQL = + String.format( + CALL_INFERENCE_SQL_TEMPLATE, + modelInfo.getModelId(), + i, + DEFAULT_OUTPUT_LENGTH, + DEFAULT_OUTPUT_LENGTH); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output"); + int count = 0; + while (resultSet.next()) { + count++; + } + // Ensure the call inference return results + Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count); + } + } + } + + @Test + public void errorCallInferenceTestInTree() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))"; + errorTest(statement, sql, "1505: model [notFound404] has not been created."); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java new file mode 100644 index 000000000000..a23eec97497d --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeConcurrentForecastIT { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class); + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)"; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareDataForTableModel(); + } + + @AfterClass + 
public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + private static void prepareDataForTableModel() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE root"); + statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); + } + } + } + + @Test + public void concurrentGPUForecastTest() throws SQLException, InterruptedException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) { + concurrentGPUForecastTest(modelInfo); + } + } + + public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException, InterruptedException { + final int forecastLength = 512; + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + // Single forecast request can be processed successfully + final String forecastSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength); + final int threadCnt = 10; + final int loop = 100; + final String devices = "0,1"; + statement.execute( + String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelOnSpecifiedDevice( + statement, modelInfo.getModelId(), modelInfo.getModelType(), devices); + long startTime = System.currentTimeMillis(); + concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength); + long endTime = System.currentTimeMillis(); + LOGGER.info( + String.format( + "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", + modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime)); + 
statement.execute( + String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java deleted file mode 100644 index a08990d472fe..000000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import com.google.common.collect.ImmutableSet; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.HashSet; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeConcurrentInferenceIT { - - private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class); - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataForTreeModel(); - prepareDataForTableModel(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - private static void prepareDataForTreeModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE root.AI"); - statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)", - i, Math.sin(i * Math.PI / 1440))); - } - } - } - - private static void prepareDataForTableModel() throws SQLException { - try 
(Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE root"); - statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); - } - } - } - - // @Test - public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException { - concurrentGPUCallInferenceTest("timer_xl"); - concurrentGPUCallInferenceTest("sundial"); - } - - private void concurrentGPUCallInferenceTest(String modelId) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - concurrentInference( - statement, - String.format( - "CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)", - modelId, predictLength), - threadCnt, - loop, - predictLength); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - String forecastTableFunctionSql = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d"; - String forecastUDTFSql = - "SELECT forecast(s, 'MODEL_ID'='%s', 'PREDICT_LENGTH'='%d') FROM root.AI"; - - @Test - public void concurrentGPUForecastTest() throws SQLException, InterruptedException { - concurrentGPUForecastTest("timer_xl", forecastUDTFSql); - concurrentGPUForecastTest("sundial", forecastUDTFSql); - concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql); - concurrentGPUForecastTest("sundial", 
forecastTableFunctionSql); - } - - public void concurrentGPUForecastTest(String modelId, String selectSql) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - long startTime = System.currentTimeMillis(); - concurrentInference( - statement, - String.format(selectSql, modelId, predictLength), - threadCnt, - loop, - predictLength); - long endTime = System.currentTimeMillis(); - LOGGER.info( - String.format( - "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", - modelId, threadCnt * loop, threadCnt, loop, endTime - startTime)); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) - throws SQLException, InterruptedException { - Set targetDevices = ImmutableSet.copyOf(device.split(",")); - LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); - for (int retry = 0; retry < 200; retry++) { - Set foundDevices = new HashSet<>(); - try (final ResultSet resultSet = - statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { - while (resultSet.next()) { - String deviceId = resultSet.getString("DeviceId"); - String loadedModelId = resultSet.getString("ModelId"); - int count = resultSet.getInt("Count(instances)"); - LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); - if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { - foundDevices.add(deviceId); - LOGGER.info("Model {} is loaded to device {}", 
modelId, device); - } - } - if (foundDevices.containsAll(targetDevices)) { - LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); - return; - } - } - TimeUnit.SECONDS.sleep(3); - } - Assert.fail("Model " + modelId + " is not loaded on device " + device); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java new file mode 100644 index 000000000000..8953bec07a74 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeForecastIT { + + private static final String[] WRITE_SQL_IN_TABLE = + new String[] { + "CREATE DATABASE root", + "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", + }; + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM root.AI) ORDER BY time)"; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareData(WRITE_SQL_IN_TABLE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void forecastTableFunctionTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : 
BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + forecastTableFunctionTest(statement, modelInfo); + } + } + } + + public void forecastTableFunctionTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { + // Invoke call inference for specified models, there should exist result. + for (int i = 0; i < 4; i++) { + String forecastTableFunctionSQL = + String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i); + try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) { + int count = 0; + while (resultSet.next()) { + count++; + } + // Ensure the call inference return results + Assert.assertTrue(count > 0); + } + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java deleted file mode 100644 index 70f7a1d9f9eb..000000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData; -import static org.junit.Assert.assertEquals; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeInferenceSQLIT { - - static String[] WRITE_SQL_IN_TREE = - new String[] { - "set configuration \"trusted_uri_pattern\"='.*'", - "create model identity using uri \"" + EXAMPLE_MODEL_PATH + "\"", - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; - - static String[] WRITE_SQL_IN_TABLE = - new String[] { - "CREATE DATABASE root", - "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", - }; - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - 
EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TREE); - prepareTableData(WRITE_SQL_IN_TABLE); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - // @Test - public void callInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void callInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - public void callInferenceTest(Statement statement) throws SQLException { - // SQL0: Invoke timer-sundial and timer-xl to inference, the result should success - try (ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(sundial, \"select s1 from root.AI\", generateTime=true, predict_length=720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while 
(resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(timer_xl, \"select s2 from root.AI\", generateTime=true, predict_length=256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", generateTime=true)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "CALL INFERENCE(identity, \"select s2,s0,s3,s1 from root.AI\", generateTime=false)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "CALL INFERENCE(naive_forecaster, \"select s0 from root.AI\", predict_length=3, generateTime=true)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "CALL INFERENCE(holtwinters, \"select s0 from root.AI\", predict_length=6, generateTime=true)"; - // TODO: enable following tests after refactor the CALL INFERENCE - - // try (ResultSet resultSet = statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2,output3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // 
ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - - // @Test - public void errorCallInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void errorCallInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - public void errorCallInferenceTest(Statement statement) { - String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))"; - errorTest(statement, 
sql, "1505: model [notFound404] has not been created."); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", window=head(2))"; - // TODO: enable following tests after refactor the CALL INFERENCE - // errorTest(statement, sql, "701: Window output 2 is not equal to input size of model 7"); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI limit 5\")"; - // errorTest( - // statement, - // sql, - // "301: The number of rows 5 in the input data does not match the model input 7. Try to - // use LIMIT in SQL or WINDOW in CALL INFERENCE"); - sql = "CREATE MODEL 中文 USING URI \"" + EXAMPLE_MODEL_PATH + "\""; - errorTest(statement, sql, "701: ModelId can only contain letters, numbers, and underscores"); - } - - @Test - public void selectForecastTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - // SQL0: Invoke timer-sundial and timer-xl to forecast, the result should success - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s1 FROM root.AI) ORDER BY time, output_length=>720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s1"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'timer_xl', input=>(SELECT time,s2 FROM root.AI) ORDER BY time, output_length=>256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s2"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s0,s1,s2,s3 FROM 
root.AI) ORDER BY time, output_length=>7)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s2,s0,s3,s1 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "SELECT * FROM FORECAST(model_id=>'naive_forecaster', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>3)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "SELECT * FROM FORECAST(model_id=>'holtwinters', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>6)"; - // TODO: enable following tests after refactor the FORECAST - // try (ResultSet resultSet = statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2,s3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s2,s0,s3,s1"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 
0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index 93351c017852..56b9a5bbda73 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -35,14 +35,14 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.TimeUnit; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; public class AINodeInstanceManagementIT { - private static final int WAITING_TIME_SEC = 30; private static final Set TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1")); @BeforeClass @@ -85,52 +85,18 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter } // Load sundial to each device - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try 
(ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS 0")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - while (resultSet.next()) { - Assert.assertEquals("0", resultSet.getString("DeviceID")); - Assert.assertEquals("Timer-Sundial", resultSet.getString("ModelType")); - Assert.assertTrue(resultSet.getInt("Count(instances)") > 1); - } - } - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - final Set resultDevices = new HashSet<>(); - while (resultSet.next()) { - resultDevices.add(resultSet.getString("DeviceID")); - } - Assert.assertEquals(TARGET_DEVICES, resultDevices); - } + statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES)); + checkModelOnSpecifiedDevice(statement, "sundial", "sundial", TARGET_DEVICES.toString()); // Load timer_xl to each device - statement.execute("LOAD MODEL timer_xl TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - final Set resultDevices = new HashSet<>(); - while (resultSet.next()) { - if (resultSet.getString("ModelType").equals("Timer-XL")) { - resultDevices.add(resultSet.getString("DeviceID")); - } - Assert.assertTrue(resultSet.getInt("Count(instances)") > 1); - } - Assert.assertEquals(TARGET_DEVICES, resultDevices); - } + statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES)); + checkModelOnSpecifiedDevice(statement, "timer_xl", "timer_xl", TARGET_DEVICES.toString()); // Clean every device - statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); - 
statement.execute("UNLOAD MODEL timer_xl FROM DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES)); + statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES)); + checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); } private static final int LOOP_CNT = 10; @@ -141,23 +107,9 @@ public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException Statement statement = connection.createStatement()) { for (int i = 0; i < LOOP_CNT; i++) { statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - final Set resultDevices = new HashSet<>(); - while (resultSet.next()) { - resultDevices.add(resultSet.getString("DeviceID")); - } - Assert.assertEquals(TARGET_DEVICES, resultDevices); - } + checkModelOnSpecifiedDevice(statement, "sundial", "sundial", TARGET_DEVICES.toString()); statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + checkModelNotOnSpecifiedDevice(statement, "sundial", 
TARGET_DEVICES.toString()); } } } @@ -170,12 +122,7 @@ public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedExcept statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); } - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC * LOOP_CNT); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 2a1461e4a15b..25cdf0f8ceef 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -19,6 +19,7 @@ package org.apache.iotdb.ainode.it; +import org.apache.iotdb.ainode.utils.AINodeTestUtils; import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo; import org.apache.iotdb.it.env.EnvFactory; import org.apache.iotdb.it.framework.IoTDBTestRunner; @@ -36,13 +37,8 @@ import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; -import java.util.AbstractMap; -import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; import static org.junit.Assert.assertEquals; @@ -54,36 +50,6 @@ @Category({AIClusterIT.class}) public class AINodeModelManageIT { - private static final Map 
BUILT_IN_MODEL_MAP = - Stream.of( - new AbstractMap.SimpleEntry<>( - "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "holtwinters", - new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "exponential_smoothing", - new FakeModelInfo( - "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "naive_forecaster", - new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "stl_forecaster", - new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "gaussian_hmm", - new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "sundial", new FakeModelInfo("sundial", "Timer-Sundial", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "timer_xl", new FakeModelInfo("timer_xl", "Timer-XL", "BUILT-IN", "ACTIVE"))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - @BeforeClass public static void setUp() throws Exception { // Init 1C1D1A cluster environment @@ -95,7 +61,7 @@ public static void tearDown() throws Exception { EnvFactory.getEnv().cleanClusterEnvironment(); } - @Test + // @Test public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -103,7 +69,7 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup } } - @Test + // @Test public void userDefinedModelManagementTestInTable() throws SQLException, 
InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -114,8 +80,7 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru private void userDefinedModelManagementTest(Statement statement) throws SQLException, InterruptedException { final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = - "create model operationTest using uri \"" + EXAMPLE_MODEL_PATH + "\""; + final String registerSql = "create model operationTest using uri \"" + "\""; final String showSql = "SHOW MODELS operationTest"; final String dropSql = "DROP MODEL operationTest"; @@ -208,10 +173,10 @@ private void showBuiltInModelTest(Statement statement) throws SQLException { resultSet.getString(2), resultSet.getString(3), resultSet.getString(4)); - assertTrue(BUILT_IN_MODEL_MAP.containsKey(modelInfo.getModelId())); - assertEquals(BUILT_IN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); + assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId())); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); } } - assertEquals(BUILT_IN_MODEL_MAP.size(), built_in_model_count); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.size(), built_in_model_count); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index cbb0b03b2299..d9ddb6a4e097 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -19,30 +19,70 @@ package org.apache.iotdb.ainode.utils; -import java.io.File; +import com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.ResultSet; 
import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.util.AbstractMap; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; public class AINodeTestUtils { - public static final String EXAMPLE_MODEL_PATH = - "file://" - + System.getProperty("user.dir") - + File.separator - + "src" - + File.separator - + "test" - + File.separator - + "resources" - + File.separator - + "ainode-example"; + public static final Map BUILTIN_LTSM_MAP = + Stream.of( + new AbstractMap.SimpleEntry<>( + "sundial", new FakeModelInfo("sundial", "sundial", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "timer_xl", new FakeModelInfo("timer_xl", "timer", "BUILT-IN", "ACTIVE"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + public static final Map BUILTIN_MODEL_MAP; + + static { + Map tmp = + Stream.of( + new AbstractMap.SimpleEntry<>( + "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "holtwinters", + new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "exponential_smoothing", + new FakeModelInfo( + "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "naive_forecaster", + new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "stl_forecaster", + new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "gaussian_hmm", + new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "gmm_hmm", new FakeModelInfo("gmm_hmm", 
"GmmHmm", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + tmp.putAll(BUILTIN_LTSM_MAP); + BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp); + } + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeTestUtils.class); public static void checkHeader(ResultSetMetaData resultSetMetaData, String title) throws SQLException { @@ -94,6 +134,68 @@ public static void concurrentInference( } } + public static void checkModelOnSpecifiedDevice( + Statement statement, String modelId, String modelType, String device) + throws SQLException, InterruptedException { + Set targetDevices = ImmutableSet.copyOf(device.split(",")); + LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); + for (int retry = 0; retry < 200; retry++) { + Set foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString("DeviceId"); + String loadedModelId = resultSet.getString("ModelId"); + String loadedModelType = resultSet.getString("ModelType"); + int count = resultSet.getInt("Count(instances)"); + LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); + if (loadedModelId.equals(modelId) + && loadedModelType.equals(modelType) + && targetDevices.contains(deviceId) + && count > 0) { + foundDevices.add(deviceId); + LOGGER.info("Model {} is loaded to device {}", modelId, device); + } + } + if (foundDevices.containsAll(targetDevices)) { + LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + fail("Model " + modelId + " is not loaded on device " + device); + } + + public static void checkModelNotOnSpecifiedDevice( + Statement statement, String modelId, 
String device) + throws SQLException, InterruptedException { + Set targetDevices = ImmutableSet.copyOf(device.split(",")); + LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices); + for (int retry = 0; retry < 50; retry++) { + Set foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString("DeviceId"); + String loadedModelId = resultSet.getString("ModelId"); + int count = resultSet.getInt("Count(instances)"); + LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); + if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { + foundDevices.add(deviceId); + LOGGER.info("Model {} is loaded to device {}", modelId, device); + } + } + if (foundDevices.isEmpty()) { + LOGGER.info("Model {} is unloaded from devices {}.", modelId, targetDevices); + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + fail("Model " + modelId + " is still loaded on device " + device); + } + public static class FakeModelInfo { private final String modelId; diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index e04ff838819e..609a228022f8 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -403,7 +403,6 @@ public void testInformationSchema() throws SQLException { "databases,INF,", "functions,INF,", "keywords,INF,", - "models,INF,", "nodes,INF,", "pipe_plugins,INF,", "pipes,INF,", @@ -504,16 +503,6 @@ public void testInformationSchema() throws SQLException { "database,STRING,TAG,", "table_name,STRING,TAG,", "view_definition,STRING,ATTRIBUTE,"))); - 
TestUtils.assertResultSetEqual( - statement.executeQuery("desc models"), - "ColumnName,DataType,Category,", - new HashSet<>( - Arrays.asList( - "model_id,STRING,TAG,", - "model_type,STRING,ATTRIBUTE,", - "state,STRING,ATTRIBUTE,", - "configs,STRING,ATTRIBUTE,", - "notes,STRING,ATTRIBUTE,"))); TestUtils.assertResultSetEqual( statement.executeQuery("desc functions"), "ColumnName,DataType,Category,", @@ -638,7 +627,6 @@ public void testInformationSchema() throws SQLException { "information_schema,pipes,INF,USING,null,SYSTEM VIEW,", "information_schema,subscriptions,INF,USING,null,SYSTEM VIEW,", "information_schema,views,INF,USING,null,SYSTEM VIEW,", - "information_schema,models,INF,USING,null,SYSTEM VIEW,", "information_schema,functions,INF,USING,null,SYSTEM VIEW,", "information_schema,configurations,INF,USING,null,SYSTEM VIEW,", "information_schema,keywords,INF,USING,null,SYSTEM VIEW,", @@ -651,7 +639,7 @@ public void testInformationSchema() throws SQLException { TestUtils.assertResultSetEqual( statement.executeQuery("count devices from tables where status = 'USING'"), "count(devices),", - Collections.singleton("20,")); + Collections.singleton("18,")); TestUtils.assertResultSetEqual( statement.executeQuery( "select * from columns where table_name = 'queries' or database = 'test'"), diff --git a/integration-test/src/test/resources/ainode-example/config.yaml b/integration-test/src/test/resources/ainode-example/config.yaml deleted file mode 100644 index 665acb8704e2..000000000000 --- a/integration-test/src/test/resources/ainode-example/config.yaml +++ /dev/null @@ -1,5 +0,0 @@ -configs: - input_shape: [7, 4] - output_shape: [7, 4] - input_type: ["float32", "float32", "float32", "float32"] - output_type: ["float32", "float32", "float32", "float32"] diff --git a/integration-test/src/test/resources/ainode-example/model.pt b/integration-test/src/test/resources/ainode-example/model.pt deleted file mode 100644 index 67d4aec6999f..000000000000 Binary files 
a/integration-test/src/test/resources/ainode-example/model.pt and /dev/null differ diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index afcf0683d7d0..e465df7e36d2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -20,7 +20,6 @@ from iotdb.ainode.core.constant import ( AINODE_BUILD_INFO, - AINODE_BUILTIN_MODELS_DIR, AINODE_CLUSTER_INGRESS_ADDRESS, AINODE_CLUSTER_INGRESS_PASSWORD, AINODE_CLUSTER_INGRESS_PORT, @@ -33,10 +32,11 @@ AINODE_CONF_POM_FILE_NAME, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, AINODE_INFERENCE_EXTRA_MEMORY_RATIO, - AINODE_INFERENCE_MAX_PREDICT_LENGTH, + AINODE_INFERENCE_MAX_OUTPUT_LENGTH, AINODE_INFERENCE_MEMORY_USAGE_RATIO, AINODE_INFERENCE_MODEL_MEM_USAGE_MAP, AINODE_LOG_DIR, + AINODE_MODELS_BUILTIN_DIR, AINODE_MODELS_DIR, AINODE_RPC_ADDRESS, AINODE_RPC_PORT, @@ -75,9 +75,7 @@ def __init__(self): self._ain_inference_batch_interval_in_ms: int = ( AINODE_INFERENCE_BATCH_INTERVAL_IN_MS ) - self._ain_inference_max_predict_length: int = ( - AINODE_INFERENCE_MAX_PREDICT_LENGTH - ) + self._ain_inference_max_output_length: int = AINODE_INFERENCE_MAX_OUTPUT_LENGTH self._ain_inference_model_mem_usage_map: dict[str, int] = ( AINODE_INFERENCE_MODEL_MEM_USAGE_MAP ) @@ -95,7 +93,7 @@ def __init__(self): # Directory to save models self._ain_models_dir = AINODE_MODELS_DIR - self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR + self._ain_models_builtin_dir = AINODE_MODELS_BUILTIN_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ -160,13 +158,13 @@ def set_ain_inference_batch_interval_in_ms( ) -> None: self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms - def get_ain_inference_max_predict_length(self) -> int: - return self._ain_inference_max_predict_length + def get_ain_inference_max_output_length(self) -> int: + return self._ain_inference_max_output_length - def 
set_ain_inference_max_predict_length( - self, ain_inference_max_predict_length: int + def set_ain_inference_max_output_length( + self, ain_inference_max_output_length: int ) -> None: - self._ain_inference_max_predict_length = ain_inference_max_predict_length + self._ain_inference_max_output_length = ain_inference_max_output_length def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]: return self._ain_inference_model_mem_usage_map @@ -204,11 +202,11 @@ def get_ain_models_dir(self) -> str: def set_ain_models_dir(self, ain_models_dir: str) -> None: self._ain_models_dir = ain_models_dir - def get_ain_builtin_models_dir(self) -> str: - return self._ain_builtin_models_dir + def get_ain_models_builtin_dir(self) -> str: + return self._ain_models_builtin_dir - def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None: - self._ain_builtin_models_dir = ain_builtin_models_dir + def set_ain_models_builtin_dir(self, ain_models_builtin_dir: str) -> None: + self._ain_models_builtin_dir = ain_models_builtin_dir def get_ain_system_dir(self) -> str: return self._ain_system_dir @@ -374,6 +372,11 @@ def _load_config_from_file(self) -> None: if "ain_models_dir" in config_keys: self._config.set_ain_models_dir(file_configs["ain_models_dir"]) + if "ain_models_builtin_dir" in config_keys: + self._config.set_ain_models_builtin_dir( + file_configs["ain_models_builtin_dir"] + ) + if "ain_system_dir" in config_keys: self._config.set_ain_system_dir(file_configs["ain_system_dir"]) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index b9923d3e3ee7..c0b19a570d20 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -18,9 +18,7 @@ import logging import os from enum import Enum -from typing import List -from iotdb.ainode.core.model.model_enums import BuiltInModelType from iotdb.thrift.common.ttypes import TEndPoint IOTDB_AINODE_HOME = 
os.getenv("IOTDB_AINODE_HOME", "") @@ -50,23 +48,23 @@ # AINode inference configuration AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 -AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 +AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880 + +# TODO: Should be optimized AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { - BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB - BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB + "sundial": 1036 * 1024**2, # 1036 MiB + "timer": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes + AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( 1.2 # the overhead ratio for inference, used to estimate the pool size ) -# AINode folder structure AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") -AINODE_BUILTIN_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/weights" -) # For built-in models, we only need to store their weights and config. -AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system") -AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs") +AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model" +AINODE_SYSTEM_DIR = "data/ainode/system" +AINODE_LOG_DIR = "logs" # AINode log LOG_FILE_TYPE = ["all", "info", "warn", "error"] @@ -77,11 +75,6 @@ "log_inference_rank_{}_" # example: log_inference_rank_0_all.log ) -# AINode model management -MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" -MODEL_CONFIG_FILE_IN_JSON = "config.json" -MODEL_WEIGHTS_FILE_IN_PT = "model.pt" -MODEL_CONFIG_FILE_IN_YAML = "config.yaml" DEFAULT_CHUNK_SIZE = 8192 @@ -141,132 +134,8 @@ def name(self): return self.value -class ForecastModelType(Enum): - DLINEAR = "dlinear" - DLINEAR_INDIVIDUAL = "dlinear_individual" - NBEATS = "nbeats" - - @classmethod - def values(cls) -> List[str]: - values = [] - for item in list(cls): - values.append(item.value) - return values - - class ModelInputName(Enum): DATA_X = "data_x" 
TIME_STAMP_X = "time_stamp_x" TIME_STAMP_Y = "time_stamp_y" DEC_INP = "dec_inp" - - -class AttributeName(Enum): - # forecast Attribute - PREDICT_LENGTH = "predict_length" - - # NaiveForecaster - STRATEGY = "strategy" - SP = "sp" - - # STLForecaster - # SP = 'sp' - SEASONAL = "seasonal" - SEASONAL_DEG = "seasonal_deg" - TREND_DEG = "trend_deg" - LOW_PASS_DEG = "low_pass_deg" - SEASONAL_JUMP = "seasonal_jump" - TREND_JUMP = "trend_jump" - LOSS_PASS_JUMP = "low_pass_jump" - - # ExponentialSmoothing - DAMPED_TREND = "damped_trend" - INITIALIZATION_METHOD = "initialization_method" - OPTIMIZED = "optimized" - REMOVE_BIAS = "remove_bias" - USE_BRUTE = "use_brute" - - # Arima - ORDER = "order" - SEASONAL_ORDER = "seasonal_order" - METHOD = "method" - MAXITER = "maxiter" - SUPPRESS_WARNINGS = "suppress_warnings" - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" - SCORING = "scoring" - WITH_INTERCEPT = "with_intercept" - TIME_VARYING_REGRESSION = "time_varying_regression" - ENFORCE_STATIONARITY = "enforce_stationarity" - ENFORCE_INVERTIBILITY = "enforce_invertibility" - SIMPLE_DIFFERENCING = "simple_differencing" - MEASUREMENT_ERROR = "measurement_error" - MLE_REGRESSION = "mle_regression" - HAMILTON_REPRESENTATION = "hamilton_representation" - CONCENTRATE_SCALE = "concentrate_scale" - - # GAUSSIAN_HMM - N_COMPONENTS = "n_components" - COVARIANCE_TYPE = "covariance_type" - MIN_COVAR = "min_covar" - STARTPROB_PRIOR = "startprob_prior" - TRANSMAT_PRIOR = "transmat_prior" - MEANS_PRIOR = "means_prior" - MEANS_WEIGHT = "means_weight" - COVARS_PRIOR = "covars_prior" - COVARS_WEIGHT = "covars_weight" - ALGORITHM = "algorithm" - N_ITER = "n_iter" - TOL = "tol" - PARAMS = "params" - INIT_PARAMS = "init_params" - IMPLEMENTATION = "implementation" - - # GMMHMM - # N_COMPONENTS = "n_components" - N_MIX = "n_mix" - # MIN_COVAR = "min_covar" - # STARTPROB_PRIOR = "startprob_prior" - # TRANSMAT_PRIOR = "transmat_prior" - WEIGHTS_PRIOR = "weights_prior" - - # MEANS_PRIOR = "means_prior" - # 
MEANS_WEIGHT = "means_weight" - # ALGORITHM = "algorithm" - # COVARIANCE_TYPE = "covariance_type" - # N_ITER = "n_iter" - # TOL = "tol" - # INIT_PARAMS = "init_params" - # PARAMS = "params" - # IMPLEMENTATION = "implementation" - - # STRAY - ALPHA = "alpha" - K = "k" - KNN_ALGORITHM = "knn_algorithm" - P = "p" - SIZE_THRESHOLD = "size_threshold" - OUTLIER_TAIL = "outlier_tail" - - # timerxl - INPUT_TOKEN_LEN = "input_token_len" - HIDDEN_SIZE = "hidden_size" - INTERMEDIATE_SIZE = "intermediate_size" - OUTPUT_TOKEN_LENS = "output_token_lens" - NUM_HIDDEN_LAYERS = "num_hidden_layers" - NUM_ATTENTION_HEADS = "num_attention_heads" - HIDDEN_ACT = "hidden_act" - USE_CACHE = "use_cache" - ROPE_THETA = "rope_theta" - ATTENTION_DROPOUT = "attention_dropout" - INITIALIZER_RANGE = "initializer_range" - MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - CKPT_PATH = "ckpt_path" - - # sundial - DROPOUT_RATE = "dropout_rate" - FLOW_LOSS_DEPTH = "flow_loss_depth" - NUM_SAMPLING_STEPS = "num_sampling_steps" - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" - - def name(self) -> str: - return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index bc89cdc30662..30b9d54dcc7d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -17,7 +17,7 @@ # import re -from iotdb.ainode.core.constant import ( +from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_YAML, MODEL_WEIGHTS_FILE_IN_PT, ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py index 82c72cc37abf..50634914c273 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. 
# + import threading from typing import Any import torch -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.atmoic_int import AtomicInt @@ -41,8 +39,7 @@ def __init__( req_id: str, model_id: str, inputs: torch.Tensor, - inference_pipeline: AbstractInferencePipeline, - max_new_tokens: int = 96, + output_length: int = 96, **infer_kwargs, ): if inputs.ndim == 1: @@ -52,9 +49,8 @@ def __init__( self.model_id = model_id self.inputs = inputs self.infer_kwargs = infer_kwargs - self.inference_pipeline = inference_pipeline - self.max_new_tokens = ( - max_new_tokens # Number of time series data points to generate + self.output_length = ( + output_length # Number of time series data points to generate ) self.batch_size = inputs.size(0) @@ -65,7 +61,7 @@ def __init__( # Preallocate output buffer [batch_size, max_new_tokens] self.output_tensor = torch.zeros( - self.batch_size, max_new_tokens, device="cpu" + self.batch_size, output_length, device="cpu" ) # shape: [self.batch_size, max_new_steps] def mark_running(self): @@ -77,7 +73,7 @@ def mark_finished(self): def is_finished(self) -> bool: return ( self.state == InferenceRequestState.FINISHED - or self.cur_step_idx >= self.max_new_tokens + or self.cur_step_idx >= self.output_length ) def write_step_output(self, step_output: torch.Tensor): @@ -87,11 +83,11 @@ def write_step_output(self, step_output: torch.Tensor): batch_size, step_size = step_output.shape end_idx = self.cur_step_idx + step_size - if end_idx > self.max_new_tokens: + if end_idx > self.output_length: self.output_tensor[:, self.cur_step_idx :] = step_output[ - :, : self.max_new_tokens - self.cur_step_idx + :, : self.output_length - self.cur_step_idx ] - self.cur_step_idx = self.max_new_tokens + self.cur_step_idx = self.output_length else: self.output_tensor[:, self.cur_step_idx : end_idx] = step_output self.cur_step_idx = end_idx diff 
--git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 6b054c91fe31..a6c415a6c848 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -25,19 +25,22 @@ import numpy as np import torch import torch.multiprocessing as mp -from transformers import PretrainedConfig from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, +) +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device @@ -62,7 +65,6 @@ def __init__( pool_id: int, model_info: ModelInfo, device: str, - config: PretrainedConfig, request_queue: mp.Queue, result_queue: mp.Queue, ready_event, @@ -71,7 +73,6 @@ def __init__( super().__init__() self.pool_id = pool_id self.model_info = model_info - self.config = config self.pool_kwargs = pool_kwargs self.ready_event = ready_event self.device = convert_device_id_to_torch_device(device) @@ -86,8 +87,8 @@ def __init__( self._batcher = BasicBatcher() self._stop_event = mp.Event() - self._model = None - 
self._model_manager = None + self._inference_pipeline = None + self._logger = None # Fix inference seed @@ -98,9 +99,6 @@ def __init__( def _activate_requests(self): requests = self._request_scheduler.schedule_activate() for request in requests: - request.inputs = request.inference_pipeline.preprocess_inputs( - request.inputs - ) request.mark_running() self._running_queue.put(request) self._logger.debug( @@ -117,72 +115,51 @@ def _step(self): grouped_requests = defaultdict(list) for req in all_requests: - key = (req.inputs.shape[1], req.max_new_tokens) + key = (req.inputs.shape[1], req.output_length) grouped_requests[key].append(req) grouped_requests = list(grouped_requests.values()) for requests in grouped_requests: batch_inputs = self._batcher.batch_request(requests).to(self.device) - if self.model_info.model_type == BuiltInModelType.SUNDIAL.value: - batch_output = self._model.generate( + if isinstance(self._inference_pipeline, ForecastPipeline): + batch_output = self._inference_pipeline.forecast( batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - num_samples=10, + predict_length=requests[0].output_length, revin=True, ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output.mean(dim=1)) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - 
self._waiting_queue.put(request) - - elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value: - batch_output = self._model.generate( + elif isinstance(self._inference_pipeline, ClassificationPipeline): + batch_output = self._inference_pipeline.classify( batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - revin=True, + # more infer kwargs can be added here ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) + elif isinstance(self._inference_pipeline, ChatPipeline): + batch_output = self._inference_pipeline.chat( + batch_inputs, + # more infer kwargs can be added here + ) + else: + self._logger.error("[Inference] Unsupported pipeline type.") + offset = 0 + for request in requests: + request.output_tensor = request.output_tensor.to(self.device) + cur_batch_size = request.batch_size + cur_output = batch_output[offset : offset + cur_batch_size] + offset += cur_batch_size + request.write_step_output(cur_output) + + if request.is_finished(): + # ensure the output tensor is on CPU before sending to result queue + request.output_tensor = request.output_tensor.cpu() + self._finished_queue.put(request) + self._logger.debug( + 
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" + ) + else: + self._waiting_queue.put(request) + self._logger.debug( + f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" + ) + return def _requests_execute_loop(self): while not self._stop_event.is_set(): @@ -193,11 +170,8 @@ def run(self): self._logger = Logger( INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device) ) - self._model_manager = ModelManager() self._request_scheduler.device = self.device - self._model = self._model_manager.load_model(self.model_info.model_id, {}).to( - self.device - ) + self._inference_pipeline = load_pipeline(self.model_info, str(self.device)) self.ready_event.set() activate_daemon = threading.Thread( diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py new file mode 100644 index 000000000000..82601e398059 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from abc import ABC, abstractmethod + +import torch + +from iotdb.ainode.core.model.model_loader import load_model + + +class BasicPipeline(ABC): + def __init__(self, model_info, **model_kwargs): + self.model_info = model_info + self.device = model_kwargs.get("device", "cpu") + self.model = load_model(model_info, device_map=self.device, **model_kwargs) + + def _preprocess(self, inputs): + """ + Preprocess the input before inference, including shape validation and value transformation. + """ + return inputs + + def _postprocess(self, output: torch.Tensor): + """ + Post-process the outputs after the entire inference task. 
+ """ + return output + + +class ForecastPipeline(BasicPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + @abstractmethod + def forecast(self, inputs, **infer_kwargs): + pass + + def _postprocess(self, output: torch.Tensor): + return output + + +class ClassificationPipeline(BasicPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + @abstractmethod + def classify(self, inputs, **kwargs): + pass + + def _postprocess(self, output: torch.Tensor): + return output + + +class ChatPipeline(BasicPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + @abstractmethod + def chat(self, inputs, **kwargs): + pass + + def _postprocess(self, output: torch.Tensor): + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py new file mode 100644 index 000000000000..a30038dd5fef --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +from pathlib import Path + +from iotdb.ainode.core.config import AINodeDescriptor +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_storage import ModelInfo +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path + +logger = Logger() + + +def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs): + if model_info.model_type == "sktime": + from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline + + pipeline_cls = SktimePipeline + elif model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." 
+ + model_info.model_id + ) + pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls) + else: + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + pipeline_cls = import_class_from_path( + model_info.model_id, model_info.pipeline_cls + ) + + return pipeline_cls(model_info, device=device, **model_kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 54580402ec29..8ffa89ffd675 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -22,7 +22,6 @@ from concurrent.futures import wait from typing import Dict, Optional -import torch import torch.multiprocessing as mp from iotdb.ainode.core.exception import InferenceModelInternalError @@ -41,9 +40,6 @@ ) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig from iotdb.ainode.core.util.atmoic_int import AtomicInt from iotdb.ainode.core.util.batch_executor import BatchExecutor from iotdb.ainode.core.util.decorator import synchronized @@ -76,9 +72,9 @@ def __init__(self, result_queue: mp.Queue): thread_name_prefix=ThreadName.INFERENCE_POOL_CONTROLLER.value ) - # =============== Pool Management =============== + # =============== Automatic Pool Management (Developing) =============== @synchronized(threading.Lock()) - def first_req_init(self, model_id: str): + def first_req_init(self, model_id: str, device): """ Initialize the pools when 
the first request for the given model_id arrives. """ @@ -107,38 +103,35 @@ def _first_pool_init(self, model_id: str, device_str: str): Initialize the first pool for the given model_id. Ensure the pool is ready before returning. """ - device = torch.device(device_str) - device_id = device.index - - if model_id == "sundial": - config = SundialConfig() - elif model_id == "timer_xl": - config = TimerConfig() - first_queue = mp.Queue() - ready_event = mp.Event() - first_pool = InferenceRequestPool( - pool_id=0, - model_id=model_id, - device=device_str, - config=config, - request_queue=first_queue, - result_queue=self._result_queue, - ready_event=ready_event, - ) - first_pool.start() - self._register_pool(model_id, device_str, 0, first_pool, first_queue) - - if not ready_event.wait(timeout=30): - self._erase_pool(model_id, device_id, 0) - logger.error( - f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" - ) - else: - self.set_state(model_id, device_id, 0, PoolState.RUNNING) - logger.info( - f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" - ) + pass + # device = torch.device(device_str) + # device_id = device.index + # + # first_queue = mp.Queue() + # ready_event = mp.Event() + # first_pool = InferenceRequestPool( + # pool_id=0, + # model_id=model_id, + # device=device_str, + # request_queue=first_queue, + # result_queue=self._result_queue, + # ready_event=ready_event, + # ) + # first_pool.start() + # self._register_pool(model_id, device_str, 0, first_pool, first_queue) + # + # if not ready_event.wait(timeout=30): + # self._erase_pool(model_id, device_id, 0) + # logger.error( + # f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" + # ) + # else: + # self.set_state(model_id, device_id, 0, PoolState.RUNNING) + # logger.info( + # f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" + # ) + # =============== Pool Management =============== def load_model(self, model_id: 
str, device_id_list: list[str]): """ Load the model to the specified devices asynchronously. @@ -255,29 +248,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int): """ def _expand_pool_on_device(*_): - result_queue = mp.Queue() + request_queue = mp.Queue() pool_id = self._new_pool_id.get_and_increment() model_info = self._model_manager.get_model_info(model_id) - model_type = model_info.model_type - if model_type == BuiltInModelType.SUNDIAL.value: - config = SundialConfig() - elif model_type == BuiltInModelType.TIMER_XL.value: - config = TimerConfig() - else: - raise InferenceModelInternalError( - f"Unsupported model type {model_type} for loading model {model_id}" - ) pool = InferenceRequestPool( pool_id=pool_id, model_info=model_info, device=device_id, - config=config, - request_queue=result_queue, + request_queue=request_queue, result_queue=self._result_queue, ready_event=mp.Event(), ) pool.start() - self._register_pool(model_id, device_id, pool_id, pool, result_queue) + self._register_pool(model_id, device_id, pool_id, pool, request_queue) if not pool.ready_event.wait(timeout=300): logger.error( f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time" diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 6a2bd2b619aa..d2e7292ecd8f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -36,7 +36,7 @@ estimate_pool_size, evaluate_system_resources, ) -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo +from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device logger = Logger() diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py deleted file mode 100644 index 2300169a6ee9..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from abc import ABC, abstractmethod - -import torch - - -class AbstractInferencePipeline(ABC): - """ - Abstract assistance strategy class for model inference. - This class shall define the interface process for specific model. - """ - - def __init__(self, model_config, **infer_kwargs): - self.model_config = model_config - self.infer_kwargs = infer_kwargs - - @abstractmethod - def preprocess_inputs(self, inputs: torch.Tensor): - """ - Preprocess the inputs before inference, including shape validation and value transformation. - - Args: - inputs (torch.Tensor): The input tensor to be preprocessed. - - Returns: - torch.Tensor: The preprocessed input tensor. - """ - # TODO: Integrate with the data processing pipeline operators - pass - - @abstractmethod - def post_decode(self): - """ - Post-process the outputs after each decode step. 
- """ - pass - - @abstractmethod - def post_inference(self): - """ - Post-process the outputs after the entire inference task. - """ - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index a67d576b0ec8..1ce2e84e0592 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -18,7 +18,6 @@ import threading import time -from abc import ABC, abstractmethod from typing import Dict import pandas as pd @@ -29,29 +28,22 @@ from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import ( InferenceModelInternalError, - InvalidWindowArgumentError, NumericalRangeException, - runtime_error_extractor, ) from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, InferenceRequestProxy, ) -from iotdb.ainode.core.inference.pool_controller import PoolController -from iotdb.ainode.core.inference.strategy.timer_sundial_inference_pipeline import ( - TimerSundialInferencePipeline, -) -from iotdb.ainode.core.inference.strategy.timerxl_inference_pipeline import ( - TimerXLInferencePipeline, +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, ) +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline +from iotdb.ainode.core.inference.pool_controller import PoolController from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from 
iotdb.ainode.core.model.timerxl.modeling_timer import TimerForPrediction from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.ainode.core.util.serde import convert_to_binary @@ -71,83 +63,6 @@ logger = Logger() -class InferenceStrategy(ABC): - def __init__(self, model): - self.model = model - - @abstractmethod - def infer(self, full_data, **kwargs): - pass - - -# [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length], -# we only get valueList currently. -class TimerXLStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) - df = pd.DataFrame(output[0]) - return convert_to_binary(df) - - -class SundialStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate( - seqs, max_new_tokens=predict_length, num_samples=10, revin=True - ) - df = pd.DataFrame(output[0].mean(dim=0)) - return convert_to_binary(df) - - -class BuiltInStrategy(InferenceStrategy): - def infer(self, full_data, **_): - data = pd.DataFrame(full_data[1]).T - output = self.model.inference(data) - df = pd.DataFrame(output) - return convert_to_binary(df) - - -class RegisteredStrategy(InferenceStrategy): - def infer(self, full_data, window_interval=None, window_step=None, **_): - _, dataset, _, length = full_data - if window_interval is None or window_step is None: - 
window_interval = length - window_step = float("inf") - - if window_interval <= 0 or window_step <= 0 or window_interval > length: - raise InvalidWindowArgumentError(window_interval, window_step, length) - - data = torch.tensor(dataset, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1) - - times = int((length - window_interval) // window_step + 1) - results = [] - try: - for i in range(times): - start = 0 if window_step == float("inf") else i * window_step - end = start + window_interval - window = data[:, start:end, :] - out = self.model(window) - df = pd.DataFrame(out.squeeze(0).detach().numpy()) - results.append(df) - except Exception as e: - msg = runtime_error_extractor(str(e)) or str(e) - raise InferenceModelInternalError(msg) - - # concatenate or return first window for forecast - return [convert_to_binary(df) for df in results] - - class InferenceManager: WAITING_INTERVAL_IN_MS = ( AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms() @@ -251,15 +166,6 @@ def _process_request(self, req): with self._result_wrapper_lock: del self._result_wrapper_map[req_id] - def _get_strategy(self, model_id, model): - if isinstance(model, TimerForPrediction): - return TimerXLStrategy(model) - if isinstance(model, SundialForPrediction): - return SundialStrategy(model) - if self._model_manager.model_storage.is_built_in_or_fine_tuned(model_id): - return BuiltInStrategy(model) - return RegisteredStrategy(model) - def _run( self, req, @@ -272,59 +178,54 @@ def _run( model_id = req.modelId try: raw = data_getter(req) + # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently. 
full_data = deserializer(raw) - inference_attrs = extract_attrs(req) + # TODO: TSBlock -> Tensor codes should be unified + data = full_data[1][0] # get valueList in ndarray + if data.dtype.byteorder not in ("=", "|"): + np_data = data.byteswap() + data = np_data.view(np_data.dtype.newbyteorder()) + # the inputs should be on CPU before passing to the inference request + inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - predict_length = int(inference_attrs.pop("predict_length", 96)) + inference_attrs = extract_attrs(req) + output_length = int(inference_attrs.pop("output_length", 96)) if ( - predict_length - > AINodeDescriptor().get_config().get_ain_inference_max_predict_length() + output_length + > AINodeDescriptor().get_config().get_ain_inference_max_output_length() ): raise NumericalRangeException( "output_length", 1, AINodeDescriptor() .get_config() - .get_ain_inference_max_predict_length(), - predict_length, + .get_ain_inference_max_output_length(), + output_length, ) if self._pool_controller.has_request_pools(model_id): - # use request pool to accelerate inference when the model instance is already loaded. 
- # TODO: TSBlock -> Tensor codes should be unified - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - # the inputs should be on CPU before passing to the inference request - inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - model_type = self._model_manager.get_model_info(model_id).model_type - if model_type == BuiltInModelType.SUNDIAL.value: - inference_pipeline = TimerSundialInferencePipeline(SundialConfig()) - elif model_type == BuiltInModelType.TIMER_XL.value: - inference_pipeline = TimerXLInferencePipeline(TimerConfig()) - else: - raise InferenceModelInternalError( - f"Unsupported model_id: {model_id}" - ) infer_req = InferenceRequest( req_id=generate_req_id(), model_id=model_id, inputs=inputs, - inference_pipeline=inference_pipeline, - max_new_tokens=predict_length, + output_length=output_length, ) outputs = self._process_request(infer_req) outputs = convert_to_binary(pd.DataFrame(outputs[0])) else: - # load model - accel = str(inference_attrs.get("acceleration", "")).lower() == "true" - model = self._model_manager.load_model(model_id, inference_attrs, accel) - # inference by strategy - strategy = self._get_strategy(model_id, model) - outputs = strategy.infer( - full_data, predict_length=predict_length, **inference_attrs - ) + model_info = self._model_manager.get_model_info(model_id) + inference_pipeline = load_pipeline(model_info, device="cpu") + if isinstance(inference_pipeline, ForecastPipeline): + outputs = inference_pipeline.forecast( + inputs, predict_length=output_length, **inference_attrs + ) + elif isinstance(inference_pipeline, ClassificationPipeline): + outputs = inference_pipeline.classify(inputs) + elif isinstance(inference_pipeline, ChatPipeline): + outputs = inference_pipeline.chat(inputs) + else: + logger.error("[Inference] Unsupported pipeline type.") + outputs = convert_to_binary(pd.DataFrame(outputs[0])) # construct response 
status = get_status(TSStatusCode.SUCCESS_STATUS) @@ -345,7 +246,7 @@ def forecast(self, req: TForecastReq): data_getter=lambda r: r.inputData, deserializer=deserialize, extract_attrs=lambda r: { - "predict_length": r.outputLength, + "output_length": r.outputLength, **(r.options or {}), }, resp_cls=TForecastResp, @@ -358,8 +259,7 @@ def inference(self, req: TInferenceReq): data_getter=lambda r: r.dataset, deserializer=deserialize, extract_attrs=lambda r: { - "window_interval": getattr(r.windowParams, "windowInterval", None), - "window_step": getattr(r.windowParams, "windowStep", None), + "output_length": int(r.inferenceAttributes.pop("outputLength", 96)), **(r.inferenceAttributes or {}), }, resp_cls=TInferenceResp, diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index d84bca77c843..8ffb33d91e2d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -15,17 +15,14 @@ # specific language governing permissions and limitations # under the License. 
# -from typing import Callable, Dict -from torch import nn -from yaml import YAMLError +from typing import Any, List, Optional from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType, ModelStates -from iotdb.ainode.core.model.model_info import ModelInfo -from iotdb.ainode.core.model.model_storage import ModelStorage +from iotdb.ainode.core.model.model_loader import load_model +from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.ainode.ttypes import ( @@ -43,127 +40,60 @@ @singleton class ModelManager: def __init__(self): - self.model_storage = ModelStorage() + self._model_storage = ModelStorage() - def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: - logger.info(f"register model {req.modelId} from {req.uri}") + def register_model( + self, + req: TRegisterModelReq, + ) -> TRegisterModelResp: try: - configs, attributes = self.model_storage.register_model( - req.modelId, req.uri - ) - return TRegisterModelResp( - get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes - ) - except InvalidUriError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, e.message) - ) - except BadConfigValueError as e: - logger.warning(e) + if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + except ValueError as e: return TRegisterModelResp( - get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) + get_status(TSStatusCode.INVALID_URI_ERROR, 
str(e)) ) - except YAMLError as e: - logger.warning(e) - if hasattr(e, "problem_mark"): - mark = e.problem_mark - return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file, " - f"at line {mark.line + 1} column {mark.column + 1}.", - ) - ) + except Exception as e: return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file", - ) + get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) ) - except Exception as e: - logger.warning(e) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + self._refresh() + return self._model_storage.show_models(req) def delete_model(self, req: TDeleteModelReq) -> TSStatus: - logger.info(f"delete model {req.modelId}") try: - self.model_storage.delete_model(req.modelId) + self._model_storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) - except Exception as e: + except BuiltInModelDeletionError as e: logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False - ) -> Callable: - """ - Load the model with the given model_id. 
- """ - logger.info(f"Load model {model_id}") - try: - model = self.model_storage.load_model( - model_id, inference_attrs, acceleration - ) - logger.info(f"Model {model_id} loaded") - return model - except Exception as e: - logger.error(f"Failed to load model {model_id}: {e}") - raise - - def save_model(self, model_id: str, model: nn.Module) -> TSStatus: - """ - Save the model using save_pretrained - """ - logger.info(f"Saving model {model_id}") - try: - self.model_storage.save_model(model_id, model) - logger.info(f"Saving model {model_id} successfully") - return get_status( - TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved successfully" - ) except Exception as e: - logger.error(f"Save model failed: {e}") + logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. - - Args: - model_id (str): The ID of the model. - - Returns: - str: The path to the checkpoint file for the model. - """ - return self.model_storage.get_ckpt_path(model_id) - - def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - return self.model_storage.show_models(req) - - def register_built_in_model(self, model_info: ModelInfo): - self.model_storage.register_built_in_model(model_info) - - def get_model_info(self, model_id: str) -> ModelInfo: - return self.model_storage.get_model_info(model_id) - - def update_model_state(self, model_id: str, state: ModelStates): - self.model_storage.update_model_state(model_id, state) - - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: - """ - Get the type of the model with the given model_id. - """ - return self.model_storage.get_built_in_model_type(model_id) - - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. - - Args: - model_id (str): The ID of the model. 
- - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. - """ - return self.model_storage.is_built_in_or_fine_tuned(model_id) + def get_model_info( + self, + model_id: str, + category: Optional[ModelCategory] = None, + ) -> Optional[ModelInfo]: + return self._model_storage.get_model_info(model_id, category) + + def get_model_infos( + self, + category: Optional[ModelCategory] = None, + model_type: Optional[str] = None, + ) -> List[ModelInfo]: + return self._model_storage.get_model_infos(category, model_type) + + def _refresh(self): + """Refresh the model list (re-scan the file system)""" + self._model_storage.discover_all_models() + + def get_registered_models(self) -> List[str]: + return self._model_storage.get_registered_models() + + def is_model_registered(self, model_id: str) -> bool: + return self._model_storage.is_model_registered(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 0264e27331a8..23a98f26bbff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -25,7 +25,7 @@ from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP +from iotdb.ainode.core.model.model_loader import load_model logger = Logger() @@ -47,7 +47,8 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = ModelManager().load_model(model_id, {}).to(device) + model_info = ModelManager().get_model_info(model_id) + model = load_model(model_info).to(device) torch.cuda.synchronize(device) end = torch.cuda.memory_reserved(device) usage = end - start @@ -80,7 +81,7 @@ def evaluate_system_resources(device: torch.device) -> dict: def 
estimate_pool_size(device: torch.device, model_id: str) -> int: - model_info = BUILT_IN_LTSM_MAP.get(model_id, None) + model_info = ModelManager().get_model_info(model_id) if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py deleted file mode 100644 index 3b55142350ba..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py +++ /dev/null @@ -1,1238 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -import os -from abc import abstractmethod -from typing import Callable, Dict, List - -import numpy as np -from huggingface_hub import hf_hub_download -from sklearn.preprocessing import MinMaxScaler -from sktime.detection.hmm_learn import GMMHMM, GaussianHMM -from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA -from sktime.forecasting.exp_smoothing import ExponentialSmoothing -from sktime.forecasting.naive import NaiveForecaster -from sktime.forecasting.trend import STLForecaster - -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - AttributeName, -) -from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - InferenceModelInternalError, - ListRangeException, - NumericalRangeException, - StringRangeException, - WrongAttributeTypeError, -) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import TIMER_REPO_ID -from iotdb.ainode.core.model.sundial import modeling_sundial -from iotdb.ainode.core.model.timerxl import modeling_timer - -logger = Logger() - - -def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: - weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) - if not os.path.exists(weights_path): - logger.info( - f"Model weights file not found at {weights_path}, downloading from HuggingFace..." 
- ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - local_dir=local_dir, - ) - logger.info(f"Got file to {weights_path}") - except Exception as e: - logger.error( - f"Failed to download model weights file to {local_dir} due to {e}" - ) - return False - if not os.path.exists(config_path): - logger.info( - f"Model config file not found at {config_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, - local_dir=local_dir, - ) - logger.info(f"Got file to {config_path}") - except Exception as e: - logger.error( - f"Failed to download model config file to {local_dir} due to {e}" - ) - return False - return True - - -def download_built_in_ltsm_from_hf_if_necessary( - model_type: BuiltInModelType, local_dir: str -) -> bool: - """ - Download the built-in ltsm from HuggingFace repository when necessary. - - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - repo_id = TIMER_REPO_ID[model_type] - if not _download_file_from_hf_if_necessary(local_dir, repo_id): - return False - return True - - -def get_model_attributes(model_type: BuiltInModelType): - if model_type == BuiltInModelType.ARIMA: - attribute_map = arima_attribute_map - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - attribute_map = naive_forecaster_attribute_map - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - attribute_map = exponential_smoothing_attribute_map - elif model_type == BuiltInModelType.STL_FORECASTER: - attribute_map = stl_forecaster_attribute_map - elif model_type == BuiltInModelType.GMM_HMM: - attribute_map = gmmhmm_attribute_map - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - attribute_map = gaussian_hmm_attribute_map - elif model_type == BuiltInModelType.STRAY: - attribute_map = stray_attribute_map - elif model_type == BuiltInModelType.TIMER_XL: - attribute_map = timerxl_attribute_map - elif model_type == BuiltInModelType.SUNDIAL: - attribute_map = sundial_attribute_map - else: - raise BuiltInModelNotSupportError(model_type.value) - return attribute_map - - -def fetch_built_in_model( - model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str] -) -> Callable: - """ - Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config. 
- Args: - model_type: the type of the built-in model - model_dir: for huggingface models only, the directory where the model is stored - Returns: - model: the built-in model - """ - default_attributes = get_model_attributes(model_type) - # parse the attributes from inference_attrs - attributes = parse_attribute(inference_attrs, default_attributes) - - # build the built-in model - if model_type == BuiltInModelType.ARIMA: - model = ArimaModel(attributes) - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - model = ExponentialSmoothingModel(attributes) - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - model = NaiveForecasterModel(attributes) - elif model_type == BuiltInModelType.STL_FORECASTER: - model = STLForecasterModel(attributes) - elif model_type == BuiltInModelType.GMM_HMM: - model = GMMHMMModel(attributes) - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - model = GaussianHmmModel(attributes) - elif model_type == BuiltInModelType.STRAY: - model = STRAYModel(attributes) - elif model_type == BuiltInModelType.TIMER_XL: - model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) - elif model_type == BuiltInModelType.SUNDIAL: - model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) - else: - raise BuiltInModelNotSupportError(model_type.value) - - return model - - -class Attribute(object): - def __init__(self, name: str): - """ - Args: - name: the name of the attribute - """ - self._name = name - - @abstractmethod - def get_default_value(self): - raise NotImplementedError - - @abstractmethod - def validate_value(self, value): - raise NotImplementedError - - @abstractmethod - def parse(self, string_value: str): - raise NotImplementedError - - -class IntAttribute(Attribute): - def __init__( - self, - name: str, - default_value: int, - default_low: int, - default_high: int, - ): - super(IntAttribute, self).__init__(name) - self.__default_value = default_value - 
self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - int_value = int(string_value) - except: - raise WrongAttributeTypeError(self._name, "int") - return int_value - - -class FloatAttribute(Attribute): - def __init__( - self, - name: str, - default_value: float, - default_low: float, - default_high: float, - ): - super(FloatAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - float_value = float(string_value) - except: - raise WrongAttributeTypeError(self._name, "float") - return float_value - - -class StringAttribute(Attribute): - def __init__(self, name: str, default_value: str, value_choices: List[str]): - super(StringAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_choices = value_choices - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if value in self.__value_choices: - return True - raise StringRangeException(self._name, value, self.__value_choices) - - def parse(self, string_value: str): - return string_value - - -class BooleanAttribute(Attribute): - def __init__(self, name: str, default_value: bool): - super(BooleanAttribute, self).__init__(name) - self.__default_value = default_value - - def get_default_value(self): - return 
self.__default_value - - def validate_value(self, value): - if isinstance(value, bool): - return True - raise WrongAttributeTypeError(self._name, "bool") - - def parse(self, string_value: str): - if string_value.lower() == "true": - return True - elif string_value.lower() == "false": - return False - else: - raise WrongAttributeTypeError(self._name, "bool") - - -class ListAttribute(Attribute): - def __init__(self, name: str, default_value: List, value_type): - """ - value_type is the type of the elements in the list, e.g. int, float, str - """ - super(ListAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, list): - raise WrongAttributeTypeError(self._name, "list") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - list_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "list") - if not isinstance(list_value, list): - raise WrongAttributeTypeError(self._name, "list") - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - return list_value - - -class TupleAttribute(Attribute): - def __init__(self, name: str, default_value: tuple, value_type): - """ - value_type is the type of the elements in the list, e.g. 
int, float, str - """ - super(TupleAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - tuple_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "tuple") - if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - list_value = list(tuple_value) - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - tuple_value = tuple(list_value) - return tuple_value - - -def parse_attribute( - input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] -): - """ - Args: - input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of - the attribute - attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute - object - Returns: - a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute - """ - attributes = {} - for attribute_name in attribute_map: - # user specified the attribute - if attribute_name in input_attributes: - attribute = attribute_map[attribute_name] - value = attribute.parse(input_attributes[attribute_name]) - attribute.validate_value(value) - attributes[attribute_name] = value - # user did not specify the attribute, use the default value - else: - try: - attributes[attribute_name] = attribute_map[ - 
attribute_name - ].get_default_value() - except NotImplementedError as e: - logger.error(f"attribute {attribute_name} is not implemented.") - raise e - return attributes - - -sundial_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=16, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=768, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=3072, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=12, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=12, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.DROPOUT_RATE.value: FloatAttribute( - name=AttributeName.DROPOUT_RATE.value, - default_value=0.1, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - 
name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute( - name=AttributeName.FLOW_LOSS_DEPTH.value, - default_value=3, - default_low=1, - default_high=50, - ), - AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute( - name=AttributeName.NUM_SAMPLING_STEPS.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute( - name=AttributeName.DIFFUSION_BATCH_MUL.value, - default_value=4, - default_low=1, - default_high=5000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "sundial", - ), - value_choices=[""], - ), -} - -timerxl_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=96, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=1024, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=2048, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=8, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=8, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], 
- ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.ATTENTION_DROPOUT.value: FloatAttribute( - name=AttributeName.ATTENTION_DROPOUT.value, - default_value=0.0, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "timerxl", - "model.safetensors", - ), - value_choices=[""], - ), -} - -# built-in sktime model attributes -# NaiveForecaster -naive_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.STRATEGY.value: StringAttribute( - name=AttributeName.STRATEGY.value, - default_value="last", - value_choices=["last", "mean"], - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 - ), -} -# ExponentialSmoothing -exponential_smoothing_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.DAMPED_TREND.value: BooleanAttribute( - name=AttributeName.DAMPED_TREND.value, - default_value=False, - ), - AttributeName.INITIALIZATION_METHOD.value: 
StringAttribute( - name=AttributeName.INITIALIZATION_METHOD.value, - default_value="estimated", - value_choices=["estimated", "heuristic", "legacy-heuristic", "known"], - ), - AttributeName.OPTIMIZED.value: BooleanAttribute( - name=AttributeName.OPTIMIZED.value, - default_value=True, - ), - AttributeName.REMOVE_BIAS.value: BooleanAttribute( - name=AttributeName.REMOVE_BIAS.value, - default_value=False, - ), - AttributeName.USE_BRUTE.value: BooleanAttribute( - name=AttributeName.USE_BRUTE.value, - default_value=False, - ), -} -# Arima -arima_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.ORDER.value: TupleAttribute( - name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int - ), - AttributeName.SEASONAL_ORDER.value: TupleAttribute( - name=AttributeName.SEASONAL_ORDER.value, - default_value=(0, 0, 0, 0), - value_type=int, - ), - AttributeName.METHOD.value: StringAttribute( - name=AttributeName.METHOD.value, - default_value="lbfgs", - value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], - ), - AttributeName.MAXITER.value: IntAttribute( - name=AttributeName.MAXITER.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( - name=AttributeName.SUPPRESS_WARNINGS.value, - default_value=True, - ), - AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute( - name=AttributeName.OUT_OF_SAMPLE_SIZE.value, - default_value=0, - default_low=0, - default_high=5000, - ), - AttributeName.SCORING.value: StringAttribute( - name=AttributeName.SCORING.value, - default_value="mse", - value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], - ), - AttributeName.WITH_INTERCEPT.value: BooleanAttribute( - name=AttributeName.WITH_INTERCEPT.value, - default_value=True, - ), - AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute( - 
name=AttributeName.TIME_VARYING_REGRESSION.value, - default_value=False, - ), - AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_STATIONARITY.value, - default_value=True, - ), - AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_INVERTIBILITY.value, - default_value=True, - ), - AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute( - name=AttributeName.SIMPLE_DIFFERENCING.value, - default_value=False, - ), - AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute( - name=AttributeName.MEASUREMENT_ERROR.value, - default_value=False, - ), - AttributeName.MLE_REGRESSION.value: BooleanAttribute( - name=AttributeName.MLE_REGRESSION.value, - default_value=True, - ), - AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute( - name=AttributeName.HAMILTON_REPRESENTATION.value, - default_value=False, - ), - AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( - name=AttributeName.CONCENTRATE_SCALE.value, - default_value=False, - ), -} -# STLForecaster -stl_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000 - ), - AttributeName.SEASONAL.value: IntAttribute( - name=AttributeName.SEASONAL.value, - default_value=7, - default_low=1, - default_high=5000, - ), - AttributeName.SEASONAL_DEG.value: IntAttribute( - name=AttributeName.SEASONAL_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_DEG.value: IntAttribute( - name=AttributeName.TREND_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOW_PASS_DEG.value: IntAttribute( - name=AttributeName.LOW_PASS_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - 
AttributeName.SEASONAL_JUMP.value: IntAttribute( - name=AttributeName.SEASONAL_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_JUMP.value: IntAttribute( - name=AttributeName.TREND_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOSS_PASS_JUMP.value: IntAttribute( - name=AttributeName.LOSS_PASS_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), -} - -# GAUSSIAN_HMM -gaussian_hmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["spherical", "diag", "full", "tied"], - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_PRIOR.value: FloatAttribute( - name=AttributeName.COVARS_PRIOR.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_WEIGHT.value: FloatAttribute( - name=AttributeName.COVARS_WEIGHT.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - 
AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# GMMHMM -gmmhmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.N_MIX.value: IntAttribute( - name=AttributeName.N_MIX.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.WEIGHTS_PRIOR.value: FloatAttribute( - name=AttributeName.WEIGHTS_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - 
default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["sperical", "diag", "full", "tied"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# STRAY -stray_attribute_map = { - AttributeName.ALPHA.value: FloatAttribute( - name=AttributeName.ALPHA.value, - default_value=0.01, - default_low=-1e10, - default_high=1e10, - ), - 
AttributeName.K.value: IntAttribute( - name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000 - ), - AttributeName.KNN_ALGORITHM.value: StringAttribute( - name=AttributeName.KNN_ALGORITHM.value, - default_value="brute", - value_choices=["brute", "kd_tree", "ball_tree", "auto"], - ), - AttributeName.P.value: FloatAttribute( - name=AttributeName.P.value, - default_value=0.5, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.SIZE_THRESHOLD.value: IntAttribute( - name=AttributeName.SIZE_THRESHOLD.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.OUTLIER_TAIL.value: StringAttribute( - name=AttributeName.OUTLIER_TAIL.value, - default_value="max", - value_choices=["min", "max"], - ), -} - - -class BuiltInModel(object): - def __init__(self, attributes): - self._attributes = attributes - self._model = None - - @abstractmethod - def inference(self, data): - raise NotImplementedError - - -class ArimaModel(BuiltInModel): - def __init__(self, attributes): - super(ArimaModel, self).__init__(attributes) - self._model = ARIMA( - order=attributes["order"], - seasonal_order=attributes["seasonal_order"], - method=attributes["method"], - suppress_warnings=attributes["suppress_warnings"], - maxiter=attributes["maxiter"], - out_of_sample_size=attributes["out_of_sample_size"], - scoring=attributes["scoring"], - with_intercept=attributes["with_intercept"], - time_varying_regression=attributes["time_varying_regression"], - enforce_stationarity=attributes["enforce_stationarity"], - enforce_invertibility=attributes["enforce_invertibility"], - simple_differencing=attributes["simple_differencing"], - measurement_error=attributes["measurement_error"], - mle_regression=attributes["mle_regression"], - hamilton_representation=attributes["hamilton_representation"], - concentrate_scale=attributes["concentrate_scale"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - 
self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class ExponentialSmoothingModel(BuiltInModel): - def __init__(self, attributes): - super(ExponentialSmoothingModel, self).__init__(attributes) - self._model = ExponentialSmoothing( - damped_trend=attributes["damped_trend"], - initialization_method=attributes["initialization_method"], - optimized=attributes["optimized"], - remove_bias=attributes["remove_bias"], - use_brute=attributes["use_brute"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class NaiveForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(NaiveForecasterModel, self).__init__(attributes) - self._model = NaiveForecaster( - strategy=attributes["strategy"], sp=attributes["sp"] - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STLForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(STLForecasterModel, self).__init__(attributes) - self._model = STLForecaster( - sp=attributes["sp"], - seasonal=attributes["seasonal"], - seasonal_deg=attributes["seasonal_deg"], - trend_deg=attributes["trend_deg"], - low_pass_deg=attributes["low_pass_deg"], - seasonal_jump=attributes["seasonal_jump"], - trend_jump=attributes["trend_jump"], - low_pass_jump=attributes["low_pass_jump"], - ) - - def inference(self, data): - try: - predict_length = 
self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GMMHMMModel(BuiltInModel): - def __init__(self, attributes): - super(GMMHMMModel, self).__init__(attributes) - self._model = GMMHMM( - n_components=attributes["n_components"], - n_mix=attributes["n_mix"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - weights_prior=attributes["weights_prior"], - algorithm=attributes["algorithm"], - covariance_type=attributes["covariance_type"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GaussianHmmModel(BuiltInModel): - def __init__(self, attributes): - super(GaussianHmmModel, self).__init__(attributes) - self._model = GaussianHMM( - n_components=attributes["n_components"], - covariance_type=attributes["covariance_type"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - covars_prior=attributes["covars_prior"], - covars_weight=attributes["covars_weight"], - algorithm=attributes["algorithm"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def 
inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STRAYModel(BuiltInModel): - def __init__(self, attributes): - super(STRAYModel, self).__init__(attributes) - self._model = STRAY( - alpha=attributes["alpha"], - k=attributes["k"], - knn_algorithm=attributes["knn_algorithm"], - p=attributes["p"], - size_threshold=attributes["size_threshold"], - outlier_tail=attributes["outlier_tail"], - ) - - def inference(self, data): - try: - data = MinMaxScaler().fit_transform(data) - output = self._model.fit_transform(data) - # change the output to int - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py new file mode 100644 index 000000000000..c42ec98551b8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +from enum import Enum + +# Model file constants +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" +MODEL_CONFIG_FILE_IN_JSON = "config.json" +MODEL_WEIGHTS_FILE_IN_PT = "model.pt" +MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + + +# Model file constants +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" +MODEL_CONFIG_FILE_IN_JSON = "config.json" +MODEL_WEIGHTS_FILE_IN_PT = "model.pt" +MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + + +class ModelCategory(Enum): + BUILTIN = "builtin" + USER_DEFINED = "user_defined" + + +class ModelStates(Enum): + INACTIVE = "inactive" + ACTIVATING = "activating" + ACTIVE = "active" + DROPPING = "dropping" + + +class UriType(Enum): + REPO = "repo" + FILE = "file" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py deleted file mode 100644 index 348f9924316b..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -from enum import Enum -from typing import List - - -class BuiltInModelType(Enum): - # forecast models - ARIMA = "Arima" - HOLTWINTERS = "HoltWinters" - EXPONENTIAL_SMOOTHING = "ExponentialSmoothing" - NAIVE_FORECASTER = "NaiveForecaster" - STL_FORECASTER = "StlForecaster" - - # anomaly detection models - GAUSSIAN_HMM = "GaussianHmm" - GMM_HMM = "GmmHmm" - STRAY = "Stray" - - # large time series models (LTSM) - TIMER_XL = "Timer-XL" - # sundial - SUNDIAL = "Timer-Sundial" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def is_built_in_model(model_type: str) -> bool: - """ - Check if the given model type corresponds to a built-in model. - """ - return model_type in BuiltInModelType.values() - - -class ModelFileType(Enum): - SAFETENSORS = "safetensors" - PYTORCH = "pytorch" - UNKNOWN = "unknown" - - -class ModelCategory(Enum): - BUILT_IN = "BUILT-IN" - FINE_TUNED = "FINE-TUNED" - USER_DEFINED = "USER-DEFINED" - - -class ModelStates(Enum): - ACTIVE = "ACTIVE" - INACTIVE = "INACTIVE" - LOADING = "LOADING" - DROPPING = "DROPPING" - TRAINING = "TRAINING" - FAILED = "FAILED" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py deleted file mode 100644 index 26d863156f37..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py +++ /dev/null @@ -1,291 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import glob -import os -import shutil -from urllib.parse import urljoin - -import yaml - -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.uri_utils import ( - UriType, - download_file, - download_snapshot_from_hf, -) -from iotdb.ainode.core.util.serde import get_data_type_byte_from_str -from iotdb.thrift.ainode.ttypes import TConfigs - -logger = Logger() - - -def fetch_model_by_uri( - uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Fetch the model files from the specified URI. - - Args: - uri_type (UriType): type of the URI, either repo, file, http or https - uri (str): a network or a local path of the model to be registered - storage_path (str): path to save the whole model, including weights, config, codes, etc. 
- model_file_type (ModelFileType): The type of model file, either safetensors or pytorch - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if uri_type == UriType.REPO or uri_type in [UriType.HTTP, UriType.HTTPS]: - return _fetch_model_from_network(uri, storage_path, model_file_type) - elif uri_type == UriType.FILE: - return _fetch_model_from_local(uri, storage_path, model_file_type) - else: - raise InvalidUriError(f"Invalid URI type: {uri_type}") - - -def _fetch_model_from_network( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - download_snapshot_from_hf(uri, storage_path) - return _process_huggingface_files(storage_path) - - # TODO: The following codes might be refactored in future - # concat uri to get complete url - uri = uri if uri.endswith("/") else uri + "/" - target_model_path = urljoin(uri, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = urljoin(uri, MODEL_CONFIG_FILE_IN_YAML) - - # download config file - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - download_file(target_config_path, config_storage_path) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, download model file - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - download_file(target_model_path, model_storage_path) - return configs, attributes - - -def _fetch_model_from_local( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - # copy anything in the uri to local_dir - for file in 
os.listdir(uri): - shutil.copy(os.path.join(uri, file), storage_path) - return _process_huggingface_files(storage_path) - # concat uri to get complete path - target_model_path = os.path.join(uri, MODEL_WEIGHTS_FILE_IN_PT) - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = os.path.join(uri, MODEL_CONFIG_FILE_IN_YAML) - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - - # check if file exist - exist_model_file = os.path.exists(target_model_path) - exist_config_file = os.path.exists(target_config_path) - - configs = None - attributes = None - if exist_model_file and exist_config_file: - # copy config.yaml - shutil.copy(target_config_path, config_storage_path) - logger.info( - f"copy file from {target_config_path} to {config_storage_path} success" - ) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, copy model file - shutil.copy(target_model_path, model_storage_path) - logger.info( - f"copy file from {target_model_path} to {model_storage_path} success" - ) - - elif not exist_model_file or not exist_config_file: - raise InvalidUriError(uri) - - return configs, attributes - - -def _parse_inference_config(config_dict): - """ - Args: - config_dict: dict - - configs: dict - - input_shape (list): input shape of the model and needs to be two-dimensional array like [96, 2] - - output_shape (list): output shape of the model and needs to be two-dimensional array like [96, 2] - - input_type (list): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - output_type (list): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - attributes: dict - Returns: - 
configs: TConfigs - attributes: str - """ - configs = config_dict["configs"] - - # check if input_shape and output_shape are two-dimensional array - if not ( - isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2 - ): - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "input_shape should be a two-dimensional array.", - ) - if not ( - isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2 - ): - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "output_shape should be a two-dimensional array.", - ) - - # check if input_shape and output_shape are positive integer - input_shape_is_positive_number = ( - isinstance(configs["input_shape"][0], int) - and isinstance(configs["input_shape"][1], int) - and configs["input_shape"][0] > 0 - and configs["input_shape"][1] > 0 - ) - if not input_shape_is_positive_number: - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "element in input_shape should be positive integer.", - ) - - output_shape_is_positive_number = ( - isinstance(configs["output_shape"][0], int) - and isinstance(configs["output_shape"][1], int) - and configs["output_shape"][0] > 0 - and configs["output_shape"][1] > 0 - ) - if not output_shape_is_positive_number: - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "element in output_shape should be positive integer.", - ) - - # check if input_type and output_type are one-dimensional array with right length - if "input_type" in configs and not ( - isinstance(configs["input_type"], list) - and len(configs["input_type"]) == configs["input_shape"][1] - ): - raise BadConfigValueError( - "input_type", - configs["input_type"], - "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].", - ) - - if "output_type" in configs and not ( - isinstance(configs["output_type"], list) - and len(configs["output_type"]) == configs["output_shape"][1] - ): 
- raise BadConfigValueError( - "output_type", - configs["output_type"], - "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].", - ) - - # parse input_type and output_type to byte - if "input_type" in configs: - input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]] - else: - input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][ - 1 - ] - - if "output_type" in configs: - output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]] - else: - output_type = [get_data_type_byte_from_str("float32")] * configs[ - "output_shape" - ][1] - - # parse attributes - attributes = "" - if "attributes" in config_dict: - attributes = str(config_dict["attributes"]) - - return ( - TConfigs( - configs["input_shape"], configs["output_shape"], input_type, output_type - ), - attributes, - ) - - -def _process_huggingface_files(local_dir: str): - """ - TODO: Currently, we use this function to convert the model config from huggingface, we will refactor this in the future. 
- """ - config_file = None - for config_name in ["config.json", "model_config.json"]: - config_path = os.path.join(local_dir, config_name) - if os.path.exists(config_path): - config_file = config_path - break - - if not config_file: - raise InvalidUriError(f"No config.json found in {local_dir}") - - safetensors_files = glob.glob(os.path.join(local_dir, "*.safetensors")) - if not safetensors_files: - raise InvalidUriError(f"No .safetensors files found in {local_dir}") - - simple_config = { - "configs": { - "input_shape": [96, 1], - "output_shape": [96, 1], - "input_type": ["float32"], - "output_type": ["float32"], - }, - "attributes": { - "model_type": "huggingface_model", - "source_dir": local_dir, - "files": [os.path.basename(f) for f in safetensors_files], - }, - } - - configs, attributes = _parse_inference_config(simple_config) - return configs, attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 167bfd76640d..718ead530dd2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -15,140 +15,118 @@ # specific language governing permissions and limitations # under the License. # -import glob -import os -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) +from typing import Dict, Optional - -def get_model_file_type(model_path: str) -> ModelFileType: - """ - Determine the file type of the specified model directory. 
- """ - if _has_safetensors_format(model_path): - return ModelFileType.SAFETENSORS - elif _has_pytorch_format(model_path): - return ModelFileType.PYTORCH - else: - return ModelFileType.UNKNOWN - - -def _has_safetensors_format(path: str) -> bool: - """Check if directory contains safetensors files.""" - safetensors_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)) - json_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_JSON)) - return len(safetensors_files) > 0 and len(json_files) > 0 - - -def _has_pytorch_format(path: str) -> bool: - """Check if directory contains pytorch files.""" - pt_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_PT)) - yaml_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_YAML)) - return len(pt_files) > 0 and len(yaml_files) > 0 - - -def get_built_in_model_type(model_type: str) -> BuiltInModelType: - if not BuiltInModelType.is_built_in_model(model_type): - raise ValueError(f"Invalid built-in model type: {model_type}") - return BuiltInModelType(model_type) +from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates class ModelInfo: def __init__( self, model_id: str, - model_type: str, category: ModelCategory, state: ModelStates, + model_type: str = "", + config_cls: str = "", + model_cls: str = "", + pipeline_cls: str = "", + repo_id: str = "", + auto_map: Optional[Dict] = None, + _transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type self.category = category self.state = state + self.config_cls = config_cls + self.model_cls = model_cls + self.pipeline_cls = pipeline_cls + self.repo_id = repo_id + self.auto_map = auto_map # If exists, indicates it's a Transformers model + self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + def __repr__(self): + return ( + f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " + f"category={self.category.value}, 
state={self.state.value}, " + f"has_auto_map={self.auto_map is not None})" + ) -TIMER_REPO_ID = { - BuiltInModelType.TIMER_XL: "thuml/timer-base-84m", - BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m", -} -# Built-in machine learning models, they can be employed directly -BUILT_IN_MACHINE_LEARNING_MODEL_MAP = { +BUILTIN_SKTIME_MODEL_MAP = { # forecast models "arima": ModelInfo( model_id="arima", - model_type=BuiltInModelType.ARIMA.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "holtwinters": ModelInfo( model_id="holtwinters", - model_type=BuiltInModelType.HOLTWINTERS.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "exponential_smoothing": ModelInfo( model_id="exponential_smoothing", - model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "naive_forecaster": ModelInfo( model_id="naive_forecaster", - model_type=BuiltInModelType.NAIVE_FORECASTER.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stl_forecaster": ModelInfo( model_id="stl_forecaster", - model_type=BuiltInModelType.STL_FORECASTER.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), # anomaly detection models "gaussian_hmm": ModelInfo( model_id="gaussian_hmm", - model_type=BuiltInModelType.GAUSSIAN_HMM.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "gmm_hmm": ModelInfo( model_id="gmm_hmm", - model_type=BuiltInModelType.GMM_HMM.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stray": ModelInfo( model_id="stray", - 
model_type=BuiltInModelType.STRAY.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), } -# Built-in large time series models (LTSM), their weights are not included in AINode by default -BUILT_IN_LTSM_MAP = { +# Built-in huggingface transformers models, their weights are not included in AINode by default +BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { "timer_xl": ModelInfo( model_id="timer_xl", - model_type=BuiltInModelType.TIMER_XL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="timer", + config_cls="configuration_timer.TimerConfig", + model_cls="modeling_timer.TimerForPrediction", + pipeline_cls="pipeline_timer.TimerPipeline", + repo_id="thuml/timer-base-84m", ), "sundial": ModelInfo( model_id="sundial", - model_type=BuiltInModelType.SUNDIAL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="sundial", + config_cls="configuration_sundial.SundialConfig", + model_cls="modeling_sundial.SundialForPrediction", + pipeline_cls="pipeline_sundial.SundialPipeline", + repo_id="thuml/sundial-base-128m", ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py new file mode 100644 index 000000000000..a6e3b1f7b5e3 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +from pathlib import Path +from typing import Any + +import torch +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForNextSentencePrediction, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTimeSeriesPrediction, + AutoModelForTokenClassification, +) + +from iotdb.ainode.core.config import AINodeDescriptor +from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path + +logger = Logger() + + +def load_model(model_info: ModelInfo, **model_kwargs) -> Any: + if model_info.auto_map is not None: + model = load_model_from_transformers(model_info, **model_kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_info.model_id) + else: + model = load_model_from_pt(model_info, **model_kwargs) + + logger.info( + f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully." 
+ ) + return model + + +def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): + device_map = model_kwargs.get("device_map", "cpu") + trust_remote_code = model_kwargs.get("trust_remote_code", True) + train_from_scratch = model_kwargs.get("train_from_scratch", False) + + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + + if model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." + + model_info.model_id + ) + config_cls = import_class_from_path(module_name, model_info.config_cls) + model_cls = import_class_from_path(module_name, model_info.model_cls) + elif model_info.model_cls and model_info.config_cls: + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + config_cls = import_class_from_path( + model_info.model_id, model_info.config_cls + ) + model_cls = import_class_from_path( + model_info.model_id, model_info.model_cls + ) + else: + config_cls = AutoConfig.from_pretrained(model_path) + if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys(): + model_cls = AutoModelForTimeSeriesPrediction + elif ( + type(config_cls) in AutoModelForNextSentencePrediction._model_mapping.keys() + ): + model_cls = AutoModelForNextSentencePrediction + elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys(): + model_cls = AutoModelForSeq2SeqLM + elif ( + type(config_cls) in AutoModelForSequenceClassification._model_mapping.keys() + ): + model_cls = AutoModelForSequenceClassification + elif type(config_cls) in AutoModelForTokenClassification._model_mapping.keys(): + model_cls = AutoModelForTokenClassification + else: + model_cls = AutoModelForCausalLM + + if train_from_scratch: + model = model_cls.from_config( + config_cls, trust_remote_code=trust_remote_code, device_map=device_map + ) + else: + model = 
model_cls.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + device_map=device_map, + ) + + return model + + +def load_model_from_pt(model_info: ModelInfo, **kwargs): + device_map = kwargs.get("device_map", "cpu") + acceleration = kwargs.get("acceleration", False) + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + model_file = os.path.join(model_path, "model.pt") + if not os.path.exists(model_file): + logger.error(f"Model file not found at {model_file}.") + raise ModelNotExistError(model_file) + model = torch.jit.load(model_file) + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: + return model + try: + model = torch.compile(model) + except Exception as e: + logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") + return model.to(device_map) + + +def load_model_for_efficient_inference(): + # TODO: An efficient model loading method for inference based on model_arguments + pass + + +def load_model_for_powerful_finetune(): + # TODO: A powerful model loading method for finetune based on model_arguments + pass + + +def unload_model(): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index e346f569102e..5194ed4df1bd 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -20,43 +20,37 @@ import json import os import shutil -from collections.abc import Callable -from typing import Dict +from pathlib import Path +from typing import Dict, List, Optional -import torch -from torch import nn +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - 
MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, ModelCategory, - ModelFileType, ModelStates, + UriType, ) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from iotdb.ainode.core.model.utils import ( + ensure_init_file, + get_parsed_uri, + import_class_from_path, + load_model_config_in_json, + parse_uri_type, + temporary_sys_path, + validate_model_files, +) from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus @@ -64,320 +58,368 @@ logger = Logger() -class ModelStorage(object): +class ModelStorage: + """Model storage class - unified management of model discovery and registration""" + def __init__(self): - self._model_dir = os.path.join( + self._models_dir = os.path.join( os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() ) - if not os.path.exists(self._model_dir): - try: - os.makedirs(self._model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._builtin_model_dir = 
os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir() - ) - if not os.path.exists(self._builtin_model_dir): - try: - os.makedirs(self._builtin_model_dir) - except PermissionError as e: - logger.error(e) - raise e + # Unified storage: category -> {model_id -> ModelInfo} + self._models: Dict[str, Dict[str, ModelInfo]] = { + ModelCategory.BUILTIN.value: {}, + ModelCategory.USER_DEFINED.value: {}, + } + # Async download executor + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + # Thread lock pool for protecting concurrent access to model information self._lock_pool = ModelLockPool() - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1 - ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading. - self._model_info_map: Dict[str, ModelInfo] = {} - self._init_model_info_map() + self._initialize_directories() + self.discover_all_models() + + def _initialize_directories(self): + """Initialize directory structure and ensure __init__.py files exist""" + os.makedirs(self._models_dir, exist_ok=True) + ensure_init_file(self._models_dir) + for category in ModelCategory: + category_path = os.path.join(self._models_dir, category.value) + os.makedirs(category_path, exist_ok=True) + ensure_init_file(category_path) + + # ==================== Discovery Methods ==================== + + def discover_all_models(self): + """Scan file system to discover all models""" + self._discover_category(ModelCategory.BUILTIN) + self._discover_category(ModelCategory.USER_DEFINED) + + def _discover_category(self, category: ModelCategory): + """Discover all models in a category directory""" + category_path = os.path.join(self._models_dir, category.value) + if category == ModelCategory.BUILTIN: + self._discover_builtin_models(category_path) + elif category == ModelCategory.USER_DEFINED: + for model_id in os.listdir(category_path): + if 
os.path.isdir(os.path.join(category_path, model_id)): + self._process_user_defined_model_directory( + os.path.join(category_path, model_id), model_id + ) - def _init_model_info_map(self): - """ - Initialize the model info map. - """ - # 1. initialize built-in and ready-to-use models - for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP: - self._model_info_map[model_id] = BUILT_IN_MACHINE_LEARNING_MODEL_MAP[ - model_id - ] - # 2. retrieve fine-tuned models from the built-in model directory - fine_tuned_models = self._retrieve_fine_tuned_models() - for model_id in fine_tuned_models: - self._model_info_map[model_id] = fine_tuned_models[model_id] - # 3. automatically downloading the weights of built-in LSTM models when necessary - for model_id in BUILT_IN_LTSM_MAP: - if model_id not in self._model_info_map: - self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id] - future = self._executor.submit( - self._download_built_in_model_if_necessary, model_id - ) - future.add_done_callback( - lambda f, mid=model_id: self._callback_model_download_result(f, mid) - ) - # 4. retrieve user-defined models from the model directory - user_defined_models = self._retrieve_user_defined_models() - for model_id in user_defined_models: - self._model_info_map[model_id] = user_defined_models[model_id] + def _discover_builtin_models(self, category_path: str): + # Register SKTIME models directly from map + for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_SKTIME_MODEL_MAP[model_id] + ) - def _retrieve_fine_tuned_models(self): - """ - Retrieve fine-tuned models from the built-in model directory. 
+ # Process HuggingFace Transformers models + for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys(): + model_dir = os.path.join(category_path, model_id) + os.makedirs(model_dir, exist_ok=True) + self._process_builtin_model_directory(model_dir, model_id) - Returns: - {"model_id": ModelInfo} - """ - result = {} - build_in_dirs = [ - d - for d in os.listdir(self._builtin_model_dir) - if os.path.isdir(os.path.join(self._builtin_model_dir, d)) - ] - for model_id in build_in_dirs: - config_file_path = os.path.join( - self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON + def _process_builtin_model_directory(self, model_dir: str, model_id: str): + """Handling the discovery logic for a builtin model directory.""" + ensure_init_file(model_dir) + with self._lock_pool.get_lock(model_id).write_lock(): + # Check if model already exists and is in a valid state + existing_model = self._models[ModelCategory.BUILTIN.value].get(model_id) + if existing_model: + # If model is already ACTIVATING or ACTIVE, skip duplicate download + if existing_model.state in (ModelStates.ACTIVATING, ModelStates.ACTIVE): + return + + # If model not exists or is INACTIVE, we'll try to update its info and download its weights + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id] ) - if os.path.isfile(config_file_path): - with open(config_file_path, "r") as f: - model_config = json.load(f) - if "model_type" in model_config: - model_type = model_config["model_type"] - model_info = ModelInfo( - model_id=model_id, - model_type=model_type, - category=ModelCategory.FINE_TUNED, - state=ModelStates.ACTIVE, + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.ACTIVATING + + def _download_model_if_necessary() -> bool: + """Returns: True if the model is existed or downloaded successfully, False otherwise.""" + repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id + weights_path = os.path.join(model_dir, 
MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + if not os.path.exists(weights_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + local_dir=model_dir, ) - # Refactor the built-in model category - if "timer_xl" == model_id: - model_info.category = ModelCategory.BUILT_IN - if "sundial" == model_id: - model_info.category = ModelCategory.BUILT_IN - # Compatible patch with the codes in HuggingFace - if "timer" == model_type: - model_info.model_type = BuiltInModelType.TIMER_XL.value - if "sundial" == model_type: - model_info.model_type = BuiltInModelType.SUNDIAL.value - result[model_id] = model_info - return result - - def _download_built_in_model_if_necessary(self, model_id: str) -> bool: - """ - Download the built-in model if it is not already downloaded. - - Args: - model_id (str): The ID of the model to download. + except Exception as e: + logger.error( + f"Failed to download model weights from HuggingFace: {e}" + ) + return False + if not os.path.exists(config_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_CONFIG_FILE_IN_JSON, + local_dir=model_dir, + ) + except Exception as e: + logger.error( + f"Failed to download model config from HuggingFace: {e}" + ) + return False + return True - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - with self._lock_pool.get_lock(model_id).write_lock(): - local_dir = os.path.join(self._builtin_model_dir, model_id) - return download_built_in_ltsm_from_hf_if_necessary( - get_built_in_model_type(self._model_info_map[model_id].model_type), - local_dir, - ) + future = self._executor.submit(_download_model_if_necessary) + future.add_done_callback( + lambda f, mid=model_id: self._callback_model_download_result(f, mid) + ) def _callback_model_download_result(self, future, model_id: str): + """Callback function for handling model download results""" with self._lock_pool.get_lock(model_id).write_lock(): - if future.result(): - self._model_info_map[model_id].state = ModelStates.ACTIVE - logger.info( - f"The built-in model: {model_id} is active and ready to use." - ) - else: - self._model_info_map[model_id].state = ModelStates.INACTIVE - - def _retrieve_user_defined_models(self): - """ - Retrieve user_defined models from the model directory. + try: + if future.result(): + model_info = self._models[ModelCategory.BUILTIN.value][model_id] + model_info.state = ModelStates.ACTIVE + config_path = os.path.join( + self._models_dir, + ModelCategory.BUILTIN.value, + model_id, + MODEL_CONFIG_FILE_IN_JSON, + ) + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + if model_info.model_type == "": + model_info.model_type = config.get("model_type", "") + model_info.auto_map = config.get("auto_map", None) + logger.info( + f"Model {model_id} downloaded successfully and is ready to use." 
+ ) + else: + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.INACTIVE + logger.warning(f"Failed to download model {model_id}.") + except Exception as e: + logger.error(f"Error in download callback for model {model_id}: {e}") + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.INACTIVE + + def _process_user_defined_model_directory(self, model_dir: str, model_id: str): + """Handling the discovery logic for a user-defined model directory.""" + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + model_type = "" + auto_map = {} + pipeline_cls = "" + if os.path.exists(config_path): + config = load_model_config_in_json(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map", None) + pipeline_cls = config.get("pipeline_cls", "") - Returns: - {"model_id": ModelInfo} - """ - result = {} - user_dirs = [ - d - for d in os.listdir(self._model_dir) - if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights" - ] - for model_id in user_dirs: - result[model_id] = ModelInfo( + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( model_id=model_id, - model_type="", + model_type=model_type, category=ModelCategory.USER_DEFINED, state=ModelStates.ACTIVE, + pipeline_cls=pipeline_cls, + auto_map=auto_map, + _transformers_registered=False, # Lazy registration ) - return result + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + # ==================== Registration Methods ==================== - def register_model(self, model_id: str, uri: str): + def register_model(self, model_id: str, uri: str) -> bool: """ - Args: - model_id: id of model to register - uri: network or local dir path of the model to register - Returns: - configs: TConfigs - attributes: str + Supported URI formats: + - repo:// (Maybe in the future) + - file:// """ + uri_type = parse_uri_type(uri) + parsed_uri = get_parsed_uri(uri) + + model_dir = 
os.path.join( + self._models_dir, ModelCategory.USER_DEFINED.value, model_id + ) + os.makedirs(model_dir, exist_ok=True) + ensure_init_file(model_dir) + + if uri_type == UriType.REPO: + self._fetch_model_from_hf_repo(parsed_uri, model_dir) + else: + self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) + + config_path, _ = validate_model_files(model_dir) + config = load_model_config_in_json(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + pipeline_cls = config.get("pipeline_cls", "") + with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - # create storage dir if not exist - if not os.path.exists(storage_path): - os.makedirs(storage_path) - uri_type, parsed_uri, model_file_type = get_model_register_strategy(uri) - self._model_info_map[model_id] = ModelInfo( + model_info = ModelInfo( model_id=model_id, - model_type="", + model_type=model_type, category=ModelCategory.USER_DEFINED, - state=ModelStates.LOADING, + state=ModelStates.ACTIVE, + pipeline_cls=pipeline_cls, + auto_map=auto_map, + _transformers_registered=False, # Register later + ) + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + if auto_map: + # Transformers model: immediately register to Transformers auto-loading mechanism + success = self._register_transformers_model(model_info) + if success: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info._transformers_registered = True + else: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info.state = ModelStates.INACTIVE + logger.error(f"Failed to register Transformers model {model_id}") + return False + else: + # Other type models: only log + self._register_other_model(model_info) + + logger.info(f"Successfully registered model {model_id} from URI: {uri}") + return True + + def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): + logger.info( + f"Downloading model 
from HuggingFace repository: {repo_id} -> {storage_path}" + ) + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise + + def _fetch_model_from_local(self, source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") + source_dir = Path(source_path) + if not source_dir.is_dir(): + raise ValueError( + f"Source path does not exist or is not a directory: {source_path}" ) - try: - # TODO: The uri should be fetched asynchronously - configs, attributes = fetch_model_by_uri( - uri_type, parsed_uri, storage_path, model_file_type - ) - self._model_info_map[model_id].state = ModelStates.ACTIVE - return configs, attributes - except Exception as e: - logger.error(f"Failed to register model {model_id}: {e}") - self._model_info_map[model_id].state = ModelStates.INACTIVE - raise e - def delete_model(self, model_id: str) -> None: - """ - Args: - model_id: id of model to delete - Returns: - None - """ - # check if the model is built-in - with self._lock_pool.get_lock(model_id).read_lock(): - if self._is_built_in(model_id): - raise BuiltInModelDeletionError(model_id) + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + return - # delete the user-defined or fine-tuned model - with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - storage_path = os.path.join(self._builtin_model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - if model_id in self._model_info_map: - del self._model_info_map[model_id] - logger.info(f"Model {model_id} 
deleted successfully.") - - def _is_built_in(self, model_id: str) -> bool: + def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ - Check if the model_id corresponds to a built-in model. + Register Transformers model to auto-loading mechanism (internal method) + """ + auto_map = model_info.auto_map + if not auto_map: + return False - Args: - model_id (str): The ID of the model. + auto_config_path = auto_map.get("AutoConfig") + auto_model_path = auto_map.get("AutoModelForCausalLM") - Returns: - bool: True if the model is built-in, False otherwise. - """ - return ( - model_id in self._model_info_map - and self._model_info_map[model_id].category == ModelCategory.BUILT_IN - ) + try: + model_path = os.path.join( + self._models_dir, model_info.category.value, model_info.model_id + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + config_class = import_class_from_path( + model_info.model_id, auto_config_path + ) + AutoConfig.register(model_info.model_type, config_class) + logger.info( + f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}" + ) - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. + model_class = import_class_from_path( + model_info.model_id, auto_model_path + ) + AutoModelForCausalLM.register(config_class, model_class) + logger.info( + f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" + ) - Args: - model_id (str): The ID of the model. + return True + except Exception as e: + logger.warning( + f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct." + ) + return False - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. 
- """ - return model_id in self._model_info_map and ( - self._model_info_map[model_id].category == ModelCategory.BUILT_IN - or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED + def _register_other_model(self, model_info: ModelInfo): + """Register other type models (non-Transformers models)""" + logger.info( + f"Registered other type model: {model_info.model_id} ({model_info.model_type})" ) - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool - ) -> Callable: + def ensure_transformers_registered(self, model_id: str) -> ModelInfo: """ - Load a model with automatic detection of .safetensors or .pt format - + Ensure Transformers model is registered (called for lazy registration) + This method uses locks to ensure thread safety. All check logic is within lock protection. Returns: - model: The model instance corresponding to specific model_id - """ - with self._lock_pool.get_lock(model_id).read_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - return fetch_built_in_model( - get_built_in_model_type(self._model_info_map[model_id].model_type), - model_dir, - inference_attrs, - ) - else: - # load the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - model_file_type = get_model_file_type(model_dir) - if model_file_type == ModelFileType.SAFETENSORS: - # TODO: Support this function - raise UnsupportedError("SAFETENSORS format") - else: - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - - if not os.path.exists(model_path): - raise ModelNotExistError(model_path) - model = torch.jit.load(model_path) - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - - try: - model = torch.compile(model) - except Exception as e: - logger.warning( - f"acceleration failed, fallback to normal mode: {str(e)}" - ) - return model - - def save_model(self, model_id: 
str, model: nn.Module): - """ - Save the model using save_pretrained - - Returns: - Whether saving succeeded + str: If None, registration failed, otherwise returns model path """ + # Use lock to protect entire check-execute process with self._lock_pool.get_lock(model_id).write_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - model.save_pretrained(model_dir) - else: - # save the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - os.makedirs(model_dir, exist_ok=True) - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - try: - scripted_model = ( - model - if isinstance(model, torch.jit.ScriptModule) - else torch.jit.script(model) + # Directly access _models dictionary (avoid calling get_model_info which may cause deadlock) + model_info = None + for category_dict in self._models.values(): + if model_id in category_dict: + model_info = category_dict[model_id] + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot register") + return None + + # If already registered, return directly + if model_info._transformers_registered: + return model_info + + # If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks) + if ( + not model_info.auto_map + or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys() + ): + model_info._transformers_registered = True + return model_info + + # Execute registration (under lock protection) + try: + success = self._register_transformers_model(model_info) + if success: + model_info._transformers_registered = True + logger.info( + f"Model {model_id} successfully registered to Transformers" ) - torch.jit.save(scripted_model, model_path) - except Exception as e: - logger.error(f"Failed to save scripted model: {e}") - - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. 
+ return model_info + else: + model_info.state = ModelStates.INACTIVE + logger.error(f"Model {model_id} failed to register to Transformers") + return None - Args: - model_id (str): The ID of the model. + except Exception as e: + # Ensure state consistency in exception cases + model_info.state = ModelStates.INACTIVE + model_info._transformers_registered = False + logger.error( + f"Exception occurred while registering model {model_id} to Transformers: {e}" + ) + return None - Returns: - str: The path to the checkpoint file for the model. - """ - # Only support built-in models for now - return os.path.join(self._builtin_model_dir, f"{model_id}") + # ==================== Show and Delete Models ==================== def show_models(self, req: TShowModelsReq) -> TShowModelsResp: resp_status = TSStatus( @@ -385,8 +427,14 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp: message="Show models successfully", ) if req.modelId: - if req.modelId in self._model_info_map: - model_info = self._model_info_map[req.modelId] + # Find specified model + model_info = None + for category_dict in self._models.values(): + if req.modelId in category_dict: + model_info = category_dict[req.modelId] + break + + if model_info: return TShowModelsResp( status=resp_status, modelIdList=[req.modelId], @@ -402,55 +450,133 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp: categoryMap={}, stateMap={}, ) + # Return all models + model_id_list = [] + model_type_map = {} + category_map = {} + state_map = {} + + for category_dict in self._models.values(): + for model_id, model_info in category_dict.items(): + model_id_list.append(model_id) + model_type_map[model_id] = model_info.model_type + category_map[model_id] = model_info.category.value + state_map[model_id] = model_info.state.value + return TShowModelsResp( status=resp_status, - modelIdList=list(self._model_info_map.keys()), - modelTypeMap=dict( - (model_id, model_info.model_type) - for model_id, model_info in 
self._model_info_map.items() - ), - categoryMap=dict( - (model_id, model_info.category.value) - for model_id, model_info in self._model_info_map.items() - ), - stateMap=dict( - (model_id, model_info.state.value) - for model_id, model_info in self._model_info_map.items() - ), + modelIdList=model_id_list, + modelTypeMap=model_type_map, + categoryMap=category_map, + stateMap=state_map, ) - def register_built_in_model(self, model_info: ModelInfo): - with self._lock_pool.get_lock(model_info.model_id).write_lock(): - self._model_info_map[model_info.model_id] = model_info + def delete_model(self, model_id: str) -> None: + # Use write lock to protect entire deletion process + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = None + category_value = None + for cat_value, category_dict in self._models.items(): + if model_id in category_dict: + model_info = category_dict[model_id] + category_value = cat_value + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot delete") + return + + if model_info.category == ModelCategory.BUILTIN: + raise BuiltInModelDeletionError(model_id) + model_info.state = ModelStates.DROPPING + model_path = os.path.join( + self._models_dir, model_info.category.value, model_id + ) + if model_path.exists(): + try: + shutil.rmtree(model_path) + logger.info(f"Deleted model directory: {model_path}") + except Exception as e: + logger.error(f"Failed to delete model directory {model_path}: {e}") + raise - def get_model_info(self, model_id: str) -> ModelInfo: - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return self._model_info_map[model_id] - else: - raise ValueError(f"Model {model_id} does not exist.") + if category_value and model_id in self._models[category_value]: + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from storage") - def update_model_state(self, model_id: str, state: ModelStates): - with 
self._lock_pool.get_lock(model_id).write_lock(): - if model_id in self._model_info_map: - self._model_info_map[model_id].state = state - else: - raise ValueError(f"Model {model_id} does not exist.") + return - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: + # ==================== Query Methods ==================== + + def get_model_info( + self, model_id: str, category: Optional[ModelCategory] = None + ) -> Optional[ModelInfo]: """ - Get the type of the model with the given model_id. + Get single model information - Args: - model_id (str): The ID of the model. + If category is specified, use model_id's lock + If category is not specified, need to traverse all dictionaries, use global lock + """ + if category: + # Category specified, only need to access specific dictionary, use model_id's lock + with self._lock_pool.get_lock(model_id).read_lock(): + return self._models[category.value].get(model_id) + else: + # Category not specified, need to traverse all dictionaries, use global lock + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return category_dict[model_id] + return None + + def get_model_infos( + self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None + ) -> List[ModelInfo]: + """ + Get model information list - Returns: - str: The type of the model. 
+ Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure + For single model access, using model_id-based lock would be more efficient """ - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return get_built_in_model_type( - self._model_info_map[model_id].model_type - ) + matching_models = [] + + # For traversal operations, we need to protect the entire dictionary structure + # Use a special lock (using empty string as key) to protect the entire dictionary + with self._lock_pool.get_lock("").read_lock(): + if category and model_type: + for model_info in self._models[category.value].values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models + elif category: + return list(self._models[category.value].values()) + elif model_type: + for category_dict in self._models.values(): + for model_info in category_dict.values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models else: - raise ValueError(f"Model {model_id} does not exist.") + for category_dict in self._models.values(): + matching_models.extend(category_dict.values()) + return matching_models + + def is_model_registered(self, model_id: str) -> bool: + """Check if model is registered (search in _models)""" + # Lazy registration: if it's a Transformers model and not registered, register it first + if self.ensure_transformers_registered(model_id) is None: + return False + + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return True + return False + + def get_registered_models(self) -> List[str]: + """Get list of all registered model IDs""" + with self._lock_pool.get_lock("").read_lock(): + model_ids = [] + for category_dict in self._models.values(): + model_ids.extend(category_dict.keys()) + return model_ids diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json new file mode 100644 index 000000000000..1561124badd1 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -0,0 +1,25 @@ +{ + "model_type": "sktime", + "model_id": "arima", + "predict_length": 1, + "order": [1, 0, 0], + "seasonal_order": [0, 0, 0, 0], + "start_params": null, + "method": "lbfgs", + "maxiter": 50, + "suppress_warnings": false, + "out_of_sample_size": 0, + "scoring": "mse", + "scoring_args": null, + "trend": null, + "with_intercept": true, + "time_varying_regression": false, + "enforce_stationarity": true, + "enforce_invertibility": true, + "simple_differencing": false, + "measurement_error": false, + "mle_regression": true, + "hamilton_representation": false, + "concentrate_scale": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py new file mode 100644 index 000000000000..261de3c9abe7 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -0,0 +1,409 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +@dataclass +class AttributeConfig: + """Base class for attribute configuration""" + + name: str + default: Any + type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple' + low: Union[int, float, None] = None + high: Union[int, float, None] = None + choices: List[str] = field(default_factory=list) + value_type: type = None # Element type for list and tuple + + def validate_value(self, value): + """Validate if the value meets the requirements""" + if self.type == "int": + if value is None: + return True # Allow None for optional int parameters + if not isinstance(value, int): + raise WrongAttributeTypeError(self.name, "int") + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "float": + if value is None: + return True # Allow None for optional float parameters + if not isinstance(value, (int, float)): + raise WrongAttributeTypeError(self.name, "float") + value = float(value) + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "str": + if value is None: + return 
True # Allow None for optional str parameters + if not isinstance(value, str): + raise WrongAttributeTypeError(self.name, "str") + if self.choices and value not in self.choices: + raise StringRangeException(self.name, value, self.choices) + elif self.type == "bool": + if value is None: + return True # Allow None for optional bool parameters + if not isinstance(value, bool): + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + if not isinstance(value, list): + raise WrongAttributeTypeError(self.name, "list") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + elif self.type == "tuple": + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + return True + + def parse(self, string_value: str): + """Parse string value to corresponding type""" + if self.type == "int": + if string_value.lower() == "none" or string_value.strip() == "": + return None + try: + return int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + elif self.type == "float": + if string_value.lower() == "none" or string_value.strip() == "": + return None + try: + return float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + elif self.type == "str": + if string_value.lower() == "none" or string_value.strip() == "": + return None + return string_value + elif self.type == "bool": + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + elif string_value.lower() == "none" or string_value.strip() == "": + return None + else: + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") + if not isinstance(list_value, list): + 
raise WrongAttributeTypeError(self.name, "list") + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return list_value + elif self.type == "tuple": + try: + tuple_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "tuple") + if not isinstance(tuple_value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + list_value = list(tuple_value) + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return tuple(list_value) + + +# Model configuration definitions - using concise dictionary format +MODEL_CONFIGS = { + "NAIVE_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "strategy": AttributeConfig( + "strategy", "last", "str", choices=["last", "mean", "drift"] + ), + "window_length": AttributeConfig("window_length", None, "int"), + "sp": AttributeConfig("sp", 1, "int", 1, 5000), + }, + "EXPONENTIAL_SMOOTHING": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "damped_trend": AttributeConfig("damped_trend", False, "bool"), + "initialization_method": AttributeConfig( + "initialization_method", + "estimated", + "str", + choices=["estimated", "heuristic", "legacy-heuristic", "known"], + ), + "optimized": AttributeConfig("optimized", True, "bool"), + "remove_bias": AttributeConfig("remove_bias", False, "bool"), + "use_brute": AttributeConfig("use_brute", False, "bool"), + }, + "ARIMA": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int), + "seasonal_order": AttributeConfig( + "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int + ), + "start_params": AttributeConfig("start_params", None, "str"), + "method": AttributeConfig( 
+ "method", + "lbfgs", + "str", + choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], + ), + "maxiter": AttributeConfig("maxiter", 50, "int", 1, 5000), + "suppress_warnings": AttributeConfig("suppress_warnings", False, "bool"), + "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000), + "scoring": AttributeConfig( + "scoring", + "mse", + "str", + choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], + ), + "scoring_args": AttributeConfig("scoring_args", None, "str"), + "trend": AttributeConfig("trend", None, "str"), + "with_intercept": AttributeConfig("with_intercept", True, "bool"), + "time_varying_regression": AttributeConfig( + "time_varying_regression", False, "bool" + ), + "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"), + "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"), + "simple_differencing": AttributeConfig("simple_differencing", False, "bool"), + "measurement_error": AttributeConfig("measurement_error", False, "bool"), + "mle_regression": AttributeConfig("mle_regression", True, "bool"), + "hamilton_representation": AttributeConfig( + "hamilton_representation", False, "bool" + ), + "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"), + }, + "STL_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "sp": AttributeConfig("sp", 2, "int", 1, 5000), + "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000), + "trend": AttributeConfig("trend", None, "int"), + "low_pass": AttributeConfig("low_pass", None, "int"), + "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000), + "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000), + "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000), + "robust": AttributeConfig("robust", False, "bool"), + "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000), + "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 
5000), + "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + "inner_iter": AttributeConfig("inner_iter", None, "int"), + "outer_iter": AttributeConfig("outer_iter", None, "int"), + "forecaster_trend": AttributeConfig("forecaster_trend", None, "str"), + "forecaster_seasonal": AttributeConfig("forecaster_seasonal", None, "str"), + "forecaster_resid": AttributeConfig("forecaster_resid", None, "str"), + }, + "GAUSSIAN_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 0.01, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "random_state": AttributeConfig("random_state", None, "float"), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), + "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), + "init_params": AttributeConfig( + "init_params", "stmc", "str", choices=["stmc", "stm"] + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "GMM_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000), + "min_covar": 
AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", None, "float"), + "covars_weight": AttributeConfig("covars_weight", None, "float"), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "random_state": AttributeConfig("random_state", None, "int"), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), + "init_params": AttributeConfig( + "init_params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "params": AttributeConfig( + "params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "STRAY": { + "alpha": AttributeConfig("alpha", 0.01, "float", -1e10, 1e10), + "k": AttributeConfig("k", 10, "int", 
1, 5000), + "knn_algorithm": AttributeConfig( + "knn_algorithm", + "brute", + "str", + choices=["brute", "kd_tree", "ball_tree", "auto"], + ), + "p": AttributeConfig("p", 0.5, "float", -1e10, 1e10), + "size_threshold": AttributeConfig("size_threshold", 50, "int", 1, 5000), + "outlier_tail": AttributeConfig( + "outlier_tail", "max", "str", choices=["min", "max"] + ), + }, +} + + +def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: + """Get attribute configuration for Sktime model""" + model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id + if model_id not in MODEL_CONFIGS: + raise BuiltInModelNotSupportError(model_id) + return MODEL_CONFIGS[model_id] + + +def update_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, AttributeConfig] +) -> Dict[str, Any]: + """Update Sktime model attributes using input attributes""" + attributes = {} + for name, config in attribute_map.items(): + if name in input_attributes: + value = config.parse(input_attributes[name]) + config.validate_value(value) + attributes[name] = value + else: + attributes[name] = config.default + return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json new file mode 100644 index 000000000000..4126d9de857a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "model_id": "exponential_smoothing", + "predict_length": 1, + "damped_trend": false, + "initialization_method": "estimated", + "optimized": true, + "remove_bias": false, + "use_brute": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json new file mode 100644 index 000000000000..94f7d7ec659f --- /dev/null +++ 
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "model_id": "gaussian_hmm", + "n_components": 1, + "covariance_type": "diag", + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "means_prior": 0, + "means_weight": 0, + "covars_prior": 0.01, + "covars_weight": 1, + "algorithm": "viterbi", + "random_state": null, + "n_iter": 10, + "tol": 0.01, + "verbose": false, + "params": "stmc", + "init_params": "stmc", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json new file mode 100644 index 000000000000..fb19d1aaf86d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json @@ -0,0 +1,24 @@ +{ + "model_type": "sktime", + "model_id": "gmm_hmm", + "n_components": 1, + "n_mix": 1, + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "weights_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "covars_prior": null, + "covars_weight": null, + "algorithm": "viterbi", + "covariance_type": "diag", + "random_state": null, + "n_iter": 10, + "tol": 0.01, + "verbose": false, + "init_params": "stmcw", + "params": "stmcw", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py new file mode 100644 index 000000000000..eca812d35ec9 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from abc import abstractmethod +from typing import Any, Dict + +import numpy as np +import pandas as pd +from sklearn.preprocessing import MinMaxScaler +from sktime.detection.hmm_learn import GMMHMM, GaussianHMM +from sktime.detection.stray import STRAY +from sktime.forecasting.arima import ARIMA +from sktime.forecasting.exp_smoothing import ExponentialSmoothing +from sktime.forecasting.naive import NaiveForecaster +from sktime.forecasting.trend import STLForecaster + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + InferenceModelInternalError, +) +from iotdb.ainode.core.log import Logger + +from .configuration_sktime import get_attributes, update_attribute + +logger = Logger() + + +class SktimeModel: + """Base class for Sktime models""" + + def __init__(self, attributes: Dict[str, Any]): + self._attributes = attributes + self._model = None + + @abstractmethod + def generate(self, data, **kwargs): + """Execute generation/inference""" + raise NotImplementedError + + +class ForecastingModel(SktimeModel): + """Base class for forecasting models""" + + def generate(self, data, **kwargs): + """Execute forecasting""" + try: + predict_length = kwargs.get( + "predict_length", self._attributes["predict_length"] + ) + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + return np.array(output, dtype=np.float64) + except Exception as e: + raise 
InferenceModelInternalError(str(e)) + + +class DetectionModel(SktimeModel): + """Base class for detection models""" + + def generate(self, data, **kwargs): + """Execute detection""" + try: + predict_length = kwargs.get("predict_length", data.size) + output = self._model.fit_transform(data[:predict_length]) + if isinstance(output, pd.DataFrame): + return np.array(output["labels"], dtype=np.int32) + else: + return np.array(output, dtype=np.int32) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class ArimaModel(ForecastingModel): + """ARIMA model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ARIMA( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class ExponentialSmoothingModel(ForecastingModel): + """Exponential smoothing model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ExponentialSmoothing( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class NaiveForecasterModel(ForecastingModel): + """Naive forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = NaiveForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class STLForecasterModel(ForecastingModel): + """STL forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STLForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class GMMHMMModel(DetectionModel): + """GMM HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GMMHMM(**attributes) + + +class GaussianHmmModel(DetectionModel): + """Gaussian HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GaussianHMM(**attributes) + + +class STRAYModel(DetectionModel): + 
"""STRAY anomaly detection model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STRAY(**{k: v for k, v in attributes.items() if v is not None}) + + def generate(self, data, **kwargs): + """STRAY requires special handling: normalize first""" + try: + scaled_data = MinMaxScaler().fit_transform(data.values.reshape(-1, 1)) + scaled_data = pd.Series(scaled_data.flatten()) + return super().generate(scaled_data, **kwargs) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +# Model factory mapping +_MODEL_FACTORY = { + "ARIMA": ArimaModel, + "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel, + "HOLTWINTERS": ExponentialSmoothingModel, # Use the same model class + "NAIVE_FORECASTER": NaiveForecasterModel, + "STL_FORECASTER": STLForecasterModel, + "GMM_HMM": GMMHMMModel, + "GAUSSIAN_HMM": GaussianHmmModel, + "STRAY": STRAYModel, +} + + +def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: + """Create a Sktime model instance""" + attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) + model_class = _MODEL_FACTORY.get(model_id.upper()) + if model_class is None: + raise BuiltInModelNotSupportError(model_id) + return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json new file mode 100644 index 000000000000..3dadd7c3b1e5 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json @@ -0,0 +1,9 @@ +{ + "model_type": "sktime", + "model_id": "naive_forecaster", + "predict_length": 1, + "strategy": "last", + "window_length": null, + "sp": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py new file mode 100644 index 000000000000..ced21f29a2b8 --- /dev/null +++ 
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import numpy as np +import pandas as pd +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline + + +class SktimePipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + model_kwargs.pop("device", None) # sktime models run on CPU + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + input_ids = self._preprocess(inputs) + + # Convert to pandas Series for sktime (sktime expects Series or DataFrame) + # Handle batch dimension: if batch_size > 1, process each sample separately + if len(input_ids.shape) == 2 and input_ids.shape[0] > 1: + # Batch processing: convert each row to Series + outputs = [] + for i in range(input_ids.shape[0]): + series = pd.Series( + input_ids[i].cpu().numpy() + if isinstance(input_ids, torch.Tensor) + else input_ids[i] + ) + output = self.model.generate(series, predict_length=predict_length) + outputs.append(output) + output = np.array(outputs) + 
else: + # Single sample: convert to Series + if isinstance(input_ids, torch.Tensor): + series = pd.Series(input_ids.squeeze().cpu().numpy()) + else: + series = pd.Series(input_ids.squeeze()) + output = self.model.generate(series, predict_length=predict_length) + # Add batch dimension if needed + if len(output.shape) == 1: + output = output[np.newaxis, :] + + return self._postprocess(output) + + def _postprocess(self, output): + if isinstance(output, np.ndarray): + return torch.from_numpy(output).float() + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json new file mode 100644 index 000000000000..bfe71dbc4861 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "model_id": "stl_forecaster", + "predict_length": 1, + "sp": 2, + "seasonal": 7, + "trend": null, + "low_pass": null, + "seasonal_deg": 1, + "trend_deg": 1, + "low_pass_deg": 1, + "robust": false, + "seasonal_jump": 1, + "trend_jump": 1, + "low_pass_jump": 1, + "inner_iter": null, + "outer_iter": null, + "forecaster_trend": null, + "forecaster_seasonal": null, + "forecaster_resid": null +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json new file mode 100644 index 000000000000..e5bcc03cd071 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "model_id": "stray", + "alpha": 0.01, + "k": 10, + "knn_algorithm": "brute", + "p": 0.5, + "size_threshold": 50, + "outlier_tail": "max" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py index 3ebf516f705e..dc1de32506e5 100644 --- 
a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -16,13 +16,10 @@ # under the License. # -import os -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors from torch import nn from transformers import Cache, DynamicCache, PreTrainedModel from transformers.activations import ACT2FN @@ -32,13 +29,10 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin -logger = Logger() - def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] @@ -616,11 +610,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. 
elif past_length < (input_ids.shape[1] // self.config.input_token_len): @@ -633,10 +623,9 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - token_num = ( - input_ids.shape[1] + self.config.input_token_len - 1 - ) // self.config.input_token_len - position_ids = position_ids[:, -token_num:] + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py similarity index 56% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 17c88e32fb5a..85b6f7db2ffe 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -19,33 +19,33 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class TimerSundialInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-Sundial model inference. 
- """ +class SundialPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) - def __init__(self, model_config: SundialConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) + def _preprocess(self, inputs): if len(inputs.shape) != 2: raise InferenceModelInternalError( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py return inputs - def post_decode(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass - - def post_inference(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + num_samples = infer_kwargs.get("num_samples", 10) + revin = infer_kwargs.get("revin", True) + + input_ids = self._preprocess(inputs) + output = self.model.generate( + input_ids, + max_new_tokens=predict_length, + num_samples=num_samples, + revin=revin, + ) + return self._postprocess(output) + + def _postprocess(self, output: torch.Tensor): + return output.mean(dim=1) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py similarity index 98% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py index 0a33c682742a..fc9d7b41388b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py @@ -16,7 +16,7 @@ # under the License. 
# -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -29,11 +29,8 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin - -logger = Logger() +from iotdb.ainode.core.model.timer_xl.configuration_timer import TimerConfig +from iotdb.ainode.core.model.timer_xl.ts_generation_mixin import TSGenerationMixin def rotate_half(x): @@ -606,11 +603,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. 
elif past_length < (input_ids.shape[1] // self.config.input_token_len): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py similarity index 52% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index dc1dd304f68e..c0f00b1f5caf 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -19,33 +19,29 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class TimerXLInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-XL model inference. 
- """ +class TimerPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) - def __init__(self, model_config: TimerConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) + def _preprocess(self, inputs): if len(inputs.shape) != 2: raise InferenceModelInternalError( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) - # Considering that we are currently using the generate function interface, it seems that no pre-processing is required return inputs - def post_decode(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + revin = infer_kwargs.get("revin", True) + + input_ids = self._preprocess(inputs) + output = self.model.generate( + input_ids, max_new_tokens=predict_length, revin=revin + ) + return self._postprocess(output) - def post_inference(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass + def _postprocess(self, output: torch.Tensor): + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py deleted file mode 100644 index b2e759e00ce0..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py +++ /dev/null @@ -1,137 
+0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import os -from enum import Enum -from typing import List - -from huggingface_hub import snapshot_download -from requests import Session -from requests.adapters import HTTPAdapter - -from iotdb.ainode.core.constant import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_RECONNECT_TIMEOUT, - DEFAULT_RECONNECT_TIMES, -) -from iotdb.ainode.core.exception import UnsupportedError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.model_info import get_model_file_type - -HTTP_PREFIX = "http://" -HTTPS_PREFIX = "https://" - -logger = Logger() - - -class UriType(Enum): - REPO = "repo" - FILE = "file" - HTTP = "http" - HTTPS = "https" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def parse_uri_type(uri: str): - """ - Parse the URI type from the given string. 
- """ - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): - return UriType.FILE - elif uri.startswith("http://"): - return UriType.HTTP - elif uri.startswith("https://"): - return UriType.HTTPS - else: - raise ValueError(f"Invalid URI type for {uri}") - - -def get_model_register_strategy(uri: str): - """ - Determine the loading strategy for a model based on its URI/path. - - Args: - uri (str): The URI of the model to be registered. - - Returns: - uri_type (UriType): The type of the URI, which can be one of: REPO, FILE, HTTP, or HTTPS. - parsed_uri (str): Parsed uri to get related file - model_file_type (ModelFileType): The type of the model file, which can be one of: SAFETENSORS, PYTORCH, or UNKNOWN. - """ - - uri_type = UriType.parse_uri_type(uri) - if uri_type in (UriType.HTTP, UriType.HTTPS): - # TODO: support HTTP(S) URI - raise UnsupportedError("CREATE MODEL FROM HTTP(S) URI") - else: - parsed_uri = uri[7:] - if uri_type == UriType.FILE: - # handle ~ in URI - parsed_uri = os.path.expanduser(parsed_uri) - model_file_type = get_model_file_type(uri) - elif uri_type == UriType.REPO: - # Currently, UriType.REPO only corresponds to huggingface repository with SAFETENSORS format - model_file_type = ModelFileType.SAFETENSORS - else: - raise ValueError(f"Invalid URI type for {uri}") - return uri_type, parsed_uri, model_file_type - - -def download_snapshot_from_hf(repo_id: str, local_dir: str): - """ - Download everything from a HuggingFace repository. - - Args: - repo_id (str): The HuggingFace repository ID. - local_dir (str): The local directory to save the downloaded files. 
- """ - try: - snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - ) - except Exception as e: - logger.error(f"Failed to download HuggingFace model {repo_id}: {e}") - raise e - - -def download_file(url: str, storage_path: str) -> None: - """ - Args: - url: url of file to download - storage_path: path to save the file - Returns: - None - """ - logger.info(f"Start Downloading file from {url} to {storage_path}") - session = Session() - adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES) - session.mount(HTTP_PREFIX, adapter) - session.mount(HTTPS_PREFIX, adapter) - response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) - response.raise_for_status() - with open(storage_path, "wb") as file: - for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): - if chunk: - file.write(chunk) - logger.info(f"Download file from {url} to {storage_path} success") diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py new file mode 100644 index 000000000000..1cd0ee44912d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import importlib +import json +import os.path +import sys +from contextlib import contextmanager +from typing import Dict, Tuple + +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + UriType, +) + + +def parse_uri_type(uri: str) -> UriType: + if uri.startswith("repo://"): + return UriType.REPO + elif uri.startswith("file://"): + return UriType.FILE + else: + raise ValueError( + f"Unsupported URI type: {uri}. Supported formats: repo:// or file://" + ) + + +def get_parsed_uri(uri: str) -> str: + return uri[7:] # Remove "repo://" or "file://" prefix + + +@contextmanager +def temporary_sys_path(path: str): + """Context manager for temporarily adding a path to sys.path""" + path_added = path not in sys.path + if path_added: + sys.path.insert(0, path) + try: + yield + finally: + if path_added and path in sys.path: + sys.path.remove(path) + + +def load_model_config_in_json(config_path: str) -> Dict: + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def validate_model_files(model_dir: str) -> Tuple[str, str]: + """Validate model files exist, return config and weights file paths""" + + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + + if not os.path.exists(config_path): + raise ValueError(f"Model config file does not exist: {config_path}") + if not os.path.exists(weights_path): + raise ValueError(f"Model weights file does not exist: {weights_path}") + + # Create __init__.py file to ensure model directory can be imported as a module + init_file = os.path.join(model_dir, "__init__.py") + if not os.path.exists(init_file): + with open(init_file, "w"): + pass + + return config_path, weights_path + + +def import_class_from_path(module_name, class_path: str): + file_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name + "." 
+ file_name) + return getattr(module, class_name) + + +def ensure_init_file(dir_path: str): + """Ensure __init__.py file exists in the given dir path""" + init_file = os.path.join(dir_path, "__init__.py") + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(init_file): + with open(init_file, "w"): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py index e2be6459508b..ea6362ef080a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py @@ -38,7 +38,6 @@ TAINodeRemoveReq, TAINodeRestartReq, TNodeVersionInfo, - TUpdateModelInfoReq, ) logger = Logger() @@ -155,13 +154,8 @@ def _wait_and_reconnect(self) -> None: self._try_to_connect() except TException: # can not connect to each config node - self._sync_latest_config_node_list() self._try_to_connect() - def _sync_latest_config_node_list(self) -> None: - # TODO - pass - def _update_config_node_leader(self, status: TSStatus) -> bool: if status.code == TSStatusCode.REDIRECTION_RECOMMEND.get_status_code(): if status.redirectNode is not None: @@ -271,36 +265,3 @@ def get_ainode_configuration(self, node_id: int) -> map: self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) - - def update_model_info( - self, - model_id: str, - model_status: int, - attribute: str = "", - ainode_id=None, - input_length=0, - output_length=0, - ) -> None: - if ainode_id is None: - ainode_id = [] - for _ in range(0, self._RETRY_NUM): - try: - req = TUpdateModelInfoReq(model_id, model_status, attribute) - if ainode_id is not None: - req.aiNodeIds = ainode_id - req.inputLength = input_length - req.outputLength = output_length - status = self._client.updateModelInfo(req) - if not self._update_config_node_leader(status): - verify_success( - status, "An error occurs when calling update model info" - ) - return status - except TTransport.TException: - logger.warning( 
- "Failed to connect to ConfigNode {} from AINode when executing update model info", - self._config_leader, - ) - self._config_leader = None - self._wait_and_reconnect() - raise TException(self._MSG_RECONNECTION_FAIL) diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index f01e1594f069..6c4eedeb99f7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -29,6 +29,7 @@ TAIHeartbeatResp, TDeleteModelReq, TForecastReq, + TForecastResp, TInferenceReq, TInferenceResp, TLoadModelReq, @@ -78,8 +79,14 @@ def stopAINode(self) -> TSStatus: def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: return self._model_manager.register_model(req) + def deleteModel(self, req: TDeleteModelReq) -> TSStatus: + return self._model_manager.delete_model(req) + + def showModels(self, req: TShowModelsReq) -> TShowModelsResp: + return self._model_manager.show_models(req) + def loadModel(self, req: TLoadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId) + status = self._ensure_model_is_registered(req.existingModelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -88,7 +95,7 @@ def loadModel(self, req: TLoadModelReq) -> TSStatus: return self._inference_manager.load_model(req) def unloadModel(self, req: TUnloadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + status = self._ensure_model_is_registered(req.modelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -96,21 +103,6 @@ def unloadModel(self, req: TUnloadModelReq) -> TSStatus: return status return self._inference_manager.unload_model(req) - def deleteModel(self, req: TDeleteModelReq) -> TSStatus: - return 
self._model_manager.delete_model(req) - - def inference(self, req: TInferenceReq) -> TInferenceResp: - return self._inference_manager.inference(req) - - def forecast(self, req: TForecastReq) -> TSStatus: - return self._inference_manager.forecast(req) - - def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: - return ClusterManager.get_heart_beat(req) - - def showModels(self, req: TShowModelsReq) -> TShowModelsResp: - return self._model_manager.show_models(req) - def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp: status = _ensure_device_id_is_available(req.deviceIdList) if status.code != TSStatusCode.SUCCESS_STATUS.value: @@ -123,13 +115,28 @@ def showAIDevices(self) -> TShowAIDevicesResp: deviceIdList=get_available_devices(), ) + def inference(self, req: TInferenceReq) -> TInferenceResp: + status = self._ensure_model_is_registered(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TInferenceResp(status, []) + return self._inference_manager.inference(req) + + def forecast(self, req: TForecastReq) -> TForecastResp: + status = self._ensure_model_is_registered(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TForecastResp(status, []) + return self._inference_manager.forecast(req) + + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: + return ClusterManager.get_heart_beat(req) + def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass - def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus: - if not self._model_manager.is_built_in_or_fine_tuned(model_id): + def _ensure_model_is_registered(self, model_id: str) -> TSStatus: + if not self._model_manager.is_model_registered(model_id): return TSStatus( code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.", + message=f"Model [{model_id}] is not registered yet. 
You can use 'SHOW MODELS' to retrieve the available models.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index 331cb8ab32a3..3773f69a847a 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -79,7 +79,7 @@ exclude = [ python = ">=3.11.0,<3.14.0" # ---- DL / HF stack ---- -torch = ">=2.7.0" +torch = "^2.7.1,<2.8.0" torchmetrics = "^1.8.0" transformers = "==4.56.2" tokenizers = ">=0.22.0,<=0.23.0" @@ -93,7 +93,10 @@ scipy = "^1.12.0" pandas = "^2.3.2" scikit-learn = "^1.7.1" statsmodels = "^0.14.5" -sktime = "0.38.5" +sktime = "0.40.1" +pmdarima = "2.1.1" +hmmlearn = "0.3.2" +accelerate = "^1.10.1" # ---- Optimizers / utils ---- optuna = "^4.4.0" diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java index 2721fedafb1e..8d9081f43527 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java @@ -21,21 +21,28 @@ import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ClientPoolFactory; import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; -import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; +/** Asynchronously send RPC requests to AINodes. 
*/ public class AsyncAINodeHeartbeatClientPool { - private final IClientManager clientManager; + private final IClientManager clientManager; private AsyncAINodeHeartbeatClientPool() { clientManager = - new IClientManager.Factory() - .createClientManager(new AINodeClientFactory.AINodeHeartbeatClientPoolFactory()); + new IClientManager.Factory() + .createClientManager( + new ClientPoolFactory.AsyncAINodeHeartbeatServiceClientPoolFactory()); } + /** + * Only used in LoadManager. + * + * @param endPoint The specific DataNode + */ public void getAINodeHeartBeat( TEndPoint endPoint, TAIHeartbeatReq req, AINodeHeartbeatHandler handler) { try { @@ -56,6 +63,6 @@ private AsyncAINodeHeartbeatClientPoolHolder() { } public static AsyncAINodeHeartbeatClientPool getInstance() { - return AsyncAINodeHeartbeatClientPool.AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; + return AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; } } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java index ccc19f1a9f38..324e35130278 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java @@ -63,7 +63,6 @@ public void writeAuditLog( } } - // TODO: Is the AsyncDataNodeHeartbeatClientPool must be a singleton? 
private static class AsyncDataNodeHeartbeatClientPoolHolder { private static final AsyncDataNodeHeartbeatClientPool INSTANCE = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java index e0b2c144c0ea..65c1ee0a9fed 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java @@ -21,8 +21,6 @@ import org.apache.iotdb.commons.exception.runtime.SerializationRunTimeException; import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.subscription.ShowTopicPlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; @@ -52,10 +50,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import 
org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -574,24 +568,6 @@ public static ConfigPhysicalPlan create(final ByteBuffer buffer) throws IOExcept case UPDATE_CQ_LAST_EXEC_TIME: plan = new UpdateCQLastExecTimePlan(); break; - case CreateModel: - plan = new CreateModelPlan(); - break; - case UpdateModelInfo: - plan = new UpdateModelInfoPlan(); - break; - case DropModel: - plan = new DropModelPlan(); - break; - case ShowModel: - plan = new ShowModelPlan(); - break; - case DropModelInNode: - plan = new DropModelInNodePlan(); - break; - case GetModelInfo: - plan = new GetModelInfoPlan(); - break; case CreatePipePlugin: plan = new CreatePipePluginPlan(); break; diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java deleted file mode 100644 index dd79910e51fa..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; - -import java.util.Objects; - -public class GetModelInfoPlan extends ConfigPhysicalReadPlan { - - private String modelId; - - public GetModelInfoPlan() { - super(ConfigPhysicalPlanType.GetModelInfo); - } - - public GetModelInfoPlan(final TGetModelInfoReq getModelInfoReq) { - super(ConfigPhysicalPlanType.GetModelInfo); - this.modelId = getModelInfoReq.getModelId(); - } - - public String getModelId() { - return modelId; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final GetModelInfoPlan that = (GetModelInfoPlan) o; - return Objects.equals(modelId, that.modelId); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java deleted file mode 100644 index eca00e8827d9..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; - -import java.util.Objects; - -public class ShowModelPlan extends ConfigPhysicalReadPlan { - - private String modelName; - - public ShowModelPlan() { - super(ConfigPhysicalPlanType.ShowModel); - } - - public ShowModelPlan(final TShowModelsReq showModelReq) { - super(ConfigPhysicalPlanType.ShowModel); - if (showModelReq.isSetModelId()) { - this.modelName = showModelReq.getModelId(); - } - } - - public boolean isSetModelName() { - return modelName != null; - } - - public String getModelName() { - return modelName; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final ShowModelPlan that = (ShowModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java deleted file mode 100644 index 61e37cdd2187..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class CreateModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public CreateModelPlan() { - super(ConfigPhysicalPlanType.CreateModel); - } - - public CreateModelPlan(String modelName) { - super(ConfigPhysicalPlanType.CreateModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - CreateModelPlan that = (CreateModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java deleted file mode 100644 index 885543f84e15..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software 
Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelInNodePlan extends ConfigPhysicalPlan { - - private int nodeId; - - public DropModelInNodePlan() { - super(ConfigPhysicalPlanType.DropModelInNode); - } - - public DropModelInNodePlan(int nodeId) { - super(ConfigPhysicalPlanType.DropModelInNode); - this.nodeId = nodeId; - } - - public int getNodeId() { - return nodeId; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - stream.writeInt(nodeId); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - nodeId = buffer.getInt(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof DropModelInNodePlan)) return false; - DropModelInNodePlan that = (DropModelInNodePlan) o; - return nodeId == that.nodeId; - } - - 
@Override - public int hashCode() { - return Objects.hash(super.hashCode(), nodeId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java deleted file mode 100644 index 813b116c645c..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public DropModelPlan() { - super(ConfigPhysicalPlanType.DropModel); - } - - public DropModelPlan(String modelName) { - super(ConfigPhysicalPlanType.DropModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - DropModelPlan that = (DropModelPlan) o; - return modelName.equals(that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java deleted file mode 100644 index ce7219e42813..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under 
one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - -public class UpdateModelInfoPlan extends ConfigPhysicalPlan { - - private String modelName; - private ModelInformation modelInformation; - - // The node which has the model which is only updated in model registration - private List nodeIds; - - public UpdateModelInfoPlan() { - super(ConfigPhysicalPlanType.UpdateModelInfo); - } - - public UpdateModelInfoPlan(String modelName, ModelInformation modelInformation) { - super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = Collections.emptyList(); - } - - public UpdateModelInfoPlan( - String modelName, ModelInformation modelInformation, List nodeIds) { - 
super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = nodeIds; - } - - public String getModelName() { - return modelName; - } - - public ModelInformation getModelInformation() { - return modelInformation; - } - - public List getNodeIds() { - return nodeIds; - } - - public void setNodeIds(List nodeIds) { - this.nodeIds = nodeIds; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - this.modelInformation.serialize(stream); - ReadWriteIOUtils.write(nodeIds.size(), stream); - for (Integer nodeId : nodeIds) { - ReadWriteIOUtils.write(nodeId, stream); - } - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - this.modelName = ReadWriteIOUtils.readString(buffer); - this.modelInformation = ModelInformation.deserialize(buffer); - int size = ReadWriteIOUtils.readInt(buffer); - this.nodeIds = new ArrayList<>(); - for (int i = 0; i < size; i++) { - this.nodeIds.add(ReadWriteIOUtils.readInt(buffer)); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - UpdateModelInfoPlan that = (UpdateModelInfoPlan) o; - return modelName.equals(that.modelName) - && modelInformation.equals(that.modelInformation) - && nodeIds.equals(that.nodeIds); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName, modelInformation, nodeIds); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java deleted file mode 100644 index cebc1301b891..000000000000 --- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.consensus.common.DataSet; - -public class GetModelInfoResp implements DataSet { - - private final TSStatus status; - - private int targetAINodeId; - private TEndPoint targetAINodeAddress; - - public TSStatus getStatus() { - return status; - } - - public GetModelInfoResp(TSStatus status) { - this.status = status; - } - - public int getTargetAINodeId() { - return targetAINodeId; - } - - public void setTargetAINodeId(int targetAINodeId) { - this.targetAINodeId = targetAINodeId; - } - - public void setTargetAINodeAddress(TAINodeConfiguration aiNodeConfiguration) { - if (aiNodeConfiguration.getLocation() == null) { - return; - } - this.targetAINodeAddress = 
aiNodeConfiguration.getLocation().getInternalEndPoint(); - } - - public TGetModelInfoResp convertToThriftResponse() { - TGetModelInfoResp resp = new TGetModelInfoResp(status); - resp.setAiNodeAddress(targetAINodeAddress); - return resp; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java deleted file mode 100644 index 7490a53a01c5..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.consensus.common.DataSet; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -// TODO: Will be removed in the future -public class ModelTableResp implements DataSet { - - private final TSStatus status; - private final List serializedAllModelInformation; - private Map modelTypeMap; - private Map algorithmMap; - - public ModelTableResp(TSStatus status) { - this.status = status; - this.serializedAllModelInformation = new ArrayList<>(); - } - - public void addModelInformation(List modelInformationList) throws IOException { - for (ModelInformation modelInformation : modelInformationList) { - this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult()); - } - } - - public void addModelInformation(ModelInformation modelInformation) throws IOException { - this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult()); - } - - public void setModelTypeMap(Map modelTypeMap) { - this.modelTypeMap = modelTypeMap; - } - - public void setAlgorithmMap(Map algorithmMap) { - this.algorithmMap = algorithmMap; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java index 5d4b09adfc71..9d7151a8d20e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java @@ -19,22 +19,12 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.IDataSchema; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import 
org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.common.rpc.thrift.TFlushReq; import org.apache.iotdb.common.rpc.thrift.TPipeHeartbeatResp; import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet; @@ -58,7 +48,6 @@ import org.apache.iotdb.commons.conf.TrimProperties; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelStatus; import org.apache.iotdb.commons.path.PartialPath; import org.apache.iotdb.commons.path.PathPatternTree; import org.apache.iotdb.commons.path.PathPatternUtil; @@ -97,7 +86,6 @@ import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan; import org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; import org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan; import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import 
org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp; @@ -129,7 +117,6 @@ import org.apache.iotdb.confignode.manager.schema.ClusterSchemaQuotaStatistics; import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -144,7 +131,6 @@ import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo; import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo; import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; @@ -163,13 +149,11 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq; @@ -186,7 +170,6 @@ import 
org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -203,8 +186,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -257,11 +238,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.consensus.common.DataSet; import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.schemaengine.template.Template; import org.apache.iotdb.db.schemaengine.template.TemplateAlterOperationType; import org.apache.iotdb.db.schemaengine.template.alter.TemplateAlterOperationUtil; @@ -340,9 +318,6 @@ public class ConfigManager implements IManager { /** CQ. */ private final CQManager cqManager; - /** AI Model. 
*/ - private final ModelManager modelManager; - /** Pipe */ private final PipeManager pipeManager; @@ -362,8 +337,6 @@ public class ConfigManager implements IManager { private static final String DATABASE = "\tDatabase="; - private static final String DOT = "."; - public ConfigManager() throws IOException { // Build the persistence module ClusterInfo clusterInfo = new ClusterInfo(); @@ -375,7 +348,6 @@ public ConfigManager() throws IOException { UDFInfo udfInfo = new UDFInfo(); TriggerInfo triggerInfo = new TriggerInfo(); CQInfo cqInfo = new CQInfo(); - ModelInfo modelInfo = new ModelInfo(); PipeInfo pipeInfo = new PipeInfo(); QuotaInfo quotaInfo = new QuotaInfo(); TTLInfo ttlInfo = new TTLInfo(); @@ -393,7 +365,6 @@ public ConfigManager() throws IOException { udfInfo, triggerInfo, cqInfo, - modelInfo, pipeInfo, subscriptionInfo, quotaInfo, @@ -415,7 +386,6 @@ public ConfigManager() throws IOException { this.udfManager = new UDFManager(this, udfInfo); this.triggerManager = new TriggerManager(this, triggerInfo); this.cqManager = new CQManager(this); - this.modelManager = new ModelManager(this, modelInfo); this.pipeManager = new PipeManager(this, pipeInfo); this.subscriptionManager = new SubscriptionManager(this, subscriptionInfo); this.auditLogger = new CNAuditLogger(this); @@ -1289,11 +1259,6 @@ public TriggerManager getTriggerManager() { return triggerManager; } - @Override - public ModelManager getModelManager() { - return modelManager; - } - @Override public PipeManager getPipeManager() { return pipeManager; @@ -2757,150 +2722,6 @@ public TSStatus transfer(List newUnknownDataList) { return transferResult; } - @Override - public TSStatus createModel(TCreateModelReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! 
Try to start one."); - } - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.createModel(req) - : status; - } - - private List fetchSchemaForTreeModel(TCreateTrainingReq req) { - List dataSchemaList = new ArrayList<>(); - for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) { - IDataSchema dataSchema = new IDataSchema(req.getDataSchemaForTree().getPath().get(i)); - dataSchema.setTimeRange(req.getTimeRanges().get(i)); - dataSchemaList.add(dataSchema); - } - return dataSchemaList; - } - - private List fetchSchemaForTableModel(TCreateTrainingReq req) { - return Collections.singletonList(new IDataSchema(req.getDataSchemaForTable().getTargetSql())); - } - - public TSStatus createTraining(TCreateTrainingReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! Try to start one."); - } - - TTrainingReq trainingReq = new TTrainingReq(); - trainingReq.setModelId(req.getModelId()); - if (req.isSetExistingModelId()) { - trainingReq.setExistingModelId(req.getExistingModelId()); - } - if (req.isSetParameters() && !req.getParameters().isEmpty()) { - trainingReq.setParameters(req.getParameters()); - } - - try { - status = getConsensusManager().write(new CreateModelPlan(req.getModelId())); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new MetadataException("Can't init model " + req.getModelId()); - } - - List dataSchema; - if (req.isTableModel) { - dataSchema = fetchSchemaForTableModel(req); - trainingReq.setDbType("iotdb.table"); - } else { - dataSchema = fetchSchemaForTreeModel(req); - trainingReq.setDbType("iotdb.tree"); - } - updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal())); - trainingReq.setTargetDataSchema(dataSchema); - - TAINodeInfo registeredAINode = 
getNodeManager().getRegisteredAINodeInfoList().get(0); - TEndPoint targetAINodeEndPoint = - new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort()); - try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) { - status = client.createTrainingTask(trainingReq); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new IllegalArgumentException(status.message); - } - } - } catch (final Exception e) { - status.setCode(TSStatusCode.CAN_NOT_CONNECT_CONFIGNODE.getStatusCode()); - status.setMessage(e.getMessage()); - try { - updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.UNAVAILABLE.ordinal())); - } catch (Exception e2) { - LOGGER.error(e2.getMessage()); - } - } - return status; - } - - @Override - public TSStatus dropModel(TDropModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.dropModel(req) - : status; - } - - @Override - public TSStatus loadModel(TLoadModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.loadModel(req) - : status; - } - - @Override - public TSStatus unloadModel(TUnloadModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.unloadModel(req) - : status; - } - - @Override - public TShowModelsResp showModel(TShowModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showModel(req) - : new TShowModelsResp(status); - } - - @Override - public TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? 
modelManager.showLoadedModel(req) - : new TShowLoadedModelsResp(status, Collections.emptyMap()); - } - - @Override - public TShowAIDevicesResp showAIDevices() { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showAIDevices() - : new TShowAIDevicesResp(status, Collections.emptyList()); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.getModelInfo(req) - : new TGetModelInfoResp(status); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.updateModelInfo(req) - : status; - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) { TSStatus status = confirmLeader(); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java index 33e77db24907..dff994d70e7e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java @@ -19,13 +19,6 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import 
org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; @@ -82,7 +75,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -103,7 +95,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -120,8 +111,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -255,13 +244,6 @@ public interface IManager { */ CQManager getCQManager(); - /** - * Get {@link ModelManager}. - * - * @return {@link ModelManager} instance - */ - ModelManager getModelManager(); - /** * Get {@link PipeManager}. * @@ -880,30 +862,6 @@ TDataPartitionTableResp getOrCreateDataPartition( TSStatus transfer(List newUnknownDataList); - /** Create a model. */ - TSStatus createModel(TCreateModelReq req); - - /** Drop a model. 
*/ - TSStatus dropModel(TDropModelReq req); - - /** Load the specific model to the specific devices. */ - TSStatus loadModel(TLoadModelReq req); - - /** Unload the specific model from the specific devices. */ - TSStatus unloadModel(TUnloadModelReq req); - - /** Return the model table. */ - TShowModelsResp showModel(TShowModelsReq req); - - /** Return the loaded model instances. */ - TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req); - - /** Return all available AI devices. */ - TShowAIDevicesResp showAIDevices(); - - /** Update the model state */ - TGetModelInfoResp getModelInfo(TGetModelInfoReq req); - /** Set space quota. */ TSStatus setSpaceQuota(TSetSpaceQuotaReq req); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java deleted file mode 100644 index 3efdbc222b6d..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.manager; - -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.exception.NoAvailableAINodeException; -import org.apache.iotdb.confignode.persistence.ModelInfo; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - -public class ModelManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class); - - private final ConfigManager 
configManager; - private final ModelInfo modelInfo; - - public ModelManager(ConfigManager configManager, ModelInfo modelInfo) { - this.configManager = configManager; - this.modelInfo = modelInfo; - } - - public TSStatus createModel(TCreateModelReq req) { - if (modelInfo.contain(req.modelName)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s already exists", req.modelName)); - } - try { - if (req.uri.isEmpty()) { - return configManager.getConsensusManager().write(new CreateModelPlan(req.modelName)); - } - return configManager.getProcedureManager().createModel(req.modelName, req.uri); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while getting model: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - public TSStatus dropModel(TDropModelReq req) { - if (modelInfo.checkModelType(req.getModelId()) != ModelType.USER_DEFINED) { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("Built-in model %s can't be removed", req.modelId)); - } - if (!modelInfo.contain(req.modelId)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s doesn't exists", req.modelId)); - } - return configManager.getProcedureManager().dropModel(req.getModelId()); - } - - public TSStatus loadModel(TLoadModelReq req) { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq loadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq( - req.existingModelId, req.deviceIdList); - return client.loadModel(loadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to load model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public 
TSStatus unloadModel(TUnloadModelReq req) { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq unloadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq(req.modelId, req.deviceIdList); - return client.unloadModel(unloadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to unload model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public TShowModelsResp showModel(final TShowModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowModelsReq showModelsReq = new TShowModelsReq(); - if (req.isSetModelId()) { - showModelsReq.setModelId(req.getModelId()); - } - TShowModelsResp resp = client.showModels(showModelsReq); - TShowModelsResp res = - new TShowModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setModelIdList(resp.getModelIdList()); - res.setModelTypeMap(resp.getModelTypeMap()); - res.setCategoryMap(resp.getCategoryMap()); - res.setStateMap(resp.getStateMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show models due to", e); - return new TShowModelsResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowLoadedModelsResp showLoadedModel(final TShowLoadedModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowLoadedModelsReq showModelsReq = - new TShowLoadedModelsReq().setDeviceIdList(req.getDeviceIdList()); - TShowLoadedModelsResp resp = client.showLoadedModels(showModelsReq); - TShowLoadedModelsResp res = - new TShowLoadedModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceLoadedModelsMap(resp.getDeviceLoadedModelsMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show loaded models due to", e); - return new TShowLoadedModelsResp() - .setStatus( - new 
TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowAIDevicesResp showAIDevices() { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp resp = client.showAIDevices(); - TShowAIDevicesResp res = - new TShowAIDevicesResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceIdList(resp.getDeviceIdList()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show AI devices due to", e); - return new TShowAIDevicesResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return new TGetModelInfoResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())) - .setAiNodeAddress( - configManager - .getNodeManager() - .getRegisteredAINodes() - .get(0) - .getLocation() - .getInternalEndPoint()); - } - - // Currently this method is only used by built-in timer_xl - public TSStatus updateModelInfo(TUpdateModelInfoReq req) { - if (!modelInfo.contain(req.getModelId())) { - return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode()) - .setMessage(String.format("Model %s doesn't exists", req.getModelId())); - } - try { - ModelInformation modelInformation = - new ModelInformation(ModelType.USER_DEFINED, req.getModelId()); - modelInformation.updateStatus(ModelStatus.values()[req.getModelStatus()]); - modelInformation.setAttribute(req.getAttributes()); - modelInformation.setInputColumnSize(1); - if (req.isSetOutputLength()) { - modelInformation.setOutputLength(req.getOutputLength()); - } - if (req.isSetInputLength()) { - modelInformation.setInputLength(req.getInputLength()); - } - UpdateModelInfoPlan updateModelInfoPlan = - new UpdateModelInfoPlan(req.getModelId(), modelInformation); - if (req.isSetAiNodeIds()) { - 
updateModelInfoPlan.setNodeIds(req.getAiNodeIds()); - } - return configManager.getConsensusManager().write(updateModelInfoPlan); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while updating model info: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - private AINodeClient getAINodeClient() throws NoAvailableAINodeException, ClientManagerException { - List aiNodeInfo = configManager.getNodeManager().getRegisteredAINodeInfoList(); - if (aiNodeInfo.isEmpty()) { - throw new NoAvailableAINodeException(); - } - TEndPoint targetAINodeEndPoint = - new TEndPoint(aiNodeInfo.get(0).getInternalAddress(), aiNodeInfo.get(0).getInternalPort()); - try { - return AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public List getModelDistributions(String modelName) { - return modelInfo.getNodeIds(modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java index 2e4227af3fc8..d67e7721eef8 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java @@ -61,8 +61,6 @@ import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler; import org.apache.iotdb.confignode.procedure.env.RemoveDataNodeHandler; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; import 
org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; @@ -1414,24 +1412,6 @@ public TSStatus createCQ(TCreateCQReq req, ScheduledExecutorService scheduledExe return waitingProcedureFinished(procedure); } - public TSStatus createModel(String modelName, String uri) { - long procedureId = executor.submitProcedure(new CreateModelProcedure(modelName, uri)); - LOGGER.info("CreateModelProcedure was submitted, procedureId: {}.", procedureId); - return RpcUtils.SUCCESS_STATUS; - } - - public TSStatus dropModel(String modelId) { - DropModelProcedure procedure = new DropModelProcedure(modelId); - executor.submitProcedure(procedure); - TSStatus status = waitingProcedureFinished(procedure); - if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return status; - } else { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(status.getMessage()); - } - } - public TSStatus createPipePlugin( PipePluginMeta pipePluginMeta, byte[] jarFile, boolean isSetIfNotExistsCondition) { final CreatePipePluginProcedure createPipePluginProcedure = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java deleted file mode 100644 index aeada03d15cc..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.persistence; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelTable; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.commons.snapshot.SnapshotProcessor; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; -import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.PublicBAOS; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.annotation.concurrent.ThreadSafe; - -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.locks.ReadWriteLock; -import 
java.util.concurrent.locks.ReentrantReadWriteLock; - -@ThreadSafe -public class ModelInfo implements SnapshotProcessor { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelInfo.class); - - private static final String SNAPSHOT_FILENAME = "model_info.snapshot"; - - private ModelTable modelTable; - - private final Map> modelNameToNodes; - - private final ReadWriteLock modelTableLock = new ReentrantReadWriteLock(); - - private static final Set builtInForecastModel = new HashSet<>(); - - private static final Set builtInAnomalyDetectionModel = new HashSet<>(); - - static { - builtInForecastModel.add("arima"); - builtInForecastModel.add("naive_forecaster"); - builtInForecastModel.add("stl_forecaster"); - builtInForecastModel.add("holtwinters"); - builtInForecastModel.add("exponential_smoothing"); - builtInForecastModel.add("timer_xl"); - builtInForecastModel.add("sundial"); - builtInAnomalyDetectionModel.add("gaussian_hmm"); - builtInAnomalyDetectionModel.add("gmm_hmm"); - builtInAnomalyDetectionModel.add("stray"); - } - - public ModelInfo() { - this.modelTable = new ModelTable(); - this.modelNameToNodes = new HashMap<>(); - } - - public boolean contain(String modelName) { - return modelTable.containsModel(modelName); - } - - public void acquireModelTableReadLock() { - LOGGER.info("acquire ModelTableReadLock"); - modelTableLock.readLock().lock(); - } - - public void releaseModelTableReadLock() { - LOGGER.info("release ModelTableReadLock"); - modelTableLock.readLock().unlock(); - } - - public void acquireModelTableWriteLock() { - LOGGER.info("acquire ModelTableWriteLock"); - modelTableLock.writeLock().lock(); - } - - public void releaseModelTableWriteLock() { - LOGGER.info("release ModelTableWriteLock"); - modelTableLock.writeLock().unlock(); - } - - // init the model in modeInfo, it won't update the details information of the model - public TSStatus createModel(CreateModelPlan plan) { - try { - acquireModelTableWriteLock(); - String modelName = 
plan.getModelName(); - modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING)); - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } catch (Exception e) { - final String errorMessage = - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes, because of %s", - plan.getModelName(), e); - LOGGER.warn(errorMessage, e); - return new TSStatus(TSStatusCode.CREATE_MODEL_ERROR.getStatusCode()).setMessage(errorMessage); - } finally { - releaseModelTableWriteLock(); - } - } - - public TSStatus dropModelInNode(int aiNodeId) { - acquireModelTableWriteLock(); - try { - for (Map.Entry> entry : modelNameToNodes.entrySet()) { - entry.getValue().remove(Integer.valueOf(aiNodeId)); - // if list is empty, remove this model totally - if (entry.getValue().isEmpty()) { - modelTable.removeModel(entry.getKey()); - modelNameToNodes.remove(entry.getKey()); - } - } - // currently, we only have one AINode at a time, so we can just clear failed model. - modelTable.clearFailedModel(); - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } finally { - releaseModelTableWriteLock(); - } - } - - public TSStatus dropModel(String modelName) { - acquireModelTableWriteLock(); - TSStatus status; - if (modelTable.containsModel(modelName)) { - modelTable.removeModel(modelName); - modelNameToNodes.remove(modelName); - status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } else { - status = - new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("model [%s] has not been created.", modelName)); - } - releaseModelTableWriteLock(); - return status; - } - - public List getNodeIds(String modelName) { - return modelNameToNodes.getOrDefault(modelName, Collections.emptyList()); - } - - private ModelInformation getModelByName(String modelName) { - ModelType modelType = checkModelType(modelName); - if (modelType != ModelType.USER_DEFINED) { - if (modelType == ModelType.BUILT_IN_FORECAST && 
builtInForecastModel.contains(modelName)) { - return new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName); - } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION - && builtInAnomalyDetectionModel.contains(modelName)) { - return new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName); - } - } else { - return modelTable.getModelInformationById(modelName); - } - return null; - } - - public ModelTableResp showModel(ShowModelPlan plan) { - acquireModelTableReadLock(); - try { - ModelTableResp modelTableResp = - new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - if (plan.isSetModelName()) { - ModelInformation modelInformation = getModelByName(plan.getModelName()); - if (modelInformation != null) { - modelTableResp.addModelInformation(modelInformation); - } - } else { - modelTableResp.addModelInformation(modelTable.getAllModelInformation()); - for (String modelName : builtInForecastModel) { - modelTableResp.addModelInformation( - new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName)); - } - for (String modelName : builtInAnomalyDetectionModel) { - modelTableResp.addModelInformation( - new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName)); - } - } - return modelTableResp; - } catch (IOException e) { - LOGGER.warn("Fail to get ModelTable", e); - return new ModelTableResp( - new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } finally { - releaseModelTableReadLock(); - } - } - - private boolean containsBuiltInModelName(Set builtInModelSet, String modelName) { - // ignore the case - for (String builtInModelName : builtInModelSet) { - if (builtInModelName.equalsIgnoreCase(modelName)) { - return true; - } - } - return false; - } - - public ModelType checkModelType(String modelName) { - if (containsBuiltInModelName(builtInForecastModel, modelName)) { - return ModelType.BUILT_IN_FORECAST; - } else if 
(containsBuiltInModelName(builtInAnomalyDetectionModel, modelName)) { - return ModelType.BUILT_IN_ANOMALY_DETECTION; - } else { - return ModelType.USER_DEFINED; - } - } - - private int getAvailableAINodeForModel(String modelName, ModelType modelType) { - if (modelType == ModelType.USER_DEFINED) { - List aiNodeIds = modelNameToNodes.get(modelName); - if (aiNodeIds != null) { - return aiNodeIds.get(0); - } - } else { - // any AINode is fine for built-in model - // 0 is always the nodeId for configNode, so it's fine to use 0 as special value - return 0; - } - return -1; - } - - // This method will be used by dataNode to get schema of the model for inference - public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { - acquireModelTableReadLock(); - try { - String modelName = plan.getModelId(); - GetModelInfoResp getModelInfoResp; - ModelInformation modelInformation; - ModelType modelType; - // check if it's a built-in model - if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) { - modelInformation = new ModelInformation(modelType, modelName); - } else { - modelInformation = modelTable.getModelInformationById(modelName); - } - - if (modelInformation != null) { - getModelInfoResp = - new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - } else { - TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - errorStatus.setMessage(String.format("model [%s] has not been created.", modelName)); - getModelInfoResp = new GetModelInfoResp(errorStatus); - return getModelInfoResp; - } - PublicBAOS buffer = new PublicBAOS(); - DataOutputStream stream = new DataOutputStream(buffer); - modelInformation.serialize(stream); - // select the nodeId to process the task, currently we default use the first one. 
- int aiNodeId = getAvailableAINodeForModel(modelName, modelType); - if (aiNodeId == -1) { - TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - errorStatus.setMessage(String.format("There is no AINode with %s available", modelName)); - getModelInfoResp = new GetModelInfoResp(errorStatus); - return getModelInfoResp; - } else { - getModelInfoResp.setTargetAINodeId(aiNodeId); - } - return getModelInfoResp; - } catch (IOException e) { - LOGGER.warn("Fail to get model info", e); - return new GetModelInfoResp( - new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } finally { - releaseModelTableReadLock(); - } - } - - public TSStatus updateModelInfo(UpdateModelInfoPlan plan) { - acquireModelTableWriteLock(); - try { - String modelName = plan.getModelName(); - if (modelTable.containsModel(modelName)) { - modelTable.updateModel(modelName, plan.getModelInformation()); - } - if (!plan.getNodeIds().isEmpty()) { - // only used in model registration, so we can just put the nodeIds in the map without - // checking - modelNameToNodes.put(modelName, plan.getNodeIds()); - } - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } finally { - releaseModelTableWriteLock(); - } - } - - @Override - public boolean processTakeSnapshot(File snapshotDir) throws TException, IOException { - File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); - if (snapshotFile.exists() && snapshotFile.isFile()) { - LOGGER.error( - "Failed to take snapshot of ModelInfo, because snapshot file [{}] is already exist.", - snapshotFile.getAbsolutePath()); - return false; - } - - acquireModelTableReadLock(); - try (FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)) { - modelTable.serialize(fileOutputStream); - ReadWriteIOUtils.write(modelNameToNodes.size(), fileOutputStream); - for (Map.Entry> entry : modelNameToNodes.entrySet()) { - ReadWriteIOUtils.write(entry.getKey(), 
fileOutputStream); - ReadWriteIOUtils.write(entry.getValue().size(), fileOutputStream); - for (Integer nodeId : entry.getValue()) { - ReadWriteIOUtils.write(nodeId, fileOutputStream); - } - } - fileOutputStream.getFD().sync(); - return true; - } finally { - releaseModelTableReadLock(); - } - } - - @Override - public void processLoadSnapshot(File snapshotDir) throws TException, IOException { - File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); - if (!snapshotFile.exists() || !snapshotFile.isFile()) { - LOGGER.error( - "Failed to load snapshot of ModelInfo, snapshot file [{}] does not exist.", - snapshotFile.getAbsolutePath()); - return; - } - acquireModelTableWriteLock(); - try (FileInputStream fileInputStream = new FileInputStream(snapshotFile)) { - modelTable.clear(); - modelTable = ModelTable.deserialize(fileInputStream); - int size = ReadWriteIOUtils.readInt(fileInputStream); - for (int i = 0; i < size; i++) { - String modelName = ReadWriteIOUtils.readString(fileInputStream); - int nodeSize = ReadWriteIOUtils.readInt(fileInputStream); - List nodes = new LinkedList<>(); - for (int j = 0; j < nodeSize; j++) { - nodes.add(ReadWriteIOUtils.readInt(fileInputStream)); - } - modelNameToNodes.put(modelName, nodes); - } - } finally { - releaseModelTableWriteLock(); - } - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java index fe8b28c4da2e..d6bad518f6f4 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java @@ -35,8 +35,6 @@ import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; import 
org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan; import org.apache.iotdb.confignode.consensus.request.read.function.GetUDFJarPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetNodePathsPartitionPlan; @@ -84,10 +82,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -150,7 +144,6 @@ import org.apache.iotdb.confignode.exception.physical.UnknownPhysicalPlanTypeException; import org.apache.iotdb.confignode.manager.pipe.agent.PipeConfigNodeAgent; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import 
org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -210,8 +203,6 @@ public class ConfigPlanExecutor { private final CQInfo cqInfo; - private final ModelInfo modelInfo; - private final PipeInfo pipeInfo; private final SubscriptionInfo subscriptionInfo; @@ -230,7 +221,6 @@ public ConfigPlanExecutor( UDFInfo udfInfo, TriggerInfo triggerInfo, CQInfo cqInfo, - ModelInfo modelInfo, PipeInfo pipeInfo, SubscriptionInfo subscriptionInfo, QuotaInfo quotaInfo, @@ -262,9 +252,6 @@ public ConfigPlanExecutor( this.cqInfo = cqInfo; this.snapshotProcessorList.add(cqInfo); - this.modelInfo = modelInfo; - this.snapshotProcessorList.add(modelInfo); - this.pipeInfo = pipeInfo; this.snapshotProcessorList.add(pipeInfo); @@ -362,10 +349,6 @@ public DataSet executeQueryPlan(final ConfigPhysicalReadPlan req) return udfInfo.getUDFJar((GetUDFJarPlan) req); case GetAllFunctionTable: return udfInfo.getAllUDFTable(); - case ShowModel: - return modelInfo.showModel((ShowModelPlan) req); - case GetModelInfo: - return modelInfo.getModelInfo((GetModelInfoPlan) req); case GetPipePluginTable: return pipeInfo.getPipePluginInfo().showPipePlugins(); case GetPipePluginJar: @@ -648,14 +631,6 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) return cqInfo.activeCQ((ActiveCQPlan) physicalPlan); case UPDATE_CQ_LAST_EXEC_TIME: return cqInfo.updateCQLastExecutionTime((UpdateCQLastExecTimePlan) physicalPlan); - case CreateModel: - return modelInfo.createModel((CreateModelPlan) physicalPlan); - case UpdateModelInfo: - return modelInfo.updateModelInfo((UpdateModelInfoPlan) physicalPlan); - case DropModel: - return modelInfo.dropModel(((DropModelPlan) physicalPlan).getModelName()); - case DropModelInNode: - return modelInfo.dropModelInNode(((DropModelInNodePlan) physicalPlan).getNodeId()); case CreatePipePlugin: return pipeInfo.getPipePluginInfo().createPipePlugin((CreatePipePluginPlan) physicalPlan); case DropPipePlugin: diff --git 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java deleted file mode 100644 index 989061610213..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.manager.ConfigManager; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.CreateModelState; -import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - -public class CreateModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class); - private static final int RETRY_THRESHOLD = 0; - - private String modelName; - - private String uri; - - private ModelInformation modelInformation = null; - - private List aiNodeIds; - - private String loadErrorMsg = ""; - - public CreateModelProcedure() { - 
super(); - } - - public CreateModelProcedure(String modelName, String uri) { - super(); - this.modelName = modelName; - this.uri = uri; - this.aiNodeIds = new ArrayList<>(); - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, CreateModelState state) { - if (modelName == null || uri == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case LOADING: - initModel(env); - loadModel(env); - setNextState(CreateModelState.ACTIVE); - break; - case ACTIVE: - modelInformation.updateStatus(ModelStatus.ACTIVE); - updateModel(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing createModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in CreateModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to create model [{}], state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - modelInformation = new ModelInformation(modelName, ModelStatus.UNAVAILABLE); - modelInformation.setAttribute(loadErrorMsg); - updateModel(env); - setFailure( - new ProcedureException( - String.format("Fail to create model [%s] at STATE [%s]", modelName, state))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void initModel(ConfigNodeProcedureEnv env) throws ConsensusException { - LOGGER.info("Start to add model [{}]", modelName); - - ConfigManager configManager = env.getConfigManager(); - TSStatus response = configManager.getConsensusManager().write(new CreateModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } - - private void checkModelInformationEquals(ModelInformation receiveModelInfo) { - if 
(modelInformation == null) { - modelInformation = receiveModelInfo; - } else { - if (!modelInformation.equals(receiveModelInfo)) { - throw new ModelManagementException( - String.format( - "Failed to load model [%s] on AI Nodes, model information is not equal in different nodes", - modelName)); - } - } - } - - private void loadModel(ConfigNodeProcedureEnv env) { - for (TAINodeConfiguration curNodeConfig : - env.getConfigManager().getNodeManager().getRegisteredAINodes()) { - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient(curNodeConfig.getLocation().getInternalEndPoint())) { - ModelInformation resp = client.registerModel(modelName, uri); - checkModelInformationEquals(resp); - aiNodeIds.add(curNodeConfig.getLocation().aiNodeId); - } catch (LoadModelException e) { - LOGGER.warn(e.getMessage()); - loadErrorMsg = e.getMessage(); - } catch (Exception e) { - LOGGER.warn( - "Failed to load model on AINode {} from ConfigNode", - curNodeConfig.getLocation().getInternalEndPoint()); - loadErrorMsg = e.getMessage(); - } - } - - if (aiNodeIds.isEmpty()) { - throw new ModelManagementException( - String.format("CREATE MODEL [%s] failed on all AINodes:[%s]", modelName, loadErrorMsg)); - } - } - - private void updateModel(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to update model [{}]", modelName); - - ConfigManager configManager = env.getConfigManager(); - try { - TSStatus response = - configManager - .getConsensusManager() - .write(new UpdateModelInfoPlan(modelName, modelInformation, aiNodeIds)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, e.getMessage())); - } - } - - @Override - protected 
void rollbackState(ConfigNodeProcedureEnv env, CreateModelState state) - throws IOException, InterruptedException, ProcedureException { - // do nothing - } - - @Override - protected boolean isRollbackSupported(CreateModelState state) { - return false; - } - - @Override - protected CreateModelState getState(int stateId) { - return CreateModelState.values()[stateId]; - } - - @Override - protected int getStateId(CreateModelState createModelState) { - return createModelState.ordinal(); - } - - @Override - protected CreateModelState getInitialState() { - return CreateModelState.LOADING; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - ReadWriteIOUtils.write(uri, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - uri = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean equals(Object that) { - if (that instanceof CreateModelProcedure) { - CreateModelProcedure thatProc = (CreateModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && Objects.equals(thatProc.modelName, this.modelName) - && Objects.equals(thatProc.uri, this.uri); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName, uri); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java deleted file mode 100644 index daa029e04ddf..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ /dev/null @@ -1,200 +0,0 
@@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.DropModelState; -import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import 
java.nio.ByteBuffer; -import java.util.List; -import java.util.Objects; - -import static org.apache.iotdb.confignode.procedure.state.model.DropModelState.CONFIG_NODE_DROPPED; - -public class DropModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class); - private static final int RETRY_THRESHOLD = 1; - - private String modelName; - - public DropModelProcedure() { - super(); - } - - public DropModelProcedure(String modelName) { - super(); - this.modelName = modelName; - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state) { - if (modelName == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case AI_NODE_DROPPED: - LOGGER.info("Start to drop model [{}] on AI Nodes", modelName); - dropModelOnAINode(env); - setNextState(CONFIG_NODE_DROPPED); - break; - case CONFIG_NODE_DROPPED: - dropModelOnConfigNode(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing dropModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in DropModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to drop model [{}], state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - setFailure( - new ProcedureException( - String.format( - "Fail to drop model [%s] at STATE [%s], %s", - modelName, state, e.getMessage()))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void dropModelOnAINode(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to drop model file [{}] on AI Node", modelName); - - List aiNodes = - env.getConfigManager().getNodeManager().getRegisteredAINodes(); - aiNodes.forEach( - aiNode -> { - int nodeId = aiNode.getLocation().getAiNodeId(); - try (AINodeClient client = - 
AINodeClientManager.getInstance() - .borrowClient( - env.getConfigManager() - .getNodeManager() - .getRegisteredAINode(nodeId) - .getLocation() - .getInternalEndPoint())) { - TSStatus status = client.deleteModel(new TDeleteModelReq(modelName)); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - status.getMessage()); - } - } catch (Exception e) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - e.getMessage()); - } - }); - } - - private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) { - try { - TSStatus response = - env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(response.getMessage()); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Fail to start training model [%s] on AI Node: %s", modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, DropModelState state) - throws IOException, InterruptedException, ProcedureException { - // no need to rollback - } - - @Override - protected DropModelState getState(int stateId) { - return DropModelState.values()[stateId]; - } - - @Override - protected int getStateId(DropModelState dropModelState) { - return dropModelState.ordinal(); - } - - @Override - protected DropModelState getInitialState() { - return DropModelState.AI_NODE_DROPPED; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.DROP_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override 
- public boolean equals(Object that) { - if (that instanceof DropModelProcedure) { - DropModelProcedure thatProc = (DropModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && (thatProc.modelName).equals(this.modelName); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java index 2cab08c28244..2a1c6881b141 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java @@ -23,13 +23,12 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; import org.apache.iotdb.confignode.procedure.exception.ProcedureException; import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState; import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.slf4j.Logger; @@ -65,17 +64,11 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState st try { switch (state) { - case MODEL_DELETE: - 
env.getConfigManager() - .getConsensusManager() - .write(new DropModelInNodePlan(removedAINode.aiNodeId)); - // Cause the AINode is removed, so we don't need to remove the model file. - setNextState(RemoveAINodeState.NODE_STOP); - break; case NODE_STOP: TSStatus resp = null; try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(removedAINode.getInternalEndPoint())) { + AINodeClientManager.getInstance() + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { resp = client.stopAINode(); } catch (Exception e) { LOGGER.warn( @@ -148,7 +141,7 @@ protected int getStateId(RemoveAINodeState removeAINodeState) { @Override protected RemoveAINodeState getInitialState() { - return RemoveAINodeState.MODEL_DELETE; + return RemoveAINodeState.NODE_STOP; } @Override diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java index 8a1a6a1bb03b..49820df66361 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java @@ -20,7 +20,6 @@ package org.apache.iotdb.confignode.procedure.state; public enum RemoveAINodeState { - MODEL_DELETE, NODE_STOP, NODE_REMOVE } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java index e023171f4fa8..f20a6999d593 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java @@ -22,8 +22,6 @@ import org.apache.iotdb.commons.exception.runtime.ThriftSerDeException; import 
org.apache.iotdb.confignode.procedure.Procedure; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; @@ -263,12 +261,6 @@ public Procedure create(ByteBuffer buffer) throws IOException { case DROP_PIPE_PLUGIN_PROCEDURE: procedure = new DropPipePluginProcedure(); break; - case CREATE_MODEL_PROCEDURE: - procedure = new CreateModelProcedure(); - break; - case DROP_MODEL_PROCEDURE: - procedure = new DropModelProcedure(); - break; case AUTH_OPERATE_PROCEDURE: procedure = new AuthOperationProcedure(false); break; @@ -494,10 +486,6 @@ public static ProcedureType getProcedureType(final Procedure procedure) { return ProcedureType.CREATE_PIPE_PLUGIN_PROCEDURE; } else if (procedure instanceof DropPipePluginProcedure) { return ProcedureType.DROP_PIPE_PLUGIN_PROCEDURE; - } else if (procedure instanceof CreateModelProcedure) { - return ProcedureType.CREATE_MODEL_PROCEDURE; - } else if (procedure instanceof DropModelProcedure) { - return ProcedureType.DROP_MODEL_PROCEDURE; } else if (procedure instanceof CreatePipeProcedureV2) { return ProcedureType.CREATE_PIPE_PROCEDURE_V2; } else if (procedure instanceof StartPipeProcedureV2) { diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java index 65ac1fb24ad5..d076a7d9d926 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java +++ 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java @@ -85,7 +85,9 @@ public enum ProcedureType { RENAME_VIEW_PROCEDURE((short) 764), /** AI Model */ + @Deprecated // Since 2.0.6, all models are managed by AINode CREATE_MODEL_PROCEDURE((short) 800), + @Deprecated // Since 2.0.6, all models are managed by AINode DROP_MODEL_PROCEDURE((short) 801), REMOVE_AI_NODE_PROCEDURE((short) 802), diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java index 59ce7352312f..6582a5bfff8e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java @@ -115,7 +115,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -144,7 +143,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -163,8 +161,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import 
org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -226,7 +222,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.confignode.service.ConfigNode; import org.apache.iotdb.consensus.exception.ConsensusException; import org.apache.iotdb.db.queryengine.plan.relational.type.AuthorRType; @@ -1362,26 +1357,6 @@ public TShowCQResp showCQ() { return configManager.showCQ(); } - @Override - public TSStatus createModel(TCreateModelReq req) { - return configManager.createModel(req); - } - - @Override - public TSStatus dropModel(TDropModelReq req) { - return configManager.dropModel(req); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return configManager.getModelInfo(req); - } - - @Override - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return configManager.updateModelInfo(req); - } - @Override public TSStatus setSpaceQuota(final TSetSpaceQuotaReq req) throws TException { return configManager.setSpaceQuota(req); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java deleted file mode 100644 index 0d784617c090..000000000000 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.ClientManagerMetrics; -import org.apache.iotdb.commons.client.IClientPoolFactory; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import org.apache.iotdb.commons.client.property.ClientPoolProperty; -import org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.concurrent.ThreadName; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.commons.pool2.impl.GenericKeyedObjectPool; - -import java.util.Optional; - -/** Dedicated factory for AINodeClient + 
AINodeClientPoolFactory. */ -public class AINodeClientFactory extends ThriftClientFactory { - - private static final int connectionTimeout = - CommonDescriptor.getInstance().getConfig().getDnConnectionTimeoutInMS(); - - public AINodeClientFactory( - ClientManager manager, ThriftClientProperty thriftProperty) { - super(manager, thriftProperty); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public void destroyObject(TEndPoint key, PooledObject pooled) throws Exception { - pooled.getObject().invalidate(); - } - - @Override - public boolean validateObject(TEndPoint key, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(org.apache.thrift.transport.TTransport::isOpen) - .orElse(false); - } - - /** The PoolFactory originally inside ClientPoolFactory — now moved here. */ - public static class AINodeClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - // Build thrift client properties - ThriftClientProperty thriftProperty = - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(connectionTimeout) - .setRpcThriftCompressionEnabled( - CommonDescriptor.getInstance().getConfig().isRpcThriftCompressionEnabled()) - .build(); - - GenericKeyedObjectPool pool = - new GenericKeyedObjectPool<>( - new AINodeClientFactory(manager, thriftProperty), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode( - CommonDescriptor.getInstance().getConfig().getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), pool); - - return pool; - } - } - - public static class AINodeHeartbeatClientPoolFactory - implements IClientPoolFactory { - - @Override - public 
GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - final CommonConfig conf = CommonDescriptor.getInstance().getConfig(); - - GenericKeyedObjectPool clientPool = - new GenericKeyedObjectPool<>( - new AsyncAINodeServiceClient.Factory( - manager, - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) - .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) - .setPrintLogWhenEncounterException(false) - .build(), - ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), clientPool); - - return clientPool; - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java index 2c037cf0f3e5..df80d49b502b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java @@ -73,7 +73,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -102,7 +101,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import 
org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -121,8 +119,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -184,7 +180,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; @@ -525,7 +520,8 @@ public TAINodeRestartResp restartAINode(TAINodeRestartReq req) throws TException @Override public TGetAINodeLocationResp getAINodeLocation() throws TException { - return client.getAINodeLocation(); + return executeRemoteCallWithRetry( + () -> client.getAINodeLocation(), resp -> !updateConfigNodeLeader(resp.status)); } @Override @@ -1339,28 +1335,6 @@ public TShowCQResp showCQ() throws TException { () -> client.showCQ(), resp -> !updateConfigNodeLeader(resp.status)); } - @Override - public TSStatus createModel(TCreateModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.createModel(req), status -> 
!updateConfigNodeLeader(status)); - } - - @Override - public TSStatus dropModel(TDropModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.dropModel(req), status -> !updateConfigNodeLeader(status)); - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.getModelInfo(req), resp -> !updateConfigNodeLeader(resp.getStatus())); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.updateModelInfo(req), status -> !updateConfigNodeLeader(status)); - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return executeRemoteCallWithRetry( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java index b5f5df430129..da0d84d8466f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java @@ -27,12 +27,13 @@ import org.apache.iotdb.commons.consensus.ConfigRegionId; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; import org.apache.commons.pool2.impl.GenericKeyedObjectPool; public class DataNodeClientPoolFactory { - private static final IoTDBConfig conf = IoTDBDescriptor.getInstance().getConfig(); + private static final IoTDBConfig CONF = IoTDBDescriptor.getInstance().getConfig(); private DataNodeClientPoolFactory() { // Empty constructor @@ -49,11 +50,11 @@ public GenericKeyedObjectPool createClientPool new ConfigNodeClient.Factory( manager, new ThriftClientProperty.Builder() - 
.setConnectionTimeoutMs(conf.getConnectionTimeoutInMS()) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable()) + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) .build()), new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) .build() .getConfig()); ClientManagerMetrics.getInstance() @@ -73,15 +74,38 @@ public GenericKeyedObjectPool createClientPool new ConfigNodeClient.Factory( manager, new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS() * 10) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable()) + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS() * 10) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) .setSelectorNumOfAsyncClientManager( - conf.getSelectorNumOfClientManager() / 10 > 0 - ? conf.getSelectorNumOfClientManager() / 10 + CONF.getSelectorNumOfClientManager() / 10 > 0 + ? 
CONF.getSelectorNumOfClientManager() / 10 : 1) .build()), new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } + + public static class AINodeClientPoolFactory implements IClientPoolFactory { + + @Override + public GenericKeyedObjectPool createClientPool( + ClientManager manager) { + GenericKeyedObjectPool clientPool = + new GenericKeyedObjectPool<>( + new AINodeClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) + .build()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) .build() .getConfig()); ClientManagerMetrics.getInstance() diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java deleted file mode 100644 index 54150b8f3007..000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java +++ /dev/null @@ -1,401 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; -import org.apache.iotdb.ainode.rpc.thrift.TConfigs; -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; -import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.ThriftClient; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import 
org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.rpc.TConfigurationConst; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.thrift.TException; -import org.apache.thrift.transport.TSSLTransportFactory; -import org.apache.thrift.transport.TSocket; -import org.apache.thrift.transport.TTransport; -import org.apache.thrift.transport.TTransportException; -import org.apache.thrift.transport.layered.TFramedTransport; -import org.apache.tsfile.enums.TSDataType; -import org.apache.tsfile.read.common.block.TsBlock; -import org.apache.tsfile.read.common.block.column.TsBlockSerde; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; - -import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; -import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR; - -public class AINodeClient implements AutoCloseable, ThriftClient { - - private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class); - - private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - - private TEndPoint endPoint; - - private TTransport transport; - - private final ThriftClientProperty property; 
- private IAINodeRPCService.Client client; - - public static final String MSG_CONNECTION_FAIL = - "Fail to connect to AINode. Please check status of AINode"; - private static final int MAX_RETRY = 3; - - @FunctionalInterface - private interface RemoteCall { - R apply(IAINodeRPCService.Client c) throws TException; - } - - private final TsBlockSerde tsBlockSerde = new TsBlockSerde(); - - ClientManager clientManager; - - private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = - ConfigNodeClientManager.getInstance(); - - private static final AtomicReference CURRENT_LOCATION = new AtomicReference<>(); - - public static TEndPoint getCurrentEndpoint() { - TAINodeLocation loc = CURRENT_LOCATION.get(); - if (loc == null) { - loc = refreshFromConfigNode(); - } - return (loc == null) ? null : pickEndpointFrom(loc); - } - - public static void updateGlobalAINodeLocation(final TAINodeLocation loc) { - if (loc != null) { - CURRENT_LOCATION.set(loc); - } - } - - private R executeRemoteCallWithRetry(RemoteCall call) throws TException { - TException last = null; - for (int attempt = 1; attempt <= MAX_RETRY; attempt++) { - try { - if (transport == null || !transport.isOpen()) { - final TEndPoint ep = getCurrentEndpoint(); - if (ep == null) { - throw new TException("AINode endpoint unavailable"); - } - this.endPoint = ep; - init(); - } - return call.apply(client); - } catch (TException e) { - last = e; - invalidate(); - final TAINodeLocation loc = refreshFromConfigNode(); - if (loc != null) { - this.endPoint = pickEndpointFrom(loc); - } - try { - Thread.sleep(1000L * attempt); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - } - } - } - throw (last != null ? 
last : new TException(MSG_CONNECTION_FAIL)); - } - - private static TAINodeLocation refreshFromConfigNode() { - try (final ConfigNodeClient cn = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TGetAINodeLocationResp resp = cn.getAINodeLocation(); - if (resp != null && resp.isSetAiNodeLocation()) { - final TAINodeLocation loc = resp.getAiNodeLocation(); - CURRENT_LOCATION.set(loc); - return loc; - } - } catch (Exception e) { - LoggerFactory.getLogger(AINodeClient.class) - .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString()); - } - return null; - } - - private static TEndPoint pickEndpointFrom(final TAINodeLocation loc) { - if (loc == null) return null; - if (loc.isSetInternalEndPoint() && loc.getInternalEndPoint() != null) { - return loc.getInternalEndPoint(); - } - return null; - } - - public AINodeClient( - ThriftClientProperty property, - TEndPoint endPoint, - ClientManager clientManager) - throws TException { - this.property = property; - this.clientManager = clientManager; - // Instance default endpoint (pool key). Global location can override it on retries. 
- this.endPoint = endPoint; - init(); - } - - private void init() throws TException { - try { - if (commonConfig.isEnableInternalSSL()) { - TSSLTransportFactory.TSSLTransportParameters params = - new TSSLTransportFactory.TSSLTransportParameters(); - params.setTrustStore(commonConfig.getTrustStorePath(), commonConfig.getTrustStorePwd()); - params.setKeyStore(commonConfig.getKeyStorePath(), commonConfig.getKeyStorePwd()); - transport = - new TFramedTransport.Factory() - .getTransport( - TSSLTransportFactory.getClientSocket( - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs(), - params)); - } else { - transport = - new TFramedTransport.Factory() - .getTransport( - new TSocket( - TConfigurationConst.defaultTConfiguration, - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs())); - } - if (!transport.isOpen()) { - transport.open(); - } - } catch (TTransportException e) { - throw new TException(MSG_CONNECTION_FAIL); - } - client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); - } - - public TTransport getTransport() { - return transport; - } - - public TSStatus stopAINode() throws TException { - try { - TSStatus status = client.stopAINode(); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(status.message); - } - return status; - } catch (TException e) { - logger.warn( - "Failed to connect to AINode from ConfigNode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public ModelInformation registerModel(String modelName, String uri) throws LoadModelException { - try { - TRegisterModelReq req = new TRegisterModelReq(uri, modelName); - TRegisterModelResp resp = client.registerModel(req); - if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new LoadModelException(resp.status.message, resp.status.getCode()); - } - 
return parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs()); - } catch (TException e) { - throw new LoadModelException( - e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()); - } - } - - private ModelInformation parseModelInformation( - String modelName, String attributes, TConfigs configs) { - int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray(); - int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray(); - - TSDataType[] inputType = new TSDataType[inputShape[1]]; - TSDataType[] outputType = new TSDataType[outputShape[1]]; - for (int i = 0; i < inputShape[1]; i++) { - inputType[i] = TSDataType.values()[configs.getInput_type().get(i)]; - } - for (int i = 0; i < outputShape[1]; i++) { - outputType[i] = TSDataType.values()[configs.getOutput_type().get(i)]; - } - - return new ModelInformation( - modelName, inputShape, outputShape, inputType, outputType, attributes); - } - - public TSStatus deleteModel(TDeleteModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.deleteModel(req)); - } - - public TSStatus loadModel(TLoadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.loadModel(req)); - } - - public TSStatus unloadModel(TUnloadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.unloadModel(req)); - } - - public TShowModelsResp showModels(TShowModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showModels(req)); - } - - public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showLoadedModels(req)); - } - - public TShowAIDevicesResp showAIDevices() throws TException { - return executeRemoteCallWithRetry(IAINodeRPCService.Client::showAIDevices); - } - - public TInferenceResp inference( - String modelId, - TsBlock inputTsBlock, - Map inferenceAttributes, - TWindowParams windowParams) - 
throws TException { - try { - TInferenceReq inferenceReq = new TInferenceReq(modelId, tsBlockSerde.serialize(inputTsBlock)); - if (windowParams != null) { - inferenceReq.setWindowParams(windowParams); - } - if (inferenceAttributes != null) { - inferenceReq.setInferenceAttributes(inferenceAttributes); - } - return executeRemoteCallWithRetry(c -> c.inference(inferenceReq)); - } catch (IOException e) { - throw new TException("An exception occurred while serializing input data", e); - } catch (TException e) { - logger.warn( - "Error happens in AINode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public TForecastResp forecast( - String modelId, TsBlock inputTsBlock, int outputLength, Map options) { - try { - TForecastReq forecastReq = - new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock), outputLength); - forecastReq.setOptions(options); - return executeRemoteCallWithRetry(c -> c.forecast(forecastReq)); - } catch (IOException e) { - TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode()); - tsStatus.setMessage(String.format("Failed to serialize input tsblock %s", e.getMessage())); - return new TForecastResp(tsStatus); - } catch (TException e) { - TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode()); - tsStatus.setMessage( - String.format( - "Failed to connect to AINode when executing %s: %s", - Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage())); - return new TForecastResp(tsStatus); - } - } - - public TSStatus createTrainingTask(TTrainingReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.createTrainingTask(req)); - } - - @Override - public void close() throws Exception { - clientManager.returnClient(endPoint, this); - } - - @Override - public void invalidate() { - Optional.ofNullable(transport).ifPresent(TTransport::close); - } - - @Override - public void 
invalidateAll() { - clientManager.clear(endPoint); - } - - @Override - public boolean printLogWhenEncounterException() { - return property.isPrintLogWhenEncounterException(); - } - - public static class Factory extends ThriftClientFactory { - - public Factory( - ClientManager clientClientManager, - ThriftClientProperty thriftClientProperty) { - super(clientClientManager, thriftClientProperty); - } - - @Override - public void destroyObject(TEndPoint tEndPoint, PooledObject pooledObject) - throws Exception { - pooledObject.getObject().invalidate(); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public boolean validateObject(TEndPoint tEndPoint, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(TTransport::isOpen) - .orElse(false); - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java deleted file mode 100644 index faef1c1ae7b6..000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; - -public class AINodeClientManager { - - public static final int DEFAULT_AINODE_ID = 0; - - private static final AINodeClientManager INSTANCE = new AINodeClientManager(); - - private final IClientManager clientManager; - - private volatile TEndPoint defaultAINodeEndPoint; - - private AINodeClientManager() { - this.clientManager = - new IClientManager.Factory() - .createClientManager(new AINodeClientFactory.AINodeClientPoolFactory()); - } - - public static AINodeClientManager getInstance() { - return INSTANCE; - } - - public void updateDefaultAINodeLocation(TEndPoint endPoint) { - this.defaultAINodeEndPoint = endPoint; - } - - public AINodeClient borrowClient(TEndPoint endPoint) throws Exception { - return clientManager.borrowClient(endPoint); - } - - public AINodeClient borrowClient(int aiNodeId) throws Exception { - if (aiNodeId != DEFAULT_AINODE_ID) { - throw new IllegalArgumentException("Unsupported AINodeId: " + aiNodeId); - } - if (defaultAINodeEndPoint == null) { - defaultAINodeEndPoint = AINodeClient.getCurrentEndpoint(); - } - return clientManager.borrowClient(defaultAINodeEndPoint); - } - - public void clear(TEndPoint endPoint) { - clientManager.clear(endPoint); - } - - public void clearAll() { - clientManager.close(); - } - - public IClientManager getRawClientManager() { - return clientManager; - } -} 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java new file mode 100644 index 000000000000..5eaffc40af9c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatResp; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; +import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ClientManager; +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.ThriftClient; +import org.apache.iotdb.commons.client.factory.ThriftClientFactory; +import org.apache.iotdb.commons.client.property.ThriftClientProperty; +import org.apache.iotdb.commons.client.sync.SyncThriftClientWithErrorHandler; +import org.apache.iotdb.commons.conf.CommonConfig; +import org.apache.iotdb.commons.conf.CommonDescriptor; +import org.apache.iotdb.commons.consensus.ConfigRegionId; +import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; +import org.apache.iotdb.db.conf.IoTDBConfig; +import 
org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.ConfigNodeClient; +import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; +import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; + +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLHandshakeException; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class AINodeClient implements IAINodeRPCService.Iface, AutoCloseable, ThriftClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeClient.class); + + private static final CommonConfig COMMON_CONFIG = CommonDescriptor.getInstance().getConfig(); + private static final IoTDBConfig IOTDB_CONFIG = IoTDBDescriptor.getInstance().getConfig(); + + private TTransport transport; + + private final ThriftClientProperty property; + private IAINodeRPCService.Client client; + + private static final int MAX_RETRY = 5; + private static final int RETRY_INTERVAL_MS = 100; + public static final String MSG_ALL_RETRY_FAILED = + String.format( + "Failed to connect to AINode after %d retries, please check the status of AINode", + MAX_RETRY); + public static final String MSG_AINODE_CONNECTION_FAIL = + "Fail to connect to AINode from DataNode %s when executing %s."; + private static final String UNSUPPORTED_INVOCATION = + "This method is not supported for invocation by DataNode"; + + @Override + public TSStatus stopAINode() throws TException { + return executeRemoteCallWithRetry(() -> client.stopAINode()); + } + + @Override + public TShowModelsResp showModels(TShowModelsReq req) throws TException { + return executeRemoteCallWithRetry(() -> 
client.showModels(req)); + } + + @Override + public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.showLoadedModels(req)); + } + + @Override + public TShowAIDevicesResp showAIDevices() throws TException { + return executeRemoteCallWithRetry(() -> client.showAIDevices()); + } + + @Override + public TSStatus deleteModel(TDeleteModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.deleteModel(req)); + } + + @Override + public TRegisterModelResp registerModel(TRegisterModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.registerModel(req)); + } + + @Override + public TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) { + throw new UnsupportedOperationException(UNSUPPORTED_INVOCATION); + } + + @Override + public TSStatus createTrainingTask(TTrainingReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.createTrainingTask(req)); + } + + @Override + public TSStatus loadModel(TLoadModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.loadModel(req)); + } + + @Override + public TSStatus unloadModel(TUnloadModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.unloadModel(req)); + } + + @Override + public TInferenceResp inference(TInferenceReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.inference(req)); + } + + @Override + public TForecastResp forecast(TForecastReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.forecast(req)); + } + + @FunctionalInterface + private interface RemoteCall { + R apply() throws TException; + } + + ClientManager clientManager; + + private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = + ConfigNodeClientManager.getInstance(); + + private static final AtomicReference CURRENT_LOCATION = new AtomicReference<>(); + + private R 
executeRemoteCallWithRetry(RemoteCall call) throws TException { + for (int attempt = 0; attempt < MAX_RETRY; attempt++) { + try { + return call.apply(); + } catch (TException e) { + final String message = + String.format( + MSG_AINODE_CONNECTION_FAIL, + IOTDB_CONFIG.getAddressAndPort(), + Thread.currentThread().getStackTrace()[2].getMethodName()); + LOGGER.warn(message, e); + CURRENT_LOCATION.set(null); + if (e.getCause() != null && e.getCause() instanceof SSLHandshakeException) { + throw e; + } + } + try { + TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.warn( + "Unexpected interruption when waiting to try to connect to AINode, may because current node has been down. Will break current execution process to avoid meaningless wait."); + break; + } + tryToConnect(property.getConnectionTimeoutMs()); + } + throw new TException(MSG_ALL_RETRY_FAILED); + } + + private void tryToConnect(int timeoutMs) { + TEndPoint endpoint = getCurrentEndpoint(); + if (endpoint != null) { + try { + connect(endpoint, timeoutMs); + return; + } catch (TException e) { + LOGGER.warn("The current AINode may have been down {}, because", endpoint, e); + CURRENT_LOCATION.set(null); + } + } else { + LOGGER.warn("Cannot connect to any AINode due to there are no available ones."); + } + if (transport != null) { + transport.close(); + } + } + + public void connect(TEndPoint endpoint, int timeoutMs) throws TException { + transport = + COMMON_CONFIG.isEnableInternalSSL() + ? 
DeepCopyRpcTransportFactory.INSTANCE.getTransport( + endpoint.getIp(), + endpoint.getPort(), + timeoutMs, + COMMON_CONFIG.getTrustStorePath(), + COMMON_CONFIG.getTrustStorePwd(), + COMMON_CONFIG.getKeyStorePath(), + COMMON_CONFIG.getKeyStorePwd()) + : DeepCopyRpcTransportFactory.INSTANCE.getTransport( + // As there is a try-catch already, we do not need to use TSocket.wrap + endpoint.getIp(), endpoint.getPort(), timeoutMs); + if (!transport.isOpen()) { + transport.open(); + } + client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); + } + + public TEndPoint getCurrentEndpoint() { + TAINodeLocation loc = CURRENT_LOCATION.get(); + if (loc == null) { + loc = refreshFromConfigNode(); + } + return (loc == null) ? null : loc.getInternalEndPoint(); + } + + private TAINodeLocation refreshFromConfigNode() { + try (final ConfigNodeClient cn = + CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + final TGetAINodeLocationResp resp = cn.getAINodeLocation(); + if (resp.isSetAiNodeLocation()) { + final TAINodeLocation loc = resp.getAiNodeLocation(); + CURRENT_LOCATION.set(loc); + return loc; + } + } catch (Exception e) { + LoggerFactory.getLogger(AINodeClient.class) + .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString()); + } + return null; + } + + public AINodeClient( + ThriftClientProperty property, ClientManager clientManager) { + this.property = property; + this.clientManager = clientManager; + tryToConnect(property.getConnectionTimeoutMs()); + } + + public TTransport getTransport() { + return transport; + } + + @Override + public void close() { + clientManager.returnClient(AINodeClientManager.AINODE_ID_PLACEHOLDER, this); + } + + @Override + public void invalidate() { + Optional.ofNullable(transport).ifPresent(TTransport::close); + } + + @Override + public void invalidateAll() { + clientManager.clear(AINodeClientManager.AINODE_ID_PLACEHOLDER); + } + + @Override + public boolean 
printLogWhenEncounterException() { + return property.isPrintLogWhenEncounterException(); + } + + public static class Factory extends ThriftClientFactory { + + public Factory( + ClientManager clientClientManager, + ThriftClientProperty thriftClientProperty) { + super(clientClientManager, thriftClientProperty); + } + + @Override + public void destroyObject(Integer aiNodeId, PooledObject pooledObject) { + pooledObject.getObject().invalidate(); + } + + @Override + public PooledObject makeObject(Integer Integer) throws Exception { + return new DefaultPooledObject<>( + SyncThriftClientWithErrorHandler.newErrorHandler( + AINodeClient.class, + AINodeClient.class.getConstructor( + thriftClientProperty.getClass(), clientManager.getClass()), + thriftClientProperty, + clientManager)); + } + + @Override + public boolean validateObject(Integer Integer, PooledObject pooledObject) { + return Optional.ofNullable(pooledObject.getObject().getTransport()) + .map(TTransport::isOpen) + .orElse(false); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java new file mode 100644 index 000000000000..698c8e793883 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; + +public class AINodeClientManager { + + public static final int AINODE_ID_PLACEHOLDER = 0; + + private AINodeClientManager() { + // Empty constructor + } + + public static IClientManager getInstance() { + return AINodeClientManagerHolder.INSTANCE; + } + + private static class AINodeClientManagerHolder { + + private static final IClientManager INSTANCE = + new IClientManager.Factory() + .createClientManager(new DataNodeClientPoolFactory.AINodeClientPoolFactory()); + + private AINodeClientManagerHolder() { + // Empty constructor + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java index 7126af78b8b5..29e5580311d0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java @@ -19,18 +19,15 @@ package org.apache.iotdb.db.queryengine.execution.operator.process.ai; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; import 
org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper; import org.apache.iotdb.db.queryengine.execution.operator.Operator; import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; @@ -75,7 +72,6 @@ public class InferenceOperator implements ProcessOperator { private int resultIndex = 0; private List results; private final TsBlockSerde serde = new TsBlockSerde(); - private InferenceWindowType windowType = null; private final boolean generateTimeColumn; private long maxTimestamp; @@ -109,10 +105,6 @@ public InferenceOperator( this.maxReturnSize = maxReturnSize; this.totalRow = 0; - if (modelInferenceDescriptor.getInferenceWindowParameter() != null) { - windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType(); - } - if (generateTimeColumn) { this.interval = 0; this.minTimestamp = Long.MAX_VALUE; @@ -237,62 +229,6 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) { } } - private TWindowParams getWindowParams() { - TWindowParams windowParams; - if (windowType == null) { - return null; - } - if (windowType == InferenceWindowType.COUNT) { - 
CountInferenceWindowParameter countInferenceWindowParameter = - (CountInferenceWindowParameter) modelInferenceDescriptor.getInferenceWindowParameter(); - windowParams = new TWindowParams(); - windowParams.setWindowInterval((int) countInferenceWindowParameter.getInterval()); - windowParams.setWindowStep((int) countInferenceWindowParameter.getStep()); - } else { - windowParams = null; - } - return windowParams; - } - - private TsBlock preProcess(TsBlock inputTsBlock) { - // boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn(); - boolean notBuiltIn = false; - if (windowType == null || windowType == InferenceWindowType.HEAD) { - if (notBuiltIn - && totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data does not match the model input %s. Try to use LIMIT in SQL or WINDOW in CALL INFERENCE", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - return inputTsBlock; - } else if (windowType == InferenceWindowType.COUNT) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. ", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - } else if (windowType == InferenceWindowType.TAIL) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. 
", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - // Tail window logic: get the latest data for inference - long windowSize = - (int) - ((BottomInferenceWindowParameter) - modelInferenceDescriptor.getInferenceWindowParameter()) - .getWindowSize(); - return inputTsBlock.subTsBlock((int) (totalRow - windowSize)); - } - return inputTsBlock; - } - private void submitInferenceTask() { if (generateTimeColumn) { @@ -301,20 +237,16 @@ private void submitInferenceTask() { TsBlock inputTsBlock = inputTsBlockBuilder.build(); - TsBlock finalInputTsBlock = preProcess(inputTsBlock); - TWindowParams windowParams = getWindowParams(); - inferenceExecutionFuture = Futures.submit( () -> { try (AINodeClient client = AINodeClientManager.getInstance() - .borrowClient(modelInferenceDescriptor.getTargetAINode())) { + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { return client.inference( - modelInferenceDescriptor.getModelName(), - finalInputTsBlock, - modelInferenceDescriptor.getInferenceAttributes(), - windowParams); + new TInferenceReq( + modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock)) + .setInferenceAttributes(modelInferenceDescriptor.getInferenceAttributes())); } catch (Exception e) { throw new ModelInferenceProcessException(e.getMessage()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java index fc6888165659..daceffce6b7b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java @@ -19,14 +19,11 @@ 
package org.apache.iotdb.db.queryengine.execution.operator.source.relational; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; import org.apache.iotdb.common.rpc.thrift.Model; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.audit.UserEntity; import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.conf.IoTDBConstant; @@ -68,8 +65,6 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClient; import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; @@ -157,8 +152,6 @@ public static Iterator getSupplier( return new SubscriptionSupplier(dataTypes, userEntity); case InformationSchema.VIEWS: return new ViewsSupplier(dataTypes, userEntity); - case InformationSchema.MODELS: - return new ModelsSupplier(dataTypes); case InformationSchema.FUNCTIONS: return new FunctionsSupplier(dataTypes); case InformationSchema.CONFIGURATIONS: @@ -798,112 +791,6 @@ public boolean hasNext() { } } - private static class ModelsSupplier extends TsBlockSupplier { - private final ModelIterator iterator; - - private ModelsSupplier(final List dataTypes) throws Exception { - super(dataTypes); - final TEndPoint ep = AINodeClient.getCurrentEndpoint(); - try (final AINodeClient ai = 
AINodeClientManager.getInstance().borrowClient(ep)) { - iterator = new ModelIterator(ai.showModels(new TShowModelsReq())); - } - } - - private static class ModelIterator implements Iterator { - - private int index = 0; - private final TShowModelsResp resp; - - private ModelIterator(TShowModelsResp resp) { - this.resp = resp; - } - - @Override - public boolean hasNext() { - return index < resp.getModelIdListSize(); - } - - @Override - public ModelInfoInString next() { - String modelId = resp.getModelIdList().get(index++); - return new ModelInfoInString( - modelId, - resp.getModelTypeMap().get(modelId), - resp.getCategoryMap().get(modelId), - resp.getStateMap().get(modelId)); - } - } - - private static class ModelInfoInString { - - private final String modelId; - private final String modelType; - private final String category; - private final String state; - - public ModelInfoInString(String modelId, String modelType, String category, String state) { - this.modelId = modelId; - this.modelType = modelType; - this.category = category; - this.state = state; - } - - public String getModelId() { - return modelId; - } - - public String getModelType() { - return modelType; - } - - public String getCategory() { - return category; - } - - public String getState() { - return state; - } - } - - @Override - protected void constructLine() { - final ModelInfoInString modelInfo = iterator.next(); - columnBuilders[0].writeBinary( - new Binary(modelInfo.getModelId(), TSFileConfig.STRING_CHARSET)); - columnBuilders[1].writeBinary( - new Binary(modelInfo.getModelType(), TSFileConfig.STRING_CHARSET)); - columnBuilders[2].writeBinary( - new Binary(modelInfo.getCategory(), TSFileConfig.STRING_CHARSET)); - columnBuilders[3].writeBinary(new Binary(modelInfo.getState(), TSFileConfig.STRING_CHARSET)); - // if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) { - // columnBuilders[3].writeBinary( - // new Binary( - // INPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // 
+ OUTPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + INPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // columnBuilders[4].writeBinary( - // new Binary(ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // } else { - // columnBuilders[3].appendNull(); - // columnBuilders[4].writeBinary( - // new Binary("Built-in model in IoTDB", TSFileConfig.STRING_CHARSET)); - // } - resultBuilder.declarePosition(); - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - } - private static class FunctionsSupplier extends TsBlockSupplier { private final Iterator udfIterator; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 34a289b76c9d..dc56fe118b7b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -25,7 +25,6 @@ import org.apache.iotdb.commons.conf.IoTDBConstant; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -55,14 +54,6 @@ import org.apache.iotdb.db.queryengine.common.schematree.IMeasurementSchemaInfo; import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; -import 
org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.HeadInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.TailInferenceWindow; import org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet; import org.apache.iotdb.db.queryengine.plan.analyze.load.LoadTsFileAnalyzer; import org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; @@ -425,46 +416,14 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem return; } - // Get model metadata from configNode and do some check + // Get model metadata from AINode String modelId = queryStatement.getModelId(); TSStatus status = modelFetcher.fetchModel(modelId, analysis); if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { throw new GetModelInfoException(status.getMessage()); } - // set inference window if there is - if (queryStatement.isSetInferenceWindow()) { - InferenceWindow window = queryStatement.getInferenceWindow(); - if (InferenceWindowType.HEAD == window.getType()) { - long windowSize = ((HeadInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - if (queryStatement.hasLimit() && queryStatement.getRowLimit() < windowSize) { - throw new SemanticException( - "Limit in Sql should be larger than window size in inference"); - } - // optimize head window by limitNode - 
queryStatement.setRowLimit(windowSize); - } else if (InferenceWindowType.TAIL == window.getType()) { - long windowSize = ((TailInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new BottomInferenceWindowParameter(windowSize); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } else if (InferenceWindowType.COUNT == window.getType()) { - CountInferenceWindow countInferenceWindow = (CountInferenceWindow) window; - // checkWindowSize(countInferenceWindow.getInterval(), modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new CountInferenceWindowParameter( - countInferenceWindow.getInterval(), countInferenceWindow.getStep()); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } - } - - // set inference attributes if there is + // Set inference attributes if there is if (queryStatement.hasInferenceAttributes()) { analysis .getModelInferenceDescriptor() @@ -472,12 +431,6 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem } } - private void checkWindowSize(long windowSize, ModelInformation modelInformation) { - if (modelInformation.isBuiltIn()) { - return; - } - } - private ISchemaTree analyzeSchema( QueryStatement queryStatement, Analysis analysis, @@ -1717,22 +1670,11 @@ static void analyzeOutput( } if (queryStatement.hasModelInference()) { - ModelInformation modelInformation = analysis.getModelInformation(); // check input - checkInputShape(modelInformation, outputExpressions); - checkInputType(analysis, modelInformation, outputExpressions); - + checkInputType(analysis, outputExpressions); // set output List columnHeaders = new ArrayList<>(); - int[] outputShape = modelInformation.getOutputShape(); - TSDataType[] outputDataType = modelInformation.getOutputDataType(); - for (int i = 0; i < outputShape[1]; i++) { - 
columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME + i, outputDataType[i])); - } - analysis - .getModelInferenceDescriptor() - .setOutputColumnNames( - columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList())); + columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME, TSDataType.DOUBLE)); boolean isIgnoreTimestamp = !queryStatement.isGenerateTime(); analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); return; @@ -1756,74 +1698,16 @@ static void analyzeOutput( analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); } - // check if the result of SQL matches the input of model - private static void checkInputShape( - ModelInformation modelInformation, List> outputExpressions) { - if (modelInformation.isBuiltIn()) { - modelInformation.setInputColumnSize(outputExpressions.size()); - return; - } - - // check inputShape - int[] inputShape = modelInformation.getInputShape(); - if (inputShape.length != 2) { - throw new SemanticException( - String.format( - "The input shape of model is not correct, the dimension of input shape should be 2, actual dimension is %d", - inputShape.length)); - } - int columnNumber = inputShape[1]; - if (columnNumber != outputExpressions.size()) { - throw new SemanticException( - String.format( - "The column number of SQL result does not match the number of model input [%d] for inference", - columnNumber)); - } - } - private static void checkInputType( - Analysis analysis, - ModelInformation modelInformation, - List> outputExpressions) { - - if (modelInformation.isBuiltIn()) { - TSDataType[] inputType = new TSDataType[outputExpressions.size()]; - for (int i = 0; i < outputExpressions.size(); i++) { - Expression inputExpression = outputExpressions.get(i).left; - TSDataType inputDataType = analysis.getType(inputExpression); - if (!inputDataType.isNumeric()) { - throw new SemanticException( - String.format( - "The type of SQL result column [%s in %d] should 
be numeric when inference", - inputDataType, i)); - } - inputType[i] = inputDataType; - } - modelInformation.setInputDataType(inputType); - return; - } - - TSDataType[] inputType = modelInformation.getInputDataType(); - if (inputType.length != modelInformation.getInputShape()[1]) { - throw new SemanticException( - String.format( - "The inputType does not match the input shape [%d] for inference", - modelInformation.getInputShape()[1])); - } - for (int i = 0; i < inputType.length; i++) { + Analysis analysis, List> outputExpressions) { + for (int i = 0; i < outputExpressions.size(); i++) { Expression inputExpression = outputExpressions.get(i).left; TSDataType inputDataType = analysis.getType(inputExpression); - boolean isExpressionNumeric = inputDataType.isNumeric(); - boolean isModelNumeric = inputType[i].isNumeric(); - if (isExpressionNumeric && isModelNumeric) { - // every model supports numeric by default - continue; - } - if (inputDataType != inputType[i]) { + if (!inputDataType.isNumeric()) { throw new SemanticException( String.format( - "The type of SQL result column [%s in %d] does not match the type of model input [%s] when inference", - inputDataType, i, inputType[i])); + "The type of SQL result column [%s in %d] should be numeric when inference", + inputDataType, i)); } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java index 586e12e589ab..1feecaefde9c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java @@ -20,12 +20,8 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; public 
interface IModelFetcher { /** Get model information by model id from configNode. */ TSStatus fetchModel(String modelId, Analysis analysis); - - // currently only used by table model - ModelInferenceDescriptor fetchModel(String modelName); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java index dbeee4e8ed4b..b4123c237bbd 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java @@ -20,27 +20,13 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.db.exception.ainode.ModelNotFoundException; -import org.apache.iotdb.db.exception.sql.StatementAnalyzeException; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; -import org.apache.thrift.TException; - +// TODO: This class should contact with AINode directly and cache model info in DataNode public class ModelFetcher implements IModelFetcher { - private final IClientManager configNodeClientManager = - 
ConfigNodeClientManager.getInstance(); - private static final class ModelFetcherHolder { private static final ModelFetcher INSTANCE = new ModelFetcher(); @@ -55,34 +41,9 @@ public static ModelFetcher getInstance() { private ModelFetcher() {} @Override - public TSStatus fetchModel(String modelName, Analysis analysis) { - try (ConfigNodeClient client = - configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); - if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } else { - throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); - } - } catch (ClientManagerException | TException e) { - throw new StatementAnalyzeException(e.getMessage()); - } - } - - @Override - public ModelInferenceDescriptor fetchModel(String modelName) { - try (ConfigNodeClient client = - configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); - if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return new ModelInferenceDescriptor(getModelInfoResp.aiNodeAddress); - } else { - throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); - } - } catch (ClientManagerException | TException e) { - throw new IoTDBRuntimeException( - String.format("fetch model [%s] info failed: %s", modelName, e.getMessage()), - TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - } + public TSStatus fetchModel(String modelId, Analysis analysis) { + analysis.setModelInferenceDescriptor( + new ModelInferenceDescriptor(new ModelInformation(modelId))); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); } } diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index d0f7c7f99d7e..01f6757f02eb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -19,7 +19,10 @@ package org.apache.iotdb.db.queryengine.plan.execution.config.executor; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; @@ -96,7 +99,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCountTimeSlotListResp; import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; @@ -114,7 +116,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import 
org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -175,8 +176,8 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; @@ -379,6 +380,8 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance(); + private static final IClientManager AI_NODE_CLIENT_MANAGER = + AINodeClientManager.getInstance(); /** FIXME Consolidate this clientManager with the upper one. 
*/ private static final IClientManager @@ -3596,16 +3599,16 @@ public SettableFuture showContinuousQueries() { @Override public SettableFuture createModel(String modelId, String uri) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TCreateModelReq req = new TCreateModelReq(modelId, uri); - final TSStatus status = client.createModel(req); - if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { - future.setException(new IoTDBException(status)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TRegisterModelReq req = new TRegisterModelReq(modelId, uri); + final TRegisterModelResp resp = client.registerModel(req); + if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != resp.getStatus().getCode()) { + future.setException(new IoTDBException(resp.getStatus())); } else { future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); } - } catch (final ClientManagerException | TException e) { + } catch (final TException | ClientManagerException e) { future.setException(e); } return future; @@ -3614,9 +3617,9 @@ public SettableFuture createModel(String modelId, String uri) @Override public SettableFuture dropModel(final String modelId) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TSStatus executionStatus = client.dropModel(new TDropModelReq(modelId)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TSStatus executionStatus = client.deleteModel(new TDeleteModelReq(modelId)); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) { future.setException(new IoTDBException(executionStatus)); } else { @@ -3632,7 +3635,7 @@ 
public SettableFuture dropModel(final String modelId) { public SettableFuture showModels(final String modelId) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowModelsReq req = new TShowModelsReq(); if (modelId != null) { req.setModelId(modelId); @@ -3653,7 +3656,7 @@ public SettableFuture showModels(final String modelId) { public SettableFuture showLoadedModels(List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowLoadedModelsReq req = new TShowLoadedModelsReq(); req.setDeviceIdList(deviceIdList != null ? deviceIdList : new ArrayList<>()); final TShowLoadedModelsResp resp = ai.showLoadedModels(req); @@ -3672,7 +3675,7 @@ public SettableFuture showLoadedModels(List deviceIdLi public SettableFuture showAIDevices() { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowAIDevicesResp resp = ai.showAIDevices(); if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { future.setException(new IoTDBException(resp.getStatus())); @@ -3690,7 +3693,7 @@ public SettableFuture loadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + 
AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TLoadModelReq req = new TLoadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.loadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3709,7 +3712,7 @@ public SettableFuture unloadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TUnloadModelReq req = new TUnloadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.unloadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3734,7 +3737,7 @@ public SettableFuture createTraining( @Nullable List pathList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TTrainingReq req = new TTrainingReq(); req.setModelId(modelId); req.setParameters(parameters); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java index 09205c9eb564..a01acf86db57 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java @@ -31,6 +31,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Collections; import 
java.util.List; import java.util.Objects; @@ -90,7 +91,7 @@ public PlanNode clone() { @Override public List getOutputColumnNames() { - return modelInferenceDescriptor.getOutputColumnNames(); + return Collections.singletonList("output"); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java index b7c6aaa4f4b0..1301ec97eb32 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java @@ -19,9 +19,7 @@ package org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; import org.apache.tsfile.utils.ReadWriteIOUtils; @@ -36,19 +34,15 @@ public class ModelInferenceDescriptor { - private final TEndPoint targetAINode; - private ModelInformation modelInformation; + private final ModelInformation modelInformation; private List outputColumnNames; - private InferenceWindowParameter inferenceWindowParameter; private Map inferenceAttributes; - public ModelInferenceDescriptor(TEndPoint targetAINode) { - this.targetAINode = targetAINode; + public ModelInferenceDescriptor(ModelInformation modelInformation) { + this.modelInformation = modelInformation; } private ModelInferenceDescriptor(ByteBuffer buffer) { - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelInformation = ModelInformation.deserialize(buffer); int outputColumnNamesSize = 
ReadWriteIOUtils.readInt(buffer); if (outputColumnNamesSize == 0) { @@ -59,12 +53,6 @@ private ModelInferenceDescriptor(ByteBuffer buffer) { this.outputColumnNames.add(ReadWriteIOUtils.readString(buffer)); } } - boolean hasInferenceWindowParameter = ReadWriteIOUtils.readBool(buffer); - if (hasInferenceWindowParameter) { - this.inferenceWindowParameter = InferenceWindowParameter.deserialize(buffer); - } else { - this.inferenceWindowParameter = null; - } int inferenceAttributesSize = ReadWriteIOUtils.readInt(buffer); if (inferenceAttributesSize == 0) { this.inferenceAttributes = null; @@ -85,24 +73,12 @@ public Map getInferenceAttributes() { return inferenceAttributes; } - public void setInferenceWindowParameter(InferenceWindowParameter inferenceWindowParameter) { - this.inferenceWindowParameter = inferenceWindowParameter; - } - - public InferenceWindowParameter getInferenceWindowParameter() { - return inferenceWindowParameter; - } - public ModelInformation getModelInformation() { return modelInformation; } - public TEndPoint getTargetAINode() { - return targetAINode; - } - - public String getModelName() { - return modelInformation.getModelName(); + public String getModelId() { + return modelInformation.getModelId(); } public void setOutputColumnNames(List outputColumnNames) { @@ -114,8 +90,6 @@ public List getOutputColumnNames() { } public void serialize(ByteBuffer byteBuffer) { - ReadWriteIOUtils.write(targetAINode.ip, byteBuffer); - ReadWriteIOUtils.write(targetAINode.port, byteBuffer); modelInformation.serialize(byteBuffer); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, byteBuffer); @@ -125,12 +99,6 @@ public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(outputColumnName, byteBuffer); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, byteBuffer); - } else { - ReadWriteIOUtils.write(true, byteBuffer); - inferenceWindowParameter.serialize(byteBuffer); - } if (inferenceAttributes == null) { 
ReadWriteIOUtils.write(0, byteBuffer); } else { @@ -143,8 +111,6 @@ public void serialize(ByteBuffer byteBuffer) { } public void serialize(DataOutputStream stream) throws IOException { - ReadWriteIOUtils.write(targetAINode.ip, stream); - ReadWriteIOUtils.write(targetAINode.port, stream); modelInformation.serialize(stream); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, stream); @@ -154,12 +120,6 @@ public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(outputColumnName, stream); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, stream); - } else { - ReadWriteIOUtils.write(true, stream); - inferenceWindowParameter.serialize(stream); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, stream); } else { @@ -184,20 +144,13 @@ public boolean equals(Object o) { return false; } ModelInferenceDescriptor that = (ModelInferenceDescriptor) o; - return targetAINode.equals(that.targetAINode) - && modelInformation.equals(that.modelInformation) + return modelInformation.equals(that.modelInformation) && outputColumnNames.equals(that.outputColumnNames) - && inferenceWindowParameter.equals(that.inferenceWindowParameter) && inferenceAttributes.equals(that.inferenceAttributes); } @Override public int hashCode() { - return Objects.hash( - targetAINode, - modelInformation, - outputColumnNames, - inferenceWindowParameter, - inferenceAttributes); + return Objects.hash(modelInformation, outputColumnNames, inferenceAttributes); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java index c01308f9f375..cd219fd68162 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java @@ -40,7 +40,6 @@ import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableArgumentAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableFunctionInvocationAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; -import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName; @@ -4693,11 +4692,6 @@ public Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional String functionName = node.getName().toString(); TableFunction function = metadata.getTableFunction(functionName); - // set model fetcher for ForecastTableFunction - if (function instanceof ForecastTableFunction) { - ((ForecastTableFunction) function).setModelFetcher(metadata.getModelFetcher()); - } - Node errorLocation = node; if (!node.getArguments().isEmpty()) { errorLocation = node.getArguments().get(0); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java index a5a57cebadb4..61b96809f847 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java @@ -25,6 +25,7 @@ import org.apache.iotdb.commons.udf.builtin.relational.tvf.SessionTableFunction; import 
org.apache.iotdb.commons.udf.builtin.relational.tvf.TumbleTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.VariationTableFunction; +import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ClassifyTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.PatternMatchTableFunction; import org.apache.iotdb.udf.api.relational.TableFunction; @@ -42,7 +43,8 @@ public enum TableBuiltinTableFunction { VARIATION("variation"), CAPACITY("capacity"), FORECAST("forecast"), - PATTERN_MATCH("pattern_match"); + PATTERN_MATCH("pattern_match"), + CLASSIFY("classify"); private final String functionName; @@ -86,6 +88,8 @@ public static TableFunction getBuiltinTableFunction(String functionName) { return new CapacityTableFunction(); case "forecast": return new ForecastTableFunction(); + case "classify": + return new ClassifyTableFunction(); default: throw new UnsupportedOperationException("Unsupported table function: " + functionName); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java new file mode 100644 index 000000000000..cf41841a1105 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java @@ -0,0 +1,366 @@ +package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; + +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.exception.IoTDBRuntimeException; +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import 
org.apache.iotdb.db.protocol.client.an.AINodeClientManager; +import org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender; +import org.apache.iotdb.rpc.TSStatusCode; +import org.apache.iotdb.udf.api.exception.UDFException; +import org.apache.iotdb.udf.api.relational.TableFunction; +import org.apache.iotdb.udf.api.relational.access.Record; +import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis; +import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle; +import org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider; +import org.apache.iotdb.udf.api.relational.table.argument.Argument; +import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema; +import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument; +import org.apache.iotdb.udf.api.relational.table.argument.TableArgument; +import org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor; +import org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification; +import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification; +import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification; +import org.apache.iotdb.udf.api.type.Type; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.TsBlockSerde; +import org.apache.tsfile.utils.PublicBAOS; +import org.apache.tsfile.utils.ReadWriteIOUtils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; 
+import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender; +import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; + +public class ClassifyTableFunction implements TableFunction { + + public static class ClassifyTableFunctionHandle implements TableFunctionHandle { + String modelId; + int maxInputLength; + List inputColumnTypes; + + public ClassifyTableFunctionHandle() {} + + public ClassifyTableFunctionHandle( + String modelId, int maxInputLength, List inputColumnTypes) { + this.modelId = modelId; + this.maxInputLength = maxInputLength; + this.inputColumnTypes = inputColumnTypes; + } + + @Override + public byte[] serialize() { + try (PublicBAOS publicBAOS = new PublicBAOS(); + DataOutputStream outputStream = new DataOutputStream(publicBAOS)) { + ReadWriteIOUtils.write(modelId, outputStream); + ReadWriteIOUtils.write(maxInputLength, outputStream); + ReadWriteIOUtils.write(inputColumnTypes.size(), outputStream); + for (Type type : inputColumnTypes) { + ReadWriteIOUtils.write(type.getType(), outputStream); + } + outputStream.flush(); + return publicBAOS.toByteArray(); + } catch (IOException e) { + throw new IoTDBRuntimeException( + String.format( + "Error occurred while serializing ForecastTableFunctionHandle: %s", e.getMessage()), + TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); + } + } + + @Override + public void deserialize(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.wrap(bytes); + this.modelId = ReadWriteIOUtils.readString(buffer); + this.maxInputLength = ReadWriteIOUtils.readInt(buffer); + int size = ReadWriteIOUtils.readInt(buffer); + this.inputColumnTypes = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + 
inputColumnTypes.add(Type.valueOf(ReadWriteIOUtils.readString(buffer))); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassifyTableFunctionHandle that = (ClassifyTableFunctionHandle) o; + return maxInputLength == that.maxInputLength + && Objects.equals(modelId, that.modelId) + && Objects.equals(inputColumnTypes, that.inputColumnTypes); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, maxInputLength, inputColumnTypes); + } + } + + private static final String INPUT_PARAMETER_NAME = "INPUT"; + private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID"; + public static final String TIMECOL_PARAMETER_NAME = "TIMECOL"; + private static final String DEFAULT_TIME_COL = "time"; + private static final String DEFAULT_OUTPUT_COLUMN_NAME = "category"; + private static final int MAX_INPUT_LENGTH = 2880; + + private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); + + static { + ALLOWED_INPUT_TYPES.add(Type.INT32); + ALLOWED_INPUT_TYPES.add(Type.INT64); + ALLOWED_INPUT_TYPES.add(Type.FLOAT); + ALLOWED_INPUT_TYPES.add(Type.DOUBLE); + } + + @Override + public List getArgumentsSpecifications() { + return Arrays.asList( + TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(), + ScalarParameterSpecification.builder() + .name(MODEL_ID_PARAMETER_NAME) + .type(Type.STRING) + .build(), + ScalarParameterSpecification.builder() + .name(TIMECOL_PARAMETER_NAME) + .type(Type.STRING) + .defaultValue(DEFAULT_TIME_COL) + .build()); + } + + @Override + public TableFunctionAnalysis analyze(Map arguments) throws UDFException { + TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME); + String modelId = (String) ((ScalarArgument) arguments.get(MODEL_ID_PARAMETER_NAME)).getValue(); + // modelId should never be null or empty + if (modelId == null || modelId.isEmpty()) { + throw new SemanticException( + 
String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME)); + } + + String timeColumn = + ((String) ((ScalarArgument) arguments.get(TIMECOL_PARAMETER_NAME)).getValue()) + .toLowerCase(Locale.ENGLISH); + + if (timeColumn.isEmpty()) { + throw new SemanticException( + String.format("%s should never be null or empty.", TIMECOL_PARAMETER_NAME)); + } + + // predicated columns should never contain partition by columns and time column + Set excludedColumns = + input.getPartitionBy().stream() + .map(s -> s.toLowerCase(Locale.ENGLISH)) + .collect(Collectors.toSet()); + excludedColumns.add(timeColumn); + int timeColumnIndex = findColumnIndex(input, timeColumn, Collections.singleton(Type.TIMESTAMP)); + + List requiredIndexList = new ArrayList<>(); + requiredIndexList.add(timeColumnIndex); + DescribedSchema.Builder properColumnSchemaBuilder = + new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP); + + List inputColumnTypes = new ArrayList<>(); + List> allInputColumnsName = input.getFieldNames(); + List allInputColumnsType = input.getFieldTypes(); + + for (int i = 0, size = allInputColumnsName.size(); i < size; i++) { + Optional fieldName = allInputColumnsName.get(i); + // All input value columns are required for model forecasting + if (!fieldName.isPresent() + || !excludedColumns.contains(fieldName.get().toLowerCase(Locale.ENGLISH))) { + Type columnType = allInputColumnsType.get(i); + checkType(columnType, fieldName.orElse("")); + inputColumnTypes.add(columnType); + requiredIndexList.add(i); + } + } + properColumnSchemaBuilder.addField(DEFAULT_OUTPUT_COLUMN_NAME, Type.INT32); + + ClassifyTableFunctionHandle functionHandle = + new ClassifyTableFunctionHandle(modelId, MAX_INPUT_LENGTH, inputColumnTypes); + + // outputColumnSchema + return TableFunctionAnalysis.builder() + .properColumnSchema(properColumnSchemaBuilder.build()) + .handle(functionHandle) + .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList) + .build(); + } + + // only allow for 
INT32, INT64, FLOAT, DOUBLE + private void checkType(Type type, String columnName) { + if (!ALLOWED_INPUT_TYPES.contains(type)) { + throw new SemanticException( + String.format( + "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed", + columnName, type)); + } + } + + @Override + public TableFunctionHandle createTableFunctionHandle() { + return new ClassifyTableFunctionHandle(); + } + + @Override + public TableFunctionProcessorProvider getProcessorProvider( + TableFunctionHandle tableFunctionHandle) { + return new TableFunctionProcessorProvider() { + @Override + public TableFunctionDataProcessor getDataProcessor() { + return new ClassifyDataProcessor((ClassifyTableFunctionHandle) tableFunctionHandle); + } + }; + } + + private static class ClassifyDataProcessor implements TableFunctionDataProcessor { + + private static final TsBlockSerde SERDE = new TsBlockSerde(); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); + + private final String modelId; + private final int maxInputLength; + private final LinkedList inputRecords; + private final TsBlockBuilder inputTsBlockBuilder; + private final List inputColumnAppenderList; + private final List resultColumnAppenderList; + + public ClassifyDataProcessor(ClassifyTableFunctionHandle functionHandle) { + this.modelId = functionHandle.modelId; + this.maxInputLength = functionHandle.maxInputLength; + this.inputRecords = new LinkedList<>(); + List inputTsDataTypeList = + new ArrayList<>(functionHandle.inputColumnTypes.size()); + for (Type type : functionHandle.inputColumnTypes) { + // AINode currently only accept double input + inputTsDataTypeList.add(TSDataType.DOUBLE); + } + this.inputTsBlockBuilder = new TsBlockBuilder(inputTsDataTypeList); + this.inputColumnAppenderList = new ArrayList<>(functionHandle.inputColumnTypes.size()); + for (Type type : functionHandle.inputColumnTypes) { + // AINode currently only accept double input + 
inputColumnAppenderList.add(createResultColumnAppender(Type.DOUBLE)); + } + this.resultColumnAppenderList = new ArrayList<>(1); + this.resultColumnAppenderList.add(createResultColumnAppender(Type.INT32)); + } + + @Override + public void process( + Record input, + List properColumnBuilders, + ColumnBuilder passThroughIndexBuilder) { + // only keep at most maxInputLength rows + if (maxInputLength != 0 && inputRecords.size() == maxInputLength) { + inputRecords.removeFirst(); + } + inputRecords.add(input); + } + + @Override + public void finish( + List properColumnBuilders, ColumnBuilder passThroughIndexBuilder) { + + // time column + long inputStartTime = inputRecords.getFirst().getLong(0); + long inputEndTime = inputRecords.getLast().getLong(0); + if (inputEndTime < inputStartTime) { + throw new SemanticException( + String.format( + "input end time should never less than start time, start time is %s, end time is %s", + inputStartTime, inputEndTime)); + } + int outputLength = inputRecords.size(); + for (Record inputRecord : inputRecords) { + properColumnBuilders.get(0).writeLong(inputRecord.getLong(0)); + } + + // predicated columns + TsBlock predicatedResult = classify(); + if (predicatedResult.getPositionCount() != outputLength) { + throw new IoTDBRuntimeException( + String.format( + "Model %s output length is %s, doesn't equal to specified %s", + modelId, predicatedResult.getPositionCount(), outputLength), + TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode()); + } + + for (int columnIndex = 1, size = predicatedResult.getValueColumnCount(); + columnIndex <= size; + columnIndex++) { + Column column = predicatedResult.getColumn(columnIndex - 1); + ColumnBuilder builder = properColumnBuilders.get(columnIndex); + ResultColumnAppender appender = resultColumnAppenderList.get(columnIndex - 1); + for (int row = 0; row < outputLength; row++) { + if (column.isNull(row)) { + builder.appendNull(); + } else { + // convert double to real type + 
appender.writeDouble(column.getDouble(row), builder); + } + } + } + } + + private TsBlock classify() { + int outputLength = inputRecords.size(); + while (!inputRecords.isEmpty()) { + Record row = inputRecords.removeFirst(); + inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0)); + for (int i = 1, size = row.size(); i < size; i++) { + // we set null input to 0.0 + if (row.isNull(i)) { + inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0); + } else { + // need to transform other types to DOUBLE + inputTsBlockBuilder + .getColumnBuilder(i - 1) + .writeDouble(inputColumnAppenderList.get(i - 1).getDouble(row, i)); + } + } + inputTsBlockBuilder.declarePosition(); + } + TsBlock inputData = inputTsBlockBuilder.build(); + + TForecastResp resp; + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = client.forecast(new TForecastReq(modelId, SERDE.serialize(inputData), outputLength)); + } catch (Exception e) { + throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode()); + } + + if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + String message = + String.format( + "Error occurred while executing classify:[%s]", resp.getStatus().getMessage()); + throw new IoTDBRuntimeException(message, resp.getStatus().getCode()); + } + + return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult())); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 887d7c26d305..1c89fc439079 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -19,14 +19,14 @@ package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; import org.apache.iotdb.db.exception.sql.SemanticException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; +import org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.relational.TableFunction; import org.apache.iotdb.udf.api.relational.access.Record; @@ -70,12 +70,12 @@ import java.util.stream.Collectors; import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex; +import static org.apache.iotdb.db.queryengine.plan.relational.utils.ResultColumnAppender.createResultColumnAppender; import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; public class ForecastTableFunction implements TableFunction { public static class ForecastTableFunctionHandle implements TableFunctionHandle { - TEndPoint targetAINode; String modelId; int maxInputLength; int outputLength; @@ -95,7 +95,6 @@ public ForecastTableFunctionHandle( int outputLength, long outputStartTime, long outputInterval, - TEndPoint targetAINode, List types) { this.keepInput = keepInput; 
this.maxInputLength = maxInputLength; @@ -104,7 +103,6 @@ public ForecastTableFunctionHandle( this.outputLength = outputLength; this.outputStartTime = outputStartTime; this.outputInterval = outputInterval; - this.targetAINode = targetAINode; this.types = types; } @@ -112,8 +110,6 @@ public ForecastTableFunctionHandle( public byte[] serialize() { try (PublicBAOS publicBAOS = new PublicBAOS(); DataOutputStream outputStream = new DataOutputStream(publicBAOS)) { - ReadWriteIOUtils.write(targetAINode.getIp(), outputStream); - ReadWriteIOUtils.write(targetAINode.getPort(), outputStream); ReadWriteIOUtils.write(modelId, outputStream); ReadWriteIOUtils.write(maxInputLength, outputStream); ReadWriteIOUtils.write(outputLength, outputStream); @@ -138,8 +134,6 @@ public byte[] serialize() { @Override public void deserialize(byte[] bytes) { ByteBuffer buffer = ByteBuffer.wrap(bytes); - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelId = ReadWriteIOUtils.readString(buffer); this.maxInputLength = ReadWriteIOUtils.readInt(buffer); this.outputLength = ReadWriteIOUtils.readInt(buffer); @@ -168,7 +162,6 @@ public boolean equals(Object o) { && outputStartTime == that.outputStartTime && outputInterval == that.outputInterval && keepInput == that.keepInput - && Objects.equals(targetAINode, that.targetAINode) && Objects.equals(modelId, that.modelId) && Objects.equals(options, that.options) && Objects.equals(types, that.types); @@ -177,7 +170,6 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash( - targetAINode, modelId, maxInputLength, outputLength, @@ -219,16 +211,6 @@ public int hashCode() { ALLOWED_INPUT_TYPES.add(Type.DOUBLE); } - // need to set before analyze method is called - // should only be used in fe scope, never be used in TableFunctionProcessorProvider - // The reason we don't directly set modelFetcher=ModelFetcher.getInstance() is that we need to - // mock 
IModelFetcher in UT - private IModelFetcher modelFetcher = null; - - public void setModelFetcher(IModelFetcher modelFetcher) { - this.modelFetcher = modelFetcher; - } - @Override public List getArgumentsSpecifications() { return Arrays.asList( @@ -284,8 +266,6 @@ public TableFunctionAnalysis analyze(Map arguments) { String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME)); } - TEndPoint targetAINode = getModelInfo(modelId).getTargetAINode(); - int outputLength = (int) ((ScalarArgument) arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue(); if (outputLength <= 0) { @@ -390,7 +370,6 @@ public TableFunctionAnalysis analyze(Map arguments) { outputLength, outputStartTime, outputInterval, - targetAINode, predicatedColumnTypes); // outputColumnSchema @@ -417,10 +396,6 @@ public TableFunctionDataProcessor getDataProcessor() { }; } - private ModelInferenceDescriptor getModelInfo(String modelId) { - return modelFetcher.fetchModel(modelId); - } - // only allow for INT32, INT64, FLOAT, DOUBLE private void checkType(Type type, String columnName) { if (!ALLOWED_INPUT_TYPES.contains(type)) { @@ -456,9 +431,9 @@ private static Map parseOptions(String options) { private static class ForecastDataProcessor implements TableFunctionDataProcessor { private static final TsBlockSerde SERDE = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); - private final TEndPoint targetAINode; private final String modelId; private final int maxInputLength; private final int outputLength; @@ -471,7 +446,6 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor private final TsBlockBuilder inputTsBlockBuilder; public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { - this.targetAINode = functionHandle.targetAINode; this.modelId = functionHandle.modelId; this.maxInputLength = 
functionHandle.maxInputLength; this.outputLength = functionHandle.outputLength; @@ -490,21 +464,6 @@ public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList); } - private static ResultColumnAppender createResultColumnAppender(Type type) { - switch (type) { - case INT32: - return new Int32Appender(); - case INT64: - return new Int64Appender(); - case FLOAT: - return new FloatAppender(); - case DOUBLE: - return new DoubleAppender(); - default: - throw new IllegalArgumentException("Unsupported column type: " + type); - } - } - @Override public void process( Record input, @@ -619,8 +578,12 @@ private TsBlock forecast() { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(modelId, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(modelId, SERDE.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode()); } @@ -643,100 +606,4 @@ private TsBlock forecast() { return res; } } - - private interface ResultColumnAppender { - void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder); - - double getDouble(Record row, int columnIndex); - - void writeDouble(double value, ColumnBuilder columnBuilder); - } - - private static class Int32Appender implements ResultColumnAppender { - - @Override - public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { - if (row.isNull(columnIndex)) { - properColumnBuilder.appendNull(); - } else { - properColumnBuilder.writeInt(row.getInt(columnIndex)); - } - } - - @Override - public double getDouble(Record row, int columnIndex) { - return row.getInt(columnIndex); - } 
- - @Override - public void writeDouble(double value, ColumnBuilder columnBuilder) { - columnBuilder.writeInt((int) value); - } - } - - private static class Int64Appender implements ResultColumnAppender { - - @Override - public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { - if (row.isNull(columnIndex)) { - properColumnBuilder.appendNull(); - } else { - properColumnBuilder.writeLong(row.getLong(columnIndex)); - } - } - - @Override - public double getDouble(Record row, int columnIndex) { - return row.getLong(columnIndex); - } - - @Override - public void writeDouble(double value, ColumnBuilder columnBuilder) { - columnBuilder.writeLong((long) value); - } - } - - private static class FloatAppender implements ResultColumnAppender { - - @Override - public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { - if (row.isNull(columnIndex)) { - properColumnBuilder.appendNull(); - } else { - properColumnBuilder.writeFloat(row.getFloat(columnIndex)); - } - } - - @Override - public double getDouble(Record row, int columnIndex) { - return row.getFloat(columnIndex); - } - - @Override - public void writeDouble(double value, ColumnBuilder columnBuilder) { - columnBuilder.writeFloat((float) value); - } - } - - private static class DoubleAppender implements ResultColumnAppender { - - @Override - public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { - if (row.isNull(columnIndex)) { - properColumnBuilder.appendNull(); - } else { - properColumnBuilder.writeDouble(row.getDouble(columnIndex)); - } - } - - @Override - public double getDouble(Record row, int columnIndex) { - return row.getDouble(columnIndex); - } - - @Override - public void writeDouble(double value, ColumnBuilder columnBuilder) { - columnBuilder.writeDouble(value); - } - } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java index f0c041ad8053..db706d4980cb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.TableHeaderSchemaValidator; @@ -211,9 +210,4 @@ DataPartition getDataPartitionWithUnclosedTimeRange( final String database, final List sgNameToQueryParamsMap); TableFunction getTableFunction(final String functionName); - - /** - * @return ModelFetcher - */ - IModelFetcher getModelFetcher(); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 0342c513f96a..c15071be6978 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -1471,11 +1471,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - return modelFetcher; - } - public static boolean isTwoNumericType(List argumentTypes) { return argumentTypes.size() == 2 && isNumericType(argumentTypes.get(0)) diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java index 4676559bd7b1..f8cf497546e6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java @@ -96,7 +96,6 @@ public List getDataNodeLocations(final String tableName) { case InformationSchema.TOPICS: case InformationSchema.SUBSCRIPTIONS: case InformationSchema.VIEWS: - case InformationSchema.MODELS: case InformationSchema.FUNCTIONS: case InformationSchema.CONFIGURATIONS: case InformationSchema.KEYWORDS: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java new file mode 100644 index 000000000000..7c0efb855326 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/utils/ResultColumnAppender.java @@ -0,0 +1,126 @@ +package org.apache.iotdb.db.queryengine.plan.relational.utils; + +import org.apache.iotdb.udf.api.relational.access.Record; +import org.apache.iotdb.udf.api.type.Type; + +import org.apache.tsfile.block.column.ColumnBuilder; + +public interface ResultColumnAppender { + + void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder); + + double getDouble(Record row, int columnIndex); + + void writeDouble(double value, ColumnBuilder columnBuilder); + + /** + * Static factory method to return the appropriate ResultColumnAppender instance based on the + * Type. 
+ */ + static ResultColumnAppender createResultColumnAppender(Type type) { + switch (type) { + case INT32: + return new Int32Appender(); + case INT64: + return new Int64Appender(); + case FLOAT: + return new FloatAppender(); + case DOUBLE: + return new DoubleAppender(); + default: + throw new IllegalArgumentException("Unsupported column type: " + type); + } + } + + /** INT32 Appender */ + class Int32Appender implements ResultColumnAppender { + + @Override + public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { + if (row.isNull(columnIndex)) { + properColumnBuilder.appendNull(); + } else { + properColumnBuilder.writeInt(row.getInt(columnIndex)); + } + } + + @Override + public double getDouble(Record row, int columnIndex) { + return row.getInt(columnIndex); + } + + @Override + public void writeDouble(double value, ColumnBuilder columnBuilder) { + columnBuilder.writeInt((int) value); + } + } + + /** INT64 Appender */ + class Int64Appender implements ResultColumnAppender { + + @Override + public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { + if (row.isNull(columnIndex)) { + properColumnBuilder.appendNull(); + } else { + properColumnBuilder.writeLong(row.getLong(columnIndex)); + } + } + + @Override + public double getDouble(Record row, int columnIndex) { + return row.getLong(columnIndex); + } + + @Override + public void writeDouble(double value, ColumnBuilder columnBuilder) { + columnBuilder.writeLong((long) value); + } + } + + /** FLOAT Appender */ + class FloatAppender implements ResultColumnAppender { + + @Override + public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { + if (row.isNull(columnIndex)) { + properColumnBuilder.appendNull(); + } else { + properColumnBuilder.writeFloat(row.getFloat(columnIndex)); + } + } + + @Override + public double getDouble(Record row, int columnIndex) { + return row.getFloat(columnIndex); + } + + @Override + public void writeDouble(double 
value, ColumnBuilder columnBuilder) { + columnBuilder.writeFloat((float) value); + } + } + + /** DOUBLE Appender */ + class DoubleAppender implements ResultColumnAppender { + + @Override + public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) { + if (row.isNull(columnIndex)) { + properColumnBuilder.appendNull(); + } else { + properColumnBuilder.writeDouble(row.getDouble(columnIndex)); + } + } + + @Override + public double getDouble(Record row, int columnIndex) { + return row.getDouble(columnIndex); + } + + @Override + public void writeDouble(double value, ColumnBuilder columnBuilder) { + columnBuilder.writeDouble(value); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index 260410954d4c..ebecf79f5b70 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -19,14 +19,12 @@ package org.apache.iotdb.db.queryengine.plan.udf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import 
org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.UDTF; import org.apache.iotdb.udf.api.access.Row; @@ -54,8 +52,8 @@ public class UDTFForecast implements UDTF { private static final TsBlockSerde serde = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); - private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); private String model_id; private int maxInputLength; private int outputLength; @@ -66,7 +64,6 @@ public class UDTFForecast implements UDTF { List types; private LinkedList inputRows; private TsBlockBuilder inputTsBlockBuilder; - private final IModelFetcher modelFetcher = ModelFetcher.getInstance(); private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); @@ -88,6 +85,7 @@ public class UDTFForecast implements UDTF { private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE; private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS"; private static final String DEFAULT_OPTIONS = ""; + private static final int MAX_INPUT_LENGTH = 2880; private void checkType() { for (Type type : this.types) { @@ -112,8 +110,7 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati throw new IllegalArgumentException( "MODEL_ID parameter must be provided and cannot be empty."); } - ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id); - this.targetAINode = descriptor.getTargetAINode(); + this.maxInputLength = MAX_INPUT_LENGTH; this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL); this.outputLength = @@ -211,8 +208,12 @@ private TsBlock forecast() throws Exception { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(model_id, inputData, outputLength, options); + try 
(AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(model_id, serde.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException( e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode()); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java index e60a14b727ca..79c031560973 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.AlignedDeviceEntry; @@ -402,11 +401,6 @@ public TableFunction getTableFunction(String functionName) { return null; } - @Override - public IModelFetcher getModelFetcher() { - return null; - } - private static final DataPartition DATA_PARTITION = MockTSBSDataPartition.constructDataPartition(); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index e56b48936b96..7bbfe150ade4 100644 --- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -19,7 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; @@ -378,7 +377,6 @@ public void testForecastFunction() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan @@ -439,7 +437,6 @@ public void testForecastFunctionWithNoLowerCase() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java index aa9fcdfd1b51..4b1d18944b73 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java @@ -19,8 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import 
org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -32,12 +30,10 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.function.Exclude; import org.apache.iotdb.db.queryengine.plan.function.Repeat; import org.apache.iotdb.db.queryengine.plan.function.Split; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.SubtractionResolver; @@ -560,21 +556,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - String modelId = "timer_xl"; - IModelFetcher fetcher = Mockito.mock(IModelFetcher.class); - ModelInferenceDescriptor descriptor = Mockito.mock(ModelInferenceDescriptor.class); - Mockito.when(descriptor.getTargetAINode()).thenReturn(new TEndPoint("127.0.0.1", 10810)); - ModelInformation modelInformation = Mockito.mock(ModelInformation.class); - Mockito.when(modelInformation.available()).thenReturn(true); - Mockito.when(modelInformation.getInputShape()).thenReturn(new int[] {1440, 96}); - Mockito.when(descriptor.getModelInformation()).thenReturn(modelInformation); - Mockito.when(descriptor.getModelName()).thenReturn(modelId); - Mockito.when(fetcher.fetchModel(modelId)).thenReturn(descriptor); - return fetcher; - } - private static final DataPartition TABLE_DATA_PARTITION = MockTableModelDataPartition.constructDataPartition(DB1); diff --git a/iotdb-core/node-commons/pom.xml 
b/iotdb-core/node-commons/pom.xml index 85ff69ee8ac7..e7c508c195d5 100644 --- a/iotdb-core/node-commons/pom.xml +++ b/iotdb-core/node-commons/pom.xml @@ -65,6 +65,11 @@ iotdb-thrift-confignode 2.0.6-SNAPSHOT + + org.apache.iotdb + iotdb-thrift-ainode + 2.0.6-SNAPSHOT + org.apache.iotdb iotdb-thrift diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java index 106d67b6279d..115f322348c0 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java @@ -20,6 +20,7 @@ package org.apache.iotdb.commons.client; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncConfigNodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeExternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeInternalServiceClient; @@ -390,4 +391,31 @@ public GenericKeyedObjectPool create return clientPool; } } + + public static class AsyncAINodeHeartbeatServiceClientPoolFactory + implements IClientPoolFactory { + + @Override + public GenericKeyedObjectPool createClientPool( + ClientManager manager) { + GenericKeyedObjectPool clientPool = + new GenericKeyedObjectPool<>( + new AsyncAINodeInternalServiceClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) + .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) + .setPrintLogWhenEncounterException(false) + .build(), + ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), + new ClientPoolProperty.Builder() + 
.setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java similarity index 83% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java rename to iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java index 26130287697c..8cbd55759633 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java @@ -17,7 +17,7 @@ * under the License. 
*/ -package org.apache.iotdb.db.protocol.client.ainode; +package org.apache.iotdb.commons.client.async; import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; import org.apache.iotdb.common.rpc.thrift.TEndPoint; @@ -35,20 +35,20 @@ import java.io.IOException; -public class AsyncAINodeServiceClient extends IAINodeRPCService.AsyncClient +public class AsyncAINodeInternalServiceClient extends IAINodeRPCService.AsyncClient implements ThriftClient { private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - private final boolean printLogWhenEncounterException; private final TEndPoint endPoint; - private final ClientManager clientManager; + private final boolean printLogWhenEncounterException; + private final ClientManager clientManager; - public AsyncAINodeServiceClient( + public AsyncAINodeInternalServiceClient( ThriftClientProperty property, TEndPoint endPoint, TAsyncClientManager tClientManager, - ClientManager clientManager) + ClientManager clientManager) throws IOException { super( property.getProtocolFactory(), @@ -122,10 +122,10 @@ public boolean isReady() { } public static class Factory - extends AsyncThriftClientFactory { + extends AsyncThriftClientFactory { public Factory( - ClientManager clientManager, + ClientManager clientManager, ThriftClientProperty thriftClientProperty, String threadName) { super(clientManager, thriftClientProperty, threadName); @@ -133,14 +133,15 @@ public Factory( @Override public void destroyObject( - TEndPoint endPoint, PooledObject pooledObject) { + TEndPoint endPoint, PooledObject pooledObject) { pooledObject.getObject().close(); } @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { + public PooledObject makeObject(TEndPoint endPoint) + throws Exception { return new DefaultPooledObject<>( - new AsyncAINodeServiceClient( + new AsyncAINodeInternalServiceClient( thriftClientProperty, endPoint, tManagers[clientCnt.incrementAndGet() % tManagers.length], @@ -149,7 
+150,7 @@ public PooledObject makeObject(TEndPoint endPoint) thr @Override public boolean validateObject( - TEndPoint endPoint, PooledObject pooledObject) { + TEndPoint endPoint, PooledObject pooledObject) { return pooledObject.getObject().isReady(); } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java index 3fa107685438..01968833db7f 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -32,9 +32,12 @@ public class ModelInformation { + private static final int[] DEFAULT_MODEL_INPUT_SHAPE = new int[] {2880, 1}; + private static final int[] DEFAULT_MODEL_OUTPUT_SHAPE = new int[] {720, 1}; + ModelType modelType; - private final String modelName; + private final String modelId; private final int[] inputShape; @@ -48,9 +51,17 @@ public class ModelInformation { String attribute = ""; + public ModelInformation(String modelId) { + this.modelId = modelId; + this.inputShape = DEFAULT_MODEL_INPUT_SHAPE; + this.inputDataType = new TSDataType[] {TSDataType.DOUBLE}; + this.outputShape = DEFAULT_MODEL_OUTPUT_SHAPE; + this.outputDataType = new TSDataType[] {TSDataType.DOUBLE}; + } + public ModelInformation( ModelType modelType, - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, @@ -58,7 +69,7 @@ public ModelInformation( String attribute, ModelStatus status) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -68,14 +79,14 @@ public ModelInformation( } public ModelInformation( - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, TSDataType[] 
outputDataType, String attribute) { this.modelType = ModelType.USER_DEFINED; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -83,9 +94,9 @@ public ModelInformation( this.attribute = attribute; } - public ModelInformation(String modelName, ModelStatus status) { + public ModelInformation(String modelId, ModelStatus status) { this.modelType = ModelType.BUILT_IN_FORECAST; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[0]; this.outputShape = new int[0]; this.outputDataType = new TSDataType[0]; @@ -94,9 +105,9 @@ public ModelInformation(String modelName, ModelStatus status) { } // init built-in modelInformation - public ModelInformation(ModelType modelType, String modelName) { + public ModelInformation(ModelType modelType, String modelId) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[2]; this.outputShape = new int[2]; this.inputDataType = new TSDataType[0]; @@ -116,8 +127,8 @@ public void updateStatus(ModelStatus status) { this.status = status; } - public String getModelName() { - return modelName; + public String getModelId() { + return modelId; } public void setInputLength(int length) { @@ -197,7 +208,7 @@ public void setAttribute(String attribute) { public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -222,7 +233,7 @@ public void serialize(DataOutputStream stream) throws IOException { public void serialize(FileOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + 
ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -247,7 +258,7 @@ public void serialize(FileOutputStream stream) throws IOException { public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(modelType.ordinal(), byteBuffer); ReadWriteIOUtils.write(status.ordinal(), byteBuffer); - ReadWriteIOUtils.write(modelName, byteBuffer); + ReadWriteIOUtils.write(modelId, byteBuffer); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -353,7 +364,7 @@ public static ModelInformation deserialize(InputStream stream) throws IOExceptio public ByteBuffer serializeShowModelResult() throws IOException { PublicBAOS buffer = new PublicBAOS(); DataOutputStream stream = new DataOutputStream(buffer); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelType.toString(), stream); ReadWriteIOUtils.write(status.toString(), stream); ReadWriteIOUtils.write(Arrays.toString(inputShape), stream); @@ -370,7 +381,7 @@ public ByteBuffer serializeShowModelResult() throws IOException { public boolean equals(Object obj) { if (obj instanceof ModelInformation) { ModelInformation other = (ModelInformation) obj; - return modelName.equals(other.modelName) + return modelId.equals(other.modelId) && modelType.equals(other.modelType) && Arrays.equals(inputShape, other.inputShape) && Arrays.equals(outputShape, other.outputShape) diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java index 64aff12f284e..6c6100086316 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java @@ -42,7 +42,7 @@ public boolean containsModel(String modelId) { } public void addModel(ModelInformation modelInformation) { - 
modelInfoMap.put(modelInformation.getModelName(), modelInformation); + modelInfoMap.put(modelInformation.getModelId(), modelInformation); } public void removeModel(String modelId) { @@ -63,7 +63,7 @@ public ModelInformation getModelInformationById(String modelId) { public void clearFailedModel() { for (ModelInformation modelInformation : modelInfoMap.values()) { if (modelInformation.getStatus() == ModelStatus.UNAVAILABLE) { - modelInfoMap.remove(modelInformation.getModelName()); + modelInfoMap.remove(modelInformation.getModelId()); } } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java index 2db41cc3c2dd..243bc41c40ce 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java @@ -43,7 +43,6 @@ public class InformationSchema { public static final String TOPICS = "topics"; public static final String SUBSCRIPTIONS = "subscriptions"; public static final String VIEWS = "views"; - public static final String MODELS = "models"; public static final String FUNCTIONS = "functions"; public static final String CONFIGURATIONS = "configurations"; public static final String KEYWORDS = "keywords"; @@ -256,23 +255,6 @@ public class InformationSchema { viewTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); schemaTables.put(VIEWS, viewTable); - final TsTable modelTable = new TsTable(MODELS); - modelTable.addColumnSchema( - new TagColumnSchema(ColumnHeaderConstant.MODEL_ID_TABLE_MODEL, TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema(ColumnHeaderConstant.MODEL_TYPE_TABLE_MODEL, TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.STATE.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - 
modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.CONFIGS.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.NOTES.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); - schemaTables.put(MODELS, modelTable); - final TsTable functionTable = new TsTable(FUNCTIONS); functionTable.addColumnSchema( new TagColumnSchema(ColumnHeaderConstant.FUNCTION_NAME_TABLE_MODEL, TSDataType.STRING)); diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index e0680cd29b79..1dc2f025f5c3 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -60,13 +60,7 @@ struct TRegisterModelResp { struct TInferenceReq { 1: required string modelId 2: required binary dataset - 3: optional TWindowParams windowParams - 4: optional map inferenceAttributes -} - -struct TWindowParams { - 1: required i32 windowInterval - 2: required i32 windowStep + 3: optional map inferenceAttributes } struct TInferenceResp { diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift index d8f6318063eb..f2b8ec6b8b07 100644 --- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift +++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift @@ -1096,34 +1096,6 @@ struct TUnsetSchemaTemplateReq { 4: optional bool isGeneratedByPipe } -struct TCreateModelReq { - 1: required string modelName - 2: required string uri -} - -struct TDropModelReq { - 1: required string modelId -} - -struct TGetModelInfoReq { - 1: required string modelId -} - -struct TGetModelInfoResp { - 1: required common.TSStatus status - 2: optional binary modelInfo - 3: optional common.TEndPoint aiNodeAddress -} - -struct 
TUpdateModelInfoReq { - 1: required string modelId - 2: required i32 modelStatus - 3: optional string attributes - 4: optional list aiNodeIds - 5: optional i32 inputLength - 6: optional i32 outputLength -} - struct TDataSchemaForTable{ 1: required string targetSql } @@ -1132,16 +1104,6 @@ struct TDataSchemaForTree{ 1: required list path } -struct TCreateTrainingReq { - 1: required string modelId - 2: required bool isTableModel - 3: required string existingModelId - 4: optional TDataSchemaForTable dataSchemaForTable - 5: optional TDataSchemaForTree dataSchemaForTree - 6: optional map parameters - 7: optional list> timeRanges -} - // ==================================================== // Quota // ==================================================== @@ -2006,31 +1968,6 @@ service IConfigNodeRPCService { */ TShowCQResp showCQ() - // ==================================================== - // AI Model - // ==================================================== - - /** - * Create a model - * - * @return SUCCESS_STATUS if the model was created successfully - */ - common.TSStatus createModel(TCreateModelReq req) - - /** - * Drop a model - * - * @return SUCCESS_STATUS if the model was removed successfully - */ - common.TSStatus dropModel(TDropModelReq req) - - /** - * Return the model info by model_id - */ - TGetModelInfoResp getModelInfo(TGetModelInfoReq req) - - common.TSStatus updateModelInfo(TUpdateModelInfoReq req) - // ====================================================== // Quota // ======================================================