From c5e8663f30d96c75f22df3f7871fabe2bc7f1105 Mon Sep 17 00:00:00 2001 From: RkGrit Date: Mon, 10 Nov 2025 17:20:30 +0800 Subject: [PATCH 01/38] refactor_built_in_models --- .../iotdb/ainode/core/model/model_factory.py | 60 ++ .../iotdb/ainode/core/model/model_storage.py | 7 +- .../ainode/core/model/sktime/__init__.py | 17 + .../configuration_sktime.py} | 732 +++++------------- .../core/model/sktime/modeling_sktime.py | 261 +++++++ .../model/sundial/configuration_sundial.py | 2 + 6 files changed, 526 insertions(+), 553 deletions(-) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py rename iotdb-core/ainode/iotdb/ainode/core/model/{built_in_model_factory.py => sktime/configuration_sktime.py} (56%) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py index 26d863156f379..ceedf11b4e3c5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py @@ -21,6 +21,7 @@ from urllib.parse import urljoin import yaml +from huggingface_hub import hf_hub_download from iotdb.ainode.core.constant import ( MODEL_CONFIG_FILE_IN_YAML, @@ -34,12 +35,71 @@ download_file, download_snapshot_from_hf, ) +from iotdb.ainode.core.model.model_enums import BuiltInModelType from iotdb.ainode.core.util.serde import get_data_type_byte_from_str from iotdb.thrift.ainode.ttypes import TConfigs +from iotdb.ainode.core.model.model_info import TIMER_REPO_ID +from iotdb.ainode.core.constant import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, +) logger = Logger() +def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: + weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) + if not os.path.exists(weights_path): + logger.info( + f"Model weights file not found at {weights_path}, downloading from HuggingFace..." + ) + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + local_dir=local_dir, + ) + logger.info(f"Got file to {weights_path}") + except Exception as e: + logger.error( + f"Failed to download model weights file to {local_dir} due to {e}" + ) + return False + if not os.path.exists(config_path): + logger.info( + f"Model config file not found at {config_path}, downloading from HuggingFace..." + ) + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_CONFIG_FILE_IN_JSON, + local_dir=local_dir, + ) + logger.info(f"Got file to {config_path}") + except Exception as e: + logger.error( + f"Failed to download model config file to {local_dir} due to {e}" + ) + return False + return True + + +def download_built_in_ltsm_from_hf_if_necessary( + model_type: BuiltInModelType, local_dir: str +) -> bool: + """ + Download the built-in ltsm from HuggingFace repository when necessary. + + Return: + bool: True if the model is existed or downloaded successfully, False otherwise. 
+ """ + repo_id = TIMER_REPO_ID[model_type] + if not _download_file_from_hf_if_necessary(local_dir, repo_id): + return False + return True + + def fetch_model_by_uri( uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType ): diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index e346f569102e3..1c63b56e519c9 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -38,17 +38,14 @@ UnsupportedError, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) +from iotdb.ainode.core.model.sktime.modeling_sktime import fetch_built_in_model from iotdb.ainode.core.model.model_enums import ( BuiltInModelType, ModelCategory, ModelFileType, ModelStates, ) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri +from iotdb.ainode.core.model.model_factory import fetch_model_by_uri, download_built_in_ltsm_from_hf_if_necessary from iotdb.ainode.core.model.model_info import ( BUILT_IN_LTSM_MAP, BUILT_IN_MACHINE_LEARNING_MODEL_MAP, diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py similarity index 56% rename from iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py index 3b55142350bad..18fea61b6ff03 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -15,29 +15,12 @@ # specific language governing permissions and limitations # under the License. 
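For orientation, a minimal sketch of how the relocated downloader is meant to be driven; the target directory is illustrative, and it assumes MODEL_WEIGHTS_FILE_IN_SAFETENSORS / MODEL_CONFIG_FILE_IN_JSON name the usual model.safetensors / config.json pair and that TIMER_REPO_ID maps each built-in LTSM type to its HuggingFace repo id:

    # Hypothetical caller: ensure the Timer-XL weights and config are on disk
    # before loading; the helper returns False instead of raising on failure.
    from iotdb.ainode.core.model.model_enums import BuiltInModelType
    from iotdb.ainode.core.model.model_factory import (
        download_built_in_ltsm_from_hf_if_necessary,
    )

    local_dir = "/path/to/ainode/models/weights/timerxl"  # illustrative path
    if not download_built_in_ltsm_from_hf_if_necessary(
        BuiltInModelType.TIMER_XL, local_dir
    ):
        raise RuntimeError(f"could not fetch built-in LTSM into {local_dir}")

Since hf_hub_download places the requested filename under local_dir and both files are guarded by os.path.exists checks, repeating the call is effectively a no-op once the pair is present.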
# -import os + from abc import abstractmethod from typing import Callable, Dict, List - -import numpy as np -from huggingface_hub import hf_hub_download -from sklearn.preprocessing import MinMaxScaler -from sktime.detection.hmm_learn import GMMHMM, GaussianHMM -from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA -from sktime.forecasting.exp_smoothing import ExponentialSmoothing -from sktime.forecasting.naive import NaiveForecaster -from sktime.forecasting.trend import STLForecaster - -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - AttributeName, -) +from enum import Enum from iotdb.ainode.core.exception import ( BuiltInModelNotSupportError, - InferenceModelInternalError, ListRangeException, NumericalRangeException, StringRangeException, @@ -45,134 +28,119 @@ ) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import TIMER_REPO_ID -from iotdb.ainode.core.model.sundial import modeling_sundial -from iotdb.ainode.core.model.timerxl import modeling_timer logger = Logger() -def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: - weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) - if not os.path.exists(weights_path): - logger.info( - f"Model weights file not found at {weights_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - local_dir=local_dir, - ) - logger.info(f"Got file to {weights_path}") - except Exception as e: - logger.error( - f"Failed to download model weights file to {local_dir} due to {e}" - ) - return False - if not os.path.exists(config_path): - logger.info( - f"Model config file not found at {config_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, - local_dir=local_dir, - ) - logger.info(f"Got file to {config_path}") - except Exception as e: - logger.error( - f"Failed to download model config file to {local_dir} due to {e}" - ) - return False - return True - - -def download_built_in_ltsm_from_hf_if_necessary( - model_type: BuiltInModelType, local_dir: str -) -> bool: - """ - Download the built-in ltsm from HuggingFace repository when necessary. - - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - repo_id = TIMER_REPO_ID[model_type] - if not _download_file_from_hf_if_necessary(local_dir, repo_id): - return False - return True - - -def get_model_attributes(model_type: BuiltInModelType): - if model_type == BuiltInModelType.ARIMA: - attribute_map = arima_attribute_map - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - attribute_map = naive_forecaster_attribute_map - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - attribute_map = exponential_smoothing_attribute_map - elif model_type == BuiltInModelType.STL_FORECASTER: - attribute_map = stl_forecaster_attribute_map - elif model_type == BuiltInModelType.GMM_HMM: - attribute_map = gmmhmm_attribute_map - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - attribute_map = gaussian_hmm_attribute_map - elif model_type == BuiltInModelType.STRAY: - attribute_map = stray_attribute_map - elif model_type == BuiltInModelType.TIMER_XL: - attribute_map = timerxl_attribute_map - elif model_type == BuiltInModelType.SUNDIAL: - attribute_map = sundial_attribute_map - else: - raise BuiltInModelNotSupportError(model_type.value) - return attribute_map - - -def fetch_built_in_model( - model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str] -) -> Callable: - """ - Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config. - Args: - model_type: the type of the built-in model - model_dir: for huggingface models only, the directory where the model is stored - Returns: - model: the built-in model - """ - default_attributes = get_model_attributes(model_type) - # parse the attributes from inference_attrs - attributes = parse_attribute(inference_attrs, default_attributes) - - # build the built-in model - if model_type == BuiltInModelType.ARIMA: - model = ArimaModel(attributes) - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - model = ExponentialSmoothingModel(attributes) - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - model = NaiveForecasterModel(attributes) - elif model_type == BuiltInModelType.STL_FORECASTER: - model = STLForecasterModel(attributes) - elif model_type == BuiltInModelType.GMM_HMM: - model = GMMHMMModel(attributes) - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - model = GaussianHmmModel(attributes) - elif model_type == BuiltInModelType.STRAY: - model = STRAYModel(attributes) - elif model_type == BuiltInModelType.TIMER_XL: - model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) - elif model_type == BuiltInModelType.SUNDIAL: - model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) - else: - raise BuiltInModelNotSupportError(model_type.value) - - return model +class AttributeName(Enum): + # forecast Attribute + PREDICT_LENGTH = "predict_length" + + # NaiveForecaster + STRATEGY = "strategy" + SP = "sp" + + # STLForecaster + # SP = 'sp' + SEASONAL = "seasonal" + SEASONAL_DEG = "seasonal_deg" + TREND_DEG = "trend_deg" + LOW_PASS_DEG = "low_pass_deg" + SEASONAL_JUMP = "seasonal_jump" + TREND_JUMP = "trend_jump" + LOSS_PASS_JUMP = "low_pass_jump" + + # ExponentialSmoothing + DAMPED_TREND = "damped_trend" + INITIALIZATION_METHOD = "initialization_method" + OPTIMIZED = "optimized" + REMOVE_BIAS = "remove_bias" + USE_BRUTE = "use_brute" + + # Arima + ORDER = "order" + SEASONAL_ORDER = "seasonal_order" + METHOD = "method" + MAXITER = "maxiter" + SUPPRESS_WARNINGS = "suppress_warnings" 
+ OUT_OF_SAMPLE_SIZE = "out_of_sample_size" + SCORING = "scoring" + WITH_INTERCEPT = "with_intercept" + TIME_VARYING_REGRESSION = "time_varying_regression" + ENFORCE_STATIONARITY = "enforce_stationarity" + ENFORCE_INVERTIBILITY = "enforce_invertibility" + SIMPLE_DIFFERENCING = "simple_differencing" + MEASUREMENT_ERROR = "measurement_error" + MLE_REGRESSION = "mle_regression" + HAMILTON_REPRESENTATION = "hamilton_representation" + CONCENTRATE_SCALE = "concentrate_scale" + + # GAUSSIAN_HMM + N_COMPONENTS = "n_components" + COVARIANCE_TYPE = "covariance_type" + MIN_COVAR = "min_covar" + STARTPROB_PRIOR = "startprob_prior" + TRANSMAT_PRIOR = "transmat_prior" + MEANS_PRIOR = "means_prior" + MEANS_WEIGHT = "means_weight" + COVARS_PRIOR = "covars_prior" + COVARS_WEIGHT = "covars_weight" + ALGORITHM = "algorithm" + N_ITER = "n_iter" + TOL = "tol" + PARAMS = "params" + INIT_PARAMS = "init_params" + IMPLEMENTATION = "implementation" + + # GMMHMM + # N_COMPONENTS = "n_components" + N_MIX = "n_mix" + # MIN_COVAR = "min_covar" + # STARTPROB_PRIOR = "startprob_prior" + # TRANSMAT_PRIOR = "transmat_prior" + WEIGHTS_PRIOR = "weights_prior" + + # MEANS_PRIOR = "means_prior" + # MEANS_WEIGHT = "means_weight" + # ALGORITHM = "algorithm" + # COVARIANCE_TYPE = "covariance_type" + # N_ITER = "n_iter" + # TOL = "tol" + # INIT_PARAMS = "init_params" + # PARAMS = "params" + # IMPLEMENTATION = "implementation" + + # STRAY + ALPHA = "alpha" + K = "k" + KNN_ALGORITHM = "knn_algorithm" + P = "p" + SIZE_THRESHOLD = "size_threshold" + OUTLIER_TAIL = "outlier_tail" + + # timerxl + INPUT_TOKEN_LEN = "input_token_len" + HIDDEN_SIZE = "hidden_size" + INTERMEDIATE_SIZE = "intermediate_size" + OUTPUT_TOKEN_LENS = "output_token_lens" + NUM_HIDDEN_LAYERS = "num_hidden_layers" + NUM_ATTENTION_HEADS = "num_attention_heads" + HIDDEN_ACT = "hidden_act" + USE_CACHE = "use_cache" + ROPE_THETA = "rope_theta" + ATTENTION_DROPOUT = "attention_dropout" + INITIALIZER_RANGE = "initializer_range" + MAX_POSITION_EMBEDDINGS = "max_position_embeddings" + CKPT_PATH = "ckpt_path" + + # sundial + DROPOUT_RATE = "dropout_rate" + FLOW_LOSS_DEPTH = "flow_loss_depth" + NUM_SAMPLING_STEPS = "num_sampling_steps" + DIFFUSION_BATCH_MUL = "diffusion_batch_mul" + + def name(self) -> str: + return self.value class Attribute(object): @@ -198,11 +166,11 @@ def parse(self, string_value: str): class IntAttribute(Attribute): def __init__( - self, - name: str, - default_value: int, - default_low: int, - default_high: int, + self, + name: str, + default_value: int, + default_low: int, + default_high: int, ): super(IntAttribute, self).__init__(name) self.__default_value = default_value @@ -229,11 +197,11 @@ def parse(self, string_value: str): class FloatAttribute(Attribute): def __init__( - self, - name: str, - default_value: float, - default_low: float, - default_high: float, + self, + name: str, + default_value: float, + default_low: float, + default_high: float, ): super(FloatAttribute, self).__init__(name) self.__default_value = default_value @@ -376,216 +344,8 @@ def parse(self, string_value: str): return tuple_value -def parse_attribute( - input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] -): - """ - Args: - input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of - the attribute - attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute - object - Returns: - a dict of attributes, where the key is the attribute name, the value is the 
parsed value of the attribute - """ - attributes = {} - for attribute_name in attribute_map: - # user specified the attribute - if attribute_name in input_attributes: - attribute = attribute_map[attribute_name] - value = attribute.parse(input_attributes[attribute_name]) - attribute.validate_value(value) - attributes[attribute_name] = value - # user did not specify the attribute, use the default value - else: - try: - attributes[attribute_name] = attribute_map[ - attribute_name - ].get_default_value() - except NotImplementedError as e: - logger.error(f"attribute {attribute_name} is not implemented.") - raise e - return attributes - - -sundial_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=16, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=768, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=3072, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=12, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=12, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.DROPOUT_RATE.value: FloatAttribute( - name=AttributeName.DROPOUT_RATE.value, - default_value=0.1, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute( - name=AttributeName.FLOW_LOSS_DEPTH.value, - default_value=3, - default_low=1, - default_high=50, - ), - AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute( - name=AttributeName.NUM_SAMPLING_STEPS.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute( - name=AttributeName.DIFFUSION_BATCH_MUL.value, - default_value=4, - default_low=1, - default_high=5000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "sundial", - ), - value_choices=[""], - ), -} - -timerxl_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=96, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - 
name=AttributeName.HIDDEN_SIZE.value, - default_value=1024, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=2048, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=8, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=8, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.ATTENTION_DROPOUT.value: FloatAttribute( - name=AttributeName.ATTENTION_DROPOUT.value, - default_value=0.0, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "timerxl", - "model.safetensors", - ), - value_choices=[""], - ), -} - # built-in sktime model attributes + # NaiveForecaster naive_forecaster_attribute_map = { AttributeName.PREDICT_LENGTH.value: IntAttribute( @@ -603,6 +363,7 @@ def parse_attribute( name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 ), } + # ExponentialSmoothing exponential_smoothing_attribute_map = { AttributeName.PREDICT_LENGTH.value: IntAttribute( @@ -633,6 +394,7 @@ def parse_attribute( default_value=False, ), } + # Arima arima_attribute_map = { AttributeName.PREDICT_LENGTH.value: IntAttribute( @@ -712,6 +474,7 @@ def parse_attribute( default_value=False, ), } + # STLForecaster stl_forecaster_attribute_map = { AttributeName.PREDICT_LENGTH.value: IntAttribute( @@ -1045,194 +808,67 @@ def parse_attribute( } -class BuiltInModel(object): - def __init__(self, attributes): - self._attributes = attributes - self._model = None - - @abstractmethod - def inference(self, data): - raise NotImplementedError - - -class ArimaModel(BuiltInModel): - def __init__(self, attributes): - super(ArimaModel, self).__init__(attributes) - self._model = ARIMA( - order=attributes["order"], - seasonal_order=attributes["seasonal_order"], - method=attributes["method"], - suppress_warnings=attributes["suppress_warnings"], - maxiter=attributes["maxiter"], - out_of_sample_size=attributes["out_of_sample_size"], - scoring=attributes["scoring"], - with_intercept=attributes["with_intercept"], - time_varying_regression=attributes["time_varying_regression"], - enforce_stationarity=attributes["enforce_stationarity"], - 
enforce_invertibility=attributes["enforce_invertibility"], - simple_differencing=attributes["simple_differencing"], - measurement_error=attributes["measurement_error"], - mle_regression=attributes["mle_regression"], - hamilton_representation=attributes["hamilton_representation"], - concentrate_scale=attributes["concentrate_scale"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class ExponentialSmoothingModel(BuiltInModel): - def __init__(self, attributes): - super(ExponentialSmoothingModel, self).__init__(attributes) - self._model = ExponentialSmoothing( - damped_trend=attributes["damped_trend"], - initialization_method=attributes["initialization_method"], - optimized=attributes["optimized"], - remove_bias=attributes["remove_bias"], - use_brute=attributes["use_brute"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class NaiveForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(NaiveForecasterModel, self).__init__(attributes) - self._model = NaiveForecaster( - strategy=attributes["strategy"], sp=attributes["sp"] - ) +def get_attributes(model_type: BuiltInModelType): + """ + Get the attribute map of the built-in model. - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STLForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(STLForecasterModel, self).__init__(attributes) - self._model = STLForecaster( - sp=attributes["sp"], - seasonal=attributes["seasonal"], - seasonal_deg=attributes["seasonal_deg"], - trend_deg=attributes["trend_deg"], - low_pass_deg=attributes["low_pass_deg"], - seasonal_jump=attributes["seasonal_jump"], - trend_jump=attributes["trend_jump"], - low_pass_jump=attributes["low_pass_jump"], - ) + Args: + model_type: the type of the built-in model - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GMMHMMModel(BuiltInModel): - def __init__(self, attributes): - super(GMMHMMModel, self).__init__(attributes) - self._model = GMMHMM( - n_components=attributes["n_components"], - n_mix=attributes["n_mix"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - weights_prior=attributes["weights_prior"], - algorithm=attributes["algorithm"], - covariance_type=attributes["covariance_type"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - 
implementation=attributes["implementation"], - ) + Returns: + the attribute map of the built-in model - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GaussianHmmModel(BuiltInModel): - def __init__(self, attributes): - super(GaussianHmmModel, self).__init__(attributes) - self._model = GaussianHMM( - n_components=attributes["n_components"], - covariance_type=attributes["covariance_type"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - covars_prior=attributes["covars_prior"], - covars_weight=attributes["covars_weight"], - algorithm=attributes["algorithm"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) + """ + if model_type == BuiltInModelType.ARIMA: + attribute_map = arima_attribute_map + elif model_type == BuiltInModelType.NAIVE_FORECASTER: + attribute_map = naive_forecaster_attribute_map + elif ( + model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING + or model_type == BuiltInModelType.HOLTWINTERS + ): + attribute_map = exponential_smoothing_attribute_map + elif model_type == BuiltInModelType.STL_FORECASTER: + attribute_map = stl_forecaster_attribute_map + elif model_type == BuiltInModelType.GMM_HMM: + attribute_map = gmmhmm_attribute_map + elif model_type == BuiltInModelType.GAUSSIAN_HMM: + attribute_map = gaussian_hmm_attribute_map + elif model_type == BuiltInModelType.STRAY: + attribute_map = stray_attribute_map + else: + raise BuiltInModelNotSupportError(model_type.value) + return attribute_map - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STRAYModel(BuiltInModel): - def __init__(self, attributes): - super(STRAYModel, self).__init__(attributes) - self._model = STRAY( - alpha=attributes["alpha"], - k=attributes["k"], - knn_algorithm=attributes["knn_algorithm"], - p=attributes["p"], - size_threshold=attributes["size_threshold"], - outlier_tail=attributes["outlier_tail"], - ) - def inference(self, data): - try: - data = MinMaxScaler().fit_transform(data) - output = self._model.fit_transform(data) - # change the output to int - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) +def update_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] +): + """ + Update the attribute of the built-in model using the input attributes. 
+ Args: + input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of + the attribute + attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute + object + Returns: + a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute + """ + attributes = {} + for attribute_name in attribute_map: + # user specified the attribute + if attribute_name in input_attributes: + attribute = attribute_map[attribute_name] + value = attribute.parse(input_attributes[attribute_name]) + attribute.validate_value(value) + attributes[attribute_name] = value + # user did not specify the attribute, use the default value + else: + try: + attributes[attribute_name] = attribute_map[ + attribute_name + ].get_default_value() + except NotImplementedError as e: + logger.error(f"attribute {attribute_name} is not implemented.") + raise e + return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py new file mode 100644 index 0000000000000..7e8e41c4dcf11 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
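To make the resolution contract above concrete, a small sketch with illustrative values: keys missing from the input fall back to each Attribute's default, and every supplied value is parsed from its string form and range-checked before use.

    # Hypothetical usage: resolve NaiveForecaster hyperparameters from raw strings.
    from iotdb.ainode.core.model.model_enums import BuiltInModelType
    from iotdb.ainode.core.model.sktime.configuration_sktime import (
        get_attributes,
        update_attribute,
    )

    defaults = get_attributes(BuiltInModelType.NAIVE_FORECASTER)
    attrs = update_attribute({"predict_length": "24", "sp": "12"}, defaults)
    assert attrs["predict_length"] == 24 and attrs["sp"] == 12
    # "strategy" was not supplied, so it resolves to its declared default.

fetch_built_in_model in modeling_sktime.py performs exactly this resolution before instantiating the matching sktime wrapper.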
+# + +from typing import Any, Dict +from abc import abstractmethod +import numpy as np +from sklearn.preprocessing import MinMaxScaler +from sktime.detection.hmm_learn import GMMHMM, GaussianHMM +from sktime.detection.stray import STRAY +from sktime.forecasting.arima import ARIMA +from sktime.forecasting.exp_smoothing import ExponentialSmoothing +from sktime.forecasting.naive import NaiveForecaster +from sktime.forecasting.trend import STLForecaster + +from iotdb.ainode.core.model.sktime.configuration_sktime import get_attributes, update_attribute +from iotdb.ainode.core.model.model_enums import BuiltInModelType +from iotdb.ainode.core.exception import InferenceModelInternalError, BuiltInModelNotSupportError +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +class BuiltInModel(object): + def __init__(self, attributes): + self._attributes = attributes + self._model = None + + @abstractmethod + def inference(self, data): + raise NotImplementedError + + +class ArimaModel(BuiltInModel): + def __init__(self, attributes): + super(ArimaModel, self).__init__(attributes) + self._model = ARIMA( + order=attributes["order"], + seasonal_order=attributes["seasonal_order"], + method=attributes["method"], + suppress_warnings=attributes["suppress_warnings"], + maxiter=attributes["maxiter"], + out_of_sample_size=attributes["out_of_sample_size"], + scoring=attributes["scoring"], + with_intercept=attributes["with_intercept"], + time_varying_regression=attributes["time_varying_regression"], + enforce_stationarity=attributes["enforce_stationarity"], + enforce_invertibility=attributes["enforce_invertibility"], + simple_differencing=attributes["simple_differencing"], + measurement_error=attributes["measurement_error"], + mle_regression=attributes["mle_regression"], + hamilton_representation=attributes["hamilton_representation"], + concentrate_scale=attributes["concentrate_scale"], + ) + + def inference(self, data): + try: + predict_length = self._attributes["predict_length"] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class ExponentialSmoothingModel(BuiltInModel): + def __init__(self, attributes): + super(ExponentialSmoothingModel, self).__init__(attributes) + self._model = ExponentialSmoothing( + damped_trend=attributes["damped_trend"], + initialization_method=attributes["initialization_method"], + optimized=attributes["optimized"], + remove_bias=attributes["remove_bias"], + use_brute=attributes["use_brute"], + ) + + def inference(self, data): + try: + predict_length = self._attributes["predict_length"] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class NaiveForecasterModel(BuiltInModel): + def __init__(self, attributes): + super(NaiveForecasterModel, self).__init__(attributes) + self._model = NaiveForecaster( + strategy=attributes["strategy"], sp=attributes["sp"] + ) + + def inference(self, data): + try: + predict_length = self._attributes["predict_length"] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class STLForecasterModel(BuiltInModel): + def __init__(self, attributes): + 
super(STLForecasterModel, self).__init__(attributes) + self._model = STLForecaster( + sp=attributes["sp"], + seasonal=attributes["seasonal"], + seasonal_deg=attributes["seasonal_deg"], + trend_deg=attributes["trend_deg"], + low_pass_deg=attributes["low_pass_deg"], + seasonal_jump=attributes["seasonal_jump"], + trend_jump=attributes["trend_jump"], + low_pass_jump=attributes["low_pass_jump"], + ) + + def inference(self, data): + try: + predict_length = self._attributes["predict_length"] + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + output = np.array(output, dtype=np.float64) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class GMMHMMModel(BuiltInModel): + def __init__(self, attributes): + super(GMMHMMModel, self).__init__(attributes) + self._model = GMMHMM( + n_components=attributes["n_components"], + n_mix=attributes["n_mix"], + min_covar=attributes["min_covar"], + startprob_prior=attributes["startprob_prior"], + transmat_prior=attributes["transmat_prior"], + means_prior=attributes["means_prior"], + means_weight=attributes["means_weight"], + weights_prior=attributes["weights_prior"], + algorithm=attributes["algorithm"], + covariance_type=attributes["covariance_type"], + n_iter=attributes["n_iter"], + tol=attributes["tol"], + params=attributes["params"], + init_params=attributes["init_params"], + implementation=attributes["implementation"], + ) + + def inference(self, data): + try: + self._model.fit(data) + output = self._model.predict(data) + output = np.array(output, dtype=np.int32) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class GaussianHmmModel(BuiltInModel): + def __init__(self, attributes): + super(GaussianHmmModel, self).__init__(attributes) + self._model = GaussianHMM( + n_components=attributes["n_components"], + covariance_type=attributes["covariance_type"], + min_covar=attributes["min_covar"], + startprob_prior=attributes["startprob_prior"], + transmat_prior=attributes["transmat_prior"], + means_prior=attributes["means_prior"], + means_weight=attributes["means_weight"], + covars_prior=attributes["covars_prior"], + covars_weight=attributes["covars_weight"], + algorithm=attributes["algorithm"], + n_iter=attributes["n_iter"], + tol=attributes["tol"], + params=attributes["params"], + init_params=attributes["init_params"], + implementation=attributes["implementation"], + ) + + def inference(self, data): + try: + self._model.fit(data) + output = self._model.predict(data) + output = np.array(output, dtype=np.int32) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class STRAYModel(BuiltInModel): + def __init__(self, attributes): + super(STRAYModel, self).__init__(attributes) + self._model = STRAY( + alpha=attributes["alpha"], + k=attributes["k"], + knn_algorithm=attributes["knn_algorithm"], + p=attributes["p"], + size_threshold=attributes["size_threshold"], + outlier_tail=attributes["outlier_tail"], + ) + + def inference(self, data): + try: + data = MinMaxScaler().fit_transform(data) + output = self._model.fit_transform(data) + # change the output to int + output = np.array(output, dtype=np.int32) + return output + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +def fetch_built_in_model( + model_type: BuiltInModelType, inference_attrs: Dict[str, str] +) -> Any: + default_attributes = get_attributes(model_type) + attributes = update_attribute(inference_attrs, default_attributes) + + if model_type == 
BuiltInModelType.ARIMA: + model = ArimaModel(attributes) + elif ( + model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING + or model_type == BuiltInModelType.HOLTWINTERS + ): + model = ExponentialSmoothingModel(attributes) + elif model_type == BuiltInModelType.NAIVE_FORECASTER: + model = NaiveForecasterModel(attributes) + elif model_type == BuiltInModelType.STL_FORECASTER: + model = STLForecasterModel(attributes) + elif model_type == BuiltInModelType.GMM_HMM: + model = GMMHMMModel(attributes) + elif model_type == BuiltInModelType.GAUSSIAN_HMM: + model = GaussianHmmModel(attributes) + elif model_type == BuiltInModelType.STRAY: + model = STRAYModel(attributes) + # elif model_type == BuiltInModelType.TIMER_XL: + # model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) + # elif model_type == BuiltInModelType.SUNDIAL: + # model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) + else: + raise BuiltInModelNotSupportError(model_type.value) + + return model diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py index 21eefef2933b3..5b9eb7f1f6b03 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py @@ -63,3 +63,5 @@ def __init__( super().__init__( **kwargs, ) + +# TODO: Lacking checkpoint_path \ No newline at end of file From e07dad84280d0c89a5c68afa9ffa8daeaf4af742 Mon Sep 17 00:00:00 2001 From: RkGrit Date: Wed, 12 Nov 2025 23:39:52 +0800 Subject: [PATCH 02/38] delete old code in model folder --- .../ainode/core/manager/model_manager.py | 169 ---- .../iotdb/ainode/core/model/__init__.py | 17 - .../iotdb/ainode/core/model/model_enums.py | 70 -- .../iotdb/ainode/core/model/model_factory.py | 351 ------- .../iotdb/ainode/core/model/model_info.py | 154 --- .../iotdb/ainode/core/model/model_storage.py | 453 --------- .../ainode/core/model/sktime/__init__.py | 17 - .../core/model/sktime/configuration_sktime.py | 874 ------------------ .../core/model/sktime/modeling_sktime.py | 261 ------ .../ainode/core/model/sundial/__init__.py | 17 - .../model/sundial/configuration_sundial.py | 67 -- .../ainode/core/model/sundial/flow_loss.py | 255 ----- .../core/model/sundial/modeling_sundial.py | 656 ------------- .../core/model/sundial/ts_generation_mixin.py | 383 -------- .../ainode/core/model/timerxl/__init__.py | 17 - .../core/model/timerxl/configuration_timer.py | 59 -- .../core/model/timerxl/modeling_timer.py | 644 ------------- .../core/model/timerxl/ts_generation_mixin.py | 370 -------- .../iotdb/ainode/core/model/uri_utils.py | 137 --- 19 files changed, 4971 deletions(-) delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/__init__.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_info.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py 
delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py deleted file mode 100644 index d84bca77c8430..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ /dev/null @@ -1,169 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -from typing import Callable, Dict - -from torch import nn -from yaml import YAMLError - -from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType, ModelStates -from iotdb.ainode.core.model.model_info import ModelInfo -from iotdb.ainode.core.model.model_storage import ModelStorage -from iotdb.ainode.core.rpc.status import get_status -from iotdb.ainode.core.util.decorator import singleton -from iotdb.thrift.ainode.ttypes import ( - TDeleteModelReq, - TRegisterModelReq, - TRegisterModelResp, - TShowModelsReq, - TShowModelsResp, -) -from iotdb.thrift.common.ttypes import TSStatus - -logger = Logger() - - -@singleton -class ModelManager: - def __init__(self): - self.model_storage = ModelStorage() - - def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: - logger.info(f"register model {req.modelId} from {req.uri}") - try: - configs, attributes = self.model_storage.register_model( - req.modelId, req.uri - ) - return TRegisterModelResp( - get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes - ) - except InvalidUriError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, e.message) - ) - except BadConfigValueError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) - ) - except YAMLError as e: - logger.warning(e) - if hasattr(e, "problem_mark"): - mark = e.problem_mark - return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file, " - f"at line {mark.line + 1} column {mark.column + 1}.", - ) - ) - return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file", - ) - ) - except Exception as e: - logger.warning(e) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) - - def delete_model(self, req: TDeleteModelReq) -> TSStatus: - logger.info(f"delete model {req.modelId}") - try: - self.model_storage.delete_model(req.modelId) - return get_status(TSStatusCode.SUCCESS_STATUS) - except Exception as e: - logger.warning(e) - return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False - ) -> Callable: - """ - Load the model with the given model_id. - """ - logger.info(f"Load model {model_id}") - try: - model = self.model_storage.load_model( - model_id, inference_attrs, acceleration - ) - logger.info(f"Model {model_id} loaded") - return model - except Exception as e: - logger.error(f"Failed to load model {model_id}: {e}") - raise - - def save_model(self, model_id: str, model: nn.Module) -> TSStatus: - """ - Save the model using save_pretrained - """ - logger.info(f"Saving model {model_id}") - try: - self.model_storage.save_model(model_id, model) - logger.info(f"Saving model {model_id} successfully") - return get_status( - TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved successfully" - ) - except Exception as e: - logger.error(f"Save model failed: {e}") - return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. - - Args: - model_id (str): The ID of the model. 
- - Returns: - str: The path to the checkpoint file for the model. - """ - return self.model_storage.get_ckpt_path(model_id) - - def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - return self.model_storage.show_models(req) - - def register_built_in_model(self, model_info: ModelInfo): - self.model_storage.register_built_in_model(model_info) - - def get_model_info(self, model_id: str) -> ModelInfo: - return self.model_storage.get_model_info(model_id) - - def update_model_state(self, model_id: str, state: ModelStates): - self.model_storage.update_model_state(model_id, state) - - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: - """ - Get the type of the model with the given model_id. - """ - return self.model_storage.get_built_in_model_type(model_id) - - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. - - Args: - model_id (str): The ID of the model. - - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. - """ - return self.model_storage.is_built_in_or_fine_tuned(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/__init__.py deleted file mode 100644 index 2a1e720805f29..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py deleted file mode 100644 index 348f9924316b6..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
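For reference, the call pattern the removed ModelManager facade supported, with hypothetical model ids and attributes:

    # Hypothetical driver for the ModelManager singleton deleted by this patch.
    from iotdb.ainode.core.manager.model_manager import ModelManager

    manager = ModelManager()  # @singleton: repeated construction yields one instance
    model = manager.load_model(
        "sundial", {"predict_length": "96"}, acceleration=False
    )
    ckpt_path = manager.get_ckpt_path("sundial")
    if manager.is_built_in_or_fine_tuned("sundial"):
        ...  # dispatch differs for built-in / fine-tuned vs user-registered models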
-# -from enum import Enum -from typing import List - - -class BuiltInModelType(Enum): - # forecast models - ARIMA = "Arima" - HOLTWINTERS = "HoltWinters" - EXPONENTIAL_SMOOTHING = "ExponentialSmoothing" - NAIVE_FORECASTER = "NaiveForecaster" - STL_FORECASTER = "StlForecaster" - - # anomaly detection models - GAUSSIAN_HMM = "GaussianHmm" - GMM_HMM = "GmmHmm" - STRAY = "Stray" - - # large time series models (LTSM) - TIMER_XL = "Timer-XL" - # sundial - SUNDIAL = "Timer-Sundial" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def is_built_in_model(model_type: str) -> bool: - """ - Check if the given model type corresponds to a built-in model. - """ - return model_type in BuiltInModelType.values() - - -class ModelFileType(Enum): - SAFETENSORS = "safetensors" - PYTORCH = "pytorch" - UNKNOWN = "unknown" - - -class ModelCategory(Enum): - BUILT_IN = "BUILT-IN" - FINE_TUNED = "FINE-TUNED" - USER_DEFINED = "USER-DEFINED" - - -class ModelStates(Enum): - ACTIVE = "ACTIVE" - INACTIVE = "INACTIVE" - LOADING = "LOADING" - DROPPING = "DROPPING" - TRAINING = "TRAINING" - FAILED = "FAILED" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py deleted file mode 100644 index ceedf11b4e3c5..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py +++ /dev/null @@ -1,351 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import glob -import os -import shutil -from urllib.parse import urljoin - -import yaml -from huggingface_hub import hf_hub_download - -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.uri_utils import ( - UriType, - download_file, - download_snapshot_from_hf, -) -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.util.serde import get_data_type_byte_from_str -from iotdb.thrift.ainode.ttypes import TConfigs -from iotdb.ainode.core.model.model_info import TIMER_REPO_ID -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) - -logger = Logger() - - -def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: - weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) - if not os.path.exists(weights_path): - logger.info( - f"Model weights file not found at {weights_path}, downloading from HuggingFace..." 
- ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - local_dir=local_dir, - ) - logger.info(f"Got file to {weights_path}") - except Exception as e: - logger.error( - f"Failed to download model weights file to {local_dir} due to {e}" - ) - return False - if not os.path.exists(config_path): - logger.info( - f"Model config file not found at {config_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, - local_dir=local_dir, - ) - logger.info(f"Got file to {config_path}") - except Exception as e: - logger.error( - f"Failed to download model config file to {local_dir} due to {e}" - ) - return False - return True - - -def download_built_in_ltsm_from_hf_if_necessary( - model_type: BuiltInModelType, local_dir: str -) -> bool: - """ - Download the built-in ltsm from HuggingFace repository when necessary. - - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. - """ - repo_id = TIMER_REPO_ID[model_type] - if not _download_file_from_hf_if_necessary(local_dir, repo_id): - return False - return True - - -def fetch_model_by_uri( - uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Fetch the model files from the specified URI. - - Args: - uri_type (UriType): type of the URI, either repo, file, http or https - uri (str): a network or a local path of the model to be registered - storage_path (str): path to save the whole model, including weights, config, codes, etc. - model_file_type (ModelFileType): The type of model file, either safetensors or pytorch - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if uri_type == UriType.REPO or uri_type in [UriType.HTTP, UriType.HTTPS]: - return _fetch_model_from_network(uri, storage_path, model_file_type) - elif uri_type == UriType.FILE: - return _fetch_model_from_local(uri, storage_path, model_file_type) - else: - raise InvalidUriError(f"Invalid URI type: {uri_type}") - - -def _fetch_model_from_network( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - download_snapshot_from_hf(uri, storage_path) - return _process_huggingface_files(storage_path) - - # TODO: The following codes might be refactored in future - # concat uri to get complete url - uri = uri if uri.endswith("/") else uri + "/" - target_model_path = urljoin(uri, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = urljoin(uri, MODEL_CONFIG_FILE_IN_YAML) - - # download config file - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - download_file(target_config_path, config_storage_path) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, download model file - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - download_file(target_model_path, model_storage_path) - return configs, attributes - - -def _fetch_model_from_local( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - # copy anything in the uri to local_dir - for file 
in os.listdir(uri):
-            shutil.copy(os.path.join(uri, file), storage_path)
-        return _process_huggingface_files(storage_path)
-    # concatenate the URI to get the complete paths
-    target_model_path = os.path.join(uri, MODEL_WEIGHTS_FILE_IN_PT)
-    model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT)
-    target_config_path = os.path.join(uri, MODEL_CONFIG_FILE_IN_YAML)
-    config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML)
-
-    # check whether both files exist
-    exist_model_file = os.path.exists(target_model_path)
-    exist_config_file = os.path.exists(target_config_path)
-
-    configs = None
-    attributes = None
-    if exist_model_file and exist_config_file:
-        # copy config.yaml
-        shutil.copy(target_config_path, config_storage_path)
-        logger.info(
-            f"Copied file from {target_config_path} to {config_storage_path} successfully"
-        )
-
-        # read and parse the config dict from config.yaml
-        with open(config_storage_path, "r", encoding="utf-8") as file:
-            config_dict = yaml.safe_load(file)
-        configs, attributes = _parse_inference_config(config_dict)
-
-        # if config.yaml is correct, copy the model file
-        shutil.copy(target_model_path, model_storage_path)
-        logger.info(
-            f"Copied file from {target_model_path} to {model_storage_path} successfully"
-        )
-
-    else:
-        raise InvalidUriError(uri)
-
-    return configs, attributes
-
-
-def _parse_inference_config(config_dict):
-    """
-    Args:
-        config_dict: dict
-            - configs: dict
-                - input_shape (list): input shape of the model; must be a two-element list like [96, 2]
-                - output_shape (list): output shape of the model; must be a two-element list like [96, 2]
-                - input_type (list): input types of the model; each element must be one of ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float32
-                - output_type (list): output types of the model; each element must be one of ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float32
-            - attributes: dict
-    Returns:
-        configs: TConfigs
-        attributes: str
-    """
-    configs = config_dict["configs"]
-
-    # check that input_shape and output_shape are two-element lists
-    if not (
-        isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2
-    ):
-        raise BadConfigValueError(
-            "input_shape",
-            configs["input_shape"],
-            "input_shape should be a two-dimensional array.",
-        )
-    if not (
-        isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2
-    ):
-        raise BadConfigValueError(
-            "output_shape",
-            configs["output_shape"],
-            "output_shape should be a two-dimensional array.",
-        )
-
-    # check that the elements of input_shape and output_shape are positive integers
-    input_shape_is_positive_number = (
-        isinstance(configs["input_shape"][0], int)
-        and isinstance(configs["input_shape"][1], int)
-        and configs["input_shape"][0] > 0
-        and configs["input_shape"][1] > 0
-    )
-    if not input_shape_is_positive_number:
-        raise BadConfigValueError(
-            "input_shape",
-            configs["input_shape"],
-            "each element in input_shape should be a positive integer.",
-        )
-
-    output_shape_is_positive_number = (
-        isinstance(configs["output_shape"][0], int)
-        and isinstance(configs["output_shape"][1], int)
-        and configs["output_shape"][0] > 0
-        and configs["output_shape"][1] > 0
-    )
-    if not output_shape_is_positive_number:
-        raise BadConfigValueError(
-            "output_shape",
-            configs["output_shape"],
-            "each element in output_shape should be a positive integer.",
-        )
-
-    # check that input_type and output_type are one-dimensional arrays of the right length
-    if "input_type" in configs
and not ( - isinstance(configs["input_type"], list) - and len(configs["input_type"]) == configs["input_shape"][1] - ): - raise BadConfigValueError( - "input_type", - configs["input_type"], - "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].", - ) - - if "output_type" in configs and not ( - isinstance(configs["output_type"], list) - and len(configs["output_type"]) == configs["output_shape"][1] - ): - raise BadConfigValueError( - "output_type", - configs["output_type"], - "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].", - ) - - # parse input_type and output_type to byte - if "input_type" in configs: - input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]] - else: - input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][ - 1 - ] - - if "output_type" in configs: - output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]] - else: - output_type = [get_data_type_byte_from_str("float32")] * configs[ - "output_shape" - ][1] - - # parse attributes - attributes = "" - if "attributes" in config_dict: - attributes = str(config_dict["attributes"]) - - return ( - TConfigs( - configs["input_shape"], configs["output_shape"], input_type, output_type - ), - attributes, - ) - - -def _process_huggingface_files(local_dir: str): - """ - TODO: Currently, we use this function to convert the model config from huggingface, we will refactor this in the future. - """ - config_file = None - for config_name in ["config.json", "model_config.json"]: - config_path = os.path.join(local_dir, config_name) - if os.path.exists(config_path): - config_file = config_path - break - - if not config_file: - raise InvalidUriError(f"No config.json found in {local_dir}") - - safetensors_files = glob.glob(os.path.join(local_dir, "*.safetensors")) - if not safetensors_files: - raise InvalidUriError(f"No .safetensors files found in {local_dir}") - - simple_config = { - "configs": { - "input_shape": [96, 1], - "output_shape": [96, 1], - "input_type": ["float32"], - "output_type": ["float32"], - }, - "attributes": { - "model_type": "huggingface_model", - "source_dir": local_dir, - "files": [os.path.basename(f) for f in safetensors_files], - }, - } - - configs, attributes = _parse_inference_config(simple_config) - return configs, attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py deleted file mode 100644 index 167bfd76640d1..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ /dev/null @@ -1,154 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
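For reference, a minimal config dict that would satisfy the validation in the removed _parse_inference_config above looks roughly like this (a sketch; only the key names follow the checks above, the concrete shapes, types, and attribute values are illustrative):

    # Hypothetical input for _parse_inference_config (illustrative values).
    config_dict = {
        "configs": {
            "input_shape": [96, 2],    # [sequence_length, column_count], positive ints
            "output_shape": [96, 2],
            "input_type": ["float32", "float32"],    # optional, len == input_shape[1]
            "output_type": ["float32", "float32"],   # optional, len == output_shape[1]
        },
        "attributes": {"model_type": "example"},     # optional, stringified on return
    }
    configs, attributes = _parse_inference_config(config_dict)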
-# -import glob -import os - -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) - - -def get_model_file_type(model_path: str) -> ModelFileType: - """ - Determine the file type of the specified model directory. - """ - if _has_safetensors_format(model_path): - return ModelFileType.SAFETENSORS - elif _has_pytorch_format(model_path): - return ModelFileType.PYTORCH - else: - return ModelFileType.UNKNOWN - - -def _has_safetensors_format(path: str) -> bool: - """Check if directory contains safetensors files.""" - safetensors_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)) - json_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_JSON)) - return len(safetensors_files) > 0 and len(json_files) > 0 - - -def _has_pytorch_format(path: str) -> bool: - """Check if directory contains pytorch files.""" - pt_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_PT)) - yaml_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_YAML)) - return len(pt_files) > 0 and len(yaml_files) > 0 - - -def get_built_in_model_type(model_type: str) -> BuiltInModelType: - if not BuiltInModelType.is_built_in_model(model_type): - raise ValueError(f"Invalid built-in model type: {model_type}") - return BuiltInModelType(model_type) - - -class ModelInfo: - def __init__( - self, - model_id: str, - model_type: str, - category: ModelCategory, - state: ModelStates, - ): - self.model_id = model_id - self.model_type = model_type - self.category = category - self.state = state - - -TIMER_REPO_ID = { - BuiltInModelType.TIMER_XL: "thuml/timer-base-84m", - BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m", -} - -# Built-in machine learning models, they can be employed directly -BUILT_IN_MACHINE_LEARNING_MODEL_MAP = { - # forecast models - "arima": ModelInfo( - model_id="arima", - model_type=BuiltInModelType.ARIMA.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - "holtwinters": ModelInfo( - model_id="holtwinters", - model_type=BuiltInModelType.HOLTWINTERS.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - "exponential_smoothing": ModelInfo( - model_id="exponential_smoothing", - model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - "naive_forecaster": ModelInfo( - model_id="naive_forecaster", - model_type=BuiltInModelType.NAIVE_FORECASTER.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - "stl_forecaster": ModelInfo( - model_id="stl_forecaster", - model_type=BuiltInModelType.STL_FORECASTER.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - # anomaly detection models - "gaussian_hmm": ModelInfo( - model_id="gaussian_hmm", - model_type=BuiltInModelType.GAUSSIAN_HMM.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - "gmm_hmm": ModelInfo( - model_id="gmm_hmm", - model_type=BuiltInModelType.GMM_HMM.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), - "stray": ModelInfo( - model_id="stray", - model_type=BuiltInModelType.STRAY.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.ACTIVE, - ), -} - -# Built-in large time series models (LTSM), their weights are not included in AINode by default -BUILT_IN_LTSM_MAP = { - "timer_xl": ModelInfo( 
- model_id="timer_xl", - model_type=BuiltInModelType.TIMER_XL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, - ), - "sundial": ModelInfo( - model_id="sundial", - model_type=BuiltInModelType.SUNDIAL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, - ), -} diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py deleted file mode 100644 index 1c63b56e519c9..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ /dev/null @@ -1,453 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import concurrent.futures -import json -import os -import shutil -from collections.abc import Callable -from typing import Dict - -import torch -from torch import nn - -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.sktime.modeling_sktime import fetch_built_in_model -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri, download_built_in_ltsm_from_hf_if_necessary -from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, - ModelInfo, - get_built_in_model_type, - get_model_file_type, -) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy -from iotdb.ainode.core.util.lock import ModelLockPool -from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp -from iotdb.thrift.common.ttypes import TSStatus - -logger = Logger() - - -class ModelStorage(object): - def __init__(self): - self._model_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() - ) - if not os.path.exists(self._model_dir): - try: - os.makedirs(self._model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._builtin_model_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir() - ) - if not os.path.exists(self._builtin_model_dir): - try: - os.makedirs(self._builtin_model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._lock_pool = ModelLockPool() - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1 - ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading. 
-        self._model_info_map: Dict[str, ModelInfo] = {}
-        self._init_model_info_map()
-
-    def _init_model_info_map(self):
-        """
-        Initialize the model info map.
-        """
-        # 1. initialize built-in and ready-to-use models
-        for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP:
-            self._model_info_map[model_id] = BUILT_IN_MACHINE_LEARNING_MODEL_MAP[
-                model_id
-            ]
-        # 2. retrieve fine-tuned models from the built-in model directory
-        fine_tuned_models = self._retrieve_fine_tuned_models()
-        for model_id in fine_tuned_models:
-            self._model_info_map[model_id] = fine_tuned_models[model_id]
-        # 3. automatically download the weights of built-in LTSM models when necessary
-        for model_id in BUILT_IN_LTSM_MAP:
-            if model_id not in self._model_info_map:
-                self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id]
-                future = self._executor.submit(
-                    self._download_built_in_model_if_necessary, model_id
-                )
-                future.add_done_callback(
-                    lambda f, mid=model_id: self._callback_model_download_result(f, mid)
-                )
-        # 4. retrieve user-defined models from the model directory
-        user_defined_models = self._retrieve_user_defined_models()
-        for model_id in user_defined_models:
-            self._model_info_map[model_id] = user_defined_models[model_id]
-
-    def _retrieve_fine_tuned_models(self):
-        """
-        Retrieve fine-tuned models from the built-in model directory.
-
-        Returns:
-            {"model_id": ModelInfo}
-        """
-        result = {}
-        built_in_dirs = [
-            d
-            for d in os.listdir(self._builtin_model_dir)
-            if os.path.isdir(os.path.join(self._builtin_model_dir, d))
-        ]
-        for model_id in built_in_dirs:
-            config_file_path = os.path.join(
-                self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON
-            )
-            if os.path.isfile(config_file_path):
-                with open(config_file_path, "r") as f:
-                    model_config = json.load(f)
-                    if "model_type" in model_config:
-                        model_type = model_config["model_type"]
-                        model_info = ModelInfo(
-                            model_id=model_id,
-                            model_type=model_type,
-                            category=ModelCategory.FINE_TUNED,
-                            state=ModelStates.ACTIVE,
-                        )
-                        # correct the category for built-in models
-                        if "timer_xl" == model_id:
-                            model_info.category = ModelCategory.BUILT_IN
-                        if "sundial" == model_id:
-                            model_info.category = ModelCategory.BUILT_IN
-                        # compatibility patch for the model_type names used by the HuggingFace checkpoints
-                        if "timer" == model_type:
-                            model_info.model_type = BuiltInModelType.TIMER_XL.value
-                        if "sundial" == model_type:
-                            model_info.model_type = BuiltInModelType.SUNDIAL.value
-                        result[model_id] = model_info
-        return result
-
-    def _download_built_in_model_if_necessary(self, model_id: str) -> bool:
-        """
-        Download the built-in model if it is not already downloaded.
-
-        Args:
-            model_id (str): The ID of the model to download.
-
-        Returns:
-            bool: True if the model already exists or was downloaded successfully, False otherwise.
-        """
-        with self._lock_pool.get_lock(model_id).write_lock():
-            local_dir = os.path.join(self._builtin_model_dir, model_id)
-            return download_built_in_ltsm_from_hf_if_necessary(
-                get_built_in_model_type(self._model_info_map[model_id].model_type),
-                local_dir,
-            )
-
-    def _callback_model_download_result(self, future, model_id: str):
-        with self._lock_pool.get_lock(model_id).write_lock():
-            if future.result():
-                self._model_info_map[model_id].state = ModelStates.ACTIVE
-                logger.info(
-                    f"The built-in model: {model_id} is active and ready to use."
-                )
-            else:
-                self._model_info_map[model_id].state = ModelStates.INACTIVE
-
-    def _retrieve_user_defined_models(self):
-        """
-        Retrieve user-defined models from the model directory.
- - Returns: - {"model_id": ModelInfo} - """ - result = {} - user_dirs = [ - d - for d in os.listdir(self._model_dir) - if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights" - ] - for model_id in user_dirs: - result[model_id] = ModelInfo( - model_id=model_id, - model_type="", - category=ModelCategory.USER_DEFINED, - state=ModelStates.ACTIVE, - ) - return result - - def register_model(self, model_id: str, uri: str): - """ - Args: - model_id: id of model to register - uri: network or local dir path of the model to register - Returns: - configs: TConfigs - attributes: str - """ - with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - # create storage dir if not exist - if not os.path.exists(storage_path): - os.makedirs(storage_path) - uri_type, parsed_uri, model_file_type = get_model_register_strategy(uri) - self._model_info_map[model_id] = ModelInfo( - model_id=model_id, - model_type="", - category=ModelCategory.USER_DEFINED, - state=ModelStates.LOADING, - ) - try: - # TODO: The uri should be fetched asynchronously - configs, attributes = fetch_model_by_uri( - uri_type, parsed_uri, storage_path, model_file_type - ) - self._model_info_map[model_id].state = ModelStates.ACTIVE - return configs, attributes - except Exception as e: - logger.error(f"Failed to register model {model_id}: {e}") - self._model_info_map[model_id].state = ModelStates.INACTIVE - raise e - - def delete_model(self, model_id: str) -> None: - """ - Args: - model_id: id of model to delete - Returns: - None - """ - # check if the model is built-in - with self._lock_pool.get_lock(model_id).read_lock(): - if self._is_built_in(model_id): - raise BuiltInModelDeletionError(model_id) - - # delete the user-defined or fine-tuned model - with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - storage_path = os.path.join(self._builtin_model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - if model_id in self._model_info_map: - del self._model_info_map[model_id] - logger.info(f"Model {model_id} deleted successfully.") - - def _is_built_in(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in model. - - Args: - model_id (str): The ID of the model. - - Returns: - bool: True if the model is built-in, False otherwise. - """ - return ( - model_id in self._model_info_map - and self._model_info_map[model_id].category == ModelCategory.BUILT_IN - ) - - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. - - Args: - model_id (str): The ID of the model. - - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. 
- """ - return model_id in self._model_info_map and ( - self._model_info_map[model_id].category == ModelCategory.BUILT_IN - or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED - ) - - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool - ) -> Callable: - """ - Load a model with automatic detection of .safetensors or .pt format - - Returns: - model: The model instance corresponding to specific model_id - """ - with self._lock_pool.get_lock(model_id).read_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - return fetch_built_in_model( - get_built_in_model_type(self._model_info_map[model_id].model_type), - model_dir, - inference_attrs, - ) - else: - # load the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - model_file_type = get_model_file_type(model_dir) - if model_file_type == ModelFileType.SAFETENSORS: - # TODO: Support this function - raise UnsupportedError("SAFETENSORS format") - else: - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - - if not os.path.exists(model_path): - raise ModelNotExistError(model_path) - model = torch.jit.load(model_path) - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - - try: - model = torch.compile(model) - except Exception as e: - logger.warning( - f"acceleration failed, fallback to normal mode: {str(e)}" - ) - return model - - def save_model(self, model_id: str, model: nn.Module): - """ - Save the model using save_pretrained - - Returns: - Whether saving succeeded - """ - with self._lock_pool.get_lock(model_id).write_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - model.save_pretrained(model_dir) - else: - # save the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - os.makedirs(model_dir, exist_ok=True) - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - try: - scripted_model = ( - model - if isinstance(model, torch.jit.ScriptModule) - else torch.jit.script(model) - ) - torch.jit.save(scripted_model, model_path) - except Exception as e: - logger.error(f"Failed to save scripted model: {e}") - - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. - - Args: - model_id (str): The ID of the model. - - Returns: - str: The path to the checkpoint file for the model. 
- """ - # Only support built-in models for now - return os.path.join(self._builtin_model_dir, f"{model_id}") - - def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - resp_status = TSStatus( - code=TSStatusCode.SUCCESS_STATUS.value, - message="Show models successfully", - ) - if req.modelId: - if req.modelId in self._model_info_map: - model_info = self._model_info_map[req.modelId] - return TShowModelsResp( - status=resp_status, - modelIdList=[req.modelId], - modelTypeMap={req.modelId: model_info.model_type}, - categoryMap={req.modelId: model_info.category.value}, - stateMap={req.modelId: model_info.state.value}, - ) - else: - return TShowModelsResp( - status=resp_status, - modelIdList=[], - modelTypeMap={}, - categoryMap={}, - stateMap={}, - ) - return TShowModelsResp( - status=resp_status, - modelIdList=list(self._model_info_map.keys()), - modelTypeMap=dict( - (model_id, model_info.model_type) - for model_id, model_info in self._model_info_map.items() - ), - categoryMap=dict( - (model_id, model_info.category.value) - for model_id, model_info in self._model_info_map.items() - ), - stateMap=dict( - (model_id, model_info.state.value) - for model_id, model_info in self._model_info_map.items() - ), - ) - - def register_built_in_model(self, model_info: ModelInfo): - with self._lock_pool.get_lock(model_info.model_id).write_lock(): - self._model_info_map[model_info.model_id] = model_info - - def get_model_info(self, model_id: str) -> ModelInfo: - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return self._model_info_map[model_id] - else: - raise ValueError(f"Model {model_id} does not exist.") - - def update_model_state(self, model_id: str, state: ModelStates): - with self._lock_pool.get_lock(model_id).write_lock(): - if model_id in self._model_info_map: - self._model_info_map[model_id].state = state - else: - raise ValueError(f"Model {model_id} does not exist.") - - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: - """ - Get the type of the model with the given model_id. - - Args: - model_id (str): The ID of the model. - - Returns: - str: The type of the model. - """ - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return get_built_in_model_type( - self._model_info_map[model_id].model_type - ) - else: - raise ValueError(f"Model {model_id} does not exist.") diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py deleted file mode 100644 index 2a1e720805f29..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py deleted file mode 100644 index 18fea61b6ff03..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py +++ /dev/null @@ -1,874 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from abc import abstractmethod -from typing import Callable, Dict, List -from enum import Enum -from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - ListRangeException, - NumericalRangeException, - StringRangeException, - WrongAttributeTypeError, -) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType - -logger = Logger() - - -class AttributeName(Enum): - # forecast Attribute - PREDICT_LENGTH = "predict_length" - - # NaiveForecaster - STRATEGY = "strategy" - SP = "sp" - - # STLForecaster - # SP = 'sp' - SEASONAL = "seasonal" - SEASONAL_DEG = "seasonal_deg" - TREND_DEG = "trend_deg" - LOW_PASS_DEG = "low_pass_deg" - SEASONAL_JUMP = "seasonal_jump" - TREND_JUMP = "trend_jump" - LOSS_PASS_JUMP = "low_pass_jump" - - # ExponentialSmoothing - DAMPED_TREND = "damped_trend" - INITIALIZATION_METHOD = "initialization_method" - OPTIMIZED = "optimized" - REMOVE_BIAS = "remove_bias" - USE_BRUTE = "use_brute" - - # Arima - ORDER = "order" - SEASONAL_ORDER = "seasonal_order" - METHOD = "method" - MAXITER = "maxiter" - SUPPRESS_WARNINGS = "suppress_warnings" - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" - SCORING = "scoring" - WITH_INTERCEPT = "with_intercept" - TIME_VARYING_REGRESSION = "time_varying_regression" - ENFORCE_STATIONARITY = "enforce_stationarity" - ENFORCE_INVERTIBILITY = "enforce_invertibility" - SIMPLE_DIFFERENCING = "simple_differencing" - MEASUREMENT_ERROR = "measurement_error" - MLE_REGRESSION = "mle_regression" - HAMILTON_REPRESENTATION = "hamilton_representation" - CONCENTRATE_SCALE = "concentrate_scale" - - # GAUSSIAN_HMM - N_COMPONENTS = "n_components" - COVARIANCE_TYPE = "covariance_type" - MIN_COVAR = "min_covar" - STARTPROB_PRIOR = "startprob_prior" - TRANSMAT_PRIOR = "transmat_prior" - MEANS_PRIOR = "means_prior" - MEANS_WEIGHT = "means_weight" - COVARS_PRIOR = "covars_prior" - COVARS_WEIGHT = "covars_weight" - ALGORITHM = "algorithm" - N_ITER = "n_iter" - TOL = "tol" - PARAMS = "params" - INIT_PARAMS = "init_params" - IMPLEMENTATION = "implementation" - - # GMMHMM - # N_COMPONENTS = "n_components" - N_MIX = "n_mix" - # MIN_COVAR = "min_covar" - # STARTPROB_PRIOR = "startprob_prior" - # TRANSMAT_PRIOR = "transmat_prior" - WEIGHTS_PRIOR = "weights_prior" - - # MEANS_PRIOR = "means_prior" - # MEANS_WEIGHT = "means_weight" - # ALGORITHM = "algorithm" - # COVARIANCE_TYPE = "covariance_type" - # 
N_ITER = "n_iter" - # TOL = "tol" - # INIT_PARAMS = "init_params" - # PARAMS = "params" - # IMPLEMENTATION = "implementation" - - # STRAY - ALPHA = "alpha" - K = "k" - KNN_ALGORITHM = "knn_algorithm" - P = "p" - SIZE_THRESHOLD = "size_threshold" - OUTLIER_TAIL = "outlier_tail" - - # timerxl - INPUT_TOKEN_LEN = "input_token_len" - HIDDEN_SIZE = "hidden_size" - INTERMEDIATE_SIZE = "intermediate_size" - OUTPUT_TOKEN_LENS = "output_token_lens" - NUM_HIDDEN_LAYERS = "num_hidden_layers" - NUM_ATTENTION_HEADS = "num_attention_heads" - HIDDEN_ACT = "hidden_act" - USE_CACHE = "use_cache" - ROPE_THETA = "rope_theta" - ATTENTION_DROPOUT = "attention_dropout" - INITIALIZER_RANGE = "initializer_range" - MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - CKPT_PATH = "ckpt_path" - - # sundial - DROPOUT_RATE = "dropout_rate" - FLOW_LOSS_DEPTH = "flow_loss_depth" - NUM_SAMPLING_STEPS = "num_sampling_steps" - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" - - def name(self) -> str: - return self.value - - -class Attribute(object): - def __init__(self, name: str): - """ - Args: - name: the name of the attribute - """ - self._name = name - - @abstractmethod - def get_default_value(self): - raise NotImplementedError - - @abstractmethod - def validate_value(self, value): - raise NotImplementedError - - @abstractmethod - def parse(self, string_value: str): - raise NotImplementedError - - -class IntAttribute(Attribute): - def __init__( - self, - name: str, - default_value: int, - default_low: int, - default_high: int, - ): - super(IntAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - int_value = int(string_value) - except: - raise WrongAttributeTypeError(self._name, "int") - return int_value - - -class FloatAttribute(Attribute): - def __init__( - self, - name: str, - default_value: float, - default_low: float, - default_high: float, - ): - super(FloatAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - float_value = float(string_value) - except: - raise WrongAttributeTypeError(self._name, "float") - return float_value - - -class StringAttribute(Attribute): - def __init__(self, name: str, default_value: str, value_choices: List[str]): - super(StringAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_choices = value_choices - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if value in self.__value_choices: - return True - raise StringRangeException(self._name, value, self.__value_choices) - - def parse(self, string_value: str): - return string_value - - -class BooleanAttribute(Attribute): - def __init__(self, name: str, default_value: bool): - super(BooleanAttribute, self).__init__(name) - self.__default_value = default_value - 
- def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if isinstance(value, bool): - return True - raise WrongAttributeTypeError(self._name, "bool") - - def parse(self, string_value: str): - if string_value.lower() == "true": - return True - elif string_value.lower() == "false": - return False - else: - raise WrongAttributeTypeError(self._name, "bool") - - -class ListAttribute(Attribute): - def __init__(self, name: str, default_value: List, value_type): - """ - value_type is the type of the elements in the list, e.g. int, float, str - """ - super(ListAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, list): - raise WrongAttributeTypeError(self._name, "list") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - list_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "list") - if not isinstance(list_value, list): - raise WrongAttributeTypeError(self._name, "list") - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - return list_value - - -class TupleAttribute(Attribute): - def __init__(self, name: str, default_value: tuple, value_type): - """ - value_type is the type of the elements in the list, e.g. int, float, str - """ - super(TupleAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - tuple_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "tuple") - if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - list_value = list(tuple_value) - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - tuple_value = tuple(list_value) - return tuple_value - - -# built-in sktime model attributes - -# NaiveForecaster -naive_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.STRATEGY.value: StringAttribute( - name=AttributeName.STRATEGY.value, - default_value="last", - value_choices=["last", "mean"], - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 - ), -} - -# ExponentialSmoothing -exponential_smoothing_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - 
default_high=5000, - ), - AttributeName.DAMPED_TREND.value: BooleanAttribute( - name=AttributeName.DAMPED_TREND.value, - default_value=False, - ), - AttributeName.INITIALIZATION_METHOD.value: StringAttribute( - name=AttributeName.INITIALIZATION_METHOD.value, - default_value="estimated", - value_choices=["estimated", "heuristic", "legacy-heuristic", "known"], - ), - AttributeName.OPTIMIZED.value: BooleanAttribute( - name=AttributeName.OPTIMIZED.value, - default_value=True, - ), - AttributeName.REMOVE_BIAS.value: BooleanAttribute( - name=AttributeName.REMOVE_BIAS.value, - default_value=False, - ), - AttributeName.USE_BRUTE.value: BooleanAttribute( - name=AttributeName.USE_BRUTE.value, - default_value=False, - ), -} - -# Arima -arima_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.ORDER.value: TupleAttribute( - name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int - ), - AttributeName.SEASONAL_ORDER.value: TupleAttribute( - name=AttributeName.SEASONAL_ORDER.value, - default_value=(0, 0, 0, 0), - value_type=int, - ), - AttributeName.METHOD.value: StringAttribute( - name=AttributeName.METHOD.value, - default_value="lbfgs", - value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], - ), - AttributeName.MAXITER.value: IntAttribute( - name=AttributeName.MAXITER.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( - name=AttributeName.SUPPRESS_WARNINGS.value, - default_value=True, - ), - AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute( - name=AttributeName.OUT_OF_SAMPLE_SIZE.value, - default_value=0, - default_low=0, - default_high=5000, - ), - AttributeName.SCORING.value: StringAttribute( - name=AttributeName.SCORING.value, - default_value="mse", - value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], - ), - AttributeName.WITH_INTERCEPT.value: BooleanAttribute( - name=AttributeName.WITH_INTERCEPT.value, - default_value=True, - ), - AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute( - name=AttributeName.TIME_VARYING_REGRESSION.value, - default_value=False, - ), - AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_STATIONARITY.value, - default_value=True, - ), - AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_INVERTIBILITY.value, - default_value=True, - ), - AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute( - name=AttributeName.SIMPLE_DIFFERENCING.value, - default_value=False, - ), - AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute( - name=AttributeName.MEASUREMENT_ERROR.value, - default_value=False, - ), - AttributeName.MLE_REGRESSION.value: BooleanAttribute( - name=AttributeName.MLE_REGRESSION.value, - default_value=True, - ), - AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute( - name=AttributeName.HAMILTON_REPRESENTATION.value, - default_value=False, - ), - AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( - name=AttributeName.CONCENTRATE_SCALE.value, - default_value=False, - ), -} - -# STLForecaster -stl_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=2, default_low=1, 
default_high=5000 - ), - AttributeName.SEASONAL.value: IntAttribute( - name=AttributeName.SEASONAL.value, - default_value=7, - default_low=1, - default_high=5000, - ), - AttributeName.SEASONAL_DEG.value: IntAttribute( - name=AttributeName.SEASONAL_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_DEG.value: IntAttribute( - name=AttributeName.TREND_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOW_PASS_DEG.value: IntAttribute( - name=AttributeName.LOW_PASS_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.SEASONAL_JUMP.value: IntAttribute( - name=AttributeName.SEASONAL_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_JUMP.value: IntAttribute( - name=AttributeName.TREND_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOSS_PASS_JUMP.value: IntAttribute( - name=AttributeName.LOSS_PASS_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), -} - -# GAUSSIAN_HMM -gaussian_hmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["spherical", "diag", "full", "tied"], - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_PRIOR.value: FloatAttribute( - name=AttributeName.COVARS_PRIOR.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_WEIGHT.value: FloatAttribute( - name=AttributeName.COVARS_WEIGHT.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# GMMHMM -gmmhmm_attribute_map = { - 
AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.N_MIX.value: IntAttribute( - name=AttributeName.N_MIX.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.WEIGHTS_PRIOR.value: FloatAttribute( - name=AttributeName.WEIGHTS_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["sperical", "diag", "full", "tied"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# STRAY -stray_attribute_map = { - AttributeName.ALPHA.value: FloatAttribute( - name=AttributeName.ALPHA.value, - default_value=0.01, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.K.value: IntAttribute( - name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000 - ), - AttributeName.KNN_ALGORITHM.value: StringAttribute( - name=AttributeName.KNN_ALGORITHM.value, - default_value="brute", - value_choices=["brute", "kd_tree", "ball_tree", "auto"], - ), - AttributeName.P.value: FloatAttribute( - name=AttributeName.P.value, - default_value=0.5, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.SIZE_THRESHOLD.value: IntAttribute( - 
name=AttributeName.SIZE_THRESHOLD.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.OUTLIER_TAIL.value: StringAttribute( - name=AttributeName.OUTLIER_TAIL.value, - default_value="max", - value_choices=["min", "max"], - ), -} - - -def get_attributes(model_type: BuiltInModelType): - """ - Get the attribute map of the built-in model. - - Args: - model_type: the type of the built-in model - - Returns: - the attribute map of the built-in model - - """ - if model_type == BuiltInModelType.ARIMA: - attribute_map = arima_attribute_map - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - attribute_map = naive_forecaster_attribute_map - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - attribute_map = exponential_smoothing_attribute_map - elif model_type == BuiltInModelType.STL_FORECASTER: - attribute_map = stl_forecaster_attribute_map - elif model_type == BuiltInModelType.GMM_HMM: - attribute_map = gmmhmm_attribute_map - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - attribute_map = gaussian_hmm_attribute_map - elif model_type == BuiltInModelType.STRAY: - attribute_map = stray_attribute_map - else: - raise BuiltInModelNotSupportError(model_type.value) - return attribute_map - - -def update_attribute( - input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] -): - """ - Update the attribute of the built-in model using the input attributes. - Args: - input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of - the attribute - attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute - object - Returns: - a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute - """ - attributes = {} - for attribute_name in attribute_map: - # user specified the attribute - if attribute_name in input_attributes: - attribute = attribute_map[attribute_name] - value = attribute.parse(input_attributes[attribute_name]) - attribute.validate_value(value) - attributes[attribute_name] = value - # user did not specify the attribute, use the default value - else: - try: - attributes[attribute_name] = attribute_map[ - attribute_name - ].get_default_value() - except NotImplementedError as e: - logger.error(f"attribute {attribute_name} is not implemented.") - raise e - return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py deleted file mode 100644 index 7e8e41c4dcf11..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py +++ /dev/null @@ -1,261 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. 
See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from typing import Any, Dict -from abc import abstractmethod -import numpy as np -from sklearn.preprocessing import MinMaxScaler -from sktime.detection.hmm_learn import GMMHMM, GaussianHMM -from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA -from sktime.forecasting.exp_smoothing import ExponentialSmoothing -from sktime.forecasting.naive import NaiveForecaster -from sktime.forecasting.trend import STLForecaster - -from iotdb.ainode.core.model.sktime.configuration_sktime import get_attributes, update_attribute -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.exception import InferenceModelInternalError, BuiltInModelNotSupportError -from iotdb.ainode.core.log import Logger - -logger = Logger() - - -class BuiltInModel(object): - def __init__(self, attributes): - self._attributes = attributes - self._model = None - - @abstractmethod - def inference(self, data): - raise NotImplementedError - - -class ArimaModel(BuiltInModel): - def __init__(self, attributes): - super(ArimaModel, self).__init__(attributes) - self._model = ARIMA( - order=attributes["order"], - seasonal_order=attributes["seasonal_order"], - method=attributes["method"], - suppress_warnings=attributes["suppress_warnings"], - maxiter=attributes["maxiter"], - out_of_sample_size=attributes["out_of_sample_size"], - scoring=attributes["scoring"], - with_intercept=attributes["with_intercept"], - time_varying_regression=attributes["time_varying_regression"], - enforce_stationarity=attributes["enforce_stationarity"], - enforce_invertibility=attributes["enforce_invertibility"], - simple_differencing=attributes["simple_differencing"], - measurement_error=attributes["measurement_error"], - mle_regression=attributes["mle_regression"], - hamilton_representation=attributes["hamilton_representation"], - concentrate_scale=attributes["concentrate_scale"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class ExponentialSmoothingModel(BuiltInModel): - def __init__(self, attributes): - super(ExponentialSmoothingModel, self).__init__(attributes) - self._model = ExponentialSmoothing( - damped_trend=attributes["damped_trend"], - initialization_method=attributes["initialization_method"], - optimized=attributes["optimized"], - remove_bias=attributes["remove_bias"], - use_brute=attributes["use_brute"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class NaiveForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(NaiveForecasterModel, self).__init__(attributes) - self._model = NaiveForecaster( - strategy=attributes["strategy"], sp=attributes["sp"] - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise 
InferenceModelInternalError(str(e)) - - -class STLForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(STLForecasterModel, self).__init__(attributes) - self._model = STLForecaster( - sp=attributes["sp"], - seasonal=attributes["seasonal"], - seasonal_deg=attributes["seasonal_deg"], - trend_deg=attributes["trend_deg"], - low_pass_deg=attributes["low_pass_deg"], - seasonal_jump=attributes["seasonal_jump"], - trend_jump=attributes["trend_jump"], - low_pass_jump=attributes["low_pass_jump"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GMMHMMModel(BuiltInModel): - def __init__(self, attributes): - super(GMMHMMModel, self).__init__(attributes) - self._model = GMMHMM( - n_components=attributes["n_components"], - n_mix=attributes["n_mix"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - weights_prior=attributes["weights_prior"], - algorithm=attributes["algorithm"], - covariance_type=attributes["covariance_type"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GaussianHmmModel(BuiltInModel): - def __init__(self, attributes): - super(GaussianHmmModel, self).__init__(attributes) - self._model = GaussianHMM( - n_components=attributes["n_components"], - covariance_type=attributes["covariance_type"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - covars_prior=attributes["covars_prior"], - covars_weight=attributes["covars_weight"], - algorithm=attributes["algorithm"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STRAYModel(BuiltInModel): - def __init__(self, attributes): - super(STRAYModel, self).__init__(attributes) - self._model = STRAY( - alpha=attributes["alpha"], - k=attributes["k"], - knn_algorithm=attributes["knn_algorithm"], - p=attributes["p"], - size_threshold=attributes["size_threshold"], - outlier_tail=attributes["outlier_tail"], - ) - - def inference(self, data): - try: - data = MinMaxScaler().fit_transform(data) - output = self._model.fit_transform(data) - # change the output to int - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -def fetch_built_in_model( - model_type: BuiltInModelType, inference_attrs: Dict[str, str] -) -> Any: - default_attributes = 
get_attributes(model_type) - attributes = update_attribute(inference_attrs, default_attributes) - - if model_type == BuiltInModelType.ARIMA: - model = ArimaModel(attributes) - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - model = ExponentialSmoothingModel(attributes) - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - model = NaiveForecasterModel(attributes) - elif model_type == BuiltInModelType.STL_FORECASTER: - model = STLForecasterModel(attributes) - elif model_type == BuiltInModelType.GMM_HMM: - model = GMMHMMModel(attributes) - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - model = GaussianHmmModel(attributes) - elif model_type == BuiltInModelType.STRAY: - model = STRAYModel(attributes) - # elif model_type == BuiltInModelType.TIMER_XL: - # model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) - # elif model_type == BuiltInModelType.SUNDIAL: - # model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) - else: - raise BuiltInModelNotSupportError(model_type.value) - - return model diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py deleted file mode 100644 index 2a1e720805f29..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py deleted file mode 100644 index 5b9eb7f1f6b03..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py +++ /dev/null @@ -1,67 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
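Before the removed Sundial sources that follow, it is worth showing how the `fetch_built_in_model` dispatch above is typically driven. This is a minimal sketch, not code from this patch: it assumes the factory's post-refactor home in `sktime/modeling_sktime.py`, and that the defaults merged by `get_attributes`/`update_attribute` supply the keys the model classes above consume (`predict_length`, `strategy`, `sp`, ...).

import numpy as np

from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.ainode.core.model.sktime.modeling_sktime import fetch_built_in_model

# Resolve a built-in forecaster and run it on a toy univariate series.
# inference_attrs carries string values; the factory merges them with defaults.
model = fetch_built_in_model(BuiltInModelType.NAIVE_FORECASTER, {"predict_length": "24"})
series = np.sin(np.arange(96) / 6.0)
forecast = model.inference(series)  # fits, then predicts 24 points as np.float64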
-# - -from typing import List - -from transformers import PretrainedConfig - - -class SundialConfig(PretrainedConfig): - model_type = "sundial" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - input_token_len: int = 16, - hidden_size: int = 768, - intermediate_size: int = 3072, - output_token_lens: List[int] = [720], - num_hidden_layers: int = 12, - num_attention_heads: int = 12, - hidden_act: str = "silu", - use_cache: bool = True, - rope_theta: int = 10000, - dropout_rate: float = 0.1, - initializer_range: float = 0.02, - max_position_embeddings: int = 10000, - flow_loss_depth: int = 3, - num_sampling_steps: int = 50, - diffusion_batch_mul: int = 4, - **kwargs, - ): - self.input_token_len = input_token_len - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.output_token_lens = output_token_lens - self.use_cache = use_cache - self.rope_theta = rope_theta - self.dropout_rate = dropout_rate - self.initializer_range = initializer_range - self.max_position_embeddings = max_position_embeddings - self.flow_loss_depth = flow_loss_depth - self.num_sampling_steps = num_sampling_steps - self.diffusion_batch_mul = diffusion_batch_mul - - super().__init__( - **kwargs, - ) - -# TODO: Lacking checkpoint_path \ No newline at end of file diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py deleted file mode 100644 index b3fe95dbe2d27..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py +++ /dev/null @@ -1,255 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
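The `flow_loss.py` module removed below implements a flow-matching head: training interpolates the target with Gaussian noise along a straight line (`noised = t * target + (1 - t) * noise`), the network learns to recover the clean target, and `sample` integrates the implied velocity with fixed-step Euler updates. A standalone sketch of that sampling recursion (all names here are illustrative, not from the patch):

import torch

def euler_sample(net, z, in_channels, steps=50):
    # Mirrors FlowLoss.sample below: start at pure noise and take `steps`
    # Euler updates; since the net predicts the clean target, the term
    # (prediction - initial noise) approximates the straight-line velocity.
    noise = torch.randn(z.shape[0], in_channels, device=z.device)
    x, dt = noise, 1.0 / steps
    for i in range(steps):
        t = torch.full((x.shape[0],), i / steps, device=x.device)
        pred = net(x, t * 1000, z)
        x = x + (pred - noise) * dt
    return x

# With an "oracle" net that always returns the conditioning vector as its
# prediction, the iterate moves from noise exactly onto that vector:
z = torch.randn(4, 8)
samples = euler_sample(lambda x, t, cond: cond, z, in_channels=8)

With a learned network the trajectory only approximates the straight path, which is why the module exposes `num_sampling_steps` (50 by default in `SundialConfig`) rather than taking a single jump.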
-# - -import math - -import torch -import torch.nn as nn - - -class FlowLoss(nn.Module): - """Flow Loss""" - - def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps): - super(FlowLoss, self).__init__() - self.in_channels = target_channels - self.net = SimpleMLPAdaLN( - in_channels=target_channels, - model_channels=width, - out_channels=target_channels, - z_channels=z_channels, - num_res_blocks=depth, - ) - self.num_sampling_steps = num_sampling_steps - - def forward(self, target, z, mask=None, mask_y=None): - noise = torch.randn_like(target) - t = torch.rand(target.shape[0], device=target.device) - - noised_target = t[:, None] * target + (1 - t[:, None]) * noise - - predict_v = self.net(noised_target, t * 1000, z) - - weights = 1.0 / torch.arange( - 1, self.in_channels + 1, dtype=torch.float32, device=target.device - ) - if mask_y is not None: - loss = (mask_y * weights * (predict_v - target) ** 2).sum(dim=-1) - else: - loss = (weights * (predict_v - target) ** 2).sum(dim=-1) - - if mask is not None: - loss = (loss * mask).sum() / mask.sum() - return loss.mean() - - def sample(self, z, num_samples=1): - z = z.repeat(num_samples, 1) - noise = torch.randn(z.shape[0], self.in_channels).to(z.device) - x = noise - dt = 1.0 / self.num_sampling_steps - for i in range(self.num_sampling_steps): - t = (torch.ones((x.shape[0])) * i / self.num_sampling_steps).to(x.device) - pred = self.net(x, t * 1000, z) - x = x + (pred - noise) * dt - x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1) - return x - - -def modulate(x, shift, scale): - return x * (1 + scale) + shift - - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. - """ - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - half = dim // 2 - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=half, dtype=torch.float32) - / half - ).to(device=t.device) - args = t[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 - ) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) - return t_emb - - -class ResBlock(nn.Module): - """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. 
- """ - - def __init__(self, channels): - super().__init__() - self.channels = channels - - self.in_ln = nn.LayerNorm(channels, eps=1e-6) - self.mlp = nn.Sequential( - nn.Linear(channels, channels, bias=True), - nn.SiLU(), - nn.Linear(channels, channels, bias=True), - ) - - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True) - ) - - def forward(self, x, y): - shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) - h = modulate(self.in_ln(x), shift_mlp, scale_mlp) - h = self.mlp(h) - return x + gate_mlp * h - - -class FinalLayer(nn.Module): - """ - The final layer adopted from DiT. - """ - - def __init__(self, model_channels, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm( - model_channels, elementwise_affine=False, eps=1e-6 - ) - self.linear = nn.Linear(model_channels, out_channels, bias=True) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - - -class SimpleMLPAdaLN(nn.Module): - """ - The MLP for Diffusion Loss. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param z_channels: channels in the condition. - :param num_res_blocks: number of residual blocks per downsample. - """ - - def __init__( - self, - in_channels, - model_channels, - out_channels, - z_channels, - num_res_blocks, - ): - super().__init__() - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - - self.time_embed = TimestepEmbedder(model_channels) - self.cond_embed = nn.Linear(z_channels, model_channels) - - self.input_proj = nn.Linear(in_channels, model_channels) - - res_blocks = [] - for i in range(num_res_blocks): - res_blocks.append( - ResBlock( - model_channels, - ) - ) - - self.res_blocks = nn.ModuleList(res_blocks) - self.final_layer = FinalLayer(model_channels, out_channels) - - self.initialize_weights() - - def initialize_weights(self): - def _basic_init(module): - if isinstance(module, nn.Linear): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) - - # Initialize timestep embedding MLP - nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) - nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) - - # Zero-out adaLN modulation layers - for block in self.res_blocks: - nn.init.constant_(block.adaLN_modulation[-1].weight, 0) - nn.init.constant_(block.adaLN_modulation[-1].bias, 0) - - # Zero-out output layers - nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) - nn.init.constant_(self.final_layer.linear.weight, 0) - nn.init.constant_(self.final_layer.linear.bias, 0) - - def forward(self, x, t, c): - """ - Apply the model to an input batch. - :param x: an [N x C] Tensor of inputs. - :param t: a 1-D batch of timesteps. - :param c: conditioning from AR transformer. - :return: an [N x C] Tensor of outputs. 
- """ - x = self.input_proj(x) - t = self.time_embed(t) - c = self.cond_embed(c) - y = t + c - - for block in self.res_blocks: - x = block(x, y) - - return self.final_layer(x, y) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py deleted file mode 100644 index 3ebf516f705e0..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ /dev/null @@ -1,656 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import os -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors -from torch import nn -from transformers import Cache, DynamicCache, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, -) - -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss -from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin - -logger = Logger() - - -def rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class SundialPatchEmbedding(nn.Module): - def __init__(self, config: SundialConfig): - super().__init__() - self.dropout = nn.Dropout(config.dropout_rate) - self.hidden_layer = nn.Linear( - config.input_token_len * 2, config.intermediate_size - ) - self.act = ACT2FN[config.hidden_act] - self.output_layer = nn.Linear(config.intermediate_size, config.hidden_size) - self.residual_layer = nn.Linear(config.input_token_len * 2, config.hidden_size) - self.input_token_len = config.input_token_len - - def forward(self, x): - mask = torch.ones_like(x, dtype=torch.float32) - input_length = x.shape[-1] - padding_length = ( - self.input_token_len - (input_length % self.input_token_len) - ) % self.input_token_len - x = F.pad(x, (padding_length, 0)) - mask = F.pad(mask, (padding_length, 0)) - x = x.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len) - mask = mask.unfold( - dimension=-1, size=self.input_token_len, step=self.input_token_len - ) - - x = torch.cat([x, 
mask], dim=-1) - hid = self.act(self.hidden_layer(x)) - out = self.dropout(self.output_layer(hid)) - res = self.residual_layer(x) - out = out + res - return out - - -class SundialRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) - / self.dim - ) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64 - ).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class SundialAttention(nn.Module): - def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None): - super().__init__() - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.attention_dropout = config.dropout_rate - self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_emb = SundialRotaryEmbedding( - self.head_dim, max_position_embeddings=config.max_position_embeddings - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - if past_key_value is not None: - key_states, 
value_states = past_key_value.update( - key_states, value_states, self.layer_idx - ) - - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attention_mask, - dropout_p=(self.attention_dropout if self.training else 0.0), - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class SundialMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): - super().__init__() - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) - - -class SundialDecoderLayer(nn.Module): - def __init__(self, config: SundialConfig, layer_idx: int): - super().__init__() - self.self_attn = SundialAttention(config, layer_idx) - - self.ffn_layer = SundialMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.norm1 = torch.nn.LayerNorm(config.hidden_size) - self.norm2 = torch.nn.LayerNorm(config.hidden_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - Optional[torch.Tensor], - Optional[Cache], - ]: - residual = hidden_states - - hidden_states = self.norm1(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.norm2(hidden_states) - hidden_states = self.ffn_layer(hidden_states) - hidden_states = residual + hidden_states - - if not output_attentions: - self_attn_weights = None - - return hidden_states, self_attn_weights, present_key_value - - -class SundialPreTrainedModel(PreTrainedModel): - config_class = SundialConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["SundialDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = False - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, torch.nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, torch.nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -class SundialModel(SundialPreTrainedModel): - def __init__(self, config: SundialConfig): - super().__init__(config) - self.embed_layer = SundialPatchEmbedding(config) - 
self.layers = nn.ModuleList( - [ - SundialDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self.norm = torch.nn.LayerNorm(config.hidden_size) - self.gradient_checkpointing = False - - def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[ - Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] - ] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: - # input_ids is the input of time series, its shape is [batch_size, seq_len] - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_layer(input_ids) - seq_length = inputs_embeds.shape[1] - - past_key_values_length = 0 - use_legacy_cache = False - - if past_key_values is not None: - use_legacy_cache = not isinstance(past_key_values, Cache) - # Converts the legacy cache which is tuple into an equivalent Cache. Used for backward compatibility. 
-            if use_legacy_cache:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            # Assume the cached sequence length is the same for every layer
-            past_key_values_length = past_key_values.get_seq_length()
-
-        # With gradient checkpointing enabled during training, caching is disabled (the cache is simply not passed through)
-        if (
-            self.gradient_checkpointing
-            and self.training
-            and isinstance(past_key_values, Cache)
-        ):
-            past_key_values = None
-            past_key_values_length = 0
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-            position_ids = position_ids.view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
-
-        # 4d mask is passed through the layers
-        attention_mask = _prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-            sliding_window=None,
-        )
-
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = None
-
-        for decoder_layer in self.layers:
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    attention_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                )
-
-            hidden_states = layer_outputs[0]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-            if isinstance(past_key_values, Cache):
-                next_decoder_cache = layer_outputs[2]
-
-        hidden_states = self.norm(hidden_states)
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = None
-        if isinstance(past_key_values, Cache):
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if use_legacy_cache
-                else next_decoder_cache
-            )
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
-                if v is not None
-            )
-        return MoeModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-
-class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
-    def __init__(self, config: SundialConfig):
-        super().__init__(config)
-        self.config = config
-        self.model = SundialModel(self.config)
-        self.flow_loss = FlowLoss(
-            self.config.output_token_lens[-1],
-            self.config.hidden_size,
-            self.config.flow_loss_depth,
-            self.config.hidden_size,
-            self.config.num_sampling_steps,
-        )
-        self.post_init()
-
-    def set_decoder(self, decoder):
-        self.model = decoder
-
-    def get_decoder(self):
-        return self.model
-
-    def forward(
-        self,
-        input_ids: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[
-            Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]]
-        ] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.FloatTensor] = None,
-        loss_masks:
Optional[torch.FloatTensor] = None, - mask_y: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - max_output_length: Optional[int] = None, - revin: Optional[bool] = False, - num_samples: Optional[int] = 1, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if revin: - means = input_ids.mean(1, keepdim=True).detach() - stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach() - stdev = torch.where( - stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device) - ) - input_ids = (input_ids - means) / stdev - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state - predictions = None - - loss = None - if labels is not None: - if revin: - labels = (labels - means) / stdev - output_token_len = self.config.output_token_lens[-1] - seq_len = hidden_states.shape[1] * self.config.input_token_len - labels = labels[ - :, : seq_len - self.config.input_token_len + output_token_len - ] - shift_labels = labels.unfold( - dimension=-1, size=output_token_len, step=self.config.input_token_len - ) - - bsz, L, _ = shift_labels.shape - shift_labels = shift_labels.reshape(bsz * L, -1).repeat( - self.config.diffusion_batch_mul, 1 - ) - hidden_states = hidden_states.reshape(bsz * L, -1).repeat( - self.config.diffusion_batch_mul, 1 - ) - loss_masks = loss_masks.reshape(bsz * L).repeat( - self.config.diffusion_batch_mul - ) - mask_y = mask_y.repeat(L * self.config.diffusion_batch_mul, 1) - - loss = self.flow_loss(shift_labels, hidden_states, loss_masks, mask_y) - else: - if max_output_length is None: - output_token_len = self.config.output_token_lens[0] - max_output_length = output_token_len - else: - output_token_len = self.config.output_token_lens[0] - for h in self.config.output_token_lens[1:]: - if h > max_output_length: - break - else: - output_token_len = h - - bsz = hidden_states.shape[0] - hidden_states = hidden_states[:, -1, :] - predictions = self.flow_loss.sample(hidden_states, num_samples) - if output_token_len > max_output_length: - predictions = predictions[:, :, :max_output_length] - if revin: - predictions = predictions * stdev + means - if not return_dict: - output = (predictions,) + outputs[1:] - return (loss) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - logits=predictions, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - revin=False, - num_samples=1, - **kwargs, - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = past_key_values.get_seq_length() - else: - past_length = 
past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > ( - input_ids.shape[1] // self.config.input_token_len - ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, past_length * self.config.input_token_len :] - # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - token_num = ( - input_ids.shape[1] + self.config.input_token_len - 1 - ) // self.config.input_token_len - position_ids = position_ids[:, -token_num:] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - "revin": revin, - "num_samples": num_samples, - } - ) - return model_inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py deleted file mode 100644 index f09621f2cb0a5..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py +++ /dev/null @@ -1,383 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
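The generation mixin removed below wraps Hugging Face's `generate` in reversible instance normalization (RevIN): each input series is standardized with its own mean and standard deviation, sampling runs in normalized space, and the statistics are broadcast back over the sample dimension at the end. A self-contained sketch of that wrapper (the function and the toy sampler are illustrative only):

import torch

def revin_generate(series, sampler):
    # series: [batch, seq_len]; sampler returns [batch, num_samples, horizon].
    means = series.mean(dim=-1, keepdim=True)
    stdev = series.std(dim=-1, keepdim=True, unbiased=False) + 1e-5
    samples = sampler((series - means) / stdev)
    # Unsqueeze so the per-series statistics broadcast over num_samples.
    return samples * stdev.unsqueeze(1) + means.unsqueeze(1)

# Toy sampler: echo the last 16 normalized points as three "samples".
out = revin_generate(
    torch.randn(2, 96) * 10 + 5,
    lambda x: x[:, -16:].unsqueeze(1).repeat(1, 3, 1),
)  # -> [2, 3, 16], back on the original scale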
-# - -import warnings -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList -from transformers.generation import EosTokenCriteria, validate_stopping_criteria -from transformers.generation.utils import ( - GenerateDecoderOnlyOutput, - GenerateEncoderDecoderOutput, - GenerateNonBeamOutput, - GenerateOutput, - GenerationConfig, -) -from transformers.utils import ModelOutput - - -class TSGenerationMixin(GenerationMixin): - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - revin: Optional[bool] = True, - num_samples: Optional[int] = 1, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - if len(inputs.shape) != 2: - raise ValueError("Input shape must be: [batch_size, seq_len]") - batch_size, cur_len = inputs.shape - if cur_len < self.config.input_token_len: - raise ValueError( - f"Input length must be at least {self.config.input_token_len}" - ) - if revin: - means = inputs.mean(dim=-1, keepdim=True) - stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5 - inputs = (inputs - means) / stdev - outputs = super().generate( - inputs=inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - negative_prompt_ids=negative_prompt_ids, - negative_prompt_attention_mask=negative_prompt_attention_mask, - num_samples=num_samples, - **kwargs, - ) - if revin: - stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1) - means = means.unsqueeze(1).repeat(1, num_samples, 1) - outputs = (outputs * stdev) + means - return outputs - - def _sample( - self, - input_ids: torch.Tensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.Tensor]: - input_ids = input_ids.to(self.device) - batch_size, cur_len = input_ids.shape - # init values - logits_processor = ( - logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = 
validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - if eos_token_id is not None: - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() - for criteria in stopping_criteria - if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - raw_logits = () if (return_dict_in_generate and output_logits) else None - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - # keep track of which sequences are already finished - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - this_peer_finished = False - unfinished_sequences = torch.ones( - batch_size, dtype=torch.long, device=input_ids.device - ) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - true_seq_len = ( - cur_len + self.config.input_token_len - 1 - ) // self.config.input_token_len - model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ - :, -true_seq_len: - ] - max_length = stopping_criteria.max_length - generate_results = None - while self._has_unfinished_sequences( - this_peer_finished, synced_gpus, device=input_ids.device - ): - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - input_length = input_ids.shape[1] - - # forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - max_output_length=max_length - input_length, - ) - - if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits 
- - # pre-process distribution - next_tokens_scores = logits_processor(input_ids, next_token_logits) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_tokens_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # argmax - # next_tokens = torch.argmax(next_tokens_scores, dim=-1) - next_tokens = next_tokens_scores - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) - - # update generated ids, model inputs, and length for next step - horizon_length = next_tokens.shape[-1] // self.config.input_token_len - - past_key_values = model_kwargs.get("past_key_values") - if past_key_values is None or generate_results is None: - generate_results = next_tokens - else: - generate_results = torch.cat([generate_results, next_tokens], dim=-1) - input_ids = torch.cat([input_ids, next_tokens.median(dim=1)[0]], dim=-1) - - if streamer is not None: - streamer.put(next_tokens.cpu()) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - horizon_length=horizon_length, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores - ) - this_peer_finished = unfinished_sequences.max() == 0 - - if input_ids.shape[-1] > max_length: - input_ids = input_ids[:, :max_length] - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return generate_results[:, :, : (max_length - cur_len)] - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - horizon_length: int = 1, - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - if "past_key_values" in outputs: - model_kwargs["past_key_values"] = outputs.past_key_values - elif "mems" in outputs: - model_kwargs["past_key_values"] = outputs.mems - elif "past_buckets_states" in outputs: - model_kwargs["past_key_values"] = outputs.past_buckets_states - - if getattr(outputs, "state", None) is not None: - model_kwargs["state"] = outputs.state - - # update token_type_ids with last value - if "token_type_ids" in 
model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 - ) - - if not is_encoder_decoder: - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones( - (attention_mask.shape[0], horizon_length) - ), - ], - dim=-1, - ) - else: - # update decoder attention mask - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - model_kwargs["decoder_attention_mask"] = torch.cat( - [ - decoder_attention_mask, - decoder_attention_mask.new_ones( - (decoder_attention_mask.shape[0], horizon_length) - ), - ], - dim=-1, - ) - - if ( - "cache_position" in model_kwargs - and model_kwargs["cache_position"] is not None - ): - model_kwargs["cache_position"] = ( - model_kwargs["cache_position"][-1:] + horizon_length - ) - - return model_kwargs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py deleted file mode 100644 index 2a1e720805f29..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py deleted file mode 100644 index 34f9de91b633d..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py +++ /dev/null @@ -1,59 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
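`TimerConfig` below declares several decoding granularities through `output_token_lens` (defaulting to `[1, 8, 32, 64]`). At inference time the prediction head snaps a requested horizon to the largest configured length that still fits, mirroring the selection loop in the Sundial prediction head earlier in this patch. A small sketch (the helper name is ours; the list is assumed ascending, as in the defaults):

from typing import List

def select_output_token_len(output_token_lens: List[int], max_output_length: int) -> int:
    chosen = output_token_lens[0]  # fall back to the smallest length
    for h in output_token_lens[1:]:
        if h > max_output_length:
            break
        chosen = h
    return chosen

assert select_output_token_len([1, 8, 32, 64], 48) == 32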
-# - -from typing import List - -from transformers import PretrainedConfig - - -class TimerConfig(PretrainedConfig): - model_type = "timer" - keys_to_ignore_at_inference = ["past_key_values"] - - def __init__( - self, - input_token_len: int = 1, - hidden_size: int = 1024, - intermediate_size: int = 2048, - output_token_lens: List[int] = [1, 8, 32, 64], - num_hidden_layers: int = 8, - num_attention_heads: int = 8, - hidden_act: str = "silu", - use_cache: bool = True, - rope_theta: int = 10000, - attention_dropout: float = 0.0, - initializer_range: float = 0.02, - max_position_embeddings: int = 10000, - **kwargs, - ): - self.input_token_len = input_token_len - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.output_token_lens = output_token_lens - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.initializer_range = initializer_range - self.max_position_embeddings = max_position_embeddings - - super().__init__( - **kwargs, - ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py deleted file mode 100644 index 0a33c682742aa..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ /dev/null @@ -1,644 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
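The tokenization step of the model removed below is `TimerPatchEmbedding`: the raw series is unfolded into non-overlapping windows of `input_token_len` points, and each window is linearly projected into one decoder token. A toy illustration with a patch length of 4 (the `TimerConfig` default is 1, i.e. point-wise tokens):

import torch

series = torch.arange(12.0).unsqueeze(0)               # [1, 12]
patches = series.unfold(dimension=-1, size=4, step=4)  # [1, 3, 4]: three windows
tokens = torch.nn.Linear(4, 16, bias=False)(patches)   # [1, 3, 16]: one embedding each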
-# - -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from torch import nn -from transformers import Cache, DynamicCache, PreTrainedModel -from transformers.activations import ACT2FN -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, -) - -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin - -logger = Logger() - - -def rotate_half(x): - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class TimerPatchEmbedding(nn.Module): - def __init__(self, config: TimerConfig): - super().__init__() - self.input_token_len = config.input_token_len - self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False) - - def forward(self, hidden_state: torch.Tensor): - hidden_state = hidden_state.unfold( - dimension=-1, size=self.input_token_len, step=self.input_token_len - ) - return self.emb(hidden_state) - - -class TimerPointEmbedding(nn.Module): - def __init__(self, config: TimerConfig): - super().__init__() - self.emb_layer = nn.Linear( - config.input_token_len, config.hidden_size, bias=False - ) - self.gate_layer = nn.Linear( - config.input_token_len, config.hidden_size, bias=False - ) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x) - return emb - - -class TimeMoeRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / ( - self.base - ** ( - torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) - / self.dim - ) - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype(), - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=torch.int64 - ).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -class TimerAttention(nn.Module): - def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None): - super().__init__() - self.layer_idx = layer_idx - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.attention_dropout = config.attention_dropout - self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) - self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) - self.rotary_emb = TimeMoeRotaryEmbedding( - self.head_dim, max_position_embeddings=config.max_position_embeddings - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - key_states = key_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - value_states = value_states.view( - bsz, q_len, self.num_heads, self.head_dim - ).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_seq_length(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - - if past_key_value is not None: - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx - ) - - attn_output = F.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attention_mask, - dropout_p=self.attention_dropout, - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class TimerMLP(nn.Module): - def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): - super().__init__() - self.hidden_size 
= hidden_size - self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, hidden_state): - return self.down_proj( - self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) - ) - - -class TimerDecoderLayer(nn.Module): - def __init__(self, config: TimerConfig, layer_idx: int): - super().__init__() - self.self_attn = TimerAttention(config, layer_idx) - - self.ffn_layer = TimerMLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.norm1 = torch.nn.LayerNorm(config.hidden_size) - self.norm2 = torch.nn.LayerNorm(config.hidden_size) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - Optional[torch.Tensor], - Optional[Cache], - ]: - residual = hidden_states - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - ) - hidden_states = residual + hidden_states - hidden_states = self.norm1(hidden_states) - - # Fully Connected - residual = hidden_states - hidden_states = self.ffn_layer(hidden_states) - hidden_states = residual + hidden_states - hidden_states = self.norm2(hidden_states) - - if not output_attentions: - self_attn_weights = None - - return hidden_states, self_attn_weights, present_key_value - - -class TimerPreTrainedModel(PreTrainedModel): - config_class = TimerConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["TimeDecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_sdpa = False - _supports_cache_class = True - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, torch.nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, torch.nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -class TimerModel(TimerPreTrainedModel): - def __init__(self, config: TimerConfig): - super().__init__(config) - self.embed_layer = TimerPatchEmbedding(config) - self.layers = nn.ModuleList( - [ - TimerDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers) - ] - ) - self.norm = torch.nn.LayerNorm(config.hidden_size) - self.gradient_checkpointing = False - - def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[ - Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] - ] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, MoeModelOutputWithPast]: - # input_ids 
is the input of time series, its shape is [batch_size, seq_len] - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if inputs_embeds is None: - inputs_embeds = self.embed_layer(input_ids) - seq_length = inputs_embeds.shape[1] - - past_key_values_length = 0 - use_legacy_cache = False - - if past_key_values is not None: - use_legacy_cache = not isinstance(past_key_values, Cache) - # Converts the legacy cache which is tuple into an equivalent Cache. Used for backward compatibility. - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_seq_length() - - # When training + checkpoints, caching is usually disabled (just do not transfer) - if ( - self.gradient_checkpointing - and self.training - and isinstance(past_key_values, Cache) - ): - past_key_values = None - past_key_values_length = 0 - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - position_ids = position_ids.view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=None, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if isinstance(past_key_values, Cache): - next_decoder_cache = layer_outputs[2] - - hidden_states = self.norm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = None - if isinstance(past_key_values, Cache): - next_cache = ( - next_decoder_cache.to_legacy_cache() - if use_legacy_cache - else next_decoder_cache - ) - - if not 
return_dict: - return tuple( - v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None - ) - return MoeModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin): - def __init__(self, config: TimerConfig): - super().__init__(config) - self.config = config - self.model = TimerModel(self.config) - lm_head_list = [] - self.output_token_len_map = {} - for i, output_token_len in enumerate(self.config.output_token_lens): - lm_head_list.append( - nn.Linear(self.config.hidden_size, output_token_len, bias=False) - ) - self.output_token_len_map[output_token_len] = i - self.lm_heads = nn.ModuleList(lm_head_list) - self.loss_function = torch.nn.MSELoss(reduction="none") - self.post_init() - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[ - Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] - ] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.FloatTensor] = None, - loss_masks: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - max_output_length: Optional[int] = None, - revin: Optional[bool] = False, - ) -> Union[Tuple, MoeCausalLMOutputWithPast]: - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if revin: - mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std( - dim=-1, keepdim=True - ) - input_ids = (input_ids - mean) / std - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state - predictions = None - - loss = None - if labels is not None: - ar_loss = 0.0 - for lm_head, output_token_len in zip( - self.lm_heads, self.config.output_token_lens - ): - one_predictions = lm_head(hidden_states) - one_loss = self.calc_ar_loss( - one_predictions, labels, loss_masks, output_token_len - ) - ar_loss += one_loss - if predictions is None: - predictions = one_predictions - loss = ar_loss / len(self.config.output_token_lens) - else: - if max_output_length is None: - output_token_len = self.config.output_token_lens[0] - max_output_length = output_token_len - else: - output_token_len = self.config.output_token_lens[0] - for h in self.config.output_token_lens[1:]: - if h > max_output_length: - break - else: - output_token_len = h - lm_head = self.lm_heads[self.output_token_len_map[output_token_len]] - predictions = lm_head(hidden_states)[:, -1, :] - if output_token_len > max_output_length: - predictions = predictions[:, :max_output_length] - if revin: - predictions = predictions * std + mean - if not 
return_dict: - output = (predictions,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return MoeCausalLMOutputWithPast( - loss=loss, - logits=predictions, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len): - seq_len = predictions.shape[1] * self.config.input_token_len - labels = labels[:, : seq_len - self.config.input_token_len + output_token_len] - shift_labels = labels.unfold( - dimension=-1, size=output_token_len, step=self.config.input_token_len - ) - - # Calculate loss with mask - losses = self.loss_function(predictions, shift_labels).mean(dim=-1) - if loss_masks is not None: - losses = losses * loss_masks - loss = losses.sum() / loss_masks.sum() - else: - loss = torch.mean(losses) - - return loss - - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - revin=True, - **kwargs, - ): - # Omit tokens covered by past_key_values - if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = past_key_values.get_seq_length() - else: - past_length = past_key_values[0][0].shape[2] - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > ( - input_ids.shape[1] // self.config.input_token_len - ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, past_length * self.config.input_token_len :] - # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. - - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[ - :, -(input_ids.shape[1] // self.config.input_token_len) : - ] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "attention_mask": attention_mask, - "revin": revin, - } - ) - return model_inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py deleted file mode 100644 index 18f711b8e1a13..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py +++ /dev/null @@ -1,370 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import warnings -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList -from transformers.generation import EosTokenCriteria, validate_stopping_criteria -from transformers.generation.utils import ( - GenerateDecoderOnlyOutput, - GenerateEncoderDecoderOutput, - GenerateNonBeamOutput, - GenerateOutput, - GenerationConfig, -) -from transformers.utils import ModelOutput - - -class TSGenerationMixin(GenerationMixin): - - @torch.no_grad() - def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[ - Callable[[int, torch.Tensor], List[int]] - ] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - **kwargs, - ) -> Union[GenerateOutput, torch.LongTensor]: - if len(inputs.shape) == 2: - batch_size, cur_len = inputs.shape - if cur_len < self.config.input_token_len: - raise ValueError( - f"Input length must be at least {self.config.input_token_len}" - ) - elif cur_len % self.config.input_token_len != 0: - new_len = ( - cur_len // self.config.input_token_len - ) * self.config.input_token_len - inputs = inputs[:, -new_len:] - else: - raise ValueError("Input shape must be: [batch_size, seq_len]") - return super().generate( - inputs=inputs, - generation_config=generation_config, - logits_processor=logits_processor, - stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, - synced_gpus=synced_gpus, - assistant_model=assistant_model, - streamer=streamer, - negative_prompt_ids=negative_prompt_ids, - negative_prompt_attention_mask=negative_prompt_attention_mask, - **kwargs, - ) - - def _sample( - self, - input_ids: torch.Tensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, - **model_kwargs, - ) -> Union[GenerateNonBeamOutput, torch.Tensor]: - input_ids = input_ids.to(self.device) - batch_size, cur_len = input_ids.shape - # init values - logits_processor = ( - 
logits_processor if logits_processor is not None else LogitsProcessorList() - ) - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - if max_length is not None: - warnings.warn( - "`max_length` is deprecated in this function, use" - " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", - UserWarning, - ) - stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length - ) - pad_token_id = ( - pad_token_id - if pad_token_id is not None - else self.generation_config.pad_token_id - ) - if eos_token_id is not None: - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - else: - # remove when the method is totally private - # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever - eos_token_id = [ - criteria.eos_token_id.tolist() - for criteria in stopping_criteria - if hasattr(criteria, "eos_token_id") - ] - eos_token_id = eos_token_id[0] if eos_token_id else None - if eos_token_id is None and self.generation_config.eos_token_id is not None: - eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - output_scores = ( - output_scores - if output_scores is not None - else self.generation_config.output_scores - ) - output_attentions = ( - output_attentions - if output_attentions is not None - else self.generation_config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.generation_config.output_hidden_states - ) - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else self.generation_config.return_dict_in_generate - ) - - # init attention / hidden states / scores tuples - raw_logits = () if (return_dict_in_generate and output_logits) else None - scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # if model is an encoder-decoder, retrieve encoder attention weights and hidden states - if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = ( - model_kwargs["encoder_outputs"].get("attentions") - if output_attentions - else None - ) - encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get("hidden_states") - if output_hidden_states - else None - ) - - # keep track of which sequences are already finished - if "inputs_embeds" in model_kwargs: - cur_len = model_kwargs["inputs_embeds"].shape[1] - this_peer_finished = False - unfinished_sequences = torch.ones( - batch_size, dtype=torch.long, device=input_ids.device - ) - model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) - true_seq_len = cur_len // self.config.input_token_len - model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ - :, -true_seq_len: - ] - max_length = stopping_criteria.max_length - while self._has_unfinished_sequences( - this_peer_finished, synced_gpus, device=input_ids.device - ): - # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) - - input_length = input_ids.shape[1] - - # 
forward pass to get next token - outputs = self( - **model_inputs, - return_dict=True, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - max_output_length=max_length - input_length, - ) - - if synced_gpus and this_peer_finished: - continue # don't waste resources running the code we don't need - - next_token_logits = outputs.logits - - # pre-process distribution - next_tokens_scores = logits_processor(input_ids, next_token_logits) - - # Store scores, attentions and hidden_states when required - if return_dict_in_generate: - if output_scores: - scores += (next_tokens_scores,) - if output_logits: - raw_logits += (next_token_logits,) - if output_attentions: - decoder_attentions += ( - (outputs.decoder_attentions,) - if self.config.is_encoder_decoder - else (outputs.attentions,) - ) - if self.config.is_encoder_decoder: - cross_attentions += (outputs.cross_attentions,) - - if output_hidden_states: - decoder_hidden_states += ( - (outputs.decoder_hidden_states,) - if self.config.is_encoder_decoder - else (outputs.hidden_states,) - ) - - # argmax - # next_tokens = torch.argmax(next_tokens_scores, dim=-1) - next_tokens = next_tokens_scores - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) - - # update generated ids, model inputs, and length for next step - horizon_length = next_tokens.shape[1] // self.config.input_token_len - - input_ids = torch.cat([input_ids, next_tokens], dim=-1) - if streamer is not None: - streamer.put(next_tokens.cpu()) - model_kwargs = self._update_model_kwargs_for_generation( - outputs, - model_kwargs, - horizon_length=horizon_length, - is_encoder_decoder=self.config.is_encoder_decoder, - ) - unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores - ) - this_peer_finished = unfinished_sequences.max() == 0 - - if input_ids.shape[1] > max_length: - input_ids = input_ids[:, :max_length] - - if streamer is not None: - streamer.end() - - if return_dict_in_generate: - if self.config.is_encoder_decoder: - return GenerateEncoderDecoderOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - encoder_attentions=encoder_attentions, - encoder_hidden_states=encoder_hidden_states, - decoder_attentions=decoder_attentions, - cross_attentions=cross_attentions, - decoder_hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return GenerateDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - logits=raw_logits, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - past_key_values=model_kwargs.get("past_key_values"), - ) - else: - return input_ids[:, -(max_length - cur_len) :] - - def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - horizon_length: int = 1, - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - ) -> Dict[str, Any]: - # update past_key_values - if "past_key_values" in outputs: - model_kwargs["past_key_values"] = outputs.past_key_values - elif "mems" in outputs: - model_kwargs["past_key_values"] = outputs.mems - elif "past_buckets_states" in outputs: - model_kwargs["past_key_values"] = outputs.past_buckets_states - - if getattr(outputs, "state", None) is not None: - 
model_kwargs["state"] = outputs.state - - # update token_type_ids with last value - if "token_type_ids" in model_kwargs: - token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 - ) - - if not is_encoder_decoder: - # update attention mask - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones( - (attention_mask.shape[0], horizon_length) - ), - ], - dim=-1, - ) - else: - # update decoder attention mask - if "decoder_attention_mask" in model_kwargs: - decoder_attention_mask = model_kwargs["decoder_attention_mask"] - model_kwargs["decoder_attention_mask"] = torch.cat( - [ - decoder_attention_mask, - decoder_attention_mask.new_ones( - (decoder_attention_mask.shape[0], horizon_length) - ), - ], - dim=-1, - ) - - if ( - "cache_position" in model_kwargs - and model_kwargs["cache_position"] is not None - ): - model_kwargs["cache_position"] = ( - model_kwargs["cache_position"][-1:] + horizon_length - ) - - return model_kwargs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py deleted file mode 100644 index b2e759e00ce07..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import os -from enum import Enum -from typing import List - -from huggingface_hub import snapshot_download -from requests import Session -from requests.adapters import HTTPAdapter - -from iotdb.ainode.core.constant import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_RECONNECT_TIMEOUT, - DEFAULT_RECONNECT_TIMES, -) -from iotdb.ainode.core.exception import UnsupportedError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.model_info import get_model_file_type - -HTTP_PREFIX = "http://" -HTTPS_PREFIX = "https://" - -logger = Logger() - - -class UriType(Enum): - REPO = "repo" - FILE = "file" - HTTP = "http" - HTTPS = "https" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def parse_uri_type(uri: str): - """ - Parse the URI type from the given string. - """ - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): - return UriType.FILE - elif uri.startswith("http://"): - return UriType.HTTP - elif uri.startswith("https://"): - return UriType.HTTPS - else: - raise ValueError(f"Invalid URI type for {uri}") - - -def get_model_register_strategy(uri: str): - """ - Determine the loading strategy for a model based on its URI/path. 
- - Args: - uri (str): The URI of the model to be registered. - - Returns: - uri_type (UriType): The type of the URI, which can be one of: REPO, FILE, HTTP, or HTTPS. - parsed_uri (str): Parsed uri to get related file - model_file_type (ModelFileType): The type of the model file, which can be one of: SAFETENSORS, PYTORCH, or UNKNOWN. - """ - - uri_type = UriType.parse_uri_type(uri) - if uri_type in (UriType.HTTP, UriType.HTTPS): - # TODO: support HTTP(S) URI - raise UnsupportedError("CREATE MODEL FROM HTTP(S) URI") - else: - parsed_uri = uri[7:] - if uri_type == UriType.FILE: - # handle ~ in URI - parsed_uri = os.path.expanduser(parsed_uri) - model_file_type = get_model_file_type(uri) - elif uri_type == UriType.REPO: - # Currently, UriType.REPO only corresponds to huggingface repository with SAFETENSORS format - model_file_type = ModelFileType.SAFETENSORS - else: - raise ValueError(f"Invalid URI type for {uri}") - return uri_type, parsed_uri, model_file_type - - -def download_snapshot_from_hf(repo_id: str, local_dir: str): - """ - Download everything from a HuggingFace repository. - - Args: - repo_id (str): The HuggingFace repository ID. - local_dir (str): The local directory to save the downloaded files. - """ - try: - snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - ) - except Exception as e: - logger.error(f"Failed to download HuggingFace model {repo_id}: {e}") - raise e - - -def download_file(url: str, storage_path: str) -> None: - """ - Args: - url: url of file to download - storage_path: path to save the file - Returns: - None - """ - logger.info(f"Start Downloading file from {url} to {storage_path}") - session = Session() - adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES) - session.mount(HTTP_PREFIX, adapter) - session.mount(HTTPS_PREFIX, adapter) - response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) - response.raise_for_status() - with open(storage_path, "wb") as file: - for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): - if chunk: - file.write(chunk) - logger.info(f"Download file from {url} to {storage_path} success") From f6ef03f20790d447f409dc153c633469b791889d Mon Sep 17 00:00:00 2001 From: RkGrit Date: Mon, 24 Nov 2025 21:40:59 +0800 Subject: [PATCH 03/38] Reconstruct model management and model loading --- iotdb-core/ainode/iotdb/ainode/core/config.py | 16 + .../ainode/iotdb/ainode/core/constant.py | 143 +--- .../core/inference/inference_request.py | 6 +- .../core/inference/inference_request_pool.py | 110 +-- .../core/inference/pipeline/__init__.py | 29 + .../core/inference/pipeline/basic_pipeline.py | 111 +++ .../inference/pipeline/sktime_pipeline.py | 62 ++ .../inference/pipeline/sundial_pipeline.py | 48 ++ .../inference/pipeline/timerxl_pipeline.py | 44 ++ .../ainode/core/inference/pool_controller.py | 26 +- .../pool_scheduler/basic_pool_scheduler.py | 7 +- .../strategy/abstract_inference_pipeline.py | 60 -- .../timer_sundial_inference_pipeline.py | 51 -- .../strategy/timerxl_inference_pipeline.py | 51 -- .../ainode/core/manager/inference_manager.py | 141 +--- .../ainode/core/manager/model_manager.py | 121 ++++ .../ainode/iotdb/ainode/core/manager/utils.py | 13 +- .../{inference/strategy => model}/__init__.py | 0 .../iotdb/ainode/core/model/model_enums.py | 58 ++ .../iotdb/ainode/core/model/model_info.py | 129 ++++ .../iotdb/ainode/core/model/model_loader.py | 170 +++++ .../iotdb/ainode/core/model/model_storage.py | 576 ++++++++++++++++ .../ainode/core/model/sktime/__init__.py | 17 + 
.../core/model/sktime/arima/config.json | 22 + .../core/model/sktime/configuration_sktime.py | 379 ++++++++++ .../sktime/exponential_smoothing/config.json | 11 + .../model/sktime/gaussian_hmm/config.json | 20 + .../core/model/sktime/gmm_hmm/config.json | 20 + .../core/model/sktime/modeling_sktime.py | 178 +++++ .../model/sktime/naive_forecaster/config.json | 8 + .../model/sktime/stl_forecaster/config.json | 14 + .../core/model/sktime/stray/config.json | 11 + .../ainode/core/model/sundial/__init__.py | 17 + .../model/sundial/configuration_sundial.py | 65 ++ .../ainode/core/model/sundial/flow_loss.py | 255 +++++++ .../core/model/sundial/modeling_sundial.py | 651 ++++++++++++++++++ .../core/model/sundial/ts_generation_mixin.py | 383 +++++++++++ .../ainode/core/model/timerxl/__init__.py | 17 + .../core/model/timerxl/configuration_timer.py | 59 ++ .../core/model/timerxl/modeling_timer.py | 640 +++++++++++++++++ .../core/model/timerxl/ts_generation_mixin.py | 370 ++++++++++ .../ainode/iotdb/ainode/core/model/utils.py | 94 +++ .../ainode/iotdb/ainode/core/rpc/handler.py | 105 ++- 43 files changed, 4722 insertions(+), 586 deletions(-) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py rename iotdb-core/ainode/iotdb/ainode/core/{inference/strategy => model}/__init__.py (100%) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_info.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py create mode 100644 
iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/utils.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index afcf0683d7d04..04ec3ee68c16a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -31,6 +31,7 @@ AINODE_CONF_FILE_NAME, AINODE_CONF_GIT_FILE_NAME, AINODE_CONF_POM_FILE_NAME, + AINODE_FINETUNE_MODELS_DIR, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, AINODE_INFERENCE_EXTRA_MEMORY_RATIO, AINODE_INFERENCE_MAX_PREDICT_LENGTH, @@ -44,6 +45,7 @@ AINODE_SYSTEM_FILE_NAME, AINODE_TARGET_CONFIG_NODE_LIST, AINODE_THRIFT_COMPRESSION_ENABLED, + AINODE_USER_DEFINED_MODELS_DIR, AINODE_VERSION_INFO, ) from iotdb.ainode.core.exception import BadNodeUrlError @@ -96,6 +98,8 @@ def __init__(self): # Directory to save models self._ain_models_dir = AINODE_MODELS_DIR self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR + self._ain_finetune_models_dir = AINODE_FINETUNE_MODELS_DIR + self._ain_user_defined_models_dir = AINODE_USER_DEFINED_MODELS_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ -210,6 +214,18 @@ def get_ain_builtin_models_dir(self) -> str: def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None: self._ain_builtin_models_dir = ain_builtin_models_dir + def get_ain_finetune_models_dir(self) -> str: + return self._ain_finetune_models_dir + + def set_ain_finetune_models_dir(self, ain_finetune_models_dir: str) -> None: + self._ain_finetune_models_dir = ain_finetune_models_dir + + def get_ain_user_defined_models_dir(self) -> str: + return self._ain_user_defined_models_dir + + def set_ain_user_defined_models_dir(self, ain_user_defined_models_dir: str) -> None: + self._ain_user_defined_models_dir = ain_user_defined_models_dir + def get_ain_system_dir(self) -> str: return self._ain_system_dir diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index b9923d3e3ee7e..74decf9e88a61 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -18,9 +18,7 @@ import logging import os from enum import Enum -from typing import List -from iotdb.ainode.core.model.model_enums import BuiltInModelType from iotdb.thrift.common.ttypes import TEndPoint IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "") @@ -52,21 +50,26 @@ AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { - BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB - BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB + "sundial": 1036 * 1024**2, # 1036 MiB + "timerxl": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( 1.2 # the overhead 
ratio for inference, used to estimate the pool size ) -# AINode folder structure AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") AINODE_BUILTIN_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/weights" + IOTDB_AINODE_HOME, "data/ainode/models/builtin" ) # For built-in models, we only need to store their weights and config. -AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system") -AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs") +AINODE_FINETUNE_MODELS_DIR = os.path.join( + IOTDB_AINODE_HOME, "data/ainode/models/finetune" +) +AINODE_USER_DEFINED_MODELS_DIR = os.path.join( + IOTDB_AINODE_HOME, "data/ainode/models/user_defined" +) +AINODE_SYSTEM_DIR = "data/ainode/system" +AINODE_LOG_DIR = "logs" # AINode log LOG_FILE_TYPE = ["all", "info", "warn", "error"] @@ -141,132 +144,8 @@ def name(self): return self.value -class ForecastModelType(Enum): - DLINEAR = "dlinear" - DLINEAR_INDIVIDUAL = "dlinear_individual" - NBEATS = "nbeats" - - @classmethod - def values(cls) -> List[str]: - values = [] - for item in list(cls): - values.append(item.value) - return values - - class ModelInputName(Enum): DATA_X = "data_x" TIME_STAMP_X = "time_stamp_x" TIME_STAMP_Y = "time_stamp_y" DEC_INP = "dec_inp" - - -class AttributeName(Enum): - # forecast Attribute - PREDICT_LENGTH = "predict_length" - - # NaiveForecaster - STRATEGY = "strategy" - SP = "sp" - - # STLForecaster - # SP = 'sp' - SEASONAL = "seasonal" - SEASONAL_DEG = "seasonal_deg" - TREND_DEG = "trend_deg" - LOW_PASS_DEG = "low_pass_deg" - SEASONAL_JUMP = "seasonal_jump" - TREND_JUMP = "trend_jump" - LOSS_PASS_JUMP = "low_pass_jump" - - # ExponentialSmoothing - DAMPED_TREND = "damped_trend" - INITIALIZATION_METHOD = "initialization_method" - OPTIMIZED = "optimized" - REMOVE_BIAS = "remove_bias" - USE_BRUTE = "use_brute" - - # Arima - ORDER = "order" - SEASONAL_ORDER = "seasonal_order" - METHOD = "method" - MAXITER = "maxiter" - SUPPRESS_WARNINGS = "suppress_warnings" - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" - SCORING = "scoring" - WITH_INTERCEPT = "with_intercept" - TIME_VARYING_REGRESSION = "time_varying_regression" - ENFORCE_STATIONARITY = "enforce_stationarity" - ENFORCE_INVERTIBILITY = "enforce_invertibility" - SIMPLE_DIFFERENCING = "simple_differencing" - MEASUREMENT_ERROR = "measurement_error" - MLE_REGRESSION = "mle_regression" - HAMILTON_REPRESENTATION = "hamilton_representation" - CONCENTRATE_SCALE = "concentrate_scale" - - # GAUSSIAN_HMM - N_COMPONENTS = "n_components" - COVARIANCE_TYPE = "covariance_type" - MIN_COVAR = "min_covar" - STARTPROB_PRIOR = "startprob_prior" - TRANSMAT_PRIOR = "transmat_prior" - MEANS_PRIOR = "means_prior" - MEANS_WEIGHT = "means_weight" - COVARS_PRIOR = "covars_prior" - COVARS_WEIGHT = "covars_weight" - ALGORITHM = "algorithm" - N_ITER = "n_iter" - TOL = "tol" - PARAMS = "params" - INIT_PARAMS = "init_params" - IMPLEMENTATION = "implementation" - - # GMMHMM - # N_COMPONENTS = "n_components" - N_MIX = "n_mix" - # MIN_COVAR = "min_covar" - # STARTPROB_PRIOR = "startprob_prior" - # TRANSMAT_PRIOR = "transmat_prior" - WEIGHTS_PRIOR = "weights_prior" - - # MEANS_PRIOR = "means_prior" - # MEANS_WEIGHT = "means_weight" - # ALGORITHM = "algorithm" - # COVARIANCE_TYPE = "covariance_type" - # N_ITER = "n_iter" - # TOL = "tol" - # INIT_PARAMS = "init_params" - # PARAMS = "params" - # IMPLEMENTATION = "implementation" - - # STRAY - ALPHA = "alpha" - K = "k" - KNN_ALGORITHM = "knn_algorithm" - P = "p" - SIZE_THRESHOLD = "size_threshold" - OUTLIER_TAIL = 
"outlier_tail" - - # timerxl - INPUT_TOKEN_LEN = "input_token_len" - HIDDEN_SIZE = "hidden_size" - INTERMEDIATE_SIZE = "intermediate_size" - OUTPUT_TOKEN_LENS = "output_token_lens" - NUM_HIDDEN_LAYERS = "num_hidden_layers" - NUM_ATTENTION_HEADS = "num_attention_heads" - HIDDEN_ACT = "hidden_act" - USE_CACHE = "use_cache" - ROPE_THETA = "rope_theta" - ATTENTION_DROPOUT = "attention_dropout" - INITIALIZER_RANGE = "initializer_range" - MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - CKPT_PATH = "ckpt_path" - - # sundial - DROPOUT_RATE = "dropout_rate" - FLOW_LOSS_DEPTH = "flow_loss_depth" - NUM_SAMPLING_STEPS = "num_sampling_steps" - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" - - def name(self) -> str: - return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py index 82c72cc37abf5..a70445c5efd4c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. # + import threading from typing import Any import torch -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.atmoic_int import AtomicInt @@ -41,7 +39,6 @@ def __init__( req_id: str, model_id: str, inputs: torch.Tensor, - inference_pipeline: AbstractInferencePipeline, max_new_tokens: int = 96, **infer_kwargs, ): @@ -52,7 +49,6 @@ def __init__( self.model_id = model_id self.inputs = inputs self.infer_kwargs = infer_kwargs - self.inference_pipeline = inference_pipeline self.max_new_tokens = ( max_new_tokens # Number of time series data points to generate ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 6b054c91fe31c..fcaa4c7a7543b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -25,19 +25,17 @@ import numpy as np import torch import torch.multiprocessing as mp -from transformers import PretrainedConfig from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest +from iotdb.ainode.core.inference.pipeline import get_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device @@ -62,7 +60,6 @@ def __init__( pool_id: int, model_info: ModelInfo, device: str, - config: PretrainedConfig, request_queue: mp.Queue, result_queue: mp.Queue, ready_event, @@ -71,7 +68,6 @@ def __init__( super().__init__() self.pool_id = pool_id self.model_info = model_info - self.config = config self.pool_kwargs = pool_kwargs self.ready_event = ready_event self.device = 
convert_device_id_to_torch_device(device) @@ -86,8 +82,8 @@ def __init__( self._batcher = BasicBatcher() self._stop_event = mp.Event() - self._model = None - self._model_manager = None + # self._inference_pipeline = get_pipeline(self.model_info.model_id, self.device) + self._logger = None # Fix inference seed @@ -98,9 +94,6 @@ def __init__( def _activate_requests(self): requests = self._request_scheduler.schedule_activate() for request in requests: - request.inputs = request.inference_pipeline.preprocess_inputs( - request.inputs - ) request.mark_running() self._running_queue.put(request) self._logger.debug( @@ -123,66 +116,36 @@ def _step(self): for requests in grouped_requests: batch_inputs = self._batcher.batch_request(requests).to(self.device) - if self.model_info.model_type == BuiltInModelType.SUNDIAL.value: - batch_output = self._model.generate( - batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - num_samples=10, - revin=True, - ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output.mean(dim=1)) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) - - elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value: - batch_output = self._model.generate( - batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - revin=True, - ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) + batch_output = self._inference_pipeline.infer( + batch_inputs, + predict_length=requests[0].max_new_tokens, + # num_samples=10, + revin=True, + ) + offset = 0 + for request in requests: + request.output_tensor = request.output_tensor.to(self.device) + cur_batch_size = request.batch_size + cur_output = batch_output[offset : offset + cur_batch_size] + offset += cur_batch_size + # request.write_step_output(cur_output.mean(dim=1)) + request.write_step_output(cur_output) + + # self._inference_pipeline.post_decode() + if request.is_finished(): + # self._inference_pipeline.post_inference() + # ensure the output tensor is on CPU before sending to result 
queue
+                    request.output_tensor = request.output_tensor.cpu()
+                    self._finished_queue.put(request)
+                    self._logger.debug(
+                        f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+                    )
+                else:
+                    self._waiting_queue.put(request)
+                    self._logger.debug(
+                        f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+                    )
+        return
 
     def _requests_execute_loop(self):
         while not self._stop_event.is_set():
@@ -193,11 +156,8 @@ def run(self):
         self._logger = Logger(
             INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
         )
-        self._model_manager = ModelManager()
         self._request_scheduler.device = self.device
-        self._model = self._model_manager.load_model(self.model_info.model_id, {}).to(
-            self.device
-        )
+        self._inference_pipeline = get_pipeline(self.model_info.model_id, self.device)
         self.ready_event.set()
 
         activate_daemon = threading.Thread(
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py
new file mode 100644
index 0000000000000..617c4e6738061
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from iotdb.ainode.core.inference.pipeline.sundial_pipeline import SundialPipeline
+from iotdb.ainode.core.inference.pipeline.timerxl_pipeline import TimerxlPipeline
+
+
+def get_pipeline(model_id, device):
+    if model_id == "timerxl":
+        return TimerxlPipeline(model_id, device=device)
+    elif model_id == "sundial":
+        return SundialPipeline(model_id, device=device)
+    else:
+        raise ValueError(f"No inference pipeline is registered for model_id: {model_id}")
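The function above is the only place where model ids are mapped to concrete pipelines. A minimal caller-side sketch of the intended flow (hypothetical usage, not part of the patch, assuming a registered "timerxl" model and the [batch_size, seq_len] input contract enforced by ForecastPipeline):

    import torch

    from iotdb.ainode.core.inference.pipeline import get_pipeline

    # Hypothetical history window: 4 series, 672 points each.
    history = torch.randn(4, 672)

    # Resolve the model id to a concrete pipeline; the device is threaded
    # through **infer_kwargs down to BasicPipeline and the model loader.
    pipeline = get_pipeline("timerxl", device="cpu")

    # predict_length and revin are forwarded to the model's generate().
    forecast = pipeline.infer(history, predict_length=96, revin=True)
    print(forecast.shape)  # expected: torch.Size([4, 96])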
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
new file mode 100644
index 0000000000000..c413a92e82d62
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC
+
+import torch
+
+from iotdb.ainode.core.exception import InferenceModelInternalError
+from iotdb.ainode.core.manager.model_manager import get_model_manager
+
+
+class BasicPipeline(ABC):
+    def __init__(self, model_id, **infer_kwargs):
+        self.model_id = model_id
+        self.device = infer_kwargs.get("device", "cpu")
+        # Device placement is delegated to the model loader via device_map.
+        self.model = get_model_manager().load_model(
+            model_id, device_map=str(self.device)
+        )
+
+    def _preprocess(self, inputs):
+        """
+        Preprocess the input before inference, including shape validation and value transformation.
+        """
+        # TODO: Integrate with the data processing pipeline operators
+        pass
+
+    def infer(self, inputs):
+        pass
+
+    def _post_decode(self):
+        """
+        Post-process the outputs after each decode step.
+        """
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        """
+        Post-process the outputs after the entire inference task.
+        """
+        pass
+
+
+class ForecastPipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        if len(inputs.shape) != 2:
+            raise InferenceModelInternalError(
+                f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
+            )
+        return inputs
+
+    def forecast(self, inputs, **infer_kwargs):
+        pass
+
+    def _post_decode(self):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        pass
+
+
+class ClassificationPipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        pass
+
+    def classify(self, inputs, **kwargs):
+        pass
+
+    def _post_decode(self):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        pass
+
+
+class ChatPipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        pass
+
+    def chat(self, inputs, **kwargs):
+        pass
+
+    def _post_decode(self):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        pass
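The task-specific subclasses are deliberately thin: supporting a new model family only requires overriding _preprocess, infer, and _postprocess. The following sketch shows such an extension; the class name and thresholding logic are hypothetical and not part of this patch:

    import torch

    from iotdb.ainode.core.inference.pipeline.basic_pipeline import ClassificationPipeline


    class ThresholdClassificationPipeline(ClassificationPipeline):
        """Hypothetical example: binary labels by thresholding model scores."""

        def __init__(self, model_id, **infer_kwargs):
            super().__init__(model_id, **infer_kwargs)
            # Hypothetical knob, read from the same kwargs channel as `device`.
            self.threshold = infer_kwargs.get("threshold", 0.5)

        def infer(self, inputs, **kwargs):
            scores = self.model.generate(self._preprocess(inputs))
            return self._postprocess(scores)

        def _preprocess(self, inputs):
            # A real implementation would validate input shape here.
            return inputs

        def _postprocess(self, output: torch.Tensor):
            return (output > self.threshold).long()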
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py
new file mode 100644
index 0000000000000..004222db7cad5
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import pandas as pd
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import BasicPipeline
+
+
+class SktimePipeline(BasicPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        # BasicPipeline._preprocess is still a stub, so pass the inputs through unchanged.
+        return inputs
+
+    def infer(self, inputs, **infer_kwargs):
+        input_ids = self._preprocess(inputs)
+
+        # Convert to pandas Series for sktime (sktime expects Series or DataFrame)
+        # Handle batch dimension: if batch_size > 1, process each sample separately
+        if len(input_ids.shape) == 2 and input_ids.shape[0] > 1:
+            # Batch processing: convert each row to Series
+            outputs = []
+            for i in range(input_ids.shape[0]):
+                row = (
+                    input_ids[i].cpu().numpy()
+                    if isinstance(input_ids, torch.Tensor)
+                    else input_ids[i]
+                )
+                series = pd.Series(row)
+                output = self.model.generate(series)
+                outputs.append(output)
+            output = np.array(outputs)
+        else:
+            # Single sample: convert to Series
+            if isinstance(input_ids, torch.Tensor):
+                series = pd.Series(input_ids.squeeze().cpu().numpy())
+            else:
+                series = pd.Series(input_ids.squeeze())
+            output = self.model.generate(series)
+            # Add batch dimension if needed
+            if len(output.shape) == 1:
+                output = output[np.newaxis, :]
+
+        return self._postprocess(output)
+
+    def _postprocess(self, output):
+        if isinstance(output, np.ndarray):
+            return torch.from_numpy(output).float()
+        return output
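Because sktime estimators consume pandas Series rather than tensors, the batch path above falls back to a per-row Python loop, so throughput scales linearly with batch size. For reference, a direct-construction sketch (hypothetical: get_pipeline() currently routes only "timerxl" and "sundial", and this assumes an "arima" built-in that the model manager can load):

    import torch

    from iotdb.ainode.core.inference.pipeline.sktime_pipeline import SktimePipeline

    # Assumption: "arima" resolves to a sktime-backed built-in model.
    pipeline = SktimePipeline("arima", device="cpu")

    batch = torch.randn(2, 128)       # two independent series
    forecast = pipeline.infer(batch)  # forecast per row, stacked to [2, horizon]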
+#
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+
+
+class SundialPipeline(ForecastPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        return super()._preprocess(inputs)
+
+    def infer(self, inputs, **infer_kwargs):
+        predict_length = infer_kwargs.get("predict_length", 96)
+        num_samples = infer_kwargs.get("num_samples", 10)
+        revin = infer_kwargs.get("revin", True)
+
+        input_ids = self._preprocess(inputs)
+        output = self.model.generate(
+            input_ids,
+            max_new_tokens=predict_length,
+            num_samples=num_samples,
+            revin=revin,
+        )
+        return self._postprocess(output)
+
+    def _postprocess(self, output: torch.Tensor):
+        # Average the sampled trajectories (dim=1) into a point forecast
+        return output.mean(dim=1)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py
new file mode 100644
index 0000000000000..cf6d35c805c5d
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
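+#
+# Usage sketch (illustrative; "timerxl" is the built-in model id declared in
+# BUILTIN_HF_TRANSFORMERS_MODEL_MAP; the generated tensor is returned
+# unchanged by _postprocess):
+#
+#     import torch
+#     pipeline = TimerxlPipeline("timerxl", device="cpu")
+#     forecast = pipeline.infer(torch.randn(1, 336), predict_length=96)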
+#
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+
+
+class TimerxlPipeline(ForecastPipeline):
+    def __init__(self, model_id, **infer_kwargs):
+        super().__init__(model_id, **infer_kwargs)
+
+    def _preprocess(self, inputs):
+        return super()._preprocess(inputs)
+
+    def infer(self, inputs, **infer_kwargs):
+        predict_length = infer_kwargs.get("predict_length", 96)
+        revin = infer_kwargs.get("revin", True)
+
+        input_ids = self._preprocess(inputs)
+        output = self.model.generate(
+            input_ids, max_new_tokens=predict_length, revin=revin
+        )
+        return self._postprocess(output)
+
+    def _postprocess(self, output: torch.Tensor):
+        return output
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 54580402ec293..5af5bb95102a6 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -41,9 +41,6 @@
 )
 from iotdb.ainode.core.log import Logger
 from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
 from iotdb.ainode.core.util.atmoic_int import AtomicInt
 from iotdb.ainode.core.util.batch_executor import BatchExecutor
 from iotdb.ainode.core.util.decorator import synchronized
@@ -78,7 +75,7 @@ def __init__(self, result_queue: mp.Queue):
 
     # =============== Pool Management ===============
     @synchronized(threading.Lock())
-    def first_req_init(self, model_id: str):
+    def first_req_init(self, model_id: str, device):
         """
         Initialize the pools when the first request for the given model_id arrives.
""" @@ -110,17 +107,12 @@ def _first_pool_init(self, model_id: str, device_str: str): device = torch.device(device_str) device_id = device.index - if model_id == "sundial": - config = SundialConfig() - elif model_id == "timer_xl": - config = TimerConfig() first_queue = mp.Queue() ready_event = mp.Event() first_pool = InferenceRequestPool( pool_id=0, model_id=model_id, device=device_str, - config=config, request_queue=first_queue, result_queue=self._result_queue, ready_event=ready_event, @@ -255,29 +247,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int): """ def _expand_pool_on_device(*_): - result_queue = mp.Queue() + request_queue = mp.Queue() pool_id = self._new_pool_id.get_and_increment() model_info = self._model_manager.get_model_info(model_id) - model_type = model_info.model_type - if model_type == BuiltInModelType.SUNDIAL.value: - config = SundialConfig() - elif model_type == BuiltInModelType.TIMER_XL.value: - config = TimerConfig() - else: - raise InferenceModelInternalError( - f"Unsupported model type {model_type} for loading model {model_id}" - ) pool = InferenceRequestPool( pool_id=pool_id, model_info=model_info, device=device_id, - config=config, - request_queue=result_queue, + request_queue=request_queue, result_queue=self._result_queue, ready_event=mp.Event(), ) pool.start() - self._register_pool(model_id, device_id, pool_id, pool, result_queue) + self._register_pool(model_id, device_id, pool_id, pool, request_queue) if not pool.ready_event.wait(timeout=300): logger.error( f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time" diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 6a2bd2b619aa7..6ad55f742a4e1 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -28,7 +28,7 @@ ScaleActionType, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager +from iotdb.ainode.core.manager.model_manager import get_model_manager, ModelManager from iotdb.ainode.core.manager.utils import ( INFERENCE_EXTRA_MEMORY_RATIO, INFERENCE_MEMORY_USAGE_RATIO, @@ -36,12 +36,11 @@ estimate_pool_size, evaluate_system_resources, ) -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo +from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device logger = Logger() - def _estimate_shared_pool_size_by_total_mem( device: torch.device, existing_model_infos: List[ModelInfo], @@ -63,7 +62,7 @@ def _estimate_shared_pool_size_by_total_mem( mem_usages: Dict[str, float] = {} for model_info in all_models: mem_usages[model_info.model_id] = ( - MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO + MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO ) # Evaluate system resources and get TOTAL memory diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py deleted file mode 100644 index 2300169a6ee93..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software 
Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from abc import ABC, abstractmethod - -import torch - - -class AbstractInferencePipeline(ABC): - """ - Abstract assistance strategy class for model inference. - This class shall define the interface process for specific model. - """ - - def __init__(self, model_config, **infer_kwargs): - self.model_config = model_config - self.infer_kwargs = infer_kwargs - - @abstractmethod - def preprocess_inputs(self, inputs: torch.Tensor): - """ - Preprocess the inputs before inference, including shape validation and value transformation. - - Args: - inputs (torch.Tensor): The input tensor to be preprocessed. - - Returns: - torch.Tensor: The preprocessed input tensor. - """ - # TODO: Integrate with the data processing pipeline operators - pass - - @abstractmethod - def post_decode(self): - """ - Post-process the outputs after each decode step. - """ - pass - - @abstractmethod - def post_inference(self): - """ - Post-process the outputs after the entire inference task. - """ - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py deleted file mode 100644 index 17c88e32fb5a0..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py +++ /dev/null @@ -1,51 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import torch - -from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig - - -class TimerSundialInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-Sundial model inference. 
- """ - - def __init__(self, model_config: SundialConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) - if len(inputs.shape) != 2: - raise InferenceModelInternalError( - f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" - ) - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - return inputs - - def post_decode(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass - - def post_inference(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py deleted file mode 100644 index dc1dd304f68e8..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py +++ /dev/null @@ -1,51 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -import torch - -from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig - - -class TimerXLInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-XL model inference. 
- """ - - def __init__(self, model_config: TimerConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) - if len(inputs.shape) != 2: - raise InferenceModelInternalError( - f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" - ) - # Considering that we are currently using the generate function interface, it seems that no pre-processing is required - return inputs - - def post_decode(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass - - def post_inference(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index a67d576b0ec8c..6f022f0ddd1e4 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -37,21 +37,11 @@ InferenceRequest, InferenceRequestProxy, ) +from iotdb.ainode.core.inference.pipeline import get_pipeline from iotdb.ainode.core.inference.pool_controller import PoolController -from iotdb.ainode.core.inference.strategy.timer_sundial_inference_pipeline import ( - TimerSundialInferencePipeline, -) -from iotdb.ainode.core.inference.strategy.timerxl_inference_pipeline import ( - TimerXLInferencePipeline, -) from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.modeling_timer import TimerForPrediction +from iotdb.ainode.core.manager.model_manager import get_model_manager from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.ainode.core.util.serde import convert_to_binary @@ -71,90 +61,13 @@ logger = Logger() -class InferenceStrategy(ABC): - def __init__(self, model): - self.model = model - - @abstractmethod - def infer(self, full_data, **kwargs): - pass - - -# [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length], -# we only get valueList currently. 
-class TimerXLStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) - df = pd.DataFrame(output[0]) - return convert_to_binary(df) - - -class SundialStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate( - seqs, max_new_tokens=predict_length, num_samples=10, revin=True - ) - df = pd.DataFrame(output[0].mean(dim=0)) - return convert_to_binary(df) - - -class BuiltInStrategy(InferenceStrategy): - def infer(self, full_data, **_): - data = pd.DataFrame(full_data[1]).T - output = self.model.inference(data) - df = pd.DataFrame(output) - return convert_to_binary(df) - - -class RegisteredStrategy(InferenceStrategy): - def infer(self, full_data, window_interval=None, window_step=None, **_): - _, dataset, _, length = full_data - if window_interval is None or window_step is None: - window_interval = length - window_step = float("inf") - - if window_interval <= 0 or window_step <= 0 or window_interval > length: - raise InvalidWindowArgumentError(window_interval, window_step, length) - - data = torch.tensor(dataset, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1) - - times = int((length - window_interval) // window_step + 1) - results = [] - try: - for i in range(times): - start = 0 if window_step == float("inf") else i * window_step - end = start + window_interval - window = data[:, start:end, :] - out = self.model(window) - df = pd.DataFrame(out.squeeze(0).detach().numpy()) - results.append(df) - except Exception as e: - msg = runtime_error_extractor(str(e)) or str(e) - raise InferenceModelInternalError(msg) - - # concatenate or return first window for forecast - return [convert_to_binary(df) for df in results] - - class InferenceManager: WAITING_INTERVAL_IN_MS = ( AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms() ) # How often to check for requests in the result queue def __init__(self): - self._model_manager = ModelManager() + self._model_manager = get_model_manager() self._model_mem_usage_map: Dict[str, int] = ( {} ) # store model memory usage for each model @@ -251,15 +164,6 @@ def _process_request(self, req): with self._result_wrapper_lock: del self._result_wrapper_map[req_id] - def _get_strategy(self, model_id, model): - if isinstance(model, TimerForPrediction): - return TimerXLStrategy(model) - if isinstance(model, SundialForPrediction): - return SundialStrategy(model) - if self._model_manager.model_storage.is_built_in_or_fine_tuned(model_id): - return BuiltInStrategy(model) - return RegisteredStrategy(model) - def _run( self, req, @@ -272,9 +176,17 @@ def _run( model_id = req.modelId try: raw = data_getter(req) + # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently. 
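+            # The value ndarray may arrive in non-native byte order off the
+            # wire; it is byteswapped to the native order below before being
+            # wrapped into a torch tensor.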
full_data = deserializer(raw) - inference_attrs = extract_attrs(req) + # TODO: TSBlock -> Tensor codes should be unified + data = full_data[1][0] # get valueList in ndarray + if data.dtype.byteorder not in ("=", "|"): + np_data = data.byteswap() + data = np_data.view(np_data.dtype.newbyteorder()) + # the inputs should be on CPU before passing to the inference request + inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") + inference_attrs = extract_attrs(req) predict_length = int(inference_attrs.pop("predict_length", 96)) if ( predict_length @@ -290,41 +202,20 @@ def _run( ) if self._pool_controller.has_request_pools(model_id): - # use request pool to accelerate inference when the model instance is already loaded. - # TODO: TSBlock -> Tensor codes should be unified - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - # the inputs should be on CPU before passing to the inference request - inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - model_type = self._model_manager.get_model_info(model_id).model_type - if model_type == BuiltInModelType.SUNDIAL.value: - inference_pipeline = TimerSundialInferencePipeline(SundialConfig()) - elif model_type == BuiltInModelType.TIMER_XL.value: - inference_pipeline = TimerXLInferencePipeline(TimerConfig()) - else: - raise InferenceModelInternalError( - f"Unsupported model_id: {model_id}" - ) infer_req = InferenceRequest( req_id=generate_req_id(), model_id=model_id, inputs=inputs, - inference_pipeline=inference_pipeline, max_new_tokens=predict_length, ) outputs = self._process_request(infer_req) outputs = convert_to_binary(pd.DataFrame(outputs[0])) else: - # load model - accel = str(inference_attrs.get("acceleration", "")).lower() == "true" - model = self._model_manager.load_model(model_id, inference_attrs, accel) - # inference by strategy - strategy = self._get_strategy(model_id, model) - outputs = strategy.infer( - full_data, predict_length=predict_length, **inference_attrs + inference_pipeline = get_pipeline(model_id, device="cpu") + outputs = inference_pipeline.infer( + inputs, predict_length=predict_length, **inference_attrs ) + outputs = convert_to_binary(pd.DataFrame(outputs[0])) # construct response status = get_status(TSStatusCode.SUCCESS_STATUS) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py new file mode 100644 index 0000000000000..a07552922ff36 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
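+#
+# Usage sketch (illustrative): ModelManager is obtained through the
+# module-level singleton below rather than constructed directly:
+#
+#     manager = get_model_manager()
+#     model = manager.load_model("sundial", device_map="cpu")
+#     infos = manager.get_model_infos()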
+# + +import os +from typing import Any, List, Optional + +from iotdb.ainode.core.config import AINodeDescriptor +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_loader import ModelLoader +from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage +from iotdb.ainode.core.rpc.status import get_status +from iotdb.thrift.ainode.ttypes import ( + TDeleteModelReq, + TRegisterModelReq, + TRegisterModelResp, + TShowModelsReq, + TShowModelsResp, +) +from iotdb.thrift.common.ttypes import TSStatus + +logger = Logger() + + +class ModelManager: + def __init__(self): + self.models_dir = os.path.join( + os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() + ) + self.storage = ModelStorage(models_dir=self.models_dir) + self.loader = ModelLoader(storage=self.storage) + + # Automatically discover all models + self._models = self.storage.discover_all() + + def register_model( + self, + req: TRegisterModelReq, + ) -> TRegisterModelResp: + try: + success = self.storage.register_model(model_id=req.modelId, uri=req.uri) + if success: + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + else: + return TRegisterModelResp( + get_status(TSStatusCode.AINODE_INTERNAL_ERROR) + ) + except ValueError as e: + return TRegisterModelResp( + get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) + ) + except Exception as e: + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + return self.storage.show_models(req) + + def delete_model(self, req: TDeleteModelReq) -> TSStatus: + try: + self.storage.delete_model(req.modelId) + return get_status(TSStatusCode.SUCCESS_STATUS) + except BuiltInModelDeletionError as e: + logger.warning(e) + return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + except Exception as e: + logger.warning(e) + return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + + def load_model(self, model_id: str, **kwargs) -> Any: + return self.loader.load_model(model_id=model_id, **kwargs) + + def get_model_info( + self, + model_id: str, + category: Optional[ModelCategory] = None, + ) -> Optional[ModelInfo]: + return self.storage.get_model_info(model_id, category) + + def get_model_infos( + self, + category: Optional[ModelCategory] = None, + model_type: Optional[str] = None, + ) -> List[ModelInfo]: + return self.storage.get_model_infos(category, model_type) + + def refresh(self): + """Refresh the model list (re-scan the file system)""" + self._models = self.storage.discover_all() + + def get_registered_models(self) -> List[str]: + return self.storage.get_registered_models() + + def is_model_registered(self, model_id: str) -> bool: + return self.storage.is_model_registered(model_id) + + +# Create a global model manager instance +_default_manager: Optional[ModelManager] = None + + +def get_model_manager() -> ModelManager: + global _default_manager + if _default_manager is None: + _default_manager = ModelManager() + return _default_manager diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 0264e27331a86..297a2b832d7a8 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -24,8 +24,7 @@ from iotdb.ainode.core.config import AINodeDescriptor from 
iotdb.ainode.core.exception import ModelNotExistError
 from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP
+from iotdb.ainode.core.manager.model_manager import get_model_manager
 
 logger = Logger()
 
@@ -47,7 +46,7 @@ def measure_model_memory(device: torch.device, model_id: str) -> int:
 
     torch.cuda.synchronize(device)
     start = torch.cuda.memory_reserved(device)
-    model = ModelManager().load_model(model_id, {}).to(device)
+    model = get_model_manager().load_model(model_id).to(device)
     torch.cuda.synchronize(device)
     end = torch.cuda.memory_reserved(device)
     usage = end - start
@@ -80,8 +79,8 @@ def evaluate_system_resources(device: torch.device) -> dict:
 
 
 def estimate_pool_size(device: torch.device, model_id: str) -> int:
-    model_info = BUILT_IN_LTSM_MAP.get(model_id, None)
-    if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP:
+    model_info = get_model_manager().get_model_info(model_id)
+    if model_info is None or model_info.model_id not in MODEL_MEM_USAGE_MAP:
         logger.error(
             f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported."
         )
@@ -90,9 +89,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
     system_res = evaluate_system_resources(device)
     free_mem = system_res["free_mem"]
 
-    mem_usage = (
-        MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
-    )
+    mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
     size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)
     if size <= 0:
         logger.error(
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py
new file mode 100644
index 0000000000000..a6a234a1ab8f5
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
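+#
+# Usage sketch (illustrative): enum members are compared by identity and
+# serialized through .value, and REPO_ID_MAP resolves built-in model ids to
+# HuggingFace repositories:
+#
+#     assert ModelStates.ACTIVE.value == "active"
+#     repo = REPO_ID_MAP.get("sundial")  # "thuml/sundial-base-128m"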
+# +from enum import Enum + + +class ModelCategory(Enum): + BUILTIN = "builtin" + USER_DEFINED = "user_defined" + FINETUNE = "finetune" + + +class ModelStates(Enum): + INACTIVE = "inactive" + ACTIVATING = "activating" + ACTIVE = "active" + LOADING = "loading" + LOADED = "loaded" + DROPPING = "dropping" + TRAINING = "training" + FAILED = "failed" + + +class ModelFileType(Enum): + SAFETENSORS = "safetensors" + PYTORCH = "pytorch" + UNKNOWN = "unknown" + + +class UriType(Enum): + REPO = "repo" + FILE = "file" + + +# Map for inferring which HuggingFace repository to download from based on model ID +REPO_ID_MAP = { + "timerxl": "thuml/timer-base-84m", + "sundial": "thuml/sundial-base-128m", + # More mappings can be added as needed +} + +# Model file constants +MODEL_CONFIG_FILE = "config.json" +MODEL_WEIGHTS_FILE = "model.safetensors" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py new file mode 100644 index 0000000000000..f36ad582a837b --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
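+#
+# Usage sketch (illustrative): ModelInfo is a plain record type; built-in
+# entries are looked up directly from the maps defined below:
+#
+#     info = BUILTIN_SKTIME_MODEL_MAP["arima"]
+#     assert info.category is ModelCategory.BUILTIN
+#     assert info.state is ModelStates.ACTIVE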
+# + +from typing import Dict, List, Optional, Tuple + +from iotdb.ainode.core.model.model_enums import ModelCategory, ModelStates + +# Map for inferring which HuggingFace repository to download from based on model ID +REPO_ID_MAP = { + "timerxl": "thuml/timer-base-84m", + "sundial": "thuml/sundial-base-128m", + # More mappings can be added as needed +} + +# Model file constants +MODEL_CONFIG_FILE = "config.json" +MODEL_WEIGHTS_FILE = "model.safetensors" + + +class ModelInfo: + def __init__( + self, + model_id: str, + model_type: str, + category: ModelCategory, + state: ModelStates, + path: str = "", + auto_map: Optional[Dict] = None, + _transformers_registered: bool = False, + ): + self.model_id = model_id + self.model_type = model_type + self.category = category + self.state = state + self.path = path + self.auto_map = auto_map # If exists, indicates it's a Transformers model + self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + + def __repr__(self): + return ( + f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " + f"category={self.category.value}, state={self.state.value}, " + f"path={self.path}, has_auto_map={self.auto_map is not None})" + ) + + +BUILTIN_SKTIME_MODEL_MAP = { + # forecast models + "arima": ModelInfo( + model_id="arima", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + "holtwinters": ModelInfo( + model_id="holtwinters", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + "exponential_smoothing": ModelInfo( + model_id="exponential_smoothing", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + "naive_forecaster": ModelInfo( + model_id="naive_forecaster", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + "stl_forecaster": ModelInfo( + model_id="stl_forecaster", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + # anomaly detection models + "gaussian_hmm": ModelInfo( + model_id="gaussian_hmm", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + "gmm_hmm": ModelInfo( + model_id="gmm_hmm", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), + "stray": ModelInfo( + model_id="stray", + model_type="sktime", + category=ModelCategory.BUILTIN, + state=ModelStates.ACTIVE, + ), +} + +# Built-in huggingface transformers models, their weights are not included in AINode by default +BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { + "timerxl": ModelInfo( + model_id="timerxl", + model_type="timer", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + ), + "sundial": ModelInfo( + model_id="sundial", + model_type="sundial", + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + ), +} diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py new file mode 100644 index 0000000000000..9b70b5f80401e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +from typing import Any + +import torch +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForNextSentencePrediction, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTimeSeriesPrediction, + AutoModelForTokenClassification, +) + +from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_enums import ModelCategory +from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelStorage +from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model + +logger = Logger() + + +class ModelLoader: + """Model loader - unified interface for loading different types of models""" + + def __init__(self, storage: ModelStorage): + self.storage = storage + + def load_model(self, model_id: str, **kwargs) -> Any: + # Lazy registration: if it's a Transformers model and not registered, register it first + model_info = self.storage.ensure_transformers_registered(model_id) + if not model_info: + logger.error( + f"Model {model_id} failed to register to Transformers, cannot load." + ) + return None + + if model_info.auto_map is not None: + model = self.load_model_from_transformers(model_info, **kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_id) + else: + model = self.load_model_from_pt(model_info, **kwargs) + + logger.info(f"Model {model_id} loaded to device {model.device} successfully.") + return model + + def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): + model_config, load_class = None, None + device_map = kwargs.get("device_map", "cpu") + trust_remote_code = kwargs.get("trust_remote_code", True) + train_from_scratch = kwargs.get("train_from_scratch", False) + + if model_info.category == ModelCategory.BUILTIN: + if model_info.model_id == "timerxl": + from iotdb.ainode.core.model.timerxl.configuration_timer import ( + TimerConfig, + ) + + model_config = TimerConfig() + from iotdb.ainode.core.model.timerxl.modeling_timer import ( + TimerForPrediction, + ) + + load_class = TimerForPrediction + elif model_info.model_id == "sundial": + from iotdb.ainode.core.model.sundial.configuration_sundial import ( + SundialConfig, + ) + + model_config = SundialConfig() + from iotdb.ainode.core.model.sundial.modeling_sundial import ( + SundialForPrediction, + ) + + load_class = SundialForPrediction + else: + logger.error( + f"Unsupported built-in Transformers model {model_info.model_id}." 
+                )
+        else:
+            model_config = AutoConfig.from_pretrained(model_info.path)
+            if (
+                type(model_config)
+                in AutoModelForTimeSeriesPrediction._model_mapping.keys()
+            ):
+                load_class = AutoModelForTimeSeriesPrediction
+            elif (
+                type(model_config)
+                in AutoModelForNextSentencePrediction._model_mapping.keys()
+            ):
+                load_class = AutoModelForNextSentencePrediction
+            elif type(model_config) in AutoModelForSeq2SeqLM._model_mapping.keys():
+                load_class = AutoModelForSeq2SeqLM
+            elif (
+                type(model_config)
+                in AutoModelForSequenceClassification._model_mapping.keys()
+            ):
+                load_class = AutoModelForSequenceClassification
+            elif (
+                type(model_config)
+                in AutoModelForTokenClassification._model_mapping.keys()
+            ):
+                load_class = AutoModelForTokenClassification
+            else:
+                load_class = AutoModelForCausalLM
+
+        if train_from_scratch:
+            model = load_class.from_config(
+                model_config, trust_remote_code=trust_remote_code, device_map=device_map
+            )
+        else:
+            model = load_class.from_pretrained(
+                model_info.path,
+                trust_remote_code=trust_remote_code,
+                device_map=device_map,
+            )
+
+        return model
+
+    def load_model_from_pt(self, model_info: ModelInfo, **kwargs):
+        device_map = kwargs.get("device_map", "cpu")
+        acceleration = kwargs.get("acceleration", False)
+        model_path = os.path.join(model_info.path, "model.pt")
+        if not os.path.exists(model_path):
+            logger.error(f"Model file not found at {model_path}.")
+            raise ModelNotExistError(model_path)
+        model = torch.jit.load(model_path)
+        if (
+            isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+            or not acceleration
+        ):
+            return model.to(device_map)
+        try:
+            model = torch.compile(model)
+        except Exception as e:
+            logger.warning(f"Acceleration failed, falling back to normal mode: {str(e)}")
+        return model.to(device_map)
+
+    def load_model_for_efficient_inference(self):
+        # TODO: An efficient model loading method for inference based on model_arguments
+        pass
+
+    def load_model_for_powerful_finetune(self):
+        # TODO: A powerful model loading method for finetune based on model_arguments
+        pass
+
+    def unload_model(self):
+        pass
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
new file mode 100644
index 0000000000000..5ba67e158a5dc
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -0,0 +1,576 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
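+#
+# Usage sketch (illustrative; the directory path is hypothetical):
+#
+#     storage = ModelStorage(models_dir="/data/ainode/models")
+#     models = storage.discover_all()  # {category -> {model_id -> ModelInfo}}
+#     info = storage.get_model_info("arima", ModelCategory.BUILTIN)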
+# + +import concurrent.futures +import os +import shutil +from typing import List, Optional + +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM + +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_enums import REPO_ID_MAP, ModelCategory, ModelStates +from iotdb.ainode.core.model.model_info import ( + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, + ModelInfo, +) +from iotdb.ainode.core.model.utils import * +from iotdb.ainode.core.util.lock import ModelLockPool +from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp +from iotdb.thrift.common.ttypes import TSStatus + +logger = Logger() + + +class ModelStorage: + """Model storage class - unified management of model discovery and registration""" + + def __init__(self, models_dir: str): + self.models_dir = Path(models_dir) + # Unified storage: category -> {model_id -> ModelInfo} + self._models: Dict[str, Dict[str, ModelInfo]] = { + ModelCategory.BUILTIN.value: {}, + ModelCategory.USER_DEFINED.value: {}, + ModelCategory.FINETUNE.value: {}, + } + # Async download executor (using single-threaded executor because hf download interface is unstable with concurrent downloads) + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Thread lock pool for protecting concurrent access to model information + self._lock_pool = ModelLockPool() + self._initialize_directories() + + def _initialize_directories(self): + """Initialize directory structure and ensure __init__.py files exist""" + self.models_dir.mkdir(parents=True, exist_ok=True) + ensure_init_file(self.models_dir) + + for category in ModelCategory: + category_path = self.models_dir / category.value + category_path.mkdir(parents=True, exist_ok=True) + ensure_init_file(category_path) + + # ==================== Discovery Methods ==================== + + def discover_all(self) -> Dict[str, Dict[str, ModelInfo]]: + """Scan file system to discover all models""" + self._discover_category(ModelCategory.BUILTIN) + self._discover_category(ModelCategory.USER_DEFINED) + self._discover_category(ModelCategory.FINETUNE) + return self._models + + def _discover_category(self, category: ModelCategory): + """Discover all models in a category directory""" + category_path = self.models_dir / category.value + if not category_path.exists(): + return + + if category == ModelCategory.BUILTIN: + self._discover_builtin_models(category_path) + else: + # For finetune and user_defined, scan directories + for item in category_path.iterdir(): + if item.is_dir() and not item.name.startswith("__"): + relative_path = item.relative_to(category_path) + model_id = str(relative_path).replace("/", "_").replace("\\", "_") + self._process_model_directory(item, model_id, category) + + def _discover_builtin_models(self, category_path: Path): + # Register SKTIME models directly from map + for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_SKTIME_MODEL_MAP[model_id] + ) + + # Process HuggingFace Transformers models + for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys(): + model_dir = category_path / model_id + model_dir.mkdir(parents=True, exist_ok=True) + self._process_model_directory(model_dir, model_id, ModelCategory.BUILTIN) + + def 
_process_model_directory(
+        self, model_dir: Path, model_id: str, category: ModelCategory
+    ):
+        """Handle the discovery logic for a single model directory."""
+        ensure_init_file(model_dir)
+
+        config_path = model_dir / MODEL_CONFIG_FILE
+        weights_path = model_dir / MODEL_WEIGHTS_FILE
+        needs_download = not config_path.exists() or not weights_path.exists()
+
+        if needs_download:
+            with self._lock_pool.get_lock(model_id).write_lock():
+                model_info = ModelInfo(
+                    model_id=model_id,
+                    model_type="",  # Read from config.json after download
+                    category=category,
+                    state=ModelStates.ACTIVATING,
+                    path=str(model_dir),
+                    auto_map=None,
+                    _transformers_registered=False,
+                )
+                self._models[category.value][model_id] = model_info
+
+            future = self._executor.submit(
+                self._download_model_if_necessary, str(model_dir), model_id
+            )
+            future.add_done_callback(
+                lambda f, mid=model_id, cat=category: self._callback_model_download_result(
+                    f, mid, cat
+                )
+            )
+        else:
+            config = load_model_config(config_path)
+            model_type = config.get("model_type", "")
+            auto_map = config.get("auto_map")
+
+            with self._lock_pool.get_lock(model_id).write_lock():
+                model_info = ModelInfo(
+                    model_id=model_id,
+                    model_type=model_type,
+                    category=category,
+                    state=ModelStates.ACTIVE,
+                    path=str(model_dir),
+                    auto_map=auto_map,
+                    _transformers_registered=False,  # Lazy registration
+                )
+                self._models[category.value][model_id] = model_info
+
+    def _download_model_if_necessary(self, model_dir: str, model_id: str) -> bool:
+        """Return True if the model already exists or was downloaded successfully, False otherwise."""
+        if model_id in REPO_ID_MAP:
+            repo_id = REPO_ID_MAP[model_id]
+        else:
+            logger.error(f"Model {model_id} not found in REPO_ID_MAP")
+            return False
+
+        weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE)
+        config_path = os.path.join(model_dir, MODEL_CONFIG_FILE)
+
+        if not os.path.exists(weights_path):
+            try:
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=MODEL_WEIGHTS_FILE,
+                    local_dir=model_dir,
+                )
+            except Exception as e:
+                logger.error(f"Failed to download model weights from HuggingFace: {e}")
+                return False
+
+        if not os.path.exists(config_path):
+            try:
+                hf_hub_download(
+                    repo_id=repo_id,
+                    filename=MODEL_CONFIG_FILE,
+                    local_dir=model_dir,
+                )
+            except Exception as e:
+                logger.error(f"Failed to download model config from HuggingFace: {e}")
+                return False
+
+        return True
+
+    def _callback_model_download_result(
+        self, future, model_id: str, category: ModelCategory
+    ):
+        """Callback for handling model download results."""
+        with self._lock_pool.get_lock(model_id).write_lock():
+            try:
+                if future.result():
+                    if model_id in self._models[category.value]:
+                        model_info = self._models[category.value][model_id]
+                        model_info.state = ModelStates.ACTIVE
+                        config_path = os.path.join(model_info.path, MODEL_CONFIG_FILE)
+                        if os.path.exists(config_path):
+                            with open(config_path, "r", encoding="utf-8") as f:
+                                config = json.load(f)
+                            model_info.model_type = config.get("model_type", "")
+                            model_info.auto_map = config.get("auto_map")
+                        logger.info(
+                            f"Model {model_id} downloaded successfully and is ready to use."
+ ) + else: + if model_id in self._models[category.value]: + self._models[category.value][ + model_id + ].state = ModelStates.INACTIVE + logger.warning(f"Failed to download model {model_id}.") + except Exception as e: + logger.error(f"Error in download callback for model {model_id}: {e}") + if model_id in self._models[category.value]: + self._models[category.value][model_id].state = ModelStates.INACTIVE + + # ==================== Registration Methods ==================== + + def register_model(self, model_id: str, uri: str) -> bool: + """ + Supported URI formats: + - repo:// + - file:// + """ + uri_type = parse_uri_type(uri) + parsed_uri = get_parsed_uri(uri) + + model_dir = Path(self.models_dir) / "user_defined" / model_id + model_dir.mkdir(parents=True, exist_ok=True) + ensure_init_file(model_dir) + + if uri_type == UriType.REPO: + self._fetch_model_from_hf_repo(parsed_uri, str(model_dir)) + else: + self._fetch_model_from_local(os.path.expanduser(parsed_uri), str(model_dir)) + + config_path, _ = validate_model_files(model_dir) + config = load_model_config(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( + model_id=model_id, + model_type=model_type, + category=ModelCategory.USER_DEFINED, + state=ModelStates.ACTIVE, + path=str(model_dir), + auto_map=auto_map, + _transformers_registered=False, # Register later + ) + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + if auto_map: + # Transformers model: immediately register to Transformers auto-loading mechanism + success = self._register_transformers_model(model_info) + if success: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info._transformers_registered = True + else: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info.state = ModelStates.INACTIVE + logger.error(f"Failed to register Transformers model {model_id}") + return False + else: + # Other type models: only log + self._register_other_model(model_info) + + logger.info(f"Successfully registered model {model_id} from URI: {uri}") + return True + + def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): + logger.info( + f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" + ) + + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise + + def _fetch_model_from_local(self, source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") + + source_dir = Path(source_path) + if not source_dir.is_dir(): + raise ValueError( + f"Source path does not exist or is not a directory: {source_path}" + ) + + source_config = source_dir / MODEL_CONFIG_FILE + source_weights = source_dir / MODEL_WEIGHTS_FILE + if not source_config.exists(): + raise ValueError( + f"Config file missing in source directory: {source_config}" + ) + if not source_weights.exists(): + raise ValueError( + f"Weights file missing in source directory: {source_weights}" + ) + + # Copy all files + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + + def _register_transformers_model(self, model_info: ModelInfo) -> bool: + """ + 
Register Transformers model to auto-loading mechanism (internal method)
+        """
+        auto_map = model_info.auto_map
+        if not auto_map:
+            return False
+
+        auto_config_path = auto_map.get("AutoConfig")
+        auto_model_path = auto_map.get("AutoModelForCausalLM")
+
+        try:
+            module_parent = str(Path(model_info.path).parent.absolute())
+            with temporary_sys_path(module_parent):
+                config_class = import_class_from_path(
+                    model_info.model_id, auto_config_path
+                )
+                AutoConfig.register(model_info.model_type, config_class)
+                logger.info(
+                    f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}"
+                )
+
+                model_class = import_class_from_path(
+                    model_info.model_id, auto_model_path
+                )
+                AutoModelForCausalLM.register(config_class, model_class)
+                logger.info(
+                    f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}"
+                )
+
+                return True
+        except Exception as e:
+            logger.warning(
+                f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure the module path is correct."
+            )
+            return False
+
+    def _register_other_model(self, model_info: ModelInfo):
+        """Register other model types (non-Transformers models)."""
+        logger.info(
+            f"Registered other type model: {model_info.model_id} ({model_info.model_type})"
+        )
+
+    def ensure_transformers_registered(self, model_id: str) -> Optional["ModelInfo"]:
+        """
+        Ensure the Transformers model is registered (called for lazy registration).
+        This method uses locks to ensure thread safety; all check logic runs under lock protection.
+        Returns:
+            ModelInfo: the registered model info, or None if registration failed.
+        """
+        # Use the lock to protect the entire check-and-register process
+        with self._lock_pool.get_lock(model_id).write_lock():
+            # Directly access the _models dictionary (avoid calling get_model_info, which may deadlock)
+            model_info = None
+            for 
category_dict in self._models.values(): + if req.modelId in category_dict: + model_info = category_dict[req.modelId] + break + + if model_info: + return TShowModelsResp( + status=resp_status, + modelIdList=[req.modelId], + modelTypeMap={req.modelId: model_info.model_type}, + categoryMap={req.modelId: model_info.category.value}, + stateMap={req.modelId: model_info.state.value}, + ) + else: + return TShowModelsResp( + status=resp_status, + modelIdList=[], + modelTypeMap={}, + categoryMap={}, + stateMap={}, + ) + else: + # Return all models + model_id_list = [] + model_type_map = {} + category_map = {} + state_map = {} + + for category_dict in self._models.values(): + for model_id, model_info in category_dict.items(): + model_id_list.append(model_id) + model_type_map[model_id] = model_info.model_type + category_map[model_id] = model_info.category.value + state_map[model_id] = model_info.state.value + + return TShowModelsResp( + status=resp_status, + modelIdList=model_id_list, + modelTypeMap=model_type_map, + categoryMap=category_map, + stateMap=state_map, + ) + + def delete_model(self, model_id: str) -> None: + # Use write lock to protect entire deletion process + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = None + category_value = None + for cat_value, category_dict in self._models.items(): + if model_id in category_dict: + model_info = category_dict[model_id] + category_value = cat_value + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot delete") + return + + if model_info.category == ModelCategory.BUILTIN: + raise BuiltInModelDeletionError(model_id) + + model_path = Path(model_info.path) + if model_path.exists(): + try: + shutil.rmtree(model_path) + logger.info(f"Deleted model directory: {model_path}") + except Exception as e: + logger.error(f"Failed to delete model directory {model_path}: {e}") + raise + + if category_value and model_id in self._models[category_value]: + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from storage") + + return + + # ==================== Query Methods ==================== + + def get_model_info( + self, model_id: str, category: Optional[ModelCategory] = None + ) -> Optional[ModelInfo]: + """ + Get single model information + + If category is specified, use model_id's lock + If category is not specified, need to traverse all dictionaries, use global lock + """ + if category: + # Category specified, only need to access specific dictionary, use model_id's lock + with self._lock_pool.get_lock(model_id).read_lock(): + return self._models[category.value].get(model_id) + else: + # Category not specified, need to traverse all dictionaries, use global lock + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return category_dict[model_id] + return None + + def get_model_infos( + self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None + ) -> List[ModelInfo]: + """ + Get model information list + + Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure + For single model access, using model_id-based lock would be more efficient + """ + matching_models = [] + + # For traversal operations, we need to protect the entire dictionary structure + # Use a special lock (using empty string as key) to protect the entire dictionary + with self._lock_pool.get_lock("").read_lock(): + if category and model_type: + for model_info 
in self._models[category.value].values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models + elif category: + return list(self._models[category.value].values()) + elif model_type: + for category_dict in self._models.values(): + for model_info in category_dict.values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models + else: + for category_dict in self._models.values(): + matching_models.extend(category_dict.values()) + return matching_models + + def is_model_registered(self, model_id: str) -> bool: + """Check if model is registered (search in _models)""" + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return True + return False + + def get_registered_models(self) -> List[str]: + """Get list of all registered model IDs""" + with self._lock_pool.get_lock("").read_lock(): + model_ids = [] + for category_dict in self._models.values(): + model_ids.extend(category_dict.keys()) + return model_ids diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json new file mode 100644 index 0000000000000..dcdc133529090 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "name": "ARIMA", + "predict_length": 1, + "order": [1, 0, 0], + "seasonal_order": [0, 0, 0, 0], + "method": "lbfgs", + "maxiter": 1, + "suppress_warnings": true, + "out_of_sample_size": 0, + "scoring": "mse", + "with_intercept": true, + "time_varying_regression": false, + "enforce_stationarity": true, + "enforce_invertibility": true, + "simple_differencing": false, + "measurement_error": false, + "mle_regression": true, + "hamilton_representation": false, + "concentrate_scale": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py new file mode 100644 index 0000000000000..bd780da3a73fb --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -0,0 +1,379 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+"""
+Sktime model configuration module - simplified version
+"""
+
+import ast
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Union
+
+from iotdb.ainode.core.exception import (
+    BuiltInModelNotSupportError,
+    ListRangeException,
+    NumericalRangeException,
+    StringRangeException,
+    WrongAttributeTypeError,
+)
+from iotdb.ainode.core.log import Logger
+
+logger = Logger()
+
+
+@dataclass
+class AttributeConfig:
+    """Base class for attribute configuration"""
+
+    name: str
+    default: Any
+    type: str  # 'int', 'float', 'str', 'bool', 'list', 'tuple'
+    low: Union[int, float, None] = None
+    high: Union[int, float, None] = None
+    choices: List[str] = field(default_factory=list)
+    value_type: type = None  # Element type for list and tuple
+
+    def validate_value(self, value):
+        """Validate if the value meets the requirements"""
+        if self.type == "int":
+            if not isinstance(value, int):
+                raise WrongAttributeTypeError(self.name, "int")
+            if self.low is not None and self.high is not None:
+                if not (self.low <= value <= self.high):
+                    raise NumericalRangeException(self.name, value, self.low, self.high)
+        elif self.type == "float":
+            if not isinstance(value, (int, float)):
+                raise WrongAttributeTypeError(self.name, "float")
+            value = float(value)
+            if self.low is not None and self.high is not None:
+                if not (self.low <= value <= self.high):
+                    raise NumericalRangeException(self.name, value, self.low, self.high)
+        elif self.type == "str":
+            if not isinstance(value, str):
+                raise WrongAttributeTypeError(self.name, "str")
+            if self.choices and value not in self.choices:
+                raise StringRangeException(self.name, value, self.choices)
+        elif self.type == "bool":
+            if not isinstance(value, bool):
+                raise WrongAttributeTypeError(self.name, "bool")
+        elif self.type == "list":
+            if not isinstance(value, list):
+                raise WrongAttributeTypeError(self.name, "list")
+            for item in value:
+                if not isinstance(item, self.value_type):
+                    raise WrongAttributeTypeError(self.name, self.value_type)
+        elif self.type == "tuple":
+            if not isinstance(value, tuple):
+                raise WrongAttributeTypeError(self.name, "tuple")
+            for item in value:
+                if not isinstance(item, self.value_type):
+                    raise WrongAttributeTypeError(self.name, self.value_type)
+        return True
+
+    def parse(self, string_value: str):
+        """Parse string value to corresponding type"""
+        if self.type == "int":
+            try:
+                return int(string_value)
+            except (TypeError, ValueError):
+                raise WrongAttributeTypeError(self.name, "int")
+        elif self.type == "float":
+            try:
+                return float(string_value)
+            except (TypeError, ValueError):
+                raise WrongAttributeTypeError(self.name, "float")
+        elif self.type == "str":
+            return string_value
+        elif self.type == "bool":
+            if string_value.lower() == "true":
+                return True
+            elif string_value.lower() == "false":
+                return False
+            else:
+                raise WrongAttributeTypeError(self.name, "bool")
+        elif self.type == "list":
+            try:
+                list_value = ast.literal_eval(string_value)
+            except (SyntaxError, ValueError):
+                raise WrongAttributeTypeError(self.name, "list")
+            if not isinstance(list_value, list):
+                raise WrongAttributeTypeError(self.name, "list")
+            for i in range(len(list_value)):
+                try:
+                    list_value[i] = self.value_type(list_value[i])
+                except (TypeError, ValueError):
+                    raise ListRangeException(
+                        self.name, list_value, str(self.value_type)
+                    )
+            return list_value
+        elif self.type == "tuple":
+            try:
+                tuple_value = ast.literal_eval(string_value)
+            except (SyntaxError, ValueError):
+                raise WrongAttributeTypeError(self.name, "tuple")
+            if not isinstance(tuple_value, tuple):
+                raise WrongAttributeTypeError(self.name, "tuple")
+            list_value = list(tuple_value)
+            for i in range(len(list_value)):
+                try:
+                    list_value[i] = self.value_type(list_value[i])
+                except (TypeError, ValueError):
+                    raise ListRangeException(
+                        self.name, list_value, str(self.value_type)
+                    )
+            return tuple(list_value)
"low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000), + "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000), + "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 5000), + "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + }, + "GAUSSIAN_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 1e-2, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1.0, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), + "init_params": AttributeConfig( + "init_params", "stmc", "str", choices=["stmc", "stm"] + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "GMM_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["sperical", "diag", "full", "tied"], + ), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "init_params": AttributeConfig( + "init_params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "params": AttributeConfig( + "params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "STRAY": { + "alpha": AttributeConfig("alpha", 0.01, "float", -1e10, 1e10), + "k": AttributeConfig("k", 10, "int", 1, 5000), + "knn_algorithm": AttributeConfig( + "knn_algorithm", + "brute", + "str", + 
choices=["brute", "kd_tree", "ball_tree", "auto"], + ), + "p": AttributeConfig("p", 0.5, "float", -1e10, 1e10), + "size_threshold": AttributeConfig("size_threshold", 50, "int", 1, 5000), + "outlier_tail": AttributeConfig( + "outlier_tail", "max", "str", choices=["min", "max"] + ), + }, +} + + +def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: + """Get attribute configuration for Sktime model""" + model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id + if model_id not in MODEL_CONFIGS: + raise BuiltInModelNotSupportError(model_id) + return MODEL_CONFIGS[model_id] + + +def update_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, AttributeConfig] +) -> Dict[str, Any]: + """Update Sktime model attributes using input attributes""" + attributes = {} + for name, config in attribute_map.items(): + if name in input_attributes: + value = config.parse(input_attributes[name]) + config.validate_value(value) + attributes[name] = value + else: + attributes[name] = config.default + return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json new file mode 100644 index 0000000000000..d6002fb26e87a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "name": "ExponentialSmoothing", + "predict_length": 1, + "damped_trend": false, + "initialization_method": "estimated", + "optimized": true, + "remove_bias": false, + "use_brute": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json new file mode 100644 index 0000000000000..3392e1c0b57c8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json @@ -0,0 +1,20 @@ +{ + "model_type": "sktime", + "name": "GaussianHMM", + "n_components": 1, + "covariance_type": "diag", + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "covars_prior": 0.01, + "covars_weight": 1.0, + "algorithm": "viterbi", + "n_iter": 10, + "tol": 0.01, + "params": "stmc", + "init_params": "stmc", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json new file mode 100644 index 0000000000000..235f8ae642da4 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json @@ -0,0 +1,20 @@ +{ + "model_type": "sktime", + "name": "GMMHMM", + "n_components": 1, + "n_mix": 1, + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "weights_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "algorithm": "viterbi", + "covariance_type": "diag", + "n_iter": 10, + "tol": 0.01, + "init_params": "stmcw", + "params": "stmcw", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py new file mode 100644 index 0000000000000..f272e3dda3579 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+"""
+Sktime model implementation module - simplified version
+"""
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+import numpy as np
+from sklearn.preprocessing import MinMaxScaler
+from sktime.detection.hmm_learn import GMMHMM, GaussianHMM
+from sktime.detection.stray import STRAY
+from sktime.forecasting.arima import ARIMA
+from sktime.forecasting.exp_smoothing import ExponentialSmoothing
+from sktime.forecasting.naive import NaiveForecaster
+from sktime.forecasting.trend import STLForecaster
+
+from iotdb.ainode.core.exception import (
+    BuiltInModelNotSupportError,
+    InferenceModelInternalError,
+)
+from iotdb.ainode.core.log import Logger
+
+from .configuration_sktime import get_attributes, update_attribute
+
+logger = Logger()
+
+
+class SktimeModel(ABC):
+    """Base class for Sktime models"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        self._attributes = attributes
+        self._model = None
+
+    @abstractmethod
+    def generate(self, data):
+        """Execute generation/inference"""
+        raise NotImplementedError
+
+
+class ForecastingModel(SktimeModel):
+    """Base class for forecasting models"""
+
+    def generate(self, data):
+        """Execute forecasting"""
+        try:
+            predict_length = self._attributes["predict_length"]
+            self._model.fit(data)
+            output = self._model.predict(fh=range(predict_length))
+            return np.array(output, dtype=np.float64)
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class DetectionModel(SktimeModel):
+    """Base class for detection models"""
+
+    def generate(self, data):
+        """Execute detection"""
+        try:
+            self._model.fit(data)
+            output = self._model.predict(data)
+            return np.array(output, dtype=np.int32)
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+class ArimaModel(ForecastingModel):
+    """ARIMA model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = ARIMA(
+            **{k: v for k, v in attributes.items() if k != "predict_length"}
+        )
+
+
+class ExponentialSmoothingModel(ForecastingModel):
+    """Exponential smoothing model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = ExponentialSmoothing(
+            **{k: v for k, v in attributes.items() if k != "predict_length"}
+        )
+
+
+class NaiveForecasterModel(ForecastingModel):
+    """Naive forecaster model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = NaiveForecaster(
+            **{k: v for k, v in attributes.items() if k != "predict_length"}
+        )
+
+
+class STLForecasterModel(ForecastingModel):
+    """STL forecaster model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = STLForecaster(
+            **{k: v for k, v in attributes.items() if k != "predict_length"}
+        )
+
+
+class GMMHMMModel(DetectionModel):
+    """GMM HMM model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = GMMHMM(**attributes)
+
+
+class GaussianHmmModel(DetectionModel):
+    """Gaussian HMM model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = GaussianHMM(**attributes)
+
+
+class STRAYModel(DetectionModel):
+    """STRAY anomaly detection model"""
+
+    def __init__(self, attributes: Dict[str, Any]):
+        super().__init__(attributes)
+        self._model = STRAY(**attributes)
+
+    def generate(self, data):
+        """STRAY requires special handling: normalize first"""
+        try:
+            data = MinMaxScaler().fit_transform(data)
+            output = self._model.fit_transform(data)
+            return np.array(output, dtype=np.int32)
+        except Exception as e:
+            raise InferenceModelInternalError(str(e))
+
+
+# Model factory mapping
+_MODEL_FACTORY = {
+    "ARIMA": ArimaModel,
+    "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel,
+    "HOLTWINTERS": ExponentialSmoothingModel,  # Use the same model class
+    "NAIVE_FORECASTER": NaiveForecasterModel,
+    "STL_FORECASTER": STLForecasterModel,
+    "GMM_HMM": GMMHMMModel,
+    "GAUSSIAN_HMM": GaussianHmmModel,
+    "STRAY": STRAYModel,
+}
+
+
+def create_sktime_model(model_id: str, **kwargs) -> SktimeModel:
+    """Create a Sktime model instance"""
+    attributes = update_attribute({**kwargs}, get_attributes(model_id.upper()))
+    model_class = _MODEL_FACTORY.get(model_id.upper())
+    if model_class is None:
+        raise BuiltInModelNotSupportError(model_id)
+    return model_class(attributes)
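+
+
+# Illustrative sketch (comments only, assumed inputs): attribute values arrive
+# as raw strings and are parsed/validated by configuration_sktime before the
+# underlying sktime estimator is constructed.
+#
+#   model = create_sktime_model("ARIMA", predict_length="24", maxiter="50")
+#   forecast = model.generate(train_series)  # train_series: pandas Series of history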
HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GMMHMM(**attributes) + + +class GaussianHmmModel(DetectionModel): + """Gaussian HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GaussianHMM(**attributes) + + +class STRAYModel(DetectionModel): + """STRAY anomaly detection model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STRAY(**attributes) + + def generate(self, data): + """STRAY requires special handling: normalize first""" + try: + data = MinMaxScaler().fit_transform(data) + output = self._model.fit_transform(data) + return np.array(output, dtype=np.int32) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +# Model factory mapping +_MODEL_FACTORY = { + "ARIMA": ArimaModel, + "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel, + "HOLTWINTERS": ExponentialSmoothingModel, # Use the same model class + "NAIVE_FORECASTER": NaiveForecasterModel, + "STL_FORECASTER": STLForecasterModel, + "GMM_HMM": GMMHMMModel, + "GAUSSIAN_HMM": GaussianHmmModel, + "STRAY": STRAYModel, +} + + +def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: + """Create a Sktime model instance""" + attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) + model_class = _MODEL_FACTORY.get(model_id.upper()) + if model_class is None: + raise BuiltInModelNotSupportError(model_id) + return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json new file mode 100644 index 0000000000000..20d8c1ed32b5c --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json @@ -0,0 +1,8 @@ +{ + "model_type": "sktime", + "name": "NaiveForecaster", + "predict_length": 1, + "strategy": "last", + "sp": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json new file mode 100644 index 0000000000000..1005f9d944e9d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json @@ -0,0 +1,14 @@ +{ + "model_type": "sktime", + "name": "STLForecaster", + "predict_length": 1, + "sp": 2, + "seasonal": 7, + "seasonal_deg": 1, + "trend_deg": 1, + "low_pass_deg": 1, + "seasonal_jump": 1, + "trend_jump": 1, + "low_pass_jump": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json new file mode 100644 index 0000000000000..64c64aa9e0514 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "name": "STRAY", + "alpha": 0.01, + "k": 10, + "knn_algorithm": "brute", + "p": 0.5, + "size_threshold": 50, + "outlier_tail": "max" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py new file mode 100644 index 0000000000000..21eefef2933b3 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/configuration_sundial.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import List + +from transformers import PretrainedConfig + + +class SundialConfig(PretrainedConfig): + model_type = "sundial" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + input_token_len: int = 16, + hidden_size: int = 768, + intermediate_size: int = 3072, + output_token_lens: List[int] = [720], + num_hidden_layers: int = 12, + num_attention_heads: int = 12, + hidden_act: str = "silu", + use_cache: bool = True, + rope_theta: int = 10000, + dropout_rate: float = 0.1, + initializer_range: float = 0.02, + max_position_embeddings: int = 10000, + flow_loss_depth: int = 3, + num_sampling_steps: int = 50, + diffusion_batch_mul: int = 4, + **kwargs, + ): + self.input_token_len = input_token_len + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.output_token_lens = output_token_lens + self.use_cache = use_cache + self.rope_theta = rope_theta + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + self.flow_loss_depth = flow_loss_depth + self.num_sampling_steps = num_sampling_steps + self.diffusion_batch_mul = diffusion_batch_mul + + super().__init__( + **kwargs, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py new file mode 100644 index 0000000000000..b3fe95dbe2d27 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/flow_loss.py @@ -0,0 +1,255 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import math + +import torch +import torch.nn as nn + + +class FlowLoss(nn.Module): + """Flow Loss""" + + def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps): + super(FlowLoss, self).__init__() + self.in_channels = target_channels + self.net = SimpleMLPAdaLN( + in_channels=target_channels, + model_channels=width, + out_channels=target_channels, + z_channels=z_channels, + num_res_blocks=depth, + ) + self.num_sampling_steps = num_sampling_steps + + def forward(self, target, z, mask=None, mask_y=None): + noise = torch.randn_like(target) + t = torch.rand(target.shape[0], device=target.device) + + noised_target = t[:, None] * target + (1 - t[:, None]) * noise + + predict_v = self.net(noised_target, t * 1000, z) + + weights = 1.0 / torch.arange( + 1, self.in_channels + 1, dtype=torch.float32, device=target.device + ) + if mask_y is not None: + loss = (mask_y * weights * (predict_v - target) ** 2).sum(dim=-1) + else: + loss = (weights * (predict_v - target) ** 2).sum(dim=-1) + + if mask is not None: + loss = (loss * mask).sum() / mask.sum() + return loss.mean() + + def sample(self, z, num_samples=1): + z = z.repeat(num_samples, 1) + noise = torch.randn(z.shape[0], self.in_channels).to(z.device) + x = noise + dt = 1.0 / self.num_sampling_steps + for i in range(self.num_sampling_steps): + t = (torch.ones((x.shape[0])) * i / self.num_sampling_steps).to(x.device) + pred = self.net(x, t * 1000, z) + x = x + (pred - noise) * dt + x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1) + return x + + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + """ + + def __init__(self, channels): + super().__init__() + self.channels = channels + + self.in_ln = nn.LayerNorm(channels, eps=1e-6) + self.mlp = nn.Sequential( + nn.Linear(channels, channels, bias=True), + nn.SiLU(), + nn.Linear(channels, channels, bias=True), + ) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True) + ) + + def forward(self, x, y): + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) + h = modulate(self.in_ln(x), shift_mlp, scale_mlp) + h = self.mlp(h) + return x + gate_mlp * h + + +class FinalLayer(nn.Module): + """ + The final layer adopted from DiT. + """ + + def __init__(self, model_channels, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm( + model_channels, elementwise_affine=False, eps=1e-6 + ) + self.linear = nn.Linear(model_channels, out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class SimpleMLPAdaLN(nn.Module): + """ + The MLP for Diffusion Loss. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param z_channels: channels in the condition. + :param num_res_blocks: number of residual blocks per downsample. 
+ """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + z_channels, + num_res_blocks, + ): + super().__init__() + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + + self.time_embed = TimestepEmbedder(model_channels) + self.cond_embed = nn.Linear(z_channels, model_channels) + + self.input_proj = nn.Linear(in_channels, model_channels) + + res_blocks = [] + for i in range(num_res_blocks): + res_blocks.append( + ResBlock( + model_channels, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + self.final_layer = FinalLayer(model_channels, out_channels) + + self.initialize_weights() + + def initialize_weights(self): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP + nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for block in self.res_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def forward(self, x, t, c): + """ + Apply the model to an input batch. + :param x: an [N x C] Tensor of inputs. + :param t: a 1-D batch of timesteps. + :param c: conditioning from AR transformer. + :return: an [N x C] Tensor of outputs. + """ + x = self.input_proj(x) + t = self.time_embed(t) + c = self.cond_embed(c) + y = t + c + + for block in self.res_blocks: + x = block(x, y) + + return self.final_layer(x, y) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py new file mode 100644 index 0000000000000..544193e4d9c65 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -0,0 +1,651 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file as load_safetensors +from torch import nn +from transformers import Cache, DynamicCache, PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) + +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig +from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss +from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin + +logger = Logger() + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class SundialPatchEmbedding(nn.Module): + def __init__(self, config: SundialConfig): + super().__init__() + self.dropout = nn.Dropout(config.dropout_rate) + self.hidden_layer = nn.Linear( + config.input_token_len * 2, config.intermediate_size + ) + self.act = ACT2FN[config.hidden_act] + self.output_layer = nn.Linear(config.intermediate_size, config.hidden_size) + self.residual_layer = nn.Linear(config.input_token_len * 2, config.hidden_size) + self.input_token_len = config.input_token_len + + def forward(self, x): + mask = torch.ones_like(x, dtype=torch.float32) + input_length = x.shape[-1] + padding_length = ( + self.input_token_len - (input_length % self.input_token_len) + ) % self.input_token_len + x = F.pad(x, (padding_length, 0)) + mask = F.pad(mask, (padding_length, 0)) + x = x.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len) + mask = mask.unfold( + dimension=-1, size=self.input_token_len, step=self.input_token_len + ) + + x = torch.cat([x, mask], dim=-1) + hid = self.act(self.hidden_layer(x)) + out = self.dropout(self.output_layer(hid)) + res = self.residual_layer(x) + out = out + res + return out + + +class SundialRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
self._set_cos_sin_cache(
+            seq_len=max_position_embeddings,
+            device=self.inv_freq.device,
+            dtype=torch.get_default_dtype(),
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(
+            self.max_seq_len_cached, device=device, dtype=torch.int64
+        ).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class SundialAttention(nn.Module):
+    def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.attention_dropout = config.dropout_rate
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.rotary_emb = SundialRotaryEmbedding(
+            self.head_dim, max_position_embeddings=config.max_position_embeddings
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids
+        )
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx
+            )
+
+        attn_output = F.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout_p=(self.attention_dropout if self.training else 0.0),
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        # scaled_dot_product_attention does not return attention weights
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class SundialMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+        
super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +class SundialDecoderLayer(nn.Module): + def __init__(self, config: SundialConfig, layer_idx: int): + super().__init__() + self.self_attn = SundialAttention(config, layer_idx) + + self.ffn_layer = SundialMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.norm1 = torch.nn.LayerNorm(config.hidden_size) + self.norm2 = torch.nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + Optional[torch.Tensor], + Optional[Cache], + ]: + residual = hidden_states + + hidden_states = self.norm1(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.ffn_layer(hidden_states) + hidden_states = residual + hidden_states + + if not output_attentions: + self_attn_weights = None + + return hidden_states, self_attn_weights, present_key_value + + +class SundialPreTrainedModel(PreTrainedModel): + config_class = SundialConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["SundialDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, torch.nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, torch.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class SundialModel(SundialPreTrainedModel): + def __init__(self, config: SundialConfig): + super().__init__(config) + self.embed_layer = SundialPatchEmbedding(config) + self.layers = nn.ModuleList( + [ + SundialDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = torch.nn.LayerNorm(config.hidden_size) + self.gradient_checkpointing = False + + def forward( + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[ + Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] + ] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = 
None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + # input_ids is the input of time series, its shape is [batch_size, seq_len] + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_layer(input_ids) + seq_length = inputs_embeds.shape[1] + + past_key_values_length = 0 + use_legacy_cache = False + + if past_key_values is not None: + use_legacy_cache = not isinstance(past_key_values, Cache) + # Converts the legacy cache which is tuple into an equivalent Cache. Used for backward compatibility. + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # Suppose the sequence length of each layer is the same + past_key_values_length = past_key_values.get_seq_length() + + # When training + checkpoints, caching is usually disabled (just do not transfer) + if ( + self.gradient_checkpointing + and self.training + and isinstance(past_key_values, Cache) + ): + past_key_values = None + past_key_values_length = 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=None, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if isinstance(past_key_values, Cache): + next_decoder_cache = layer_outputs[2] + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if isinstance(past_key_values, 
Cache): + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin): + def __init__(self, config: SundialConfig): + super().__init__(config) + self.config = config + self.model = SundialModel(self.config) + self.flow_loss = FlowLoss( + self.config.output_token_lens[-1], + self.config.hidden_size, + self.config.flow_loss_depth, + self.config.hidden_size, + self.config.num_sampling_steps, + ) + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[ + Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] + ] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + loss_masks: Optional[torch.FloatTensor] = None, + mask_y: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + max_output_length: Optional[int] = None, + revin: Optional[bool] = False, + num_samples: Optional[int] = 1, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if revin: + means = input_ids.mean(1, keepdim=True).detach() + stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach() + stdev = torch.where( + stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device) + ) + input_ids = (input_ids - means) / stdev + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + predictions = None + + loss = None + if labels is not None: + if revin: + labels = (labels - means) / stdev + output_token_len = self.config.output_token_lens[-1] + seq_len = hidden_states.shape[1] * self.config.input_token_len + labels = labels[ + :, : seq_len - self.config.input_token_len + output_token_len + ] + shift_labels = labels.unfold( + dimension=-1, size=output_token_len, step=self.config.input_token_len + ) + + bsz, L, _ = shift_labels.shape + shift_labels = shift_labels.reshape(bsz * L, -1).repeat( + self.config.diffusion_batch_mul, 1 + ) + hidden_states = hidden_states.reshape(bsz * L, -1).repeat( + self.config.diffusion_batch_mul, 1 + ) + loss_masks = loss_masks.reshape(bsz * L).repeat( + self.config.diffusion_batch_mul + ) + mask_y = mask_y.repeat(L * self.config.diffusion_batch_mul, 1) + + loss = self.flow_loss(shift_labels, hidden_states, 
loss_masks, mask_y)
+        else:
+            if max_output_length is None:
+                output_token_len = self.config.output_token_lens[0]
+                max_output_length = output_token_len
+            else:
+                output_token_len = self.config.output_token_lens[0]
+                for h in self.config.output_token_lens[1:]:
+                    if h > max_output_length:
+                        break
+                    else:
+                        output_token_len = h
+
+            bsz = hidden_states.shape[0]
+            hidden_states = hidden_states[:, -1, :]
+            predictions = self.flow_loss.sample(hidden_states, num_samples)
+            if output_token_len > max_output_length:
+                predictions = predictions[:, :, :max_output_length]
+            if revin:
+                predictions = predictions * stdev + means
+        if not return_dict:
+            output = (predictions,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            logits=predictions,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        revin=False,
+        num_samples=1,
+        **kwargs,
+    ):
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                past_length = past_key_values.get_seq_length()
+            else:
+                past_length = past_key_values[0][0].shape[2]
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > (
+                input_ids.shape[1] // self.config.input_token_len
+            ):
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+                input_ids = input_ids[:, past_length * self.config.input_token_len :]
+            # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[
+                    :, -(input_ids.shape[1] // self.config.input_token_len) :
+                ]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "attention_mask": attention_mask,
+                "revin": revin,
+                "num_samples": num_samples,
+            }
+        )
+        return model_inputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py
new file mode 100644
index 0000000000000..f09621f2cb0a5
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/ts_generation_mixin.py
@@ -0,0 +1,383 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList +from transformers.generation import EosTokenCriteria, validate_stopping_criteria +from transformers.generation.utils import ( + GenerateDecoderOnlyOutput, + GenerateEncoderDecoderOutput, + GenerateNonBeamOutput, + GenerateOutput, + GenerationConfig, +) +from transformers.utils import ModelOutput + + +class TSGenerationMixin(GenerationMixin): + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + revin: Optional[bool] = True, + num_samples: Optional[int] = 1, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + if len(inputs.shape) != 2: + raise ValueError("Input shape must be: [batch_size, seq_len]") + batch_size, cur_len = inputs.shape + if cur_len < self.config.input_token_len: + raise ValueError( + f"Input length must be at least {self.config.input_token_len}" + ) + if revin: + means = inputs.mean(dim=-1, keepdim=True) + stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5 + inputs = (inputs - means) / stdev + outputs = super().generate( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_samples=num_samples, + **kwargs, + ) + if revin: + stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1) + means = means.unsqueeze(1).repeat(1, num_samples, 1) + outputs = (outputs * stdev) + means + return outputs + + def _sample( + self, + input_ids: torch.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + 
streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.Tensor]: + input_ids = input_ids.to(self.device) + batch_size, cur_len = input_ids.shape + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + if eos_token_id is not None: + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None + ) + + # keep track of which sequences are already finished + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + this_peer_finished = False + unfinished_sequences = torch.ones( + batch_size, dtype=torch.long, device=input_ids.device + ) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + true_seq_len = ( + cur_len + self.config.input_token_len - 1 + ) // self.config.input_token_len + model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ + :, -true_seq_len: + ] + max_length = 
stopping_criteria.max_length + generate_results = None + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + input_length = input_ids.shape[1] + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + max_output_length=max_length - input_length, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + next_token_logits = outputs.logits + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_tokens_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # argmax + # next_tokens = torch.argmax(next_tokens_scores, dim=-1) + next_tokens = next_tokens_scores + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + + # update generated ids, model inputs, and length for next step + horizon_length = next_tokens.shape[-1] // self.config.input_token_len + + past_key_values = model_kwargs.get("past_key_values") + if past_key_values is None or generate_results is None: + generate_results = next_tokens + else: + generate_results = torch.cat([generate_results, next_tokens], dim=-1) + input_ids = torch.cat([input_ids, next_tokens.median(dim=1)[0]], dim=-1) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + horizon_length=horizon_length, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) + this_peer_finished = unfinished_sequences.max() == 0 + + if input_ids.shape[-1] > max_length: + input_ids = input_ids[:, :max_length] + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return generate_results[:, :, : (max_length - cur_len)] + + def 
_update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + horizon_length: int = 1, + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + if "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past_key_values"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past_key_values"] = outputs.past_buckets_states + + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) + + if not is_encoder_decoder: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [ + attention_mask, + attention_mask.new_ones( + (attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, + ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs["decoder_attention_mask"] = torch.cat( + [ + decoder_attention_mask, + decoder_attention_mask.new_ones( + (decoder_attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, + ) + + if ( + "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = ( + model_kwargs["cache_position"][-1:] + horizon_length + ) + + return model_kwargs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py new file mode 100644 index 0000000000000..34f9de91b633d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from typing import List + +from transformers import PretrainedConfig + + +class TimerConfig(PretrainedConfig): + model_type = "timer" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + input_token_len: int = 1, + hidden_size: int = 1024, + intermediate_size: int = 2048, + output_token_lens: List[int] = [1, 8, 32, 64], + num_hidden_layers: int = 8, + num_attention_heads: int = 8, + hidden_act: str = "silu", + use_cache: bool = True, + rope_theta: int = 10000, + attention_dropout: float = 0.0, + initializer_range: float = 0.02, + max_position_embeddings: int = 10000, + **kwargs, + ): + self.input_token_len = input_token_len + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.output_token_lens = output_token_lens + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.initializer_range = initializer_range + self.max_position_embeddings = max_position_embeddings + + super().__init__( + **kwargs, + ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py new file mode 100644 index 0000000000000..37bf56dfc59a7 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py @@ -0,0 +1,640 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Cache, DynamicCache, PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) + +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig +from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin + +logger = Logger() + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class TimerPatchEmbedding(nn.Module): + def __init__(self, config: TimerConfig): + super().__init__() + self.input_token_len = config.input_token_len + self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False) + + def forward(self, hidden_state: torch.Tensor): + hidden_state = hidden_state.unfold( + dimension=-1, size=self.input_token_len, step=self.input_token_len + ) + return self.emb(hidden_state) + + +class TimerPointEmbedding(nn.Module): + def __init__(self, config: TimerConfig): + super().__init__() + self.emb_layer = nn.Linear( + config.input_token_len, config.hidden_size, bias=False + ) + self.gate_layer = nn.Linear( + config.input_token_len, config.hidden_size, bias=False + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x) + return emb + + +class TimeMoeRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. 
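+        # cos/sin of position * inv_freq are cached once up to
+        # `max_position_embeddings`; forward() rebuilds the cache lazily
+        # if a longer sequence arrives at inference time.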
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings,
+            device=self.inv_freq.device,
+            dtype=torch.get_default_dtype(),
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(
+            self.max_seq_len_cached, device=device, dtype=torch.int64
+        ).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class TimerAttention(nn.Module):
+    def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.attention_dropout = config.attention_dropout
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.rotary_emb = TimeMoeRotaryEmbedding(
+            self.head_dim, max_position_embeddings=config.max_position_embeddings
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids
+        )
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx
+            )
+
+        attn_output = F.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout_p=self.attention_dropout,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        # F.scaled_dot_product_attention does not expose attention weights,
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class TimerMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+        super().__init__()
+        self.hidden_size 
= hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, hidden_state): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + + +class TimerDecoderLayer(nn.Module): + def __init__(self, config: TimerConfig, layer_idx: int): + super().__init__() + self.self_attn = TimerAttention(config, layer_idx) + + self.ffn_layer = TimerMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.norm1 = torch.nn.LayerNorm(config.hidden_size) + self.norm2 = torch.nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + Optional[torch.Tensor], + Optional[Cache], + ]: + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + hidden_states = self.norm1(hidden_states) + + # Fully Connected + residual = hidden_states + hidden_states = self.ffn_layer(hidden_states) + hidden_states = residual + hidden_states + hidden_states = self.norm2(hidden_states) + + if not output_attentions: + self_attn_weights = None + + return hidden_states, self_attn_weights, present_key_value + + +class TimerPreTrainedModel(PreTrainedModel): + config_class = TimerConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["TimeDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = False + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, torch.nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, torch.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class TimerModel(TimerPreTrainedModel): + def __init__(self, config: TimerConfig): + super().__init__(config) + self.embed_layer = TimerPatchEmbedding(config) + self.layers = nn.ModuleList( + [ + TimerDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = torch.nn.LayerNorm(config.hidden_size) + self.gradient_checkpointing = False + + def forward( + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[ + Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] + ] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + # input_ids 
is the input of time series, its shape is [batch_size, seq_len] + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_layer(input_ids) + seq_length = inputs_embeds.shape[1] + + past_key_values_length = 0 + use_legacy_cache = False + + if past_key_values is not None: + use_legacy_cache = not isinstance(past_key_values, Cache) + # Converts the legacy cache which is tuple into an equivalent Cache. Used for backward compatibility. + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() + + # When training + checkpoints, caching is usually disabled (just do not transfer) + if ( + self.gradient_checkpointing + and self.training + and isinstance(past_key_values, Cache) + ): + past_key_values = None + past_key_values_length = 0 + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + position_ids = position_ids.view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + sliding_window=None, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if isinstance(past_key_values, Cache): + next_decoder_cache = layer_outputs[2] + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if isinstance(past_key_values, Cache): + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + if not 
return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin): + def __init__(self, config: TimerConfig): + super().__init__(config) + self.config = config + self.model = TimerModel(self.config) + lm_head_list = [] + self.output_token_len_map = {} + for i, output_token_len in enumerate(self.config.output_token_lens): + lm_head_list.append( + nn.Linear(self.config.hidden_size, output_token_len, bias=False) + ) + self.output_token_len_map[output_token_len] = i + self.lm_heads = nn.ModuleList(lm_head_list) + self.loss_function = torch.nn.MSELoss(reduction="none") + self.post_init() + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[ + Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor]]] + ] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + loss_masks: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + max_output_length: Optional[int] = None, + revin: Optional[bool] = False, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if revin: + mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std( + dim=-1, keepdim=True + ) + input_ids = (input_ids - mean) / std + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + predictions = None + + loss = None + if labels is not None: + ar_loss = 0.0 + for lm_head, output_token_len in zip( + self.lm_heads, self.config.output_token_lens + ): + one_predictions = lm_head(hidden_states) + one_loss = self.calc_ar_loss( + one_predictions, labels, loss_masks, output_token_len + ) + ar_loss += one_loss + if predictions is None: + predictions = one_predictions + loss = ar_loss / len(self.config.output_token_lens) + else: + if max_output_length is None: + output_token_len = self.config.output_token_lens[0] + max_output_length = output_token_len + else: + output_token_len = self.config.output_token_lens[0] + for h in self.config.output_token_lens[1:]: + if h > max_output_length: + break + else: + output_token_len = h + lm_head = self.lm_heads[self.output_token_len_map[output_token_len]] + predictions = lm_head(hidden_states)[:, -1, :] + if output_token_len > max_output_length: + predictions = predictions[:, :max_output_length] + if revin: + predictions = predictions * std + mean + if not 
return_dict: + output = (predictions,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=predictions, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len): + seq_len = predictions.shape[1] * self.config.input_token_len + labels = labels[:, : seq_len - self.config.input_token_len + output_token_len] + shift_labels = labels.unfold( + dimension=-1, size=output_token_len, step=self.config.input_token_len + ) + + # Calculate loss with mask + losses = self.loss_function(predictions, shift_labels).mean(dim=-1) + if loss_masks is not None: + losses = losses * loss_masks + loss = losses.sum() / loss_masks.sum() + else: + loss = torch.mean(losses) + + return loss + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + revin=True, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + past_length = past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > ( + input_ids.shape[1] // self.config.input_token_len + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < (input_ids.shape[1] // self.config.input_token_len): + input_ids = input_ids[:, past_length * self.config.input_token_len :] + # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + "revin": revin, + } + ) + return model_inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py new file mode 100644 index 0000000000000..18f711b8e1a13 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py @@ -0,0 +1,370 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import warnings +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList +from transformers.generation import EosTokenCriteria, validate_stopping_criteria +from transformers.generation.utils import ( + GenerateDecoderOnlyOutput, + GenerateEncoderDecoderOutput, + GenerateNonBeamOutput, + GenerateOutput, + GenerationConfig, +) +from transformers.utils import ModelOutput + + +class TSGenerationMixin(GenerationMixin): + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + if len(inputs.shape) == 2: + batch_size, cur_len = inputs.shape + if cur_len < self.config.input_token_len: + raise ValueError( + f"Input length must be at least {self.config.input_token_len}" + ) + elif cur_len % self.config.input_token_len != 0: + new_len = ( + cur_len // self.config.input_token_len + ) * self.config.input_token_len + inputs = inputs[:, -new_len:] + else: + raise ValueError("Input shape must be: [batch_size, seq_len]") + return super().generate( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + **kwargs, + ) + + def _sample( + self, + input_ids: torch.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.Tensor]: + input_ids = input_ids.to(self.device) + batch_size, cur_len = input_ids.shape + # init values + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + 
if stopping_criteria is not None + else StoppingCriteriaList() + ) + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria( + stopping_criteria, max_length + ) + pad_token_id = ( + pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) + if eos_token_id is not None: + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + else: + # remove when the method is totally private + # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever + eos_token_id = [ + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") + ] + eos_token_id = eos_token_id[0] if eos_token_id else None + if eos_token_id is None and self.generation_config.eos_token_id is not None: + eos_token_id = self.generation_config.eos_token_id + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) + output_attentions = ( + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + raw_logits = () if (return_dict_in_generate and output_logits) else None + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None + ) + + # keep track of which sequences are already finished + if "inputs_embeds" in model_kwargs: + cur_len = model_kwargs["inputs_embeds"].shape[1] + this_peer_finished = False + unfinished_sequences = torch.ones( + batch_size, dtype=torch.long, device=input_ids.device + ) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + true_seq_len = cur_len // self.config.input_token_len + model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ + :, -true_seq_len: + ] + max_length = stopping_criteria.max_length + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + input_length = input_ids.shape[1] + + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + max_output_length=max_length - input_length, + ) + + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits + + # pre-process distribution + next_tokens_scores = logits_processor(input_ids, next_token_logits) + + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_tokens_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # argmax + # next_tokens = torch.argmax(next_tokens_scores, dim=-1) + next_tokens = next_tokens_scores + + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError( + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) + + # update generated ids, model inputs, and length for next step + horizon_length = next_tokens.shape[1] // self.config.input_token_len + + input_ids = torch.cat([input_ids, next_tokens], dim=-1) + if streamer is not None: + streamer.put(next_tokens.cpu()) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + horizon_length=horizon_length, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + unfinished_sequences = unfinished_sequences & ~stopping_criteria( + input_ids, scores + ) + this_peer_finished = unfinished_sequences.max() == 0 + + if input_ids.shape[1] > max_length: + input_ids = input_ids[:, :max_length] + + if streamer is not None: + streamer.end() + + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return GenerateEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return GenerateDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + ) + else: + return input_ids[:, -(max_length - cur_len) :] + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + horizon_length: int = 1, + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + if "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + elif "mems" in outputs: + model_kwargs["past_key_values"] = outputs.mems + elif "past_buckets_states" in outputs: + model_kwargs["past_key_values"] = outputs.past_buckets_states + + if getattr(outputs, "state", None) is not None: + model_kwargs["state"] = outputs.state + + # update token_type_ids with last value + if "token_type_ids" in model_kwargs: + 
token_type_ids = model_kwargs["token_type_ids"] + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) + + if not is_encoder_decoder: + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [ + attention_mask, + attention_mask.new_ones( + (attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, + ) + else: + # update decoder attention mask + if "decoder_attention_mask" in model_kwargs: + decoder_attention_mask = model_kwargs["decoder_attention_mask"] + model_kwargs["decoder_attention_mask"] = torch.cat( + [ + decoder_attention_mask, + decoder_attention_mask.new_ones( + (decoder_attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, + ) + + if ( + "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = ( + model_kwargs["cache_position"][-1:] + horizon_length + ) + + return model_kwargs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py new file mode 100644 index 0000000000000..9da8486d3905e --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import importlib +import json +import sys +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Tuple + +from iotdb.ainode.core.model.model_enums import ( + MODEL_CONFIG_FILE, + MODEL_WEIGHTS_FILE, + UriType, +) + + +def parse_uri_type(uri: str) -> UriType: + if uri.startswith("repo://"): + return UriType.REPO + elif uri.startswith("file://"): + return UriType.FILE + else: + raise ValueError( + f"Unsupported URI type: {uri}. 
Supported formats: repo:// or file://" + ) + + +def get_parsed_uri(uri: str) -> str: + return uri[7:] # Remove "repo://" or "file://" prefix + + +@contextmanager +def temporary_sys_path(path: str): + """Context manager for temporarily adding a path to sys.path""" + path_added = path not in sys.path + if path_added: + sys.path.insert(0, path) + try: + yield + finally: + if path_added and path in sys.path: + sys.path.remove(path) + + +def load_model_config(config_path: Path) -> Dict: + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def validate_model_files(model_dir: Path) -> Tuple[Path, Path]: + """Validate model files exist, return config and weights file paths""" + config_path = model_dir / MODEL_CONFIG_FILE + weights_path = model_dir / MODEL_WEIGHTS_FILE + + if not config_path.exists(): + raise ValueError(f"Model config file does not exist: {config_path}") + if not weights_path.exists(): + raise ValueError(f"Model weights file does not exist: {weights_path}") + + # Create __init__.py file to ensure model directory can be imported as a module + init_file = model_dir / "__init__.py" + if not init_file.exists(): + init_file.touch() + + return config_path, weights_path + + +def import_class_from_path(module_name, class_path: str): + file_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name + "." + file_name) + return getattr(module, class_name) + + +def ensure_init_file(path: Path): + """Ensure __init__.py file exists in the given path""" + init_file = path / "__init__.py" + if not init_file.exists(): + init_file.touch() diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index f01e1594f0698..de5de968a7e52 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -20,7 +20,7 @@ from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.cluster_manager import ClusterManager from iotdb.ainode.core.manager.inference_manager import InferenceManager -from iotdb.ainode.core.manager.model_manager import ModelManager +from iotdb.ainode.core.manager.model_manager import get_model_manager from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.thrift.ainode import IAINodeRPCService @@ -47,24 +47,10 @@ logger = Logger() -def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus: - """ - Ensure that the device IDs in the provided list are available. - """ - available_devices = get_available_devices() - for device_id in device_id_list: - if device_id not in available_devices: - return TSStatus( - code=TSStatusCode.INVALID_URI_ERROR.value, - message=f"Device ID [{device_id}] is not available. 
You can use 'SHOW AI_DEVICES' to retrieve the available devices.", - ) - return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) - - class AINodeRPCServiceHandler(IAINodeRPCService.Iface): def __init__(self, ainode): self._ainode = ainode - self._model_manager = ModelManager() + self._model_manager = get_model_manager() self._inference_manager = InferenceManager() def stop(self) -> None: @@ -78,43 +64,55 @@ def stopAINode(self) -> TSStatus: def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: return self._model_manager.register_model(req) - def loadModel(self, req: TLoadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId) - if status.code != TSStatusCode.SUCCESS_STATUS.value: - return status - status = _ensure_device_id_is_available(req.deviceIdList) - if status.code != TSStatusCode.SUCCESS_STATUS.value: - return status - return self._inference_manager.load_model(req) - - def unloadModel(self, req: TUnloadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) - if status.code != TSStatusCode.SUCCESS_STATUS.value: - return status - status = _ensure_device_id_is_available(req.deviceIdList) - if status.code != TSStatusCode.SUCCESS_STATUS.value: - return status - return self._inference_manager.unload_model(req) - def deleteModel(self, req: TDeleteModelReq) -> TSStatus: return self._model_manager.delete_model(req) - def inference(self, req: TInferenceReq) -> TInferenceResp: - return self._inference_manager.inference(req) + def showModels(self, req: TShowModelsReq) -> TShowModelsResp: + return self._model_manager.show_models(req) - def forecast(self, req: TForecastReq) -> TSStatus: - return self._inference_manager.forecast(req) + def loadModel(self, req: TLoadModelReq) -> TSStatus: + if not self._model_manager.is_model_registered(req.existingModelId): + return TSStatus( + code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, + message=f"Model [{req.existingModelId}] is not supported. You can use 'SHOW MODELS' to retrieve the available models.", + ) - def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: - return ClusterManager.get_heart_beat(req) + available_devices = get_available_devices() + for device_id in req.deviceIdList: + if device_id not in available_devices: + return TSStatus( + code=TSStatusCode.INVALID_URI_ERROR.value, + message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + ) - def showModels(self, req: TShowModelsReq) -> TShowModelsResp: - return self._model_manager.show_models(req) + return self._inference_manager.load_model(req) + + def unloadModel(self, req: TUnloadModelReq) -> TSStatus: + if not self._model_manager.is_model_registered(req.modelId): + return TSStatus( + code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, + message=f"Model [{req.modelId}] is not supported. You can use 'SHOW MODELS' to retrieve the available models.", + ) + + available_devices = get_available_devices() + for device_id in req.deviceIdList: + if device_id not in available_devices: + return TSStatus( + code=TSStatusCode.INVALID_URI_ERROR.value, + message=f"Device ID [{device_id}] is not available. 
You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + ) + + return self._inference_manager.unload_model(req) def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp: - status = _ensure_device_id_is_available(req.deviceIdList) - if status.code != TSStatusCode.SUCCESS_STATUS.value: - return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={}) + available_devices = get_available_devices() + for device_id in req.deviceIdList: + if device_id not in available_devices: + status = TSStatus( + code=TSStatusCode.INVALID_URI_ERROR.value, + message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + ) + return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={}) return self._inference_manager.show_loaded_models(req) def showAIDevices(self) -> TShowAIDevicesResp: @@ -123,13 +121,14 @@ def showAIDevices(self) -> TShowAIDevicesResp: deviceIdList=get_available_devices(), ) + def inference(self, req: TInferenceReq) -> TInferenceResp: + return self._inference_manager.inference(req) + + def forecast(self, req: TForecastReq) -> TSStatus: + return self._inference_manager.forecast(req) + + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: + return ClusterManager.get_heart_beat(req) + def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass - - def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus: - if not self._model_manager.is_built_in_or_fine_tuned(model_id): - return TSStatus( - code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.", - ) - return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) From 20923d55c7544955c82f25476087235aefbe2a72 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Wed, 26 Nov 2025 17:40:04 +0800 Subject: [PATCH 04/38] remove useless codes in IoTDB remove useless codes in IoTDB --- .../ainode/iotdb/ainode/core/constant.py | 16 +- .../ainode/iotdb/ainode/core/exception.py | 5 +- .../core/inference/inference_request_pool.py | 2 +- .../core/inference/pipeline/__init__.py | 8 +- .../core/inference/pipeline/basic_pipeline.py | 17 +- .../ainode/core/manager/model_manager.py | 53 +-- .../ainode/iotdb/ainode/core/manager/utils.py | 6 +- .../{model_enums.py => model_constants.py} | 29 +- .../iotdb/ainode/core/model/model_info.py | 54 +-- .../iotdb/ainode/core/model/model_loader.py | 8 +- .../iotdb/ainode/core/model/model_storage.py | 357 +++++++--------- .../sktime/pipeline_sktime.py} | 0 .../core/model/sundial/modeling_sundial.py | 8 +- .../sundial/pipeline_sundial.py} | 2 - .../model/{timerxl => timer_xl}/__init__.py | 0 .../configuration_timer.py | 0 .../{timerxl => timer_xl}/modeling_timer.py | 9 +- .../timer_xl/pipeline_timer.py} | 4 +- .../ts_generation_mixin.py | 0 .../ainode/iotdb/ainode/core/model/utils.py | 19 +- .../ainode/iotdb/ainode/core/rpc/handler.py | 86 ++-- .../async/AsyncAINodeHeartbeatClientPool.java | 19 +- .../AsyncDataNodeHeartbeatClientPool.java | 1 - .../consensus/request/ConfigPhysicalPlan.java | 8 - .../request/read/model/GetModelInfoPlan.java | 64 --- .../request/read/model/ShowModelPlan.java | 70 --- .../response/model/GetModelInfoResp.java | 63 --- .../response/model/ModelTableResp.java | 62 --- .../confignode/manager/ConfigManager.java | 179 -------- .../iotdb/confignode/manager/IManager.java | 42 -- .../confignode/manager/ModelManager.java | 245 ----------- 
.../confignode/manager/ProcedureManager.java | 20 - .../confignode/persistence/ModelInfo.java | 378 ----------------- .../executor/ConfigPlanExecutor.java | 25 -- .../impl/model/CreateModelProcedure.java | 250 ----------- .../impl/model/DropModelProcedure.java | 200 --------- .../impl/node/RemoveAINodeProcedure.java | 7 +- .../procedure/store/ProcedureFactory.java | 12 - .../procedure/store/ProcedureType.java | 2 + .../thrift/ConfigNodeRPCServiceProcessor.java | 25 -- .../protocol/client/AINodeClientFactory.java | 133 ------ .../db/protocol/client/ConfigNodeClient.java | 30 +- .../client/DataNodeClientPoolFactory.java | 42 +- .../protocol/client/ainode/AINodeClient.java | 401 ------------------ .../client/ainode/AINodeClientManager.java | 75 ---- .../db/protocol/client/an/AINodeClient.java | 321 ++++++++++++++ .../client/an/AINodeClientManager.java | 47 ++ .../process/ai/InferenceOperator.java | 14 +- ...formationSchemaContentSupplierFactory.java | 113 ----- .../plan/analyze/IModelFetcher.java | 4 - .../plan/analyze/ModelFetcher.java | 40 +- .../executor/ClusterConfigTaskExecutor.java | 43 +- .../analyzer/StatementAnalyzer.java | 6 - .../function/tvf/ForecastTableFunction.java | 39 +- .../plan/relational/metadata/Metadata.java | 6 - .../metadata/TableMetadataImpl.java | 5 - .../DataNodeLocationSupplierFactory.java | 1 - .../db/queryengine/plan/udf/UDTFForecast.java | 25 +- .../relational/analyzer/TSBSMetadata.java | 6 - .../analyzer/TableFunctionTest.java | 3 - .../relational/analyzer/TestMetadata.java | 19 - iotdb-core/node-commons/pom.xml | 5 + .../commons/client/ClientPoolFactory.java | 28 ++ .../AsyncAINodeInternalServiceClient.java} | 25 +- .../schema/table/InformationSchema.java | 18 - .../src/main/thrift/confignode.thrift | 63 --- 66 files changed, 808 insertions(+), 3059 deletions(-) rename iotdb-core/ainode/iotdb/ainode/core/model/{model_enums.py => model_constants.py} (67%) rename iotdb-core/ainode/iotdb/ainode/core/{inference/pipeline/sktime_pipeline.py => model/sktime/pipeline_sktime.py} (100%) rename iotdb-core/ainode/iotdb/ainode/core/{inference/pipeline/sundial_pipeline.py => model/sundial/pipeline_sundial.py} (95%) rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/__init__.py (100%) rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/configuration_timer.py (100%) rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/modeling_timer.py (98%) rename iotdb-core/ainode/iotdb/ainode/core/{inference/pipeline/timerxl_pipeline.py => model/timer_xl/pipeline_timer.py} (92%) rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/ts_generation_mixin.py (100%) delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java 
delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java delete mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java delete mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java delete mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java rename iotdb-core/{datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java => node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java} (83%) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 74decf9e88a61..3576ac711ce0e 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -51,7 +51,7 @@ AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { "sundial": 1036 * 1024**2, # 1036 MiB - "timerxl": 856 * 1024**2, # 856 MiB + "timer_xl": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( @@ -59,15 +59,6 @@ ) AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") -AINODE_BUILTIN_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/builtin" -) # For built-in models, we only need to store their weights and config. 
-AINODE_FINETUNE_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/finetune" -) -AINODE_USER_DEFINED_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/user_defined" -) AINODE_SYSTEM_DIR = "data/ainode/system" AINODE_LOG_DIR = "logs" @@ -80,11 +71,6 @@ "log_inference_rank_{}_" # example: log_inference_rank_0_all.log ) -# AINode model management -MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" -MODEL_CONFIG_FILE_IN_JSON = "config.json" -MODEL_WEIGHTS_FILE_IN_PT = "model.pt" -MODEL_CONFIG_FILE_IN_YAML = "config.yaml" DEFAULT_CHUNK_SIZE = 8192 diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index bc89cdc306625..91ad096418872 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -17,10 +17,7 @@ # import re -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) +from iotdb.ainode.core.model.model_constants import MODEL_WEIGHTS_FILE_IN_PT, MODEL_CONFIG_FILE_IN_YAML class _BaseError(Exception): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index fcaa4c7a7543b..b56ffce461f5f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -82,7 +82,7 @@ def __init__( self._batcher = BasicBatcher() self._stop_event = mp.Event() - # self._inference_pipeline = get_pipeline(self.model_info.model_id, self.device) + self._inference_pipeline = None self._logger = None diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py index 617c4e6738061..53cf7b1086891 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py @@ -16,13 +16,13 @@ # under the License. 
# -from iotdb.ainode.core.inference.pipeline.sundial_pipeline import SundialPipeline -from iotdb.ainode.core.inference.pipeline.timerxl_pipeline import TimerxlPipeline +from iotdb.ainode.core.model.sundial.pipeline_sundial import SundialPipeline +from iotdb.ainode.core.model.timer_xl.pipeline_timer import TimerPipeline def get_pipeline(model_id, device): - if model_id == "timerxl": - return TimerxlPipeline(model_id, device=device) + if model_id == "timer_xl": + return TimerPipeline(model_id, device=device) elif model_id == "sundial": return SundialPipeline(model_id, device=device) else: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index c413a92e82d62..19efe0220c64f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -21,15 +21,14 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.manager.model_manager import get_model_manager +from iotdb.ainode.core.manager.model_manager import ModelManager class BasicPipeline(ABC): def __init__(self, model_id, **infer_kwargs): self.model_id = model_id self.device = infer_kwargs.get("device", "cpu") - # self.model = get_model_manager().load_model(model_id).to(self.device) - self.model = get_model_manager().load_model( + self.model = ModelManager().load_model( model_id, device_map=str(self.device) ) @@ -40,15 +39,6 @@ def _preprocess(self, inputs): # TODO: Integrate with the data processing pipeline operators pass - def infer(self, inputs): - pass - - def _post_decode(self): - """ - Post-process the outputs after each decode step. - """ - pass - def _postprocess(self, output: torch.Tensor): """ Post-process the outputs after the entire inference task. @@ -70,9 +60,6 @@ def _preprocess(self, inputs): def forecast(self, inputs, **infer_kwargs): pass - def _post_decode(self): - pass - def _postprocess(self, output: torch.Tensor): pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index a07552922ff36..4cafbbecd8c0b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -16,16 +16,15 @@ # under the License. 
# -import os from typing import Any, List, Optional -from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_loader import ModelLoader from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status +from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.ainode.ttypes import ( TDeleteModelReq, TRegisterModelReq, @@ -37,43 +36,34 @@ logger = Logger() - +@singleton class ModelManager: def __init__(self): - self.models_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() - ) - self.storage = ModelStorage(models_dir=self.models_dir) - self.loader = ModelLoader(storage=self.storage) - - # Automatically discover all models - self._models = self.storage.discover_all() + self._model_storage = ModelStorage() + self._model_loader = ModelLoader(storage=self._model_storage) def register_model( self, req: TRegisterModelReq, ) -> TRegisterModelResp: try: - success = self.storage.register_model(model_id=req.modelId, uri=req.uri) - if success: + if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) - else: - return TRegisterModelResp( - get_status(TSStatusCode.AINODE_INTERNAL_ERROR) - ) + return TRegisterModelResp( + get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) except ValueError as e: return TRegisterModelResp( get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) ) except Exception as e: - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))) def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - return self.storage.show_models(req) + return self._model_storage.show_models(req) def delete_model(self, req: TDeleteModelReq) -> TSStatus: try: - self.storage.delete_model(req.modelId) + self._model_storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) except BuiltInModelDeletionError as e: logger.warning(e) @@ -83,39 +73,28 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus: return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) def load_model(self, model_id: str, **kwargs) -> Any: - return self.loader.load_model(model_id=model_id, **kwargs) + return self._model_loader.load_model(model_id=model_id, **kwargs) def get_model_info( self, model_id: str, category: Optional[ModelCategory] = None, ) -> Optional[ModelInfo]: - return self.storage.get_model_info(model_id, category) + return self._model_storage.get_model_info(model_id, category) def get_model_infos( self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None, ) -> List[ModelInfo]: - return self.storage.get_model_infos(category, model_type) + return self._model_storage.get_model_infos(category, model_type) def refresh(self): """Refresh the model list (re-scan the file system)""" - self._models = self.storage.discover_all() + self._model_storage.discover_all_models() def get_registered_models(self) -> List[str]: - return self.storage.get_registered_models() + return self._model_storage.get_registered_models() def is_model_registered(self, model_id: str) -> bool: - return self.storage.is_model_registered(model_id) - - -# Create a global model manager instance -_default_manager: 
Optional[ModelManager] = None - - -def get_model_manager() -> ModelManager: - global _default_manager - if _default_manager is None: - _default_manager = ModelManager() - return _default_manager + return self._model_storage.is_model_registered(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 297a2b832d7a8..afb87ee8c5acf 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -24,7 +24,7 @@ from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import get_model_manager +from iotdb.ainode.core.manager.model_manager import ModelManager logger = Logger() @@ -46,7 +46,7 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = get_model_manager().load_model(model_id, {}).to(device) + model = ModelManager().load_model(model_id).to(device) torch.cuda.synchronize(device) end = torch.cuda.memory_reserved(device) usage = end - start @@ -79,7 +79,7 @@ def evaluate_system_resources(device: torch.device) -> dict: def estimate_pool_size(device: torch.device, model_id: str) -> int: - model_info = get_model_manager.get_model_info(model_id) + model_info = ModelManager().get_model_info(model_id) if model_info is None or model_info.model_id not in MODEL_MEM_USAGE_MAP: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py similarity index 67% rename from iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py rename to iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py index a6a234a1ab8f5..ef24830142a39 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -18,41 +18,24 @@ from enum import Enum +# Model file constants +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" +MODEL_CONFIG_FILE_IN_JSON = "config.json" +MODEL_WEIGHTS_FILE_IN_PT = "model.pt" +MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + class ModelCategory(Enum): BUILTIN = "builtin" USER_DEFINED = "user_defined" - FINETUNE = "finetune" class ModelStates(Enum): INACTIVE = "inactive" ACTIVATING = "activating" ACTIVE = "active" - LOADING = "loading" - LOADED = "loaded" DROPPING = "dropping" - TRAINING = "training" - FAILED = "failed" - - -class ModelFileType(Enum): - SAFETENSORS = "safetensors" - PYTORCH = "pytorch" - UNKNOWN = "unknown" class UriType(Enum): REPO = "repo" FILE = "file" - - -# Map for inferring which HuggingFace repository to download from based on model ID -REPO_ID_MAP = { - "timerxl": "thuml/timer-base-84m", - "sundial": "thuml/sundial-base-128m", - # More mappings can be added as needed -} - -# Model file constants -MODEL_CONFIG_FILE = "config.json" -MODEL_WEIGHTS_FILE = "model.safetensors" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index f36ad582a837b..690bc09fe8aae 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -16,29 +16,20 @@ # under the License. 
# -from typing import Dict, List, Optional, Tuple - -from iotdb.ainode.core.model.model_enums import ModelCategory, ModelStates - -# Map for inferring which HuggingFace repository to download from based on model ID -REPO_ID_MAP = { - "timerxl": "thuml/timer-base-84m", - "sundial": "thuml/sundial-base-128m", - # More mappings can be added as needed -} - -# Model file constants -MODEL_CONFIG_FILE = "config.json" -MODEL_WEIGHTS_FILE = "model.safetensors" +from typing import Dict, Optional +from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates class ModelInfo: def __init__( self, model_id: str, - model_type: str, category: ModelCategory, state: ModelStates, + model_type: str = "", + model_cls: str = "", + pipeline_cls: str = "", + repo_id: str = "", path: str = "", auto_map: Optional[Dict] = None, _transformers_registered: bool = False, @@ -47,6 +38,9 @@ def __init__( self.model_type = model_type self.category = category self.state = state + self.model_cls = model_cls + self.pipeline_cls = pipeline_cls + self.repo_id = repo_id self.path = path self.auto_map = auto_map # If exists, indicates it's a Transformers model self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers @@ -63,67 +57,73 @@ def __repr__(self): # forecast models "arima": ModelInfo( model_id="arima", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "holtwinters": ModelInfo( model_id="holtwinters", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "exponential_smoothing": ModelInfo( model_id="exponential_smoothing", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "naive_forecaster": ModelInfo( model_id="naive_forecaster", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stl_forecaster": ModelInfo( model_id="stl_forecaster", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), # anomaly detection models "gaussian_hmm": ModelInfo( model_id="gaussian_hmm", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "gmm_hmm": ModelInfo( model_id="gmm_hmm", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stray": ModelInfo( model_id="stray", - model_type="sktime", category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), } # Built-in huggingface transformers models, their weights are not included in AINode by default BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { - "timerxl": ModelInfo( - model_id="timerxl", - model_type="timer", + "timer_xl": ModelInfo( + model_id="timer_xl", category=ModelCategory.BUILTIN, state=ModelStates.INACTIVE, + model_type="timer", + model_cls="modeling_timer.TimerForPrediction", + pipeline_cls="pipeline_timer.TimerPipeline", + repo_id="thuml/timer-base-84m", ), "sundial": ModelInfo( model_id="sundial", - model_type="sundial", category=ModelCategory.BUILTIN, state=ModelStates.INACTIVE, + model_type="sundial", + model_cls="modeling_sundial.SundialForPrediction", + pipeline_cls="pipeline_sundial.SundialPipeline", + repo_id="thuml/sundial-base-128m", ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index 9b70b5f80401e..dc8220aad9a25 100644 --- 
a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -32,7 +32,7 @@ from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelCategory +from iotdb.ainode.core.model.model_constants import ModelCategory from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.model.model_storage import ModelStorage from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model @@ -73,13 +73,13 @@ def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): train_from_scratch = kwargs.get("train_from_scratch", False) if model_info.category == ModelCategory.BUILTIN: - if model_info.model_id == "timerxl": - from iotdb.ainode.core.model.timerxl.configuration_timer import ( + if model_info.model_id == "timer_xl": + from iotdb.ainode.core.model.timer_xl.configuration_timer import ( TimerConfig, ) model_config = TimerConfig() - from iotdb.ainode.core.model.timerxl.modeling_timer import ( + from iotdb.ainode.core.model.timer_xl.modeling_timer import ( TimerForPrediction, ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index 5ba67e158a5dc..e2c9e8bf21316 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -17,23 +17,27 @@ # import concurrent.futures +import json import os import shutil -from typing import List, Optional +from pathlib import Path +from typing import List, Optional, Dict from huggingface_hub import hf_hub_download, snapshot_download from transformers import AutoConfig, AutoModelForCausalLM +from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import REPO_ID_MAP, ModelCategory, ModelStates +from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates, UriType, \ + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, MODEL_CONFIG_FILE_IN_JSON from iotdb.ainode.core.model.model_info import ( BUILTIN_HF_TRANSFORMERS_MODEL_MAP, BUILTIN_SKTIME_MODEL_MAP, - ModelInfo, -) -from iotdb.ainode.core.model.utils import * + ModelInfo) +from iotdb.ainode.core.model.utils import ensure_init_file, load_model_config_in_json, parse_uri_type, get_parsed_uri, \ + validate_model_files, temporary_sys_path, import_class_from_path from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus @@ -44,56 +48,49 @@ class ModelStorage: """Model storage class - unified management of model discovery and registration""" - def __init__(self, models_dir: str): - self.models_dir = Path(models_dir) + def __init__(self): + self._models_dir = os.path.join( + os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() + ) # Unified storage: category -> {model_id -> ModelInfo} self._models: Dict[str, Dict[str, ModelInfo]] = { ModelCategory.BUILTIN.value: {}, ModelCategory.USER_DEFINED.value: {}, - ModelCategory.FINETUNE.value: {}, } - # Async download executor (using single-threaded executor because hf download interface is unstable with concurrent downloads) - self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + # Async 
download executor + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) # Thread lock pool for protecting concurrent access to model information self._lock_pool = ModelLockPool() self._initialize_directories() + self.discover_all_models() def _initialize_directories(self): """Initialize directory structure and ensure __init__.py files exist""" - self.models_dir.mkdir(parents=True, exist_ok=True) - ensure_init_file(self.models_dir) - + os.makedirs(self._models_dir, exist_ok=True) + ensure_init_file(self._models_dir) for category in ModelCategory: - category_path = self.models_dir / category.value - category_path.mkdir(parents=True, exist_ok=True) + category_path = os.path.join(self._models_dir, category.value) + os.makedirs(category_path, exist_ok=True) ensure_init_file(category_path) # ==================== Discovery Methods ==================== - def discover_all(self) -> Dict[str, Dict[str, ModelInfo]]: + def discover_all_models(self): """Scan file system to discover all models""" self._discover_category(ModelCategory.BUILTIN) self._discover_category(ModelCategory.USER_DEFINED) - self._discover_category(ModelCategory.FINETUNE) - return self._models def _discover_category(self, category: ModelCategory): """Discover all models in a category directory""" - category_path = self.models_dir / category.value - if not category_path.exists(): - return - + category_path = os.path.join(self._models_dir, category.value) if category == ModelCategory.BUILTIN: self._discover_builtin_models(category_path) - else: - # For finetune and user_defined, scan directories - for item in category_path.iterdir(): - if item.is_dir() and not item.name.startswith("__"): - relative_path = item.relative_to(category_path) - model_id = str(relative_path).replace("/", "_").replace("\\", "_") - self._process_model_directory(item, model_id, category) - - def _discover_builtin_models(self, category_path: Path): + elif category == ModelCategory.USER_DEFINED: + for model_id in os.listdir(category_path): + if os.path.isdir(os.path.join(category_path, model_id)): + self._process_user_defined_model_directory(os.path.join(category_path, model_id), model_id) + + def _discover_builtin_models(self, category_path: str): # Register SKTIME models directly from map for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): with self._lock_pool.get_lock(model_id).write_lock(): @@ -103,136 +100,115 @@ def _discover_builtin_models(self, category_path: Path): # Process HuggingFace Transformers models for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys(): - model_dir = category_path / model_id - model_dir.mkdir(parents=True, exist_ok=True) - self._process_model_directory(model_dir, model_id, ModelCategory.BUILTIN) + model_dir = os.path.join(category_path, model_id) + os.makedirs(model_dir, exist_ok=True) + self._process_builtin_model_directory(model_dir, model_id) - def _process_model_directory( - self, model_dir: Path, model_id: str, category: ModelCategory + def _process_builtin_model_directory( - self, model_dir: str, model_id: str ): - """Handling the discovery logic for a single model directory.""" + """Handling the discovery logic for a builtin model directory.""" ensure_init_file(model_dir) + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id] + self._models[ModelCategory.BUILTIN.value][model_id].state = ModelStates.ACTIVATING + + def _download_model_if_necessary() -> bool: + """Returns: True if the model already exists or was 
downloaded successfully, False otherwise.""" + repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id + weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + if not os.path.exists(weights_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + local_dir=model_dir, + ) + except Exception as e: + logger.error(f"Failed to download model weights from HuggingFace: {e}") + return False + if not os.path.exists(config_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_CONFIG_FILE_IN_JSON, + local_dir=model_dir, + ) + except Exception as e: + logger.error(f"Failed to download model config from HuggingFace: {e}") + return False + return True - config_path = model_dir / MODEL_CONFIG_FILE - weights_path = model_dir / MODEL_WEIGHTS_FILE - needs_download = not config_path.exists() or not weights_path.exists() - - if needs_download: - with self._lock_pool.get_lock(model_id).write_lock(): - model_info = ModelInfo( - model_id=model_id, - model_type="", # Read from config.json after download - category=category, - state=ModelStates.ACTIVATING, - path=str(model_dir), - auto_map=None, - _transformers_registered=False, - ) - self._models[category.value][model_id] = model_info - - future = self._executor.submit( - self._download_model_if_necessary, str(model_dir), model_id - ) - future.add_done_callback( - lambda f, mid=model_id, cat=category: self._callback_model_download_result( - f, mid, cat - ) + future = self._executor.submit(_download_model_if_necessary) + future.add_done_callback( + lambda f, mid=model_id: self._callback_model_download_result( + f, mid ) - else: - config = load_model_config(config_path) - model_type = config.get("model_type", "") - auto_map = config.get("auto_map") - - with self._lock_pool.get_lock(model_id).write_lock(): - model_info = ModelInfo( - model_id=model_id, - model_type=model_type, - category=category, - state=ModelStates.ACTIVE, - path=str(model_dir), - auto_map=auto_map, - _transformers_registered=False, # Lazy registration - ) - self._models[category.value][model_id] = model_info - - def _download_model_if_necessary(self, model_dir: str, model_id: str) -> bool: - """Returns: True if the model is existed or downloaded successfully, False otherwise.""" - if model_id in REPO_ID_MAP: - repo_id = REPO_ID_MAP[model_id] - else: - logger.error(f"Model {model_id} not found in REPO_ID_MAP") - return False - - weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE) - config_path = os.path.join(model_dir, MODEL_CONFIG_FILE) - - if not os.path.exists(weights_path): - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE, - local_dir=model_dir, - ) - except Exception as e: - logger.error(f"Failed to download model weights from HuggingFace: {e}") - return False - - if not os.path.exists(config_path): - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE, - local_dir=model_dir, - ) - except Exception as e: - logger.error(f"Failed to download model config from HuggingFace: {e}") - return False - - return True + ) def _callback_model_download_result( - self, future, model_id: str, category: ModelCategory + self, future, model_id: str ): """Callback function for handling model download results""" with self._lock_pool.get_lock(model_id).write_lock(): try: if future.result(): - if model_id in self._models[category.value]: - model_info = self._models[category.value][model_id] - model_info.state = 
ModelStates.ACTIVE - config_path = os.path.join(model_info.path, MODEL_CONFIG_FILE) - if os.path.exists(config_path): - with open(config_path, "r", encoding="utf-8") as f: - config = json.load(f) + model_info = self._models[ModelCategory.BUILTIN.value][model_id] + model_info.state = ModelStates.ACTIVE + config_path = os.path.join(model_info.path, MODEL_CONFIG_FILE_IN_JSON) + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + if model_info.model_type == "": model_info.model_type = config.get("model_type", "") - model_info.auto_map = config.get("auto_map") - logger.info( - f"Model {model_id} downloaded successfully and is ready to use." - ) + logger.info( + f"Model {model_id} downloaded successfully and is ready to use." + ) else: - if model_id in self._models[category.value]: - self._models[category.value][ - model_id - ].state = ModelStates.INACTIVE - logger.warning(f"Failed to download model {model_id}.") + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.INACTIVE + logger.warning(f"Failed to download model {model_id}.") except Exception as e: logger.error(f"Error in download callback for model {model_id}: {e}") - if model_id in self._models[category.value]: - self._models[category.value][model_id].state = ModelStates.INACTIVE + self._models[ModelCategory.BUILTIN.value][model_id].state = ModelStates.INACTIVE + + def _process_user_defined_model_directory(self, model_dir: str, model_id: str): + """Handling the discovery logic for a user-defined model directory.""" + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + model_type = "" + auto_map = {} + if os.path.exists(config_path): + config = load_model_config_in_json(Path(config_path)) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( + model_id=model_id, + model_type=model_type, + category=ModelCategory.USER_DEFINED, + state=ModelStates.ACTIVE, + path=str(model_dir), + auto_map=auto_map, + _transformers_registered=False, # Lazy registration + ) + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info # ==================== Registration Methods ==================== def register_model(self, model_id: str, uri: str) -> bool: """ Supported URI formats: - - repo:// + - repo:// (Maybe in the future) - file:// """ uri_type = parse_uri_type(uri) parsed_uri = get_parsed_uri(uri) - model_dir = Path(self.models_dir) / "user_defined" / model_id - model_dir.mkdir(parents=True, exist_ok=True) + model_dir = os.path.join(self._models_dir, ModelCategory.USER_DEFINED.value, model_id) + os.makedirs(model_dir, exist_ok=True) ensure_init_file(model_dir) if uri_type == UriType.REPO: @@ -241,7 +217,7 @@ def register_model(self, model_id: str, uri: str) -> bool: self._fetch_model_from_local(os.path.expanduser(parsed_uri), str(model_dir)) config_path, _ = validate_model_files(model_dir) - config = load_model_config(config_path) + config = load_model_config_in_json(config_path) model_type = config.get("model_type", "") auto_map = config.get("auto_map") @@ -279,7 +255,6 @@ def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): logger.info( f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" ) - # Use snapshot_download to download entire repository (including config.json and model.safetensors) try: snapshot_download( @@ -293,29 +268,13 @@ def _fetch_model_from_hf_repo(self, repo_id: str, 
storage_path: str): logger.info(f"Copying model from local path: {source_path} -> {storage_path}") - - source_dir = Path(source_path) - if not source_dir.is_dir(): + if not os.path.isdir(source_path): raise ValueError( f"Source path does not exist or is not a directory: {source_path}" ) - - source_config = source_dir / MODEL_CONFIG_FILE - source_weights = source_dir / MODEL_WEIGHTS_FILE - if not source_config.exists(): - raise ValueError( - f"Config file missing in source directory: {source_config}" - ) - if not source_weights.exists(): - raise ValueError( - f"Weights file missing in source directory: {source_weights}" - ) - - # Copy all files - storage_dir = Path(storage_path) - for file in source_dir.iterdir(): - if file.is_file(): - shutil.copy2(file, storage_dir / file.name) + for file in os.listdir(source_path): + if os.path.isfile(os.path.join(source_path, file)): + shutil.copy2(os.path.join(source_path, file), os.path.join(storage_path, file)) def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ @@ -360,7 +319,7 @@ def _register_other_model(self, model_info: ModelInfo): f"Registered other type model: {model_info.model_id} ({model_info.model_type})" ) - def ensure_transformers_registered(self, model_id: str) -> "ModelInfo": + def ensure_transformers_registered(self, model_id: str) -> ModelInfo: """ Ensure Transformers model is registered (called for lazy registration) This method uses locks to ensure thread safety. All check logic is within lock protection. @@ -422,54 +381,50 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp: code=TSStatusCode.SUCCESS_STATUS.value, message="Show models successfully", ) + if req.modelId: + # Find specified model + model_info = None + for category_dict in self._models.values(): + if req.modelId in category_dict: + model_info = category_dict[req.modelId] + break - # Use global lock to protect entire dictionary structure - with self._lock_pool.get_lock("").read_lock(): - if req.modelId: - # Find specified model - model_info = None - for category_dict in self._models.values(): - if req.modelId in category_dict: - model_info = category_dict[req.modelId] - break - - if model_info: - return TShowModelsResp( - status=resp_status, - modelIdList=[req.modelId], - modelTypeMap={req.modelId: model_info.model_type}, - categoryMap={req.modelId: model_info.category.value}, - stateMap={req.modelId: model_info.state.value}, - ) - else: - return TShowModelsResp( - status=resp_status, - modelIdList=[], - modelTypeMap={}, - categoryMap={}, - stateMap={}, - ) + if model_info: + return TShowModelsResp( + status=resp_status, + modelIdList=[req.modelId], + modelTypeMap={req.modelId: model_info.model_type}, + categoryMap={req.modelId: model_info.category.value}, + stateMap={req.modelId: model_info.state.value}, + ) else: - # Return all models - model_id_list = [] - model_type_map = {} - category_map = {} - state_map = {} - - for category_dict in self._models.values(): - for model_id, model_info in category_dict.items(): - model_id_list.append(model_id) - model_type_map[model_id] = model_info.model_type - category_map[model_id] = model_info.category.value - state_map[model_id] = model_info.state.value - return TShowModelsResp( status=resp_status, - modelIdList=model_id_list, - modelTypeMap=model_type_map, - categoryMap=category_map, - stateMap=state_map, + modelIdList=[], + modelTypeMap={}, + categoryMap={}, + stateMap={}, ) + # Return all models + model_id_list = [] + model_type_map = {} + 
category_map = {} + state_map = {} + + for category_dict in self._models.values(): + for model_id, model_info in category_dict.items(): + model_id_list.append(model_id) + model_type_map[model_id] = model_info.model_type + category_map[model_id] = model_info.category.value + state_map[model_id] = model_info.state.value + + return TShowModelsResp( + status=resp_status, + modelIdList=model_id_list, + modelTypeMap=model_type_map, + categoryMap=category_map, + stateMap=state_map, + ) def delete_model(self, model_id: str) -> None: # Use write lock to protect entire deletion process @@ -488,7 +443,7 @@ def delete_model(self, model_id: str) -> None: if model_info.category == ModelCategory.BUILTIN: raise BuiltInModelDeletionError(model_id) - + model_info.state = ModelStates.DROPPING model_path = Path(model_info.path) if model_path.exists(): try: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sktime_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py index 544193e4d9c65..dc1de32506e57 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -16,13 +16,10 @@ # under the License. # -import os -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors from torch import nn from transformers import Cache, DynamicCache, PreTrainedModel from transformers.activations import ACT2FN @@ -32,13 +29,10 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin -logger = Logger() - def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py similarity index 95% rename from iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 8d0909954bf24..4d761d0f00ad0 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/sundial_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -16,11 +16,9 @@ # under the License. 
# -import pandas as pd import torch from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -from iotdb.ainode.core.util.serde import convert_to_binary class SundialPipeline(ForecastPipeline): diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py similarity index 98% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py index 37bf56dfc59a7..fc9d7b41388bc 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py @@ -16,7 +16,7 @@ # under the License. # -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -29,11 +29,8 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin - -logger = Logger() +from iotdb.ainode.core.model.timer_xl.configuration_timer import TimerConfig +from iotdb.ainode.core.model.timer_xl.ts_generation_mixin import TSGenerationMixin def rotate_half(x): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py similarity index 92% rename from iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index cf6d35c805c5d..36e91e9f91b4e 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/timerxl_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -16,14 +16,12 @@ # under the License. 
# -import pandas as pd import torch from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -from iotdb.ainode.core.util.serde import convert_to_binary -class TimerxlPipeline(ForecastPipeline): +class TimerPipeline(ForecastPipeline): def __init__(self, model_id, **infer_kwargs): super().__init__(model_id, infer_kwargs=infer_kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py index 9da8486d3905e..93ad2ab1620ac 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -18,14 +18,13 @@ import importlib import json +import os.path import sys from contextlib import contextmanager from pathlib import Path from typing import Dict, Tuple -from iotdb.ainode.core.model.model_enums import ( - MODEL_CONFIG_FILE, - MODEL_WEIGHTS_FILE, +from iotdb.ainode.core.model.model_constants import ( UriType, ) @@ -58,7 +57,7 @@ def temporary_sys_path(path: str): sys.path.remove(path) -def load_model_config(config_path: Path) -> Dict: +def load_model_config_in_json(config_path: Path) -> Dict: with open(config_path, "r", encoding="utf-8") as f: return json.load(f) @@ -87,8 +86,10 @@ def import_class_from_path(module_name, class_path: str): return getattr(module, class_name) -def ensure_init_file(path: Path): - """Ensure __init__.py file exists in the given path""" - init_file = path / "__init__.py" - if not init_file.exists(): - init_file.touch() +def ensure_init_file(dir_path: str): + """Ensure __init__.py file exists in the given dir path""" + init_file = os.path.join(dir_path, "__init__.py") + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(init_file): + with open(init_file, 'w'): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index de5de968a7e52..48460e81b87ca 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -20,7 +20,7 @@ from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.cluster_manager import ClusterManager from iotdb.ainode.core.manager.inference_manager import InferenceManager -from iotdb.ainode.core.manager.model_manager import get_model_manager +from iotdb.ainode.core.manager.model_manager import ModelManager from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.thrift.ainode import IAINodeRPCService @@ -40,17 +40,30 @@ TShowModelsReq, TShowModelsResp, TTrainingReq, - TUnloadModelReq, + TUnloadModelReq, TForecastResp, ) from iotdb.thrift.common.ttypes import TSStatus logger = Logger() +def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus: + """ + Ensure that the device IDs in the provided list are available. + """ + available_devices = get_available_devices() + for device_id in device_id_list: + if device_id not in available_devices: + return TSStatus( + code=TSStatusCode.INVALID_URI_ERROR.value, + message=f"Device ID [{device_id}] is not available. 
You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + ) + return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) + class AINodeRPCServiceHandler(IAINodeRPCService.Iface): def __init__(self, ainode): self._ainode = ainode - self._model_manager = get_model_manager() + self._model_manager = ModelManager() self._inference_manager = InferenceManager() def stop(self) -> None: @@ -71,48 +84,27 @@ def showModels(self, req: TShowModelsReq) -> TShowModelsResp: return self._model_manager.show_models(req) def loadModel(self, req: TLoadModelReq) -> TSStatus: - if not self._model_manager.is_model_registered(req.existingModelId): - return TSStatus( - code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{req.existingModelId}] is not supported. You can use 'SHOW MODELS' to retrieve the available models.", - ) - - available_devices = get_available_devices() - for device_id in req.deviceIdList: - if device_id not in available_devices: - return TSStatus( - code=TSStatusCode.INVALID_URI_ERROR.value, - message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", - ) - + status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return status + status = _ensure_device_id_is_available(req.deviceIdList) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return status return self._inference_manager.load_model(req) def unloadModel(self, req: TUnloadModelReq) -> TSStatus: - if not self._model_manager.is_model_registered(req.modelId): - return TSStatus( - code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{req.modelId}] is not supported. You can use 'SHOW MODELS' to retrieve the available models.", - ) - - available_devices = get_available_devices() - for device_id in req.deviceIdList: - if device_id not in available_devices: - return TSStatus( - code=TSStatusCode.INVALID_URI_ERROR.value, - message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", - ) - + status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return status + status = _ensure_device_id_is_available(req.deviceIdList) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return status return self._inference_manager.unload_model(req) def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp: - available_devices = get_available_devices() - for device_id in req.deviceIdList: - if device_id not in available_devices: - status = TSStatus( - code=TSStatusCode.INVALID_URI_ERROR.value, - message=f"Device ID [{device_id}] is not available. 
You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + ) + return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={}) + status = _ensure_device_id_is_available(req.deviceIdList) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={}) return self._inference_manager.show_loaded_models(req) def showAIDevices(self) -> TShowAIDevicesResp: @@ -122,9 +114,15 @@ def showAIDevices(self) -> TShowAIDevicesResp: ) def inference(self, req: TInferenceReq) -> TInferenceResp: + status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TInferenceResp(status, []) return self._inference_manager.inference(req) - def forecast(self, req: TForecastReq) -> TSStatus: + def forecast(self, req: TForecastReq) -> TForecastResp: + status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TForecastResp(status, []) return self._inference_manager.forecast(req) def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: @@ -132,3 +130,11 @@ def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass + + def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus: + if not self._model_manager.is_model_registered(model_id): + return TSStatus( + code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, + message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.", + ) + return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) \ No newline at end of file diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java index 2721fedafb1e6..8d9081f435273 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java @@ -21,21 +21,28 @@ import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ClientPoolFactory; import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; -import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; +/** Asynchronously send RPC requests to AINodes. */ public class AsyncAINodeHeartbeatClientPool { - private final IClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager; + private final IClientManager<TEndPoint, AsyncAINodeInternalServiceClient> clientManager; private AsyncAINodeHeartbeatClientPool() { clientManager = - new IClientManager.Factory<TEndPoint, AsyncAINodeServiceClient>() - .createClientManager(new AINodeClientFactory.AINodeHeartbeatClientPoolFactory()); + new IClientManager.Factory<TEndPoint, AsyncAINodeInternalServiceClient>() + .createClientManager( + new ClientPoolFactory.AsyncAINodeHeartbeatServiceClientPoolFactory()); } + /** + * Only used in LoadManager. 
+ * + * @param endPoint The specific AINode + */ public void getAINodeHeartBeat( TEndPoint endPoint, TAIHeartbeatReq req, AINodeHeartbeatHandler handler) { try { @@ -56,6 +63,6 @@ private AsyncAINodeHeartbeatClientPoolHolder() { } public static AsyncAINodeHeartbeatClientPool getInstance() { - return AsyncAINodeHeartbeatClientPool.AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; + return AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; } } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java index ccc19f1a9f382..324e351302787 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java @@ -63,7 +63,6 @@ public void writeAuditLog( } } - // TODO: Is the AsyncDataNodeHeartbeatClientPool must be a singleton? private static class AsyncDataNodeHeartbeatClientPoolHolder { private static final AsyncDataNodeHeartbeatClientPool INSTANCE = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java index e0b2c144c0eac..23b0a4e149d88 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java @@ -21,8 +21,6 @@ import org.apache.iotdb.commons.exception.runtime.SerializationRunTimeException; import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.subscription.ShowTopicPlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; @@ -583,15 +581,9 @@ public static ConfigPhysicalPlan create(final ByteBuffer buffer) throws IOExcept case DropModel: plan = new DropModelPlan(); break; - case ShowModel: - plan = new ShowModelPlan(); - break; case DropModelInNode: plan = new DropModelInNodePlan(); break; - case GetModelInfo: - plan = new GetModelInfoPlan(); - break; case CreatePipePlugin: plan = new CreatePipePluginPlan(); break; diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java deleted file mode 100644 index dd79910e51fa8..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; - -import java.util.Objects; - -public class GetModelInfoPlan extends ConfigPhysicalReadPlan { - - private String modelId; - - public GetModelInfoPlan() { - super(ConfigPhysicalPlanType.GetModelInfo); - } - - public GetModelInfoPlan(final TGetModelInfoReq getModelInfoReq) { - super(ConfigPhysicalPlanType.GetModelInfo); - this.modelId = getModelInfoReq.getModelId(); - } - - public String getModelId() { - return modelId; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final GetModelInfoPlan that = (GetModelInfoPlan) o; - return Objects.equals(modelId, that.modelId); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java deleted file mode 100644 index eca00e8827d96..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; - -import java.util.Objects; - -public class ShowModelPlan extends ConfigPhysicalReadPlan { - - private String modelName; - - public ShowModelPlan() { - super(ConfigPhysicalPlanType.ShowModel); - } - - public ShowModelPlan(final TShowModelsReq showModelReq) { - super(ConfigPhysicalPlanType.ShowModel); - if (showModelReq.isSetModelId()) { - this.modelName = showModelReq.getModelId(); - } - } - - public boolean isSetModelName() { - return modelName != null; - } - - public String getModelName() { - return modelName; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final ShowModelPlan that = (ShowModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java deleted file mode 100644 index cebc1301b8912..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.consensus.common.DataSet; - -public class GetModelInfoResp implements DataSet { - - private final TSStatus status; - - private int targetAINodeId; - private TEndPoint targetAINodeAddress; - - public TSStatus getStatus() { - return status; - } - - public GetModelInfoResp(TSStatus status) { - this.status = status; - } - - public int getTargetAINodeId() { - return targetAINodeId; - } - - public void setTargetAINodeId(int targetAINodeId) { - this.targetAINodeId = targetAINodeId; - } - - public void setTargetAINodeAddress(TAINodeConfiguration aiNodeConfiguration) { - if (aiNodeConfiguration.getLocation() == null) { - return; - } - this.targetAINodeAddress = aiNodeConfiguration.getLocation().getInternalEndPoint(); - } - - public TGetModelInfoResp convertToThriftResponse() { - TGetModelInfoResp resp = new TGetModelInfoResp(status); - resp.setAiNodeAddress(targetAINodeAddress); - return resp; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java deleted file mode 100644 index 7490a53a01c57..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */
-
-package org.apache.iotdb.confignode.consensus.response.model;
-
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.consensus.common.DataSet;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
-// TODO: Will be removed in the future
-public class ModelTableResp implements DataSet {
-
- private final TSStatus status;
- private final List<ByteBuffer> serializedAllModelInformation;
- private Map<String, String> modelTypeMap;
- private Map<String, String> algorithmMap;
-
- public ModelTableResp(TSStatus status) {
- this.status = status;
- this.serializedAllModelInformation = new ArrayList<>();
- }
-
- public void addModelInformation(List<ModelInformation> modelInformationList) throws IOException {
- for (ModelInformation modelInformation : modelInformationList) {
- this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
- }
- }
-
- public void addModelInformation(ModelInformation modelInformation) throws IOException {
- this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
- }
-
- public void setModelTypeMap(Map<String, String> modelTypeMap) {
- this.modelTypeMap = modelTypeMap;
- }
-
- public void setAlgorithmMap(Map<String, String> algorithmMap) {
- this.algorithmMap = algorithmMap;
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 5d4b09adfc710..9d7151a8d20e3 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -19,22 +19,12 @@
 package org.apache.iotdb.confignode.manager;
-import org.apache.iotdb.ainode.rpc.thrift.IDataSchema;
-import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
-import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
 import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
 import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
 import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
 import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
 import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
 import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.common.rpc.thrift.TFlushReq;
 import org.apache.iotdb.common.rpc.thrift.TPipeHeartbeatResp;
 import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
@@ -58,7 +48,6 @@
 import org.apache.iotdb.commons.conf.TrimProperties;
 import org.apache.iotdb.commons.exception.IllegalPathException;
 import org.apache.iotdb.commons.exception.MetadataException;
-import org.apache.iotdb.commons.model.ModelStatus;
 import org.apache.iotdb.commons.path.PartialPath;
 import org.apache.iotdb.commons.path.PathPatternTree;
 import org.apache.iotdb.commons.path.PathPatternUtil;
@@ -97,7 +86,6 @@
 import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan;
 import 
org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; import org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan; import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp; @@ -129,7 +117,6 @@ import org.apache.iotdb.confignode.manager.schema.ClusterSchemaQuotaStatistics; import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -144,7 +131,6 @@ import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo; import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo; import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; @@ -163,13 +149,11 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq; @@ -186,7 +170,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -203,8 +186,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -257,11 +238,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import 
org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.consensus.common.DataSet; import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.schemaengine.template.Template; import org.apache.iotdb.db.schemaengine.template.TemplateAlterOperationType; import org.apache.iotdb.db.schemaengine.template.alter.TemplateAlterOperationUtil; @@ -340,9 +318,6 @@ public class ConfigManager implements IManager { /** CQ. */ private final CQManager cqManager; - /** AI Model. */ - private final ModelManager modelManager; - /** Pipe */ private final PipeManager pipeManager; @@ -362,8 +337,6 @@ public class ConfigManager implements IManager { private static final String DATABASE = "\tDatabase="; - private static final String DOT = "."; - public ConfigManager() throws IOException { // Build the persistence module ClusterInfo clusterInfo = new ClusterInfo(); @@ -375,7 +348,6 @@ public ConfigManager() throws IOException { UDFInfo udfInfo = new UDFInfo(); TriggerInfo triggerInfo = new TriggerInfo(); CQInfo cqInfo = new CQInfo(); - ModelInfo modelInfo = new ModelInfo(); PipeInfo pipeInfo = new PipeInfo(); QuotaInfo quotaInfo = new QuotaInfo(); TTLInfo ttlInfo = new TTLInfo(); @@ -393,7 +365,6 @@ public ConfigManager() throws IOException { udfInfo, triggerInfo, cqInfo, - modelInfo, pipeInfo, subscriptionInfo, quotaInfo, @@ -415,7 +386,6 @@ public ConfigManager() throws IOException { this.udfManager = new UDFManager(this, udfInfo); this.triggerManager = new TriggerManager(this, triggerInfo); this.cqManager = new CQManager(this); - this.modelManager = new ModelManager(this, modelInfo); this.pipeManager = new PipeManager(this, pipeInfo); this.subscriptionManager = new SubscriptionManager(this, subscriptionInfo); this.auditLogger = new CNAuditLogger(this); @@ -1289,11 +1259,6 @@ public TriggerManager getTriggerManager() { return triggerManager; } - @Override - public ModelManager getModelManager() { - return modelManager; - } - @Override public PipeManager getPipeManager() { return pipeManager; @@ -2757,150 +2722,6 @@ public TSStatus transfer(List newUnknownDataList) { return transferResult; } - @Override - public TSStatus createModel(TCreateModelReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! Try to start one."); - } - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? 
modelManager.createModel(req)
- : status;
- }
-
- private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
- List<IDataSchema> dataSchemaList = new ArrayList<>();
- for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) {
- IDataSchema dataSchema = new IDataSchema(req.getDataSchemaForTree().getPath().get(i));
- dataSchema.setTimeRange(req.getTimeRanges().get(i));
- dataSchemaList.add(dataSchema);
- }
- return dataSchemaList;
- }
-
- private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
- return Collections.singletonList(new IDataSchema(req.getDataSchemaForTable().getTargetSql()));
- }
-
- public TSStatus createTraining(TCreateTrainingReq req) {
- TSStatus status = confirmLeader();
- if (nodeManager.getRegisteredAINodes().isEmpty()) {
- return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode())
- .setMessage("There is no available AINode! Try to start one.");
- }
-
- TTrainingReq trainingReq = new TTrainingReq();
- trainingReq.setModelId(req.getModelId());
- if (req.isSetExistingModelId()) {
- trainingReq.setExistingModelId(req.getExistingModelId());
- }
- if (req.isSetParameters() && !req.getParameters().isEmpty()) {
- trainingReq.setParameters(req.getParameters());
- }
-
- try {
- status = getConsensusManager().write(new CreateModelPlan(req.getModelId()));
- if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new MetadataException("Can't init model " + req.getModelId());
- }
-
- List<IDataSchema> dataSchema;
- if (req.isTableModel) {
- dataSchema = fetchSchemaForTableModel(req);
- trainingReq.setDbType("iotdb.table");
- } else {
- dataSchema = fetchSchemaForTreeModel(req);
- trainingReq.setDbType("iotdb.tree");
- }
- updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal()));
- trainingReq.setTargetDataSchema(dataSchema);
-
- TAINodeInfo registeredAINode = getNodeManager().getRegisteredAINodeInfoList().get(0);
- TEndPoint targetAINodeEndPoint =
- new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
- try (AINodeClient client =
- AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
- status = client.createTrainingTask(trainingReq);
- if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new IllegalArgumentException(status.message);
- }
- }
- } catch (final Exception e) {
- status.setCode(TSStatusCode.CAN_NOT_CONNECT_CONFIGNODE.getStatusCode());
- status.setMessage(e.getMessage());
- try {
- updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.UNAVAILABLE.ordinal()));
- } catch (Exception e2) {
- LOGGER.error(e2.getMessage());
- }
- }
- return status;
- }
-
- @Override
- public TSStatus dropModel(TDropModelReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.dropModel(req)
- : status;
- }
-
- @Override
- public TSStatus loadModel(TLoadModelReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.loadModel(req)
- : status;
- }
-
- @Override
- public TSStatus unloadModel(TUnloadModelReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.unloadModel(req)
- : status;
- }
-
- @Override
- public TShowModelsResp showModel(TShowModelsReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? 
modelManager.showModel(req) - : new TShowModelsResp(status); - } - - @Override - public TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showLoadedModel(req) - : new TShowLoadedModelsResp(status, Collections.emptyMap()); - } - - @Override - public TShowAIDevicesResp showAIDevices() { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showAIDevices() - : new TShowAIDevicesResp(status, Collections.emptyList()); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.getModelInfo(req) - : new TGetModelInfoResp(status); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.updateModelInfo(req) - : status; - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) { TSStatus status = confirmLeader(); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java index 33e77db24907d..dff994d70e7e3 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java @@ -19,13 +19,6 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; @@ -82,7 +75,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -103,7 +95,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -120,8 +111,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import 
org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
 import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq;
 import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
 import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
@@ -255,13 +244,6 @@ public interface IManager {
 */
 CQManager getCQManager();
-
- /**
- * Get {@link ModelManager}.
- *
- * @return {@link ModelManager} instance
- */
- ModelManager getModelManager();
-
 /**
 * Get {@link PipeManager}.
 *
@@ -880,30 +862,6 @@ TDataPartitionTableResp getOrCreateDataPartition(
 TSStatus transfer(List<TDataNodeLocation> newUnknownDataList);
- /** Create a model. */
- TSStatus createModel(TCreateModelReq req);
-
- /** Drop a model. */
- TSStatus dropModel(TDropModelReq req);
-
- /** Load the specific model to the specific devices. */
- TSStatus loadModel(TLoadModelReq req);
-
- /** Unload the specific model from the specific devices. */
- TSStatus unloadModel(TUnloadModelReq req);
-
- /** Return the model table. */
- TShowModelsResp showModel(TShowModelsReq req);
-
- /** Return the loaded model instances. */
- TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req);
-
- /** Return all available AI devices. */
- TShowAIDevicesResp showAIDevices();
-
- /** Update the model state */
- TGetModelInfoResp getModelInfo(TGetModelInfoReq req);
-
 /** Set space quota. */
 TSStatus setSpaceQuota(TSetSpaceQuotaReq req);
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
deleted file mode 100644
index 3efdbc222b6d2..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++ /dev/null
@@ -1,245 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License. 
- */ - -package org.apache.iotdb.confignode.manager; - -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.exception.NoAvailableAINodeException; -import org.apache.iotdb.confignode.persistence.ModelInfo; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - -public class ModelManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class); - - private final ConfigManager configManager; - private final ModelInfo modelInfo; - - public ModelManager(ConfigManager configManager, ModelInfo modelInfo) { - this.configManager = configManager; - this.modelInfo = modelInfo; - } - - public TSStatus createModel(TCreateModelReq req) { - if (modelInfo.contain(req.modelName)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s already exists", req.modelName)); - } - try { - if (req.uri.isEmpty()) { - return configManager.getConsensusManager().write(new CreateModelPlan(req.modelName)); - } - return configManager.getProcedureManager().createModel(req.modelName, req.uri); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while getting model: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - public TSStatus dropModel(TDropModelReq req) { - if (modelInfo.checkModelType(req.getModelId()) != ModelType.USER_DEFINED) { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("Built-in model %s can't be removed", req.modelId)); - } - if (!modelInfo.contain(req.modelId)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s doesn't exists", req.modelId)); - } - return configManager.getProcedureManager().dropModel(req.getModelId()); - } - - public TSStatus loadModel(TLoadModelReq req) { - try (AINodeClient 
client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq loadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq( - req.existingModelId, req.deviceIdList); - return client.loadModel(loadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to load model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public TSStatus unloadModel(TUnloadModelReq req) { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq unloadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq(req.modelId, req.deviceIdList); - return client.unloadModel(unloadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to unload model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public TShowModelsResp showModel(final TShowModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowModelsReq showModelsReq = new TShowModelsReq(); - if (req.isSetModelId()) { - showModelsReq.setModelId(req.getModelId()); - } - TShowModelsResp resp = client.showModels(showModelsReq); - TShowModelsResp res = - new TShowModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setModelIdList(resp.getModelIdList()); - res.setModelTypeMap(resp.getModelTypeMap()); - res.setCategoryMap(resp.getCategoryMap()); - res.setStateMap(resp.getStateMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show models due to", e); - return new TShowModelsResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowLoadedModelsResp showLoadedModel(final TShowLoadedModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowLoadedModelsReq showModelsReq = - new TShowLoadedModelsReq().setDeviceIdList(req.getDeviceIdList()); - TShowLoadedModelsResp resp = client.showLoadedModels(showModelsReq); - TShowLoadedModelsResp res = - new TShowLoadedModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceLoadedModelsMap(resp.getDeviceLoadedModelsMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show loaded models due to", e); - return new TShowLoadedModelsResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowAIDevicesResp showAIDevices() { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp resp = client.showAIDevices(); - TShowAIDevicesResp res = - new TShowAIDevicesResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceIdList(resp.getDeviceIdList()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show AI devices due to", e); - return new TShowAIDevicesResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return new TGetModelInfoResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())) - .setAiNodeAddress( - configManager - .getNodeManager() - .getRegisteredAINodes() - .get(0) - .getLocation() - .getInternalEndPoint()); - } - - // Currently this method is only used by built-in timer_xl - public TSStatus 
updateModelInfo(TUpdateModelInfoReq req) {
- if (!modelInfo.contain(req.getModelId())) {
- return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode())
- .setMessage(String.format("Model %s doesn't exists", req.getModelId()));
- }
- try {
- ModelInformation modelInformation =
- new ModelInformation(ModelType.USER_DEFINED, req.getModelId());
- modelInformation.updateStatus(ModelStatus.values()[req.getModelStatus()]);
- modelInformation.setAttribute(req.getAttributes());
- modelInformation.setInputColumnSize(1);
- if (req.isSetOutputLength()) {
- modelInformation.setOutputLength(req.getOutputLength());
- }
- if (req.isSetInputLength()) {
- modelInformation.setInputLength(req.getInputLength());
- }
- UpdateModelInfoPlan updateModelInfoPlan =
- new UpdateModelInfoPlan(req.getModelId(), modelInformation);
- if (req.isSetAiNodeIds()) {
- updateModelInfoPlan.setNodeIds(req.getAiNodeIds());
- }
- return configManager.getConsensusManager().write(updateModelInfoPlan);
- } catch (ConsensusException e) {
- LOGGER.warn("Unexpected error happened while updating model info: ", e);
- // consensus layer related errors
- TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
- res.setMessage(e.getMessage());
- return res;
- }
- }
-
- private AINodeClient getAINodeClient() throws NoAvailableAINodeException, ClientManagerException {
- List<TAINodeInfo> aiNodeInfo = configManager.getNodeManager().getRegisteredAINodeInfoList();
- if (aiNodeInfo.isEmpty()) {
- throw new NoAvailableAINodeException();
- }
- TEndPoint targetAINodeEndPoint =
- new TEndPoint(aiNodeInfo.get(0).getInternalAddress(), aiNodeInfo.get(0).getInternalPort());
- try {
- return AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- public List<Integer> getModelDistributions(String modelName) {
- return modelInfo.getNodeIds(modelName);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
index 2e4227af3fc8b..d67e7721eef83 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
@@ -61,8 +61,6 @@
 import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler;
 import org.apache.iotdb.confignode.procedure.env.RemoveDataNodeHandler;
 import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure;
 import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure;
 import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure;
 import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure;
@@ -1414,24 +1412,6 @@ public TSStatus createCQ(TCreateCQReq req, ScheduledExecutorService scheduledExe
 return waitingProcedureFinished(procedure);
 }
- public TSStatus createModel(String modelName, String uri) {
- long procedureId = executor.submitProcedure(new CreateModelProcedure(modelName, uri));
- LOGGER.info("CreateModelProcedure was submitted, procedureId: {}.", procedureId);
- return RpcUtils.SUCCESS_STATUS;
- }
-
- public TSStatus dropModel(String modelId) {
- DropModelProcedure procedure = new DropModelProcedure(modelId);
- 
executor.submitProcedure(procedure); - TSStatus status = waitingProcedureFinished(procedure); - if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return status; - } else { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(status.getMessage()); - } - } - public TSStatus createPipePlugin( PipePluginMeta pipePluginMeta, byte[] jarFile, boolean isSetIfNotExistsCondition) { final CreatePipePluginProcedure createPipePluginProcedure = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java deleted file mode 100644 index aeada03d15cc3..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.persistence; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelTable; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.commons.snapshot.SnapshotProcessor; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; -import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.PublicBAOS; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.annotation.concurrent.ThreadSafe; - -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; - -@ThreadSafe -public class ModelInfo implements SnapshotProcessor { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelInfo.class); - - private static final String SNAPSHOT_FILENAME = "model_info.snapshot"; - - private ModelTable 
modelTable;
-
- private final Map<String, List<Integer>> modelNameToNodes;
-
- private final ReadWriteLock modelTableLock = new ReentrantReadWriteLock();
-
- private static final Set<String> builtInForecastModel = new HashSet<>();
-
- private static final Set<String> builtInAnomalyDetectionModel = new HashSet<>();
-
- static {
- builtInForecastModel.add("arima");
- builtInForecastModel.add("naive_forecaster");
- builtInForecastModel.add("stl_forecaster");
- builtInForecastModel.add("holtwinters");
- builtInForecastModel.add("exponential_smoothing");
- builtInForecastModel.add("timer_xl");
- builtInForecastModel.add("sundial");
- builtInAnomalyDetectionModel.add("gaussian_hmm");
- builtInAnomalyDetectionModel.add("gmm_hmm");
- builtInAnomalyDetectionModel.add("stray");
- }
-
- public ModelInfo() {
- this.modelTable = new ModelTable();
- this.modelNameToNodes = new HashMap<>();
- }
-
- public boolean contain(String modelName) {
- return modelTable.containsModel(modelName);
- }
-
- public void acquireModelTableReadLock() {
- LOGGER.info("acquire ModelTableReadLock");
- modelTableLock.readLock().lock();
- }
-
- public void releaseModelTableReadLock() {
- LOGGER.info("release ModelTableReadLock");
- modelTableLock.readLock().unlock();
- }
-
- public void acquireModelTableWriteLock() {
- LOGGER.info("acquire ModelTableWriteLock");
- modelTableLock.writeLock().lock();
- }
-
- public void releaseModelTableWriteLock() {
- LOGGER.info("release ModelTableWriteLock");
- modelTableLock.writeLock().unlock();
- }
-
- // init the model in modeInfo, it won't update the details information of the model
- public TSStatus createModel(CreateModelPlan plan) {
- try {
- acquireModelTableWriteLock();
- String modelName = plan.getModelName();
- modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } catch (Exception e) {
- final String errorMessage =
- String.format(
- "Failed to add model [%s] in ModelTable on Config Nodes, because of %s",
- plan.getModelName(), e);
- LOGGER.warn(errorMessage, e);
- return new TSStatus(TSStatusCode.CREATE_MODEL_ERROR.getStatusCode()).setMessage(errorMessage);
- } finally {
- releaseModelTableWriteLock();
- }
- }
-
- public TSStatus dropModelInNode(int aiNodeId) {
- acquireModelTableWriteLock();
- try {
- for (Map.Entry<String, List<Integer>> entry : modelNameToNodes.entrySet()) {
- entry.getValue().remove(Integer.valueOf(aiNodeId));
- // if list is empty, remove this model totally
- if (entry.getValue().isEmpty()) {
- modelTable.removeModel(entry.getKey());
- modelNameToNodes.remove(entry.getKey());
- }
- }
- // currently, we only have one AINode at a time, so we can just clear failed model. 
- modelTable.clearFailedModel();
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } finally {
- releaseModelTableWriteLock();
- }
- }
-
- public TSStatus dropModel(String modelName) {
- acquireModelTableWriteLock();
- TSStatus status;
- if (modelTable.containsModel(modelName)) {
- modelTable.removeModel(modelName);
- modelNameToNodes.remove(modelName);
- status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } else {
- status =
- new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode())
- .setMessage(String.format("model [%s] has not been created.", modelName));
- }
- releaseModelTableWriteLock();
- return status;
- }
-
- public List<Integer> getNodeIds(String modelName) {
- return modelNameToNodes.getOrDefault(modelName, Collections.emptyList());
- }
-
- private ModelInformation getModelByName(String modelName) {
- ModelType modelType = checkModelType(modelName);
- if (modelType != ModelType.USER_DEFINED) {
- if (modelType == ModelType.BUILT_IN_FORECAST && builtInForecastModel.contains(modelName)) {
- return new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName);
- } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION
- && builtInAnomalyDetectionModel.contains(modelName)) {
- return new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName);
- }
- } else {
- return modelTable.getModelInformationById(modelName);
- }
- return null;
- }
-
- public ModelTableResp showModel(ShowModelPlan plan) {
- acquireModelTableReadLock();
- try {
- ModelTableResp modelTableResp =
- new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- if (plan.isSetModelName()) {
- ModelInformation modelInformation = getModelByName(plan.getModelName());
- if (modelInformation != null) {
- modelTableResp.addModelInformation(modelInformation);
- }
- } else {
- modelTableResp.addModelInformation(modelTable.getAllModelInformation());
- for (String modelName : builtInForecastModel) {
- modelTableResp.addModelInformation(
- new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName));
- }
- for (String modelName : builtInAnomalyDetectionModel) {
- modelTableResp.addModelInformation(
- new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName));
- }
- }
- return modelTableResp;
- } catch (IOException e) {
- LOGGER.warn("Fail to get ModelTable", e);
- return new ModelTableResp(
- new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- } finally {
- releaseModelTableReadLock();
- }
- }
-
- private boolean containsBuiltInModelName(Set<String> builtInModelSet, String modelName) {
- // ignore the case
- for (String builtInModelName : builtInModelSet) {
- if (builtInModelName.equalsIgnoreCase(modelName)) {
- return true;
- }
- }
- return false;
- }
-
- public ModelType checkModelType(String modelName) {
- if (containsBuiltInModelName(builtInForecastModel, modelName)) {
- return ModelType.BUILT_IN_FORECAST;
- } else if (containsBuiltInModelName(builtInAnomalyDetectionModel, modelName)) {
- return ModelType.BUILT_IN_ANOMALY_DETECTION;
- } else {
- return ModelType.USER_DEFINED;
- }
- }
-
- private int getAvailableAINodeForModel(String modelName, ModelType modelType) {
- if (modelType == ModelType.USER_DEFINED) {
- List<Integer> aiNodeIds = modelNameToNodes.get(modelName);
- if (aiNodeIds != null) {
- return aiNodeIds.get(0);
- }
- } else {
- // any AINode is fine for built-in model
- // 0 is always the nodeId for configNode, so it's fine to use 0 as special value
- return 0;
- }
- return -1;
- }
-
- // This method will 
be used by dataNode to get schema of the model for inference
- public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) {
- acquireModelTableReadLock();
- try {
- String modelName = plan.getModelId();
- GetModelInfoResp getModelInfoResp;
- ModelInformation modelInformation;
- ModelType modelType;
- // check if it's a built-in model
- if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) {
- modelInformation = new ModelInformation(modelType, modelName);
- } else {
- modelInformation = modelTable.getModelInformationById(modelName);
- }
-
- if (modelInformation != null) {
- getModelInfoResp =
- new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- } else {
- TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
- errorStatus.setMessage(String.format("model [%s] has not been created.", modelName));
- getModelInfoResp = new GetModelInfoResp(errorStatus);
- return getModelInfoResp;
- }
- PublicBAOS buffer = new PublicBAOS();
- DataOutputStream stream = new DataOutputStream(buffer);
- modelInformation.serialize(stream);
- // select the nodeId to process the task, currently we default use the first one.
- int aiNodeId = getAvailableAINodeForModel(modelName, modelType);
- if (aiNodeId == -1) {
- TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
- errorStatus.setMessage(String.format("There is no AINode with %s available", modelName));
- getModelInfoResp = new GetModelInfoResp(errorStatus);
- return getModelInfoResp;
- } else {
- getModelInfoResp.setTargetAINodeId(aiNodeId);
- }
- return getModelInfoResp;
- } catch (IOException e) {
- LOGGER.warn("Fail to get model info", e);
- return new GetModelInfoResp(
- new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- } finally {
- releaseModelTableReadLock();
- }
- }
-
- public TSStatus updateModelInfo(UpdateModelInfoPlan plan) {
- acquireModelTableWriteLock();
- try {
- String modelName = plan.getModelName();
- if (modelTable.containsModel(modelName)) {
- modelTable.updateModel(modelName, plan.getModelInformation());
- }
- if (!plan.getNodeIds().isEmpty()) {
- // only used in model registration, so we can just put the nodeIds in the map without
- // checking
- modelNameToNodes.put(modelName, plan.getNodeIds());
- }
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } finally {
- releaseModelTableWriteLock();
- }
- }
-
- @Override
- public boolean processTakeSnapshot(File snapshotDir) throws TException, IOException {
- File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME);
- if (snapshotFile.exists() && snapshotFile.isFile()) {
- LOGGER.error(
- "Failed to take snapshot of ModelInfo, because snapshot file [{}] is already exist.",
- snapshotFile.getAbsolutePath());
- return false;
- }
-
- acquireModelTableReadLock();
- try (FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)) {
- modelTable.serialize(fileOutputStream);
- ReadWriteIOUtils.write(modelNameToNodes.size(), fileOutputStream);
- for (Map.Entry<String, List<Integer>> entry : modelNameToNodes.entrySet()) {
- ReadWriteIOUtils.write(entry.getKey(), fileOutputStream);
- ReadWriteIOUtils.write(entry.getValue().size(), fileOutputStream);
- for (Integer nodeId : entry.getValue()) {
- ReadWriteIOUtils.write(nodeId, fileOutputStream);
- }
- }
- fileOutputStream.getFD().sync();
- return true;
- } finally {
- releaseModelTableReadLock();
- }
- }
-
- @Override
- public void processLoadSnapshot(File snapshotDir) throws 
TException, IOException {
- File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME);
- if (!snapshotFile.exists() || !snapshotFile.isFile()) {
- LOGGER.error(
- "Failed to load snapshot of ModelInfo, snapshot file [{}] does not exist.",
- snapshotFile.getAbsolutePath());
- return;
- }
- acquireModelTableWriteLock();
- try (FileInputStream fileInputStream = new FileInputStream(snapshotFile)) {
- modelTable.clear();
- modelTable = ModelTable.deserialize(fileInputStream);
- int size = ReadWriteIOUtils.readInt(fileInputStream);
- for (int i = 0; i < size; i++) {
- String modelName = ReadWriteIOUtils.readString(fileInputStream);
- int nodeSize = ReadWriteIOUtils.readInt(fileInputStream);
- List<Integer> nodes = new LinkedList<>();
- for (int j = 0; j < nodeSize; j++) {
- nodes.add(ReadWriteIOUtils.readInt(fileInputStream));
- }
- modelNameToNodes.put(modelName, nodes);
- }
- } finally {
- releaseModelTableWriteLock();
- }
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
index fe8b28c4da2e5..d6bad518f6f4b 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
@@ -35,8 +35,6 @@
 import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
 import org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan;
 import org.apache.iotdb.confignode.consensus.request.read.function.GetUDFJarPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
 import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan;
 import org.apache.iotdb.confignode.consensus.request.read.partition.GetDataPartitionPlan;
 import org.apache.iotdb.confignode.consensus.request.read.partition.GetNodePathsPartitionPlan;
@@ -84,10 +82,6 @@
 import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan;
 import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan;
 import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
 import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan;
 import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan;
 import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan;
@@ -150,7 +144,6 @@
 import org.apache.iotdb.confignode.exception.physical.UnknownPhysicalPlanTypeException;
 import org.apache.iotdb.confignode.manager.pipe.agent.PipeConfigNodeAgent;
 import org.apache.iotdb.confignode.persistence.ClusterInfo;
-import org.apache.iotdb.confignode.persistence.ModelInfo;
 import org.apache.iotdb.confignode.persistence.ProcedureInfo;
 import org.apache.iotdb.confignode.persistence.TTLInfo;
 import 
org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -210,8 +203,6 @@ public class ConfigPlanExecutor { private final CQInfo cqInfo; - private final ModelInfo modelInfo; - private final PipeInfo pipeInfo; private final SubscriptionInfo subscriptionInfo; @@ -230,7 +221,6 @@ public ConfigPlanExecutor( UDFInfo udfInfo, TriggerInfo triggerInfo, CQInfo cqInfo, - ModelInfo modelInfo, PipeInfo pipeInfo, SubscriptionInfo subscriptionInfo, QuotaInfo quotaInfo, @@ -262,9 +252,6 @@ public ConfigPlanExecutor( this.cqInfo = cqInfo; this.snapshotProcessorList.add(cqInfo); - this.modelInfo = modelInfo; - this.snapshotProcessorList.add(modelInfo); - this.pipeInfo = pipeInfo; this.snapshotProcessorList.add(pipeInfo); @@ -362,10 +349,6 @@ public DataSet executeQueryPlan(final ConfigPhysicalReadPlan req) return udfInfo.getUDFJar((GetUDFJarPlan) req); case GetAllFunctionTable: return udfInfo.getAllUDFTable(); - case ShowModel: - return modelInfo.showModel((ShowModelPlan) req); - case GetModelInfo: - return modelInfo.getModelInfo((GetModelInfoPlan) req); case GetPipePluginTable: return pipeInfo.getPipePluginInfo().showPipePlugins(); case GetPipePluginJar: @@ -648,14 +631,6 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) return cqInfo.activeCQ((ActiveCQPlan) physicalPlan); case UPDATE_CQ_LAST_EXEC_TIME: return cqInfo.updateCQLastExecutionTime((UpdateCQLastExecTimePlan) physicalPlan); - case CreateModel: - return modelInfo.createModel((CreateModelPlan) physicalPlan); - case UpdateModelInfo: - return modelInfo.updateModelInfo((UpdateModelInfoPlan) physicalPlan); - case DropModel: - return modelInfo.dropModel(((DropModelPlan) physicalPlan).getModelName()); - case DropModelInNode: - return modelInfo.dropModelInNode(((DropModelInNodePlan) physicalPlan).getNodeId()); case CreatePipePlugin: return pipeInfo.getPipePluginInfo().createPipePlugin((CreatePipePluginPlan) physicalPlan); case DropPipePlugin: diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java deleted file mode 100644 index 989061610213d..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */
-
-package org.apache.iotdb.confignode.procedure.impl.model;
-
-import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.exception.ainode.LoadModelException;
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.commons.model.ModelStatus;
-import org.apache.iotdb.commons.model.exception.ModelManagementException;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
-import org.apache.iotdb.confignode.manager.ConfigManager;
-import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
-import org.apache.iotdb.confignode.procedure.exception.ProcedureException;
-import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
-import org.apache.iotdb.confignode.procedure.state.model.CreateModelState;
-import org.apache.iotdb.confignode.procedure.store.ProcedureType;
-import org.apache.iotdb.consensus.exception.ConsensusException;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
-import org.apache.iotdb.rpc.TSStatusCode;
-
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Objects;
-
-public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class);
- private static final int RETRY_THRESHOLD = 0;
-
- private String modelName;
-
- private String uri;
-
- private ModelInformation modelInformation = null;
-
- private List<Integer> aiNodeIds;
-
- private String loadErrorMsg = "";
-
- public CreateModelProcedure() {
- super();
- }
-
- public CreateModelProcedure(String modelName, String uri) {
- super();
- this.modelName = modelName;
- this.uri = uri;
- this.aiNodeIds = new ArrayList<>();
- }
-
- @Override
- protected Flow executeFromState(ConfigNodeProcedureEnv env, CreateModelState state) {
- if (modelName == null || uri == null) {
- return Flow.NO_MORE_STATE;
- }
- try {
- switch (state) {
- case LOADING:
- initModel(env);
- loadModel(env);
- setNextState(CreateModelState.ACTIVE);
- break;
- case ACTIVE:
- modelInformation.updateStatus(ModelStatus.ACTIVE);
- updateModel(env);
- return Flow.NO_MORE_STATE;
- default:
- throw new UnsupportedOperationException(
- String.format("Unknown state during executing createModelProcedure, %s", state));
- }
- } catch (Exception e) {
- if (isRollbackSupported(state)) {
- LOGGER.error("Fail in CreateModelProcedure", e);
- setFailure(new ProcedureException(e.getMessage()));
- } else {
- LOGGER.error(
- "Retrievable error trying to create model [{}], state [{}]", modelName, state, e);
- if (getCycles() > RETRY_THRESHOLD) {
- modelInformation = new ModelInformation(modelName, ModelStatus.UNAVAILABLE);
- modelInformation.setAttribute(loadErrorMsg);
- updateModel(env);
- setFailure(
- new ProcedureException(
- String.format("Fail to create model [%s] at STATE [%s]", modelName, state)));
- }
- }
- }
- return Flow.HAS_MORE_STATE;
- }
-
- private void initModel(ConfigNodeProcedureEnv env) throws ConsensusException {
- LOGGER.info("Start to add model [{}]", modelName);
-
- ConfigManager configManager = 
env.getConfigManager(); - TSStatus response = configManager.getConsensusManager().write(new CreateModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } - - private void checkModelInformationEquals(ModelInformation receiveModelInfo) { - if (modelInformation == null) { - modelInformation = receiveModelInfo; - } else { - if (!modelInformation.equals(receiveModelInfo)) { - throw new ModelManagementException( - String.format( - "Failed to load model [%s] on AI Nodes, model information is not equal in different nodes", - modelName)); - } - } - } - - private void loadModel(ConfigNodeProcedureEnv env) { - for (TAINodeConfiguration curNodeConfig : - env.getConfigManager().getNodeManager().getRegisteredAINodes()) { - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient(curNodeConfig.getLocation().getInternalEndPoint())) { - ModelInformation resp = client.registerModel(modelName, uri); - checkModelInformationEquals(resp); - aiNodeIds.add(curNodeConfig.getLocation().aiNodeId); - } catch (LoadModelException e) { - LOGGER.warn(e.getMessage()); - loadErrorMsg = e.getMessage(); - } catch (Exception e) { - LOGGER.warn( - "Failed to load model on AINode {} from ConfigNode", - curNodeConfig.getLocation().getInternalEndPoint()); - loadErrorMsg = e.getMessage(); - } - } - - if (aiNodeIds.isEmpty()) { - throw new ModelManagementException( - String.format("CREATE MODEL [%s] failed on all AINodes:[%s]", modelName, loadErrorMsg)); - } - } - - private void updateModel(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to update model [{}]", modelName); - - ConfigManager configManager = env.getConfigManager(); - try { - TSStatus response = - configManager - .getConsensusManager() - .write(new UpdateModelInfoPlan(modelName, modelInformation, aiNodeIds)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, CreateModelState state) - throws IOException, InterruptedException, ProcedureException { - // do nothing - } - - @Override - protected boolean isRollbackSupported(CreateModelState state) { - return false; - } - - @Override - protected CreateModelState getState(int stateId) { - return CreateModelState.values()[stateId]; - } - - @Override - protected int getStateId(CreateModelState createModelState) { - return createModelState.ordinal(); - } - - @Override - protected CreateModelState getInitialState() { - return CreateModelState.LOADING; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - ReadWriteIOUtils.write(uri, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - uri = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean 
equals(Object that) { - if (that instanceof CreateModelProcedure) { - CreateModelProcedure thatProc = (CreateModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && Objects.equals(thatProc.modelName, this.modelName) - && Objects.equals(thatProc.uri, this.uri); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName, uri); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java deleted file mode 100644 index daa029e04ddfd..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.DropModelState; -import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Objects; - -import static org.apache.iotdb.confignode.procedure.state.model.DropModelState.CONFIG_NODE_DROPPED; - -public class DropModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class); - private static final int RETRY_THRESHOLD = 1; - - private String modelName; - - public DropModelProcedure() { - super(); - } - - public DropModelProcedure(String modelName) { - super(); - this.modelName = modelName; - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state) { - if 
(modelName == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case AI_NODE_DROPPED: - LOGGER.info("Start to drop model [{}] on AI Nodes", modelName); - dropModelOnAINode(env); - setNextState(CONFIG_NODE_DROPPED); - break; - case CONFIG_NODE_DROPPED: - dropModelOnConfigNode(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing dropModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in DropModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to drop model [{}], state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - setFailure( - new ProcedureException( - String.format( - "Fail to drop model [%s] at STATE [%s], %s", - modelName, state, e.getMessage()))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void dropModelOnAINode(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to drop model file [{}] on AI Node", modelName); - - List aiNodes = - env.getConfigManager().getNodeManager().getRegisteredAINodes(); - aiNodes.forEach( - aiNode -> { - int nodeId = aiNode.getLocation().getAiNodeId(); - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient( - env.getConfigManager() - .getNodeManager() - .getRegisteredAINode(nodeId) - .getLocation() - .getInternalEndPoint())) { - TSStatus status = client.deleteModel(new TDeleteModelReq(modelName)); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - status.getMessage()); - } - } catch (Exception e) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - e.getMessage()); - } - }); - } - - private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) { - try { - TSStatus response = - env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(response.getMessage()); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Fail to start training model [%s] on AI Node: %s", modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, DropModelState state) - throws IOException, InterruptedException, ProcedureException { - // no need to rollback - } - - @Override - protected DropModelState getState(int stateId) { - return DropModelState.values()[stateId]; - } - - @Override - protected int getStateId(DropModelState dropModelState) { - return dropModelState.ordinal(); - } - - @Override - protected DropModelState getInitialState() { - return DropModelState.AI_NODE_DROPPED; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.DROP_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean equals(Object that) { - if (that instanceof DropModelProcedure) { - DropModelProcedure thatProc = (DropModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && 
(thatProc.modelName).equals(this.modelName); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java index 2cab08c28244e..98056fc1768ea 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java @@ -28,8 +28,8 @@ import org.apache.iotdb.confignode.procedure.exception.ProcedureException; import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState; import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.slf4j.Logger; @@ -75,7 +75,8 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState st case NODE_STOP: TSStatus resp = null; try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(removedAINode.getInternalEndPoint())) { + AINodeClientManager.getInstance() + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { resp = client.stopAINode(); } catch (Exception e) { LOGGER.warn( diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java index e023171f4fa88..f20a6999d5936 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java @@ -22,8 +22,6 @@ import org.apache.iotdb.commons.exception.runtime.ThriftSerDeException; import org.apache.iotdb.confignode.procedure.Procedure; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; @@ -263,12 +261,6 @@ public Procedure create(ByteBuffer buffer) throws IOException { case DROP_PIPE_PLUGIN_PROCEDURE: procedure = new DropPipePluginProcedure(); break; - case CREATE_MODEL_PROCEDURE: - procedure = new CreateModelProcedure(); - break; - case DROP_MODEL_PROCEDURE: - procedure = new DropModelProcedure(); - break; case AUTH_OPERATE_PROCEDURE: procedure = new AuthOperationProcedure(false); break; @@ -494,10 +486,6 @@ public static ProcedureType getProcedureType(final Procedure procedure) { return ProcedureType.CREATE_PIPE_PLUGIN_PROCEDURE; } else if (procedure instanceof DropPipePluginProcedure) { return ProcedureType.DROP_PIPE_PLUGIN_PROCEDURE; - } else if (procedure instanceof CreateModelProcedure) { - return ProcedureType.CREATE_MODEL_PROCEDURE; - } else if (procedure 
instanceof DropModelProcedure) { - return ProcedureType.DROP_MODEL_PROCEDURE; } else if (procedure instanceof CreatePipeProcedureV2) { return ProcedureType.CREATE_PIPE_PROCEDURE_V2; } else if (procedure instanceof StartPipeProcedureV2) { diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java index 65ac1fb24ad5a..d076a7d9d926e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java @@ -85,7 +85,9 @@ public enum ProcedureType { RENAME_VIEW_PROCEDURE((short) 764), /** AI Model */ + @Deprecated // Since 2.0.6, all models are managed by AINode CREATE_MODEL_PROCEDURE((short) 800), + @Deprecated // Since 2.0.6, all models are managed by AINode DROP_MODEL_PROCEDURE((short) 801), REMOVE_AI_NODE_PROCEDURE((short) 802), diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java index 59ce7352312f0..6582a5bfff8e1 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java @@ -115,7 +115,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -144,7 +143,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -163,8 +161,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -226,7 +222,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.confignode.service.ConfigNode; import org.apache.iotdb.consensus.exception.ConsensusException; import org.apache.iotdb.db.queryengine.plan.relational.type.AuthorRType; @@ 
-1362,26 +1357,6 @@ public TShowCQResp showCQ() { return configManager.showCQ(); } - @Override - public TSStatus createModel(TCreateModelReq req) { - return configManager.createModel(req); - } - - @Override - public TSStatus dropModel(TDropModelReq req) { - return configManager.dropModel(req); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return configManager.getModelInfo(req); - } - - @Override - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return configManager.updateModelInfo(req); - } - @Override public TSStatus setSpaceQuota(final TSetSpaceQuotaReq req) throws TException { return configManager.setSpaceQuota(req); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java deleted file mode 100644 index 0d784617c0905..0000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.ClientManagerMetrics; -import org.apache.iotdb.commons.client.IClientPoolFactory; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import org.apache.iotdb.commons.client.property.ClientPoolProperty; -import org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.concurrent.ThreadName; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.commons.pool2.impl.GenericKeyedObjectPool; - -import java.util.Optional; - -/** Dedicated factory for AINodeClient + AINodeClientPoolFactory. 
*/ -public class AINodeClientFactory extends ThriftClientFactory { - - private static final int connectionTimeout = - CommonDescriptor.getInstance().getConfig().getDnConnectionTimeoutInMS(); - - public AINodeClientFactory( - ClientManager manager, ThriftClientProperty thriftProperty) { - super(manager, thriftProperty); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public void destroyObject(TEndPoint key, PooledObject pooled) throws Exception { - pooled.getObject().invalidate(); - } - - @Override - public boolean validateObject(TEndPoint key, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(org.apache.thrift.transport.TTransport::isOpen) - .orElse(false); - } - - /** The PoolFactory originally inside ClientPoolFactory — now moved here. */ - public static class AINodeClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - // Build thrift client properties - ThriftClientProperty thriftProperty = - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(connectionTimeout) - .setRpcThriftCompressionEnabled( - CommonDescriptor.getInstance().getConfig().isRpcThriftCompressionEnabled()) - .build(); - - GenericKeyedObjectPool pool = - new GenericKeyedObjectPool<>( - new AINodeClientFactory(manager, thriftProperty), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode( - CommonDescriptor.getInstance().getConfig().getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), pool); - - return pool; - } - } - - public static class AINodeHeartbeatClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - final CommonConfig conf = CommonDescriptor.getInstance().getConfig(); - - GenericKeyedObjectPool clientPool = - new GenericKeyedObjectPool<>( - new AsyncAINodeServiceClient.Factory( - manager, - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) - .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) - .setPrintLogWhenEncounterException(false) - .build(), - ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), clientPool); - - return clientPool; - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java index 2c037cf0f3e58..df80d49b502b0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java @@ -73,7 +73,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import 
org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -102,7 +101,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -121,8 +119,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -184,7 +180,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; @@ -525,7 +520,8 @@ public TAINodeRestartResp restartAINode(TAINodeRestartReq req) throws TException @Override public TGetAINodeLocationResp getAINodeLocation() throws TException { - return client.getAINodeLocation(); + return executeRemoteCallWithRetry( + () -> client.getAINodeLocation(), resp -> !updateConfigNodeLeader(resp.status)); } @Override @@ -1339,28 +1335,6 @@ public TShowCQResp showCQ() throws TException { () -> client.showCQ(), resp -> !updateConfigNodeLeader(resp.status)); } - @Override - public TSStatus createModel(TCreateModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.createModel(req), status -> !updateConfigNodeLeader(status)); - } - - @Override - public TSStatus dropModel(TDropModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.dropModel(req), status -> !updateConfigNodeLeader(status)); - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.getModelInfo(req), resp -> !updateConfigNodeLeader(resp.getStatus())); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.updateModelInfo(req), status -> !updateConfigNodeLeader(status)); - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return executeRemoteCallWithRetry( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java index b5f5df430129f..da0d84d8466fe 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java
@@ -27,12 +27,13 @@
 import org.apache.iotdb.commons.consensus.ConfigRegionId;
 import org.apache.iotdb.db.conf.IoTDBConfig;
 import org.apache.iotdb.db.conf.IoTDBDescriptor;
+import org.apache.iotdb.db.protocol.client.an.AINodeClient;
 
 import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
 
 public class DataNodeClientPoolFactory {
 
-  private static final IoTDBConfig conf = IoTDBDescriptor.getInstance().getConfig();
+  private static final IoTDBConfig CONF = IoTDBDescriptor.getInstance().getConfig();
 
   private DataNodeClientPoolFactory() {
     // Empty constructor
@@ -49,11 +50,11 @@ public GenericKeyedObjectPool<ConfigRegionId, ConfigNodeClient> createClientPool(
             new ConfigNodeClient.Factory(
                 manager,
                 new ThriftClientProperty.Builder()
-                    .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS())
-                    .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable())
+                    .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS())
+                    .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable())
                     .build()),
             new ClientPoolProperty.Builder()
-                .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode())
+                .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode())
                 .build()
                 .getConfig());
     ClientManagerMetrics.getInstance()
@@ -73,15 +74,38 @@ public GenericKeyedObjectPool<ConfigRegionId, ConfigNodeClient> createClientPool(
             new ConfigNodeClient.Factory(
                 manager,
                 new ThriftClientProperty.Builder()
-                    .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS() * 10)
-                    .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable())
+                    .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS() * 10)
+                    .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable())
                     .setSelectorNumOfAsyncClientManager(
-                        conf.getSelectorNumOfClientManager() / 10 > 0
-                            ? conf.getSelectorNumOfClientManager() / 10
+                        CONF.getSelectorNumOfClientManager() / 10 > 0
+                            ? CONF.getSelectorNumOfClientManager() / 10
                             : 1)
                     .build()),
             new ClientPoolProperty.Builder()
-                .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode())
+                .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode())
+                .build()
+                .getConfig());
+    ClientManagerMetrics.getInstance()
+        .registerClientManager(this.getClass().getSimpleName(), clientPool);
+    return clientPool;
+    }
+  }
+
+  public static class AINodeClientPoolFactory
+      implements IClientPoolFactory<Integer, AINodeClient> {
+
+    @Override
+    public GenericKeyedObjectPool<Integer, AINodeClient> createClientPool(
+        ClientManager<Integer, AINodeClient> manager) {
+      GenericKeyedObjectPool<Integer, AINodeClient> clientPool =
+          new GenericKeyedObjectPool<>(
+              new AINodeClient.Factory(
+                  manager,
+                  new ThriftClientProperty.Builder()
+                      .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS())
+                      .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable())
+                      .build()),
+              new ClientPoolProperty.Builder()
+                  .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode())
                 .build()
                 .getConfig());
     ClientManagerMetrics.getInstance()
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java
deleted file mode 100644
index 54150b8f3007b..0000000000000
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java
+++ /dev/null
@@ -1,401 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; -import org.apache.iotdb.ainode.rpc.thrift.TConfigs; -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; -import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.ThriftClient; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.rpc.TConfigurationConst; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.thrift.TException; -import org.apache.thrift.transport.TSSLTransportFactory; -import org.apache.thrift.transport.TSocket; -import org.apache.thrift.transport.TTransport; -import org.apache.thrift.transport.TTransportException; -import org.apache.thrift.transport.layered.TFramedTransport; -import org.apache.tsfile.enums.TSDataType; -import org.apache.tsfile.read.common.block.TsBlock; -import org.apache.tsfile.read.common.block.column.TsBlockSerde; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.Map; -import 
java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; - -import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; -import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR; - -public class AINodeClient implements AutoCloseable, ThriftClient { - - private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class); - - private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - - private TEndPoint endPoint; - - private TTransport transport; - - private final ThriftClientProperty property; - private IAINodeRPCService.Client client; - - public static final String MSG_CONNECTION_FAIL = - "Fail to connect to AINode. Please check status of AINode"; - private static final int MAX_RETRY = 3; - - @FunctionalInterface - private interface RemoteCall { - R apply(IAINodeRPCService.Client c) throws TException; - } - - private final TsBlockSerde tsBlockSerde = new TsBlockSerde(); - - ClientManager clientManager; - - private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = - ConfigNodeClientManager.getInstance(); - - private static final AtomicReference CURRENT_LOCATION = new AtomicReference<>(); - - public static TEndPoint getCurrentEndpoint() { - TAINodeLocation loc = CURRENT_LOCATION.get(); - if (loc == null) { - loc = refreshFromConfigNode(); - } - return (loc == null) ? null : pickEndpointFrom(loc); - } - - public static void updateGlobalAINodeLocation(final TAINodeLocation loc) { - if (loc != null) { - CURRENT_LOCATION.set(loc); - } - } - - private R executeRemoteCallWithRetry(RemoteCall call) throws TException { - TException last = null; - for (int attempt = 1; attempt <= MAX_RETRY; attempt++) { - try { - if (transport == null || !transport.isOpen()) { - final TEndPoint ep = getCurrentEndpoint(); - if (ep == null) { - throw new TException("AINode endpoint unavailable"); - } - this.endPoint = ep; - init(); - } - return call.apply(client); - } catch (TException e) { - last = e; - invalidate(); - final TAINodeLocation loc = refreshFromConfigNode(); - if (loc != null) { - this.endPoint = pickEndpointFrom(loc); - } - try { - Thread.sleep(1000L * attempt); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - } - } - } - throw (last != null ? last : new TException(MSG_CONNECTION_FAIL)); - } - - private static TAINodeLocation refreshFromConfigNode() { - try (final ConfigNodeClient cn = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TGetAINodeLocationResp resp = cn.getAINodeLocation(); - if (resp != null && resp.isSetAiNodeLocation()) { - final TAINodeLocation loc = resp.getAiNodeLocation(); - CURRENT_LOCATION.set(loc); - return loc; - } - } catch (Exception e) { - LoggerFactory.getLogger(AINodeClient.class) - .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString()); - } - return null; - } - - private static TEndPoint pickEndpointFrom(final TAINodeLocation loc) { - if (loc == null) return null; - if (loc.isSetInternalEndPoint() && loc.getInternalEndPoint() != null) { - return loc.getInternalEndPoint(); - } - return null; - } - - public AINodeClient( - ThriftClientProperty property, - TEndPoint endPoint, - ClientManager clientManager) - throws TException { - this.property = property; - this.clientManager = clientManager; - // Instance default endpoint (pool key). Global location can override it on retries. 
- this.endPoint = endPoint; - init(); - } - - private void init() throws TException { - try { - if (commonConfig.isEnableInternalSSL()) { - TSSLTransportFactory.TSSLTransportParameters params = - new TSSLTransportFactory.TSSLTransportParameters(); - params.setTrustStore(commonConfig.getTrustStorePath(), commonConfig.getTrustStorePwd()); - params.setKeyStore(commonConfig.getKeyStorePath(), commonConfig.getKeyStorePwd()); - transport = - new TFramedTransport.Factory() - .getTransport( - TSSLTransportFactory.getClientSocket( - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs(), - params)); - } else { - transport = - new TFramedTransport.Factory() - .getTransport( - new TSocket( - TConfigurationConst.defaultTConfiguration, - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs())); - } - if (!transport.isOpen()) { - transport.open(); - } - } catch (TTransportException e) { - throw new TException(MSG_CONNECTION_FAIL); - } - client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); - } - - public TTransport getTransport() { - return transport; - } - - public TSStatus stopAINode() throws TException { - try { - TSStatus status = client.stopAINode(); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(status.message); - } - return status; - } catch (TException e) { - logger.warn( - "Failed to connect to AINode from ConfigNode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public ModelInformation registerModel(String modelName, String uri) throws LoadModelException { - try { - TRegisterModelReq req = new TRegisterModelReq(uri, modelName); - TRegisterModelResp resp = client.registerModel(req); - if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new LoadModelException(resp.status.message, resp.status.getCode()); - } - return parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs()); - } catch (TException e) { - throw new LoadModelException( - e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()); - } - } - - private ModelInformation parseModelInformation( - String modelName, String attributes, TConfigs configs) { - int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray(); - int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray(); - - TSDataType[] inputType = new TSDataType[inputShape[1]]; - TSDataType[] outputType = new TSDataType[outputShape[1]]; - for (int i = 0; i < inputShape[1]; i++) { - inputType[i] = TSDataType.values()[configs.getInput_type().get(i)]; - } - for (int i = 0; i < outputShape[1]; i++) { - outputType[i] = TSDataType.values()[configs.getOutput_type().get(i)]; - } - - return new ModelInformation( - modelName, inputShape, outputShape, inputType, outputType, attributes); - } - - public TSStatus deleteModel(TDeleteModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.deleteModel(req)); - } - - public TSStatus loadModel(TLoadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.loadModel(req)); - } - - public TSStatus unloadModel(TUnloadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.unloadModel(req)); - } - - public TShowModelsResp showModels(TShowModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showModels(req)); - } - - 
public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showLoadedModels(req)); - } - - public TShowAIDevicesResp showAIDevices() throws TException { - return executeRemoteCallWithRetry(IAINodeRPCService.Client::showAIDevices); - } - - public TInferenceResp inference( - String modelId, - TsBlock inputTsBlock, - Map inferenceAttributes, - TWindowParams windowParams) - throws TException { - try { - TInferenceReq inferenceReq = new TInferenceReq(modelId, tsBlockSerde.serialize(inputTsBlock)); - if (windowParams != null) { - inferenceReq.setWindowParams(windowParams); - } - if (inferenceAttributes != null) { - inferenceReq.setInferenceAttributes(inferenceAttributes); - } - return executeRemoteCallWithRetry(c -> c.inference(inferenceReq)); - } catch (IOException e) { - throw new TException("An exception occurred while serializing input data", e); - } catch (TException e) { - logger.warn( - "Error happens in AINode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public TForecastResp forecast( - String modelId, TsBlock inputTsBlock, int outputLength, Map options) { - try { - TForecastReq forecastReq = - new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock), outputLength); - forecastReq.setOptions(options); - return executeRemoteCallWithRetry(c -> c.forecast(forecastReq)); - } catch (IOException e) { - TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode()); - tsStatus.setMessage(String.format("Failed to serialize input tsblock %s", e.getMessage())); - return new TForecastResp(tsStatus); - } catch (TException e) { - TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode()); - tsStatus.setMessage( - String.format( - "Failed to connect to AINode when executing %s: %s", - Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage())); - return new TForecastResp(tsStatus); - } - } - - public TSStatus createTrainingTask(TTrainingReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.createTrainingTask(req)); - } - - @Override - public void close() throws Exception { - clientManager.returnClient(endPoint, this); - } - - @Override - public void invalidate() { - Optional.ofNullable(transport).ifPresent(TTransport::close); - } - - @Override - public void invalidateAll() { - clientManager.clear(endPoint); - } - - @Override - public boolean printLogWhenEncounterException() { - return property.isPrintLogWhenEncounterException(); - } - - public static class Factory extends ThriftClientFactory { - - public Factory( - ClientManager clientClientManager, - ThriftClientProperty thriftClientProperty) { - super(clientClientManager, thriftClientProperty); - } - - @Override - public void destroyObject(TEndPoint tEndPoint, PooledObject pooledObject) - throws Exception { - pooledObject.getObject().invalidate(); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public boolean validateObject(TEndPoint tEndPoint, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(TTransport::isOpen) - .orElse(false); - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java deleted file mode 100644 index faef1c1ae7b60..0000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; - -public class AINodeClientManager { - - public static final int DEFAULT_AINODE_ID = 0; - - private static final AINodeClientManager INSTANCE = new AINodeClientManager(); - - private final IClientManager clientManager; - - private volatile TEndPoint defaultAINodeEndPoint; - - private AINodeClientManager() { - this.clientManager = - new IClientManager.Factory() - .createClientManager(new AINodeClientFactory.AINodeClientPoolFactory()); - } - - public static AINodeClientManager getInstance() { - return INSTANCE; - } - - public void updateDefaultAINodeLocation(TEndPoint endPoint) { - this.defaultAINodeEndPoint = endPoint; - } - - public AINodeClient borrowClient(TEndPoint endPoint) throws Exception { - return clientManager.borrowClient(endPoint); - } - - public AINodeClient borrowClient(int aiNodeId) throws Exception { - if (aiNodeId != DEFAULT_AINODE_ID) { - throw new IllegalArgumentException("Unsupported AINodeId: " + aiNodeId); - } - if (defaultAINodeEndPoint == null) { - defaultAINodeEndPoint = AINodeClient.getCurrentEndpoint(); - } - return clientManager.borrowClient(defaultAINodeEndPoint); - } - - public void clear(TEndPoint endPoint) { - clientManager.clear(endPoint); - } - - public void clearAll() { - clientManager.close(); - } - - public IClientManager getRawClientManager() { - return clientManager; - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java new file mode 100644 index 0000000000000..5eaffc40af9cd --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatResp; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; +import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ClientManager; +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.ThriftClient; +import org.apache.iotdb.commons.client.factory.ThriftClientFactory; +import org.apache.iotdb.commons.client.property.ThriftClientProperty; +import org.apache.iotdb.commons.client.sync.SyncThriftClientWithErrorHandler; +import org.apache.iotdb.commons.conf.CommonConfig; +import org.apache.iotdb.commons.conf.CommonDescriptor; +import org.apache.iotdb.commons.consensus.ConfigRegionId; +import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; +import org.apache.iotdb.db.conf.IoTDBConfig; +import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.ConfigNodeClient; +import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; +import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; + +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLHandshakeException; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class AINodeClient implements IAINodeRPCService.Iface, AutoCloseable, ThriftClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeClient.class); + + private static final CommonConfig COMMON_CONFIG = CommonDescriptor.getInstance().getConfig(); + private static final IoTDBConfig IOTDB_CONFIG = IoTDBDescriptor.getInstance().getConfig(); + + private TTransport transport; + + private final 
ThriftClientProperty property;
+  private IAINodeRPCService.Client client;
+
+  private static final int MAX_RETRY = 5;
+  private static final int RETRY_INTERVAL_MS = 100;
+  public static final String MSG_ALL_RETRY_FAILED =
+      String.format(
+          "Failed to connect to AINode after %d retries, please check the status of AINode",
+          MAX_RETRY);
+  public static final String MSG_AINODE_CONNECTION_FAIL =
+      "Failed to connect to AINode from DataNode %s when executing %s.";
+  private static final String UNSUPPORTED_INVOCATION =
+      "This method is not supported for invocation by DataNode";
+
+  @Override
+  public TSStatus stopAINode() throws TException {
+    return executeRemoteCallWithRetry(() -> client.stopAINode());
+  }
+
+  @Override
+  public TShowModelsResp showModels(TShowModelsReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.showModels(req));
+  }
+
+  @Override
+  public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.showLoadedModels(req));
+  }
+
+  @Override
+  public TShowAIDevicesResp showAIDevices() throws TException {
+    return executeRemoteCallWithRetry(() -> client.showAIDevices());
+  }
+
+  @Override
+  public TSStatus deleteModel(TDeleteModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.deleteModel(req));
+  }
+
+  @Override
+  public TRegisterModelResp registerModel(TRegisterModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.registerModel(req));
+  }
+
+  @Override
+  public TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) {
+    throw new UnsupportedOperationException(UNSUPPORTED_INVOCATION);
+  }
+
+  @Override
+  public TSStatus createTrainingTask(TTrainingReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.createTrainingTask(req));
+  }
+
+  @Override
+  public TSStatus loadModel(TLoadModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.loadModel(req));
+  }
+
+  @Override
+  public TSStatus unloadModel(TUnloadModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.unloadModel(req));
+  }
+
+  @Override
+  public TInferenceResp inference(TInferenceReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.inference(req));
+  }
+
+  @Override
+  public TForecastResp forecast(TForecastReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.forecast(req));
+  }
+
+  @FunctionalInterface
+  private interface RemoteCall<R> {
+    R apply() throws TException;
+  }
+
+  ClientManager<Integer, AINodeClient> clientManager;
+
+  private static final IClientManager<ConfigRegionId, ConfigNodeClient>
+      CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance();
+
+  private static final AtomicReference<TAINodeLocation> CURRENT_LOCATION =
+      new AtomicReference<>();
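+
+  // Retry wrapper shared by all RPC methods above: every TException invalidates the
+  // cached AINode location, then the loop waits RETRY_INTERVAL_MS, re-resolves the
+  // current endpoint from the ConfigNode via tryToConnect(), and tries again, up to
+  // MAX_RETRY attempts. SSL handshake failures are rethrown immediately because
+  // retrying cannot fix a certificate problem.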
+  private <R> R executeRemoteCallWithRetry(RemoteCall<R> call) throws TException {
+    for (int attempt = 0; attempt < MAX_RETRY; attempt++) {
+      try {
+        return call.apply();
+      } catch (TException e) {
+        final String message =
+            String.format(
+                MSG_AINODE_CONNECTION_FAIL,
+                IOTDB_CONFIG.getAddressAndPort(),
+                Thread.currentThread().getStackTrace()[2].getMethodName());
+        LOGGER.warn(message, e);
+        CURRENT_LOCATION.set(null);
+        if (e.getCause() instanceof SSLHandshakeException) {
+          throw e;
+        }
+      }
+      try {
+        TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MS);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        LOGGER.warn(
+            "Unexpected interruption while waiting to reconnect to AINode; the current node may"
+                + " have been stopped. Breaking the retry loop to avoid a meaningless wait.");
+        break;
+      }
+      tryToConnect(property.getConnectionTimeoutMs());
+    }
+    throw new TException(MSG_ALL_RETRY_FAILED);
+  }
+
+  private void tryToConnect(int timeoutMs) {
+    TEndPoint endpoint = getCurrentEndpoint();
+    if (endpoint != null) {
+      try {
+        connect(endpoint, timeoutMs);
+        return;
+      } catch (TException e) {
+        LOGGER.warn("The current AINode {} may have been down, because", endpoint, e);
+        CURRENT_LOCATION.set(null);
+      }
+    } else {
+      LOGGER.warn("Cannot connect to any AINode because there are no available ones.");
+    }
+    if (transport != null) {
+      transport.close();
+    }
+  }
+
+  public void connect(TEndPoint endpoint, int timeoutMs) throws TException {
+    transport =
+        COMMON_CONFIG.isEnableInternalSSL()
+            ? DeepCopyRpcTransportFactory.INSTANCE.getTransport(
+                endpoint.getIp(),
+                endpoint.getPort(),
+                timeoutMs,
+                COMMON_CONFIG.getTrustStorePath(),
+                COMMON_CONFIG.getTrustStorePwd(),
+                COMMON_CONFIG.getKeyStorePath(),
+                COMMON_CONFIG.getKeyStorePwd())
+            : DeepCopyRpcTransportFactory.INSTANCE.getTransport(
+                // As there is a try-catch already, we do not need to use TSocket.wrap
+                endpoint.getIp(), endpoint.getPort(), timeoutMs);
+    if (!transport.isOpen()) {
+      transport.open();
+    }
+    client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport));
+  }
+
+  public TEndPoint getCurrentEndpoint() {
+    TAINodeLocation loc = CURRENT_LOCATION.get();
+    if (loc == null) {
+      loc = refreshFromConfigNode();
+    }
+    return (loc == null) ? null : loc.getInternalEndPoint();
+  }
+
+  private TAINodeLocation refreshFromConfigNode() {
+    try (final ConfigNodeClient cn =
+        CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+      final TGetAINodeLocationResp resp = cn.getAINodeLocation();
+      if (resp.isSetAiNodeLocation()) {
+        final TAINodeLocation loc = resp.getAiNodeLocation();
+        CURRENT_LOCATION.set(loc);
+        return loc;
+      }
+    } catch (Exception e) {
+      LoggerFactory.getLogger(AINodeClient.class)
+          .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString());
+    }
+    return null;
+  }
+
+  public AINodeClient(
+      ThriftClientProperty property, ClientManager<Integer, AINodeClient> clientManager) {
+    this.property = property;
+    this.clientManager = clientManager;
+    tryToConnect(property.getConnectionTimeoutMs());
+  }
+
+  public TTransport getTransport() {
+    return transport;
+  }
+
+  @Override
+  public void close() {
+    clientManager.returnClient(AINodeClientManager.AINODE_ID_PLACEHOLDER, this);
+  }
+
+  @Override
+  public void invalidate() {
+    Optional.ofNullable(transport).ifPresent(TTransport::close);
+  }
+
+  @Override
+  public void invalidateAll() {
+    clientManager.clear(AINodeClientManager.AINODE_ID_PLACEHOLDER);
+  }
+
+  @Override
+  public boolean printLogWhenEncounterException() {
+    return property.isPrintLogWhenEncounterException();
+  }
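+
+  // The pool key is a single placeholder id (AINodeClientManager.AINODE_ID_PLACEHOLDER)
+  // rather than an endpoint: each client resolves the current AINode location from the
+  // ConfigNode by itself, so pooled clients keep working after the AINode address
+  // changes. newErrorHandler wraps the client in SyncThriftClientWithErrorHandler for
+  // uniform Thrift exception handling.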
clientManager)); + } + + @Override + public boolean validateObject(Integer Integer, PooledObject pooledObject) { + return Optional.ofNullable(pooledObject.getObject().getTransport()) + .map(TTransport::isOpen) + .orElse(false); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java new file mode 100644 index 0000000000000..698c8e7938836 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; + +public class AINodeClientManager { + + public static final int AINODE_ID_PLACEHOLDER = 0; + + private AINodeClientManager() { + // Empty constructor + } + + public static IClientManager getInstance() { + return AINodeClientManagerHolder.INSTANCE; + } + + private static class AINodeClientManagerHolder { + + private static final IClientManager INSTANCE = + new IClientManager.Factory() + .createClientManager(new DataNodeClientPoolFactory.AINodeClientPoolFactory()); + + private AINodeClientManagerHolder() { + // Empty constructor + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java index 7126af78b8b51..a1e22c73b4e5b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java @@ -19,11 +19,12 @@ package org.apache.iotdb.db.queryengine.execution.operator.process.ai; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper; import org.apache.iotdb.db.queryengine.execution.operator.Operator; import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; 
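Every RPC above funnels through executeRemoteCallWithRetry(), so call sites no longer pick an AINode endpoint themselves; they simply borrow a pooled client keyed by the placeholder id. A minimal usage sketch (not part of the patch; it reuses only the names introduced above, with error handling reduced to a single throw):

try (AINodeClient client =
    AINodeClientManager.getInstance()
        .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
  // The client resolves the current AINode location from the ConfigNode on
  // demand and retries up to MAX_RETRY times before giving up.
  TShowModelsResp resp = client.showModels(new TShowModelsReq());
  if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
    throw new IoTDBException(resp.getStatus());
  }
}
// close() returns the client to the pool; only invalidate() actually tears
// down the underlying Thrift transport.

The hunks below migrate InferenceOperator and the remaining call sites to exactly this pattern.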
@@ -309,12 +310,11 @@ private void submitInferenceTask() { () -> { try (AINodeClient client = AINodeClientManager.getInstance() - .borrowClient(modelInferenceDescriptor.getTargetAINode())) { + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { return client.inference( - modelInferenceDescriptor.getModelName(), - finalInputTsBlock, - modelInferenceDescriptor.getInferenceAttributes(), - windowParams); + new TInferenceReq( + modelInferenceDescriptor.getModelName(), + serde.serialize(finalInputTsBlock))); } catch (Exception e) { throw new ModelInferenceProcessException(e.getMessage()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java index fc68881656595..daceffce6b7b1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java @@ -19,14 +19,11 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; import org.apache.iotdb.common.rpc.thrift.Model; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.audit.UserEntity; import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.conf.IoTDBConstant; @@ -68,8 +65,6 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClient; import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; @@ -157,8 +152,6 @@ public static Iterator getSupplier( return new SubscriptionSupplier(dataTypes, userEntity); case InformationSchema.VIEWS: return new ViewsSupplier(dataTypes, userEntity); - case InformationSchema.MODELS: - return new ModelsSupplier(dataTypes); case InformationSchema.FUNCTIONS: return new FunctionsSupplier(dataTypes); case InformationSchema.CONFIGURATIONS: @@ -798,112 +791,6 @@ public boolean hasNext() { } } - private static class ModelsSupplier extends TsBlockSupplier { - private final ModelIterator iterator; - - private ModelsSupplier(final List dataTypes) throws Exception { - super(dataTypes); - final TEndPoint ep = AINodeClient.getCurrentEndpoint(); - try (final AINodeClient ai = AINodeClientManager.getInstance().borrowClient(ep)) { - iterator = new ModelIterator(ai.showModels(new TShowModelsReq())); - } - } - - private static class ModelIterator implements Iterator { - - private int index = 0; - private final TShowModelsResp resp; - - private ModelIterator(TShowModelsResp resp) { - 
this.resp = resp; - } - - @Override - public boolean hasNext() { - return index < resp.getModelIdListSize(); - } - - @Override - public ModelInfoInString next() { - String modelId = resp.getModelIdList().get(index++); - return new ModelInfoInString( - modelId, - resp.getModelTypeMap().get(modelId), - resp.getCategoryMap().get(modelId), - resp.getStateMap().get(modelId)); - } - } - - private static class ModelInfoInString { - - private final String modelId; - private final String modelType; - private final String category; - private final String state; - - public ModelInfoInString(String modelId, String modelType, String category, String state) { - this.modelId = modelId; - this.modelType = modelType; - this.category = category; - this.state = state; - } - - public String getModelId() { - return modelId; - } - - public String getModelType() { - return modelType; - } - - public String getCategory() { - return category; - } - - public String getState() { - return state; - } - } - - @Override - protected void constructLine() { - final ModelInfoInString modelInfo = iterator.next(); - columnBuilders[0].writeBinary( - new Binary(modelInfo.getModelId(), TSFileConfig.STRING_CHARSET)); - columnBuilders[1].writeBinary( - new Binary(modelInfo.getModelType(), TSFileConfig.STRING_CHARSET)); - columnBuilders[2].writeBinary( - new Binary(modelInfo.getCategory(), TSFileConfig.STRING_CHARSET)); - columnBuilders[3].writeBinary(new Binary(modelInfo.getState(), TSFileConfig.STRING_CHARSET)); - // if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) { - // columnBuilders[3].writeBinary( - // new Binary( - // INPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + INPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // columnBuilders[4].writeBinary( - // new Binary(ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // } else { - // columnBuilders[3].appendNull(); - // columnBuilders[4].writeBinary( - // new Binary("Built-in model in IoTDB", TSFileConfig.STRING_CHARSET)); - // } - resultBuilder.declarePosition(); - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - } - private static class FunctionsSupplier extends TsBlockSupplier { private final Iterator udfIterator; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java index 586e12e589ab1..1feecaefde9c5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java @@ -20,12 +20,8 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; public interface IModelFetcher { /** Get model information by model id from configNode. 
 */
  TSStatus fetchModel(String modelId, Analysis analysis);
-
-  // currently only used by table model
-  ModelInferenceDescriptor fetchModel(String modelName);
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
index dbeee4e8ed4b6..df729ca0ee35f 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
@@ -21,21 +21,12 @@
 
 import org.apache.iotdb.common.rpc.thrift.TSStatus;
 import org.apache.iotdb.commons.client.IClientManager;
-import org.apache.iotdb.commons.client.exception.ClientManagerException;
 import org.apache.iotdb.commons.consensus.ConfigRegionId;
-import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
-import org.apache.iotdb.db.exception.ainode.ModelNotFoundException;
-import org.apache.iotdb.db.exception.sql.StatementAnalyzeException;
 import org.apache.iotdb.db.protocol.client.ConfigNodeClient;
 import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager;
-import org.apache.iotdb.db.protocol.client.ConfigNodeInfo;
-import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
 import org.apache.iotdb.rpc.TSStatusCode;
 
-import org.apache.thrift.TException;
-
+// TODO: This class should contact the AINode directly and cache model info on the DataNode
 public class ModelFetcher implements IModelFetcher {
 
   private final IClientManager<ConfigRegionId, ConfigNodeClient> configNodeClientManager =
@@ -56,33 +47,6 @@ private ModelFetcher() {}
 
   @Override
   public TSStatus fetchModel(String modelName, Analysis analysis) {
-    try (ConfigNodeClient client =
-        configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
-      TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName));
-      if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
-        return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
-      } else {
-        throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage());
-      }
-    } catch (ClientManagerException | TException e) {
-      throw new StatementAnalyzeException(e.getMessage());
-    }
-  }
-
-  @Override
-  public ModelInferenceDescriptor fetchModel(String modelName) {
-    try (ConfigNodeClient client =
-        configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
-      TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName));
-      if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
-        return new ModelInferenceDescriptor(getModelInfoResp.aiNodeAddress);
-      } else {
-        throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage());
-      }
-    } catch (ClientManagerException | TException e) {
-      throw new IoTDBRuntimeException(
-          String.format("fetch model [%s] info failed: %s", modelName, e.getMessage()),
-          TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
-    }
+    return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
   }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index d0f7c7f99d7ed..01f6757f02ebf 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -19,7 +19,10 @@ package org.apache.iotdb.db.queryengine.plan.execution.config.executor; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; @@ -96,7 +99,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCountTimeSlotListResp; import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; @@ -114,7 +116,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -175,8 +176,8 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; @@ -379,6 +380,8 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance(); + private static final IClientManager AI_NODE_CLIENT_MANAGER = + AINodeClientManager.getInstance(); /** FIXME Consolidate this clientManager with the upper one. 
*/ private static final IClientManager @@ -3596,16 +3599,16 @@ public SettableFuture showContinuousQueries() { @Override public SettableFuture createModel(String modelId, String uri) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TCreateModelReq req = new TCreateModelReq(modelId, uri); - final TSStatus status = client.createModel(req); - if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { - future.setException(new IoTDBException(status)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TRegisterModelReq req = new TRegisterModelReq(modelId, uri); + final TRegisterModelResp resp = client.registerModel(req); + if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != resp.getStatus().getCode()) { + future.setException(new IoTDBException(resp.getStatus())); } else { future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); } - } catch (final ClientManagerException | TException e) { + } catch (final TException | ClientManagerException e) { future.setException(e); } return future; @@ -3614,9 +3617,9 @@ public SettableFuture createModel(String modelId, String uri) @Override public SettableFuture dropModel(final String modelId) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TSStatus executionStatus = client.dropModel(new TDropModelReq(modelId)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TSStatus executionStatus = client.deleteModel(new TDeleteModelReq(modelId)); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) { future.setException(new IoTDBException(executionStatus)); } else { @@ -3632,7 +3635,7 @@ public SettableFuture dropModel(final String modelId) { public SettableFuture showModels(final String modelId) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowModelsReq req = new TShowModelsReq(); if (modelId != null) { req.setModelId(modelId); @@ -3653,7 +3656,7 @@ public SettableFuture showModels(final String modelId) { public SettableFuture showLoadedModels(List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowLoadedModelsReq req = new TShowLoadedModelsReq(); req.setDeviceIdList(deviceIdList != null ? 
deviceIdList : new ArrayList<>()); final TShowLoadedModelsResp resp = ai.showLoadedModels(req); @@ -3672,7 +3675,7 @@ public SettableFuture showLoadedModels(List deviceIdLi public SettableFuture showAIDevices() { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowAIDevicesResp resp = ai.showAIDevices(); if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { future.setException(new IoTDBException(resp.getStatus())); @@ -3690,7 +3693,7 @@ public SettableFuture loadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TLoadModelReq req = new TLoadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.loadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3709,7 +3712,7 @@ public SettableFuture unloadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TUnloadModelReq req = new TUnloadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.unloadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3734,7 +3737,7 @@ public SettableFuture createTraining( @Nullable List pathList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TTrainingReq req = new TTrainingReq(); req.setModelId(modelId); req.setParameters(parameters); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java index c01308f9f375c..cd219fd68162a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java @@ -40,7 +40,6 @@ import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableArgumentAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableFunctionInvocationAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; -import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName; @@ -4693,11 +4692,6 @@ public Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional String 
functionName = node.getName().toString(); TableFunction function = metadata.getTableFunction(functionName); - // set model fetcher for ForecastTableFunction - if (function instanceof ForecastTableFunction) { - ((ForecastTableFunction) function).setModelFetcher(metadata.getModelFetcher()); - } - Node errorLocation = node; if (!node.getArguments().isEmpty()) { errorLocation = node.getArguments().get(0); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 887d7c26d305e..08f7ec6c8335c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -19,14 +19,14 @@ package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; import org.apache.iotdb.db.exception.sql.SemanticException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.relational.TableFunction; import org.apache.iotdb.udf.api.relational.access.Record; @@ -74,8 +74,9 @@ public class ForecastTableFunction implements TableFunction { + private static final TsBlockSerde SERDE = new TsBlockSerde(); + public static class ForecastTableFunctionHandle implements TableFunctionHandle { - TEndPoint targetAINode; String modelId; int maxInputLength; int outputLength; @@ -95,7 +96,6 @@ public ForecastTableFunctionHandle( int outputLength, long outputStartTime, long outputInterval, - TEndPoint targetAINode, List types) { this.keepInput = keepInput; this.maxInputLength = maxInputLength; @@ -104,7 +104,6 @@ public ForecastTableFunctionHandle( this.outputLength = outputLength; this.outputStartTime = outputStartTime; this.outputInterval = outputInterval; - this.targetAINode = targetAINode; this.types = types; } @@ -112,8 +111,6 @@ public ForecastTableFunctionHandle( public byte[] serialize() { try (PublicBAOS publicBAOS = new PublicBAOS(); DataOutputStream outputStream = new DataOutputStream(publicBAOS)) { - ReadWriteIOUtils.write(targetAINode.getIp(), outputStream); - ReadWriteIOUtils.write(targetAINode.getPort(), outputStream); ReadWriteIOUtils.write(modelId, outputStream); ReadWriteIOUtils.write(maxInputLength, outputStream); ReadWriteIOUtils.write(outputLength, outputStream); @@ -138,8 +135,6 @@ public byte[] serialize() { @Override public void deserialize(byte[] bytes) { ByteBuffer buffer = ByteBuffer.wrap(bytes); - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelId = ReadWriteIOUtils.readString(buffer); this.maxInputLength = 
ReadWriteIOUtils.readInt(buffer); this.outputLength = ReadWriteIOUtils.readInt(buffer); @@ -168,7 +163,6 @@ public boolean equals(Object o) { && outputStartTime == that.outputStartTime && outputInterval == that.outputInterval && keepInput == that.keepInput - && Objects.equals(targetAINode, that.targetAINode) && Objects.equals(modelId, that.modelId) && Objects.equals(options, that.options) && Objects.equals(types, that.types); @@ -177,7 +171,6 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash( - targetAINode, modelId, maxInputLength, outputLength, @@ -284,8 +277,6 @@ public TableFunctionAnalysis analyze(Map arguments) { String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME)); } - TEndPoint targetAINode = getModelInfo(modelId).getTargetAINode(); - int outputLength = (int) ((ScalarArgument) arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue(); if (outputLength <= 0) { @@ -390,7 +381,6 @@ public TableFunctionAnalysis analyze(Map arguments) { outputLength, outputStartTime, outputInterval, - targetAINode, predicatedColumnTypes); // outputColumnSchema @@ -417,10 +407,6 @@ public TableFunctionDataProcessor getDataProcessor() { }; } - private ModelInferenceDescriptor getModelInfo(String modelId) { - return modelFetcher.fetchModel(modelId); - } - // only allow for INT32, INT64, FLOAT, DOUBLE private void checkType(Type type, String columnName) { if (!ALLOWED_INPUT_TYPES.contains(type)) { @@ -456,9 +442,9 @@ private static Map parseOptions(String options) { private static class ForecastDataProcessor implements TableFunctionDataProcessor { private static final TsBlockSerde SERDE = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); - private final TEndPoint targetAINode; private final String modelId; private final int maxInputLength; private final int outputLength; @@ -471,7 +457,6 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor private final TsBlockBuilder inputTsBlockBuilder; public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { - this.targetAINode = functionHandle.targetAINode; this.modelId = functionHandle.modelId; this.maxInputLength = functionHandle.maxInputLength; this.outputLength = functionHandle.outputLength; @@ -619,8 +604,12 @@ private TsBlock forecast() { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(modelId, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(modelId, SERDE.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java index f0c041ad8053d..db706d4980cba 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java @@ -28,7 +28,6 @@ import 
org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.TableHeaderSchemaValidator; @@ -211,9 +210,4 @@ DataPartition getDataPartitionWithUnclosedTimeRange( final String database, final List sgNameToQueryParamsMap); TableFunction getTableFunction(final String functionName); - - /** - * @return ModelFetcher - */ - IModelFetcher getModelFetcher(); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 0342c513f96a2..c15071be69785 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -1471,11 +1471,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - return modelFetcher; - } - public static boolean isTwoNumericType(List argumentTypes) { return argumentTypes.size() == 2 && isNumericType(argumentTypes.get(0)) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java index 4676559bd7b15..f8cf497546e6c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java @@ -96,7 +96,6 @@ public List getDataNodeLocations(final String tableName) { case InformationSchema.TOPICS: case InformationSchema.SUBSCRIPTIONS: case InformationSchema.VIEWS: - case InformationSchema.MODELS: case InformationSchema.FUNCTIONS: case InformationSchema.CONFIGURATIONS: case InformationSchema.KEYWORDS: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index 260410954d4c6..09f00f8ed672b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -19,14 +19,12 @@ package org.apache.iotdb.db.queryengine.plan.udf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import 
org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.UDTF; import org.apache.iotdb.udf.api.access.Row; @@ -54,8 +52,8 @@ public class UDTFForecast implements UDTF { private static final TsBlockSerde serde = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); - private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); private String model_id; private int maxInputLength; private int outputLength; @@ -66,7 +64,6 @@ public class UDTFForecast implements UDTF { List types; private LinkedList inputRows; private TsBlockBuilder inputTsBlockBuilder; - private final IModelFetcher modelFetcher = ModelFetcher.getInstance(); private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); @@ -112,8 +109,6 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati throw new IllegalArgumentException( "MODEL_ID parameter must be provided and cannot be empty."); } - ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id); - this.targetAINode = descriptor.getTargetAINode(); this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL); this.outputLength = @@ -211,8 +206,12 @@ private TsBlock forecast() throws Exception { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(model_id, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(model_id, serde.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException( e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode()); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java index e60a14b727cad..79c031560973a 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.AlignedDeviceEntry; @@ -402,11 +401,6 @@ public TableFunction getTableFunction(String functionName) { return null; } - @Override - public IModelFetcher getModelFetcher() { - return null; - } - private static final DataPartition DATA_PARTITION = MockTSBSDataPartition.constructDataPartition(); diff --git 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index e56b48936b96d..7bbfe150ade45 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -19,7 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; @@ -378,7 +377,6 @@ public void testForecastFunction() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan @@ -439,7 +437,6 @@ public void testForecastFunctionWithNoLowerCase() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java index aa9fcdfd1b514..4b1d18944b732 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java @@ -19,8 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -32,12 +30,10 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.function.Exclude; import org.apache.iotdb.db.queryengine.plan.function.Repeat; import org.apache.iotdb.db.queryengine.plan.function.Split; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.SubtractionResolver; @@ -560,21 +556,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - String modelId = "timer_xl"; - IModelFetcher fetcher = Mockito.mock(IModelFetcher.class); - ModelInferenceDescriptor descriptor = Mockito.mock(ModelInferenceDescriptor.class); - 
Mockito.when(descriptor.getTargetAINode()).thenReturn(new TEndPoint("127.0.0.1", 10810));
-    ModelInformation modelInformation = Mockito.mock(ModelInformation.class);
-    Mockito.when(modelInformation.available()).thenReturn(true);
-    Mockito.when(modelInformation.getInputShape()).thenReturn(new int[] {1440, 96});
-    Mockito.when(descriptor.getModelInformation()).thenReturn(modelInformation);
-    Mockito.when(descriptor.getModelName()).thenReturn(modelId);
-    Mockito.when(fetcher.fetchModel(modelId)).thenReturn(descriptor);
-    return fetcher;
-  }
-
   private static final DataPartition TABLE_DATA_PARTITION =
       MockTableModelDataPartition.constructDataPartition(DB1);
 
diff --git a/iotdb-core/node-commons/pom.xml b/iotdb-core/node-commons/pom.xml
index 85ff69ee8ac7a..e7c508c195d55 100644
--- a/iotdb-core/node-commons/pom.xml
+++ b/iotdb-core/node-commons/pom.xml
@@ -65,6 +65,11 @@
       <artifactId>iotdb-thrift-confignode</artifactId>
       <version>2.0.6-SNAPSHOT</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.iotdb</groupId>
+      <artifactId>iotdb-thrift-ainode</artifactId>
+      <version>2.0.6-SNAPSHOT</version>
+    </dependency>
     <dependency>
       <groupId>org.apache.iotdb</groupId>
       <artifactId>iotdb-thrift</artifactId>
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java
index 106d67b6279d9..115f322348c06 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java
@@ -20,6 +20,7 @@
 package org.apache.iotdb.commons.client;
 
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient;
 import org.apache.iotdb.commons.client.async.AsyncConfigNodeInternalServiceClient;
 import org.apache.iotdb.commons.client.async.AsyncDataNodeExternalServiceClient;
 import org.apache.iotdb.commons.client.async.AsyncDataNodeInternalServiceClient;
@@ -390,4 +391,31 @@
       return clientPool;
     }
   }
+
+  public static class AsyncAINodeHeartbeatServiceClientPoolFactory
+      implements IClientPoolFactory<TEndPoint, AsyncAINodeInternalServiceClient> {
+
+    @Override
+    public GenericKeyedObjectPool<TEndPoint, AsyncAINodeInternalServiceClient> createClientPool(
+        ClientManager<TEndPoint, AsyncAINodeInternalServiceClient> manager) {
+      GenericKeyedObjectPool<TEndPoint, AsyncAINodeInternalServiceClient> clientPool =
+          new GenericKeyedObjectPool<>(
+              new AsyncAINodeInternalServiceClient.Factory(
+                  manager,
+                  new ThriftClientProperty.Builder()
+                      .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS())
+                      .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled())
+                      .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager())
+                      .setPrintLogWhenEncounterException(false)
+                      .build(),
+                  ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()),
+              new ClientPoolProperty.Builder<AsyncAINodeInternalServiceClient>()
+                  .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode())
+                  .build()
+                  .getConfig());
+      ClientManagerMetrics.getInstance()
+          .registerClientManager(this.getClass().getSimpleName(), clientPool);
+      return clientPool;
+    }
+  }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java
similarity index 83%
rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java
rename to iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java
index 26130287697c4..8cbd55759633f 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java
@@ -17,7 +17,7 @@
  * under the License.
  */
 
-package org.apache.iotdb.db.protocol.client.ainode;
+package org.apache.iotdb.commons.client.async;
 
 import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService;
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
@@ -35,20 +35,20 @@
 
 import java.io.IOException;
 
-public class AsyncAINodeServiceClient extends IAINodeRPCService.AsyncClient
+public class AsyncAINodeInternalServiceClient extends IAINodeRPCService.AsyncClient
     implements ThriftClient {
 
   private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig();
 
-  private final boolean printLogWhenEncounterException;
   private final TEndPoint endPoint;
-  private final ClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager;
+  private final boolean printLogWhenEncounterException;
+  private final ClientManager<TEndPoint, AsyncAINodeInternalServiceClient> clientManager;
 
-  public AsyncAINodeServiceClient(
+  public AsyncAINodeInternalServiceClient(
       ThriftClientProperty property,
       TEndPoint endPoint,
       TAsyncClientManager tClientManager,
-      ClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager)
+      ClientManager<TEndPoint, AsyncAINodeInternalServiceClient> clientManager)
       throws IOException {
     super(
         property.getProtocolFactory(),
@@ -122,10 +122,10 @@ public boolean isReady() {
   }
 
   public static class Factory
-      extends AsyncThriftClientFactory<TEndPoint, AsyncAINodeServiceClient> {
+      extends AsyncThriftClientFactory<TEndPoint, AsyncAINodeInternalServiceClient> {
 
     public Factory(
-        ClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager,
+        ClientManager<TEndPoint, AsyncAINodeInternalServiceClient> clientManager,
         ThriftClientProperty thriftClientProperty,
         String threadName) {
       super(clientManager, thriftClientProperty, threadName);
@@ -133,14 +133,15 @@ public Factory(
 
     @Override
     public void destroyObject(
-        TEndPoint endPoint, PooledObject<AsyncAINodeServiceClient> pooledObject) {
+        TEndPoint endPoint, PooledObject<AsyncAINodeInternalServiceClient> pooledObject) {
       pooledObject.getObject().close();
     }
 
     @Override
-    public PooledObject<AsyncAINodeServiceClient> makeObject(TEndPoint endPoint) throws Exception {
+    public PooledObject<AsyncAINodeInternalServiceClient> makeObject(TEndPoint endPoint)
+        throws Exception {
       return new DefaultPooledObject<>(
-          new AsyncAINodeServiceClient(
+          new AsyncAINodeInternalServiceClient(
              thriftClientProperty,
              endPoint,
              tManagers[clientCnt.incrementAndGet() % tManagers.length],
@@ -149,7 +150,7 @@ public PooledObject<AsyncAINodeServiceClient> makeObject(TEndPoint endPoint)
     @Override
     public boolean validateObject(
-        TEndPoint endPoint, PooledObject<AsyncAINodeServiceClient> pooledObject) {
+        TEndPoint endPoint, PooledObject<AsyncAINodeInternalServiceClient> pooledObject) {
       return pooledObject.getObject().isReady();
     }
   }
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java
index 2db41cc3c2dd7..243bc41c40ce1 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java
@@ -43,7 +43,6 @@ public class InformationSchema {
   public static final String TOPICS = "topics";
   public static final String SUBSCRIPTIONS = "subscriptions";
   public static final String VIEWS = "views";
-  public static final String MODELS = "models";
   public static final String FUNCTIONS = "functions";
   public static final String CONFIGURATIONS = "configurations";
   public static final String KEYWORDS = "keywords";
@@ -256,23 +255,6 @@ public class InformationSchema {
     viewTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME);
     schemaTables.put(VIEWS, viewTable);
 
-    final TsTable modelTable = new TsTable(MODELS);
-    modelTable.addColumnSchema(
-        new TagColumnSchema(ColumnHeaderConstant.MODEL_ID_TABLE_MODEL, TSDataType.STRING));
-    modelTable.addColumnSchema(
-        new AttributeColumnSchema(ColumnHeaderConstant.MODEL_TYPE_TABLE_MODEL, TSDataType.STRING));
-    modelTable.addColumnSchema(
-        new AttributeColumnSchema(
-            ColumnHeaderConstant.STATE.toLowerCase(Locale.ENGLISH), TSDataType.STRING));
-    modelTable.addColumnSchema(
-        new AttributeColumnSchema(
-            ColumnHeaderConstant.CONFIGS.toLowerCase(Locale.ENGLISH), TSDataType.STRING));
-    modelTable.addColumnSchema(
-        new AttributeColumnSchema(
-            ColumnHeaderConstant.NOTES.toLowerCase(Locale.ENGLISH), TSDataType.STRING));
-    modelTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME);
-    schemaTables.put(MODELS, modelTable);
-
     final TsTable functionTable = new TsTable(FUNCTIONS);
     functionTable.addColumnSchema(
         new TagColumnSchema(ColumnHeaderConstant.FUNCTION_NAME_TABLE_MODEL, TSDataType.STRING));
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index d8f6318063ebb..f2b8ec6b8b071 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1096,34 +1096,6 @@ struct TUnsetSchemaTemplateReq {
   4: optional bool isGeneratedByPipe
 }
 
-struct TCreateModelReq {
-  1: required string modelName
-  2: required string uri
-}
-
-struct TDropModelReq {
-  1: required string modelId
-}
-
-struct TGetModelInfoReq {
-  1: required string modelId
-}
-
-struct TGetModelInfoResp {
-  1: required common.TSStatus status
-  2: optional binary modelInfo
-  3: optional common.TEndPoint aiNodeAddress
-}
-
-struct TUpdateModelInfoReq {
-  1: required string modelId
-  2: required i32 modelStatus
-  3: optional string attributes
-  4: optional list<i32> aiNodeIds
-  5: optional i32 inputLength
-  6: optional i32 outputLength
-}
-
 struct TDataSchemaForTable{
   1: required string targetSql
 }
@@ -1132,16 +1104,6 @@ struct TDataSchemaForTree{
   1: required list<string> path
 }
 
-struct TCreateTrainingReq {
-  1: required string modelId
-  2: required bool isTableModel
-  3: required string existingModelId
-  4: optional TDataSchemaForTable dataSchemaForTable
-  5: optional TDataSchemaForTree dataSchemaForTree
-  6: optional map<string, string> parameters
-  7: optional list<list<i64>> timeRanges
-}
-
 // ====================================================
 // Quota
 // ====================================================
@@ -2006,31 +1968,6 @@ service IConfigNodeRPCService {
   */
   TShowCQResp showCQ()
 
-  // ====================================================
-  // AI Model
-  // ====================================================
-
-  /**
-   * Create a model
-   *
-   * @return SUCCESS_STATUS if the model was created successfully
-   */
-  common.TSStatus createModel(TCreateModelReq req)
-
-  /**
-   * Drop a model
-   *
-   * @return SUCCESS_STATUS if the model was removed successfully
-   */
-  common.TSStatus dropModel(TDropModelReq req)
-
-  /**
-   * Return the model info by model_id
-   */
-  TGetModelInfoResp getModelInfo(TGetModelInfoReq req)
-
-  common.TSStatus updateModelInfo(TUpdateModelInfoReq req)
-
   // ======================================================
   // Quota
   // ======================================================

From d65148909513d54e74b32ff41fc4527bb1449075 Mon Sep 17 00:00:00 2001
From: RkGrit
Date: Fri, 28 Nov 2025 11:47:29 +0800
Subject: [PATCH 05/38] Support loading inference pipelines for user-defined
 models
---
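Notes: this patch swaps the hard-coded get_pipeline() dispatch for a load_pipeline() that resolves the pipeline class from the registered ModelInfo, importing user-defined pipeline code dynamically from the model's directory. The helpers temporary_sys_path and import_class_from_path referenced by pipeline_loader.py live in iotdb/ainode/core/model/utils.py and are not shown in this series; a plausible minimal shape (an assumption for illustration, not the shipped implementation) is:

import importlib
import sys
from contextlib import contextmanager


@contextmanager
def temporary_sys_path(path: str):
    # Hypothetical sketch: temporarily make the model's parent directory importable.
    sys.path.insert(0, path)
    try:
        yield
    finally:
        if path in sys.path:
            sys.path.remove(path)


def import_class_from_path(module_name: str, class_spec: str):
    # Hypothetical sketch: resolve e.g. ("my_model", "my_pipeline.MyPipeline")
    # to the class object exported by the user's model package.
    module_path, _, cls_name = class_spec.rpartition(".")
    target = f"{module_name}.{module_path}" if module_path else module_name
    return getattr(importlib.import_module(target), cls_name)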
iotdb-core/ainode/iotdb/ainode/core/config.py | 24 --------- .../core/inference/inference_request_pool.py | 8 +-- .../core/inference/pipeline/__init__.py | 11 ----- .../core/inference/pipeline/basic_pipeline.py | 2 +- .../inference/pipeline/pipeline_loader.py | 49 +++++++++++++++++++ .../pool_scheduler/basic_pool_scheduler.py | 2 +- .../ainode/core/manager/inference_manager.py | 12 ++--- .../ainode/core/model/model_constants.py | 1 + .../iotdb/ainode/core/model/model_info.py | 1 + .../iotdb/ainode/core/model/model_storage.py | 27 +++++++--- .../ainode/iotdb/ainode/core/model/utils.py | 25 ++++++---- .../ainode/iotdb/ainode/core/rpc/handler.py | 13 ++--- 12 files changed, 101 insertions(+), 74 deletions(-) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index 04ec3ee68c16a..f30e1ecf73fff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -20,7 +20,6 @@ from iotdb.ainode.core.constant import ( AINODE_BUILD_INFO, - AINODE_BUILTIN_MODELS_DIR, AINODE_CLUSTER_INGRESS_ADDRESS, AINODE_CLUSTER_INGRESS_PASSWORD, AINODE_CLUSTER_INGRESS_PORT, @@ -31,7 +30,6 @@ AINODE_CONF_FILE_NAME, AINODE_CONF_GIT_FILE_NAME, AINODE_CONF_POM_FILE_NAME, - AINODE_FINETUNE_MODELS_DIR, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, AINODE_INFERENCE_EXTRA_MEMORY_RATIO, AINODE_INFERENCE_MAX_PREDICT_LENGTH, @@ -45,7 +43,6 @@ AINODE_SYSTEM_FILE_NAME, AINODE_TARGET_CONFIG_NODE_LIST, AINODE_THRIFT_COMPRESSION_ENABLED, - AINODE_USER_DEFINED_MODELS_DIR, AINODE_VERSION_INFO, ) from iotdb.ainode.core.exception import BadNodeUrlError @@ -97,9 +94,6 @@ def __init__(self): # Directory to save models self._ain_models_dir = AINODE_MODELS_DIR - self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR - self._ain_finetune_models_dir = AINODE_FINETUNE_MODELS_DIR - self._ain_user_defined_models_dir = AINODE_USER_DEFINED_MODELS_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ -208,24 +202,6 @@ def get_ain_models_dir(self) -> str: def set_ain_models_dir(self, ain_models_dir: str) -> None: self._ain_models_dir = ain_models_dir - def get_ain_builtin_models_dir(self) -> str: - return self._ain_builtin_models_dir - - def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None: - self._ain_builtin_models_dir = ain_builtin_models_dir - - def get_ain_finetune_models_dir(self) -> str: - return self._ain_finetune_models_dir - - def set_ain_finetune_models_dir(self, ain_finetune_models_dir: str) -> None: - self._ain_finetune_models_dir = ain_finetune_models_dir - - def get_ain_user_defined_models_dir(self) -> str: - return self._ain_user_defined_models_dir - - def set_ain_user_defined_models_dir(self, ain_user_defined_models_dir: str) -> None: - self._ain_user_defined_models_dir = ain_user_defined_models_dir - def get_ain_system_dir(self) -> str: return self._ain_system_dir diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index b56ffce461f5f..2fb00988bef9b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -30,7 +30,7 @@ from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher 
from iotdb.ainode.core.inference.inference_request import InferenceRequest -from iotdb.ainode.core.inference.pipeline import get_pipeline +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, ) @@ -119,7 +119,6 @@ def _step(self): batch_output = self._inference_pipeline.infer( batch_inputs, predict_length=requests[0].max_new_tokens, - # num_samples=10, revin=True, ) offset = 0 @@ -128,12 +127,9 @@ def _step(self): cur_batch_size = request.batch_size cur_output = batch_output[offset : offset + cur_batch_size] offset += cur_batch_size - # request.write_step_output(cur_output.mean(dim=1)) request.write_step_output(cur_output) - # self._inference_pipeline.post_decode() if request.is_finished(): - # self._inference_pipeline.post_inference() # ensure the output tensor is on CPU before sending to result queue request.output_tensor = request.output_tensor.cpu() self._finished_queue.put(request) @@ -157,7 +153,7 @@ def run(self): INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device) ) self._request_scheduler.device = self.device - self._inference_pipeline = get_pipeline(self.model_info.model_id, self.device) + self._inference_pipeline = load_pipeline(self.model_info, str(self.device)) self.ready_event.set() activate_daemon = threading.Thread( diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py index 53cf7b1086891..a4797b632bb18 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py @@ -16,14 +16,3 @@ # under the License. # -from iotdb.ainode.core.model.sundial.pipeline_sundial import SundialPipeline -from iotdb.ainode.core.model.timer_xl.pipeline_timer import TimerPipeline - - -def get_pipeline(model_id, device): - if model_id == "timer_xl": - return TimerPipeline(model_id, device=device) - elif model_id == "sundial": - return SundialPipeline(model_id, device=device) - else: - raise ValueError(f"Unsupported model_id: {model_id} with pipeline") diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index 19efe0220c64f..2d967aea9bce5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -29,7 +29,7 @@ def __init__(self, model_id, **infer_kwargs): self.model_id = model_id self.device = infer_kwargs.get("device", "cpu") self.model = ModelManager().load_model( - model_id, device_map=str(self.device) + model_id, device_map=self.device ) def _preprocess(self, inputs): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py new file mode 100644 index 0000000000000..be0fb996b2a48 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from pathlib import Path + +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_storage import ModelInfo +from iotdb.ainode.core.model.utils import temporary_sys_path, import_class_from_path + +logger = Logger() + + +def load_pipeline(model_info: ModelInfo, device: str, **kwargs): + if model_info.category == ModelCategory.BUILTIN: + if model_info.model_id == "timer_xl": + from iotdb.ainode.core.model.timer_xl.pipeline_timer import TimerPipeline + pipeline_cls = TimerPipeline + elif model_info.model_id == "sundial": + from iotdb.ainode.core.model.sundial.pipeline_sundial import SundialPipeline + pipeline_cls = SundialPipeline + else: + logger.error( + f"Unsupported built-in model {model_info.model_id}." + ) + return None + else: + module_parent = str(Path(model_info.path).parent.absolute()) + with temporary_sys_path(module_parent): + pipeline_cls = import_class_from_path( + model_info.model_id, model_info.pipeline_cls + ) + + return pipeline_cls(model_info.model_id, device=device) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 6ad55f742a4e1..d5b54280e96b3 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -28,7 +28,7 @@ ScaleActionType, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import get_model_manager, ModelManager +from iotdb.ainode.core.manager.model_manager import ModelManager from iotdb.ainode.core.manager.utils import ( INFERENCE_EXTRA_MEMORY_RATIO, INFERENCE_MEMORY_USAGE_RATIO, diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 6f022f0ddd1e4..30e71ebd75fba 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -18,7 +18,6 @@ import threading import time -from abc import ABC, abstractmethod from typing import Dict import pandas as pd @@ -29,19 +28,17 @@ from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import ( InferenceModelInternalError, - InvalidWindowArgumentError, NumericalRangeException, - runtime_error_extractor, ) from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, InferenceRequestProxy, ) -from iotdb.ainode.core.inference.pipeline import get_pipeline +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.pool_controller import PoolController from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger -from 
iotdb.ainode.core.manager.model_manager import get_model_manager +from iotdb.ainode.core.manager.model_manager import ModelManager from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.ainode.core.util.serde import convert_to_binary @@ -67,7 +64,7 @@ class InferenceManager: ) # How often to check for requests in the result queue def __init__(self): - self._model_manager = get_model_manager() + self._model_manager = ModelManager() self._model_mem_usage_map: Dict[str, int] = ( {} ) # store model memory usage for each model @@ -211,7 +208,8 @@ def _run( outputs = self._process_request(infer_req) outputs = convert_to_binary(pd.DataFrame(outputs[0])) else: - inference_pipeline = get_pipeline(model_id, device="cpu") + model_info = self._model_manager.get_model_info(model_id) + inference_pipeline = load_pipeline(model_info, device="cpu") outputs = inference_pipeline.infer( inputs, predict_length=predict_length, **inference_attrs ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py index ef24830142a39..1e096af379dff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -24,6 +24,7 @@ MODEL_WEIGHTS_FILE_IN_PT = "model.pt" MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + class ModelCategory(Enum): BUILTIN = "builtin" USER_DEFINED = "user_defined" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 690bc09fe8aae..905b13fef2f7d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -32,6 +32,7 @@ def __init__( repo_id: str = "", path: str = "", auto_map: Optional[Dict] = None, + _transformers_registered: bool = False, ): self.model_id = model_id diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index e2c9e8bf21316..de16d05c36c1d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -162,6 +162,7 @@ def _callback_model_download_result( config = json.load(f) if model_info.model_type == "": model_info.model_type = config.get("model_type", "") + model_info.auto_map = config.get("auto_map") logger.info( f"Model {model_id} downloaded successfully and is ready to use." 
) @@ -179,10 +180,12 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) model_type = "" auto_map = {} + pipeline_class = "" if os.path.exists(config_path): - config = load_model_config_in_json(Path(config_path)) + config = load_model_config_in_json(config_path) model_type = config.get("model_type", "") - auto_map = config.get("auto_map") + auto_map = config.get("auto_map", None) + pipeline_class = config.get("pipeline_class", "") with self._lock_pool.get_lock(model_id).write_lock(): model_info = ModelInfo( @@ -190,6 +193,7 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): model_type=model_type, category=ModelCategory.USER_DEFINED, state=ModelStates.ACTIVE, + pipeline_cls=pipeline_class, path=str(model_dir), auto_map=auto_map, _transformers_registered=False, # Lazy registration @@ -212,14 +216,15 @@ def register_model(self, model_id: str, uri: str) -> bool: ensure_init_file(model_dir) if uri_type == UriType.REPO: - self._fetch_model_from_hf_repo(parsed_uri, str(model_dir)) + self._fetch_model_from_hf_repo(parsed_uri, model_dir) else: - self._fetch_model_from_local(os.path.expanduser(parsed_uri), str(model_dir)) + self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) config_path, _ = validate_model_files(model_dir) config = load_model_config_in_json(config_path) model_type = config.get("model_type", "") auto_map = config.get("auto_map") + pipeline_class = config.get("pipeline_class", "") with self._lock_pool.get_lock(model_id).write_lock(): model_info = ModelInfo( @@ -227,6 +232,7 @@ def register_model(self, model_id: str, uri: str) -> bool: model_type=model_type, category=ModelCategory.USER_DEFINED, state=ModelStates.ACTIVE, + pipeline_cls=pipeline_class, path=str(model_dir), auto_map=auto_map, _transformers_registered=False, # Register later @@ -268,13 +274,18 @@ def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): def _fetch_model_from_local(self, source_path: str, storage_path: str): logger.info(f"Copying model from local path: {source_path} -> {storage_path}") - if not os.path.isdir(source_path): + source_dir = Path(source_path) + if not source_dir.is_dir(): raise ValueError( f"Source path does not exist or is not a directory: {source_path}" ) - for file in os.listdir(source_path): - if os.path.isfile(os.path.join(source_path, file)): - shutil.copy2(file, os.path.join(storage_path, file)) + + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + return + def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py index 93ad2ab1620ac..ce7524bcf729c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -21,9 +21,12 @@ import os.path import sys from contextlib import contextmanager -from pathlib import Path from typing import Dict, Tuple +from iotdb.ainode.core.model.model_constants import ( + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + MODEL_CONFIG_FILE_IN_JSON, +) from iotdb.ainode.core.model.model_constants import ( UriType, ) @@ -57,25 +60,27 @@ def temporary_sys_path(path: str): sys.path.remove(path) -def load_model_config_in_json(config_path: Path) -> Dict: +def load_model_config_in_json(config_path: str) -> Dict: with open(config_path, "r", 
encoding="utf-8") as f: return json.load(f) -def validate_model_files(model_dir: Path) -> Tuple[Path, Path]: +def validate_model_files(model_dir: str) -> Tuple[str, str]: """Validate model files exist, return config and weights file paths""" - config_path = model_dir / MODEL_CONFIG_FILE - weights_path = model_dir / MODEL_WEIGHTS_FILE - if not config_path.exists(): + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + + if not os.path.exists(config_path): raise ValueError(f"Model config file does not exist: {config_path}") - if not weights_path.exists(): + if not os.path.exists(weights_path): raise ValueError(f"Model weights file does not exist: {weights_path}") # Create __init__.py file to ensure model directory can be imported as a module - init_file = model_dir / "__init__.py" - if not init_file.exists(): - init_file.touch() + init_file = os.path.join(model_dir, "__init__.py") + if not os.path.exists(init_file): + with open(init_file, 'w'): + pass return config_path, weights_path diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index 48460e81b87ca..24792607268b9 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -60,6 +60,7 @@ def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus: ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) + class AINodeRPCServiceHandler(IAINodeRPCService.Iface): def __init__(self, ainode): self._ainode = ainode @@ -84,7 +85,7 @@ def showModels(self, req: TShowModelsReq) -> TShowModelsResp: return self._model_manager.show_models(req) def loadModel(self, req: TLoadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId) + status = self._ensure_model_is_registered(req.existingModelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -93,7 +94,7 @@ def loadModel(self, req: TLoadModelReq) -> TSStatus: return self._inference_manager.load_model(req) def unloadModel(self, req: TUnloadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + status = self._ensure_model_is_registered(req.modelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -114,13 +115,13 @@ def showAIDevices(self) -> TShowAIDevicesResp: ) def inference(self, req: TInferenceReq) -> TInferenceResp: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + status = self._ensure_model_is_registered(req.modelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return TInferenceResp(status, []) return self._inference_manager.inference(req) def forecast(self, req: TForecastReq) -> TForecastResp: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + status = self._ensure_model_is_registered(req.modelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return TForecastResp(status, []) return self._inference_manager.forecast(req) @@ -131,10 +132,10 @@ def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass - def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus: + def _ensure_model_is_registered(self, model_id: str) -> TSStatus: if not self._model_manager.is_model_registered(model_id): 
return TSStatus( code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.", + message=f"Model [{model_id}] is not available. You can use 'SHOW MODELS' to retrieve the available models.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) \ No newline at end of file From 59fe301c763985d88552bedf6b58da9c344670cb Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Sat, 29 Nov 2025 13:00:19 +0800 Subject: [PATCH 06/38] Fix IoTDBDatabaseIT fix ci --- .../relational/it/schema/IoTDBDatabaseIT.java | 14 +------ .../ainode/iotdb/ainode/core/rpc/client.py | 39 ------------------- 2 files changed, 1 insertion(+), 52 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index e04ff838819e6..609a228022f8d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -403,7 +403,6 @@ public void testInformationSchema() throws SQLException { "databases,INF,", "functions,INF,", "keywords,INF,", - "models,INF,", "nodes,INF,", "pipe_plugins,INF,", "pipes,INF,", @@ -504,16 +503,6 @@ public void testInformationSchema() throws SQLException { "database,STRING,TAG,", "table_name,STRING,TAG,", "view_definition,STRING,ATTRIBUTE,"))); - TestUtils.assertResultSetEqual( - statement.executeQuery("desc models"), - "ColumnName,DataType,Category,", - new HashSet<>( - Arrays.asList( - "model_id,STRING,TAG,", - "model_type,STRING,ATTRIBUTE,", - "state,STRING,ATTRIBUTE,", - "configs,STRING,ATTRIBUTE,", - "notes,STRING,ATTRIBUTE,"))); TestUtils.assertResultSetEqual( statement.executeQuery("desc functions"), "ColumnName,DataType,Category,", @@ -638,7 +627,6 @@ public void testInformationSchema() throws SQLException { "information_schema,pipes,INF,USING,null,SYSTEM VIEW,", "information_schema,subscriptions,INF,USING,null,SYSTEM VIEW,", "information_schema,views,INF,USING,null,SYSTEM VIEW,", - "information_schema,models,INF,USING,null,SYSTEM VIEW,", "information_schema,functions,INF,USING,null,SYSTEM VIEW,", "information_schema,configurations,INF,USING,null,SYSTEM VIEW,", "information_schema,keywords,INF,USING,null,SYSTEM VIEW,", @@ -651,7 +639,7 @@ public void testInformationSchema() throws SQLException { TestUtils.assertResultSetEqual( statement.executeQuery("count devices from tables where status = 'USING'"), "count(devices),", - Collections.singleton("20,")); + Collections.singleton("18,")); TestUtils.assertResultSetEqual( statement.executeQuery( "select * from columns where table_name = 'queries' or database = 'test'"), diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py index e2be6459508b7..ea6362ef080af 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py @@ -38,7 +38,6 @@ TAINodeRemoveReq, TAINodeRestartReq, TNodeVersionInfo, - TUpdateModelInfoReq, ) logger = Logger() @@ -155,13 +154,8 @@ def _wait_and_reconnect(self) -> None: self._try_to_connect() except TException: # can not connect to each config node - self._sync_latest_config_node_list() self._try_to_connect() - def _sync_latest_config_node_list(self) -> None: - # TODO - pass - def 
_update_config_node_leader(self, status: TSStatus) -> bool: if status.code == TSStatusCode.REDIRECTION_RECOMMEND.get_status_code(): if status.redirectNode is not None: @@ -271,36 +265,3 @@ def get_ainode_configuration(self, node_id: int) -> map: self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) - - def update_model_info( - self, - model_id: str, - model_status: int, - attribute: str = "", - ainode_id=None, - input_length=0, - output_length=0, - ) -> None: - if ainode_id is None: - ainode_id = [] - for _ in range(0, self._RETRY_NUM): - try: - req = TUpdateModelInfoReq(model_id, model_status, attribute) - if ainode_id is not None: - req.aiNodeIds = ainode_id - req.inputLength = input_length - req.outputLength = output_length - status = self._client.updateModelInfo(req) - if not self._update_config_node_leader(status): - verify_success( - status, "An error occurs when calling update model info" - ) - return status - except TTransport.TException: - logger.warning( - "Failed to connect to ConfigNode {} from AINode when executing update model info", - self._config_leader, - ) - self._config_leader = None - self._wait_and_reconnect() - raise TException(self._MSG_RECONNECTION_FAIL) From e441289944448b510b856e60543012af1a5d9e3f Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Sat, 29 Nov 2025 16:46:20 +0800 Subject: [PATCH 07/38] refactor AIN CI tests Update AINodeInstanceManagementIT.java Fix. CI --- .../ainode/it/AINodeCallInferenceIT.java | 116 ++++++ .../ainode/it/AINodeConcurrentForecastIT.java | 114 ++++++ .../it/AINodeConcurrentInferenceIT.java | 187 ---------- .../iotdb/ainode/it/AINodeForecastIT.java | 103 ++++++ .../iotdb/ainode/it/AINodeInferenceSQLIT.java | 344 ------------------ .../ainode/it/AINodeInstanceManagementIT.java | 79 +--- .../iotdb/ainode/it/AINodeModelManageIT.java | 49 +-- .../iotdb/ainode/utils/AINodeTestUtils.java | 126 ++++++- .../test/resources/ainode-example/config.yaml | 5 - .../test/resources/ainode-example/model.pt | Bin 1906 -> 0 bytes 10 files changed, 467 insertions(+), 656 deletions(-) create mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java create mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java delete mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java create mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java delete mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java delete mode 100644 integration-test/src/test/resources/ainode-example/config.yaml delete mode 100644 integration-test/src/test/resources/ainode-example/model.pt diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java new file mode 100644 index 0000000000000..6bdb3e25b91b5 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.ainode.it;
+
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
+public class AINodeCallInferenceIT {
+
+  private static final String[] WRITE_SQL_IN_TREE =
+      new String[] {
+        "CREATE DATABASE root.AI",
+        "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
+        "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
+        "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
+        "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
+      };
+
+  private static final String CALL_INFERENCE_SQL_TEMPLATE =
+      "CALL INFERENCE(%s, \"select s%d from root.AI\")";
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    // Init 1C1D1A cluster environment
+    EnvFactory.getEnv().initClusterEnvironment(1, 1);
+    prepareData(WRITE_SQL_IN_TREE);
+    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      for (int i = 0; i < 2880; i++) {
+        statement.execute(
+            String.format(
+                "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+                i, (float) i, (double) i, i, i));
+      }
+    }
+  }
+
+  @AfterClass
+  public static void tearDown() throws Exception {
+    EnvFactory.getEnv().cleanClusterEnvironment();
+  }
+
+  @Test
+  public void callInferenceTest() throws SQLException {
+    for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
+      try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+          Statement statement = connection.createStatement()) {
+        callInferenceTest(statement, modelInfo);
+      }
+    }
+  }
+
+  public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
+      throws SQLException {
+    // Invoke call inference for each specified model; the result set should not be empty.
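+    // For example, for model sundial and column s0 the template above renders to:
+    //   CALL INFERENCE(sundial, "select s0 from root.AI")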
+    for (int i = 0; i < 4; i++) {
+      String callInferenceSQL =
+          String.format(CALL_INFERENCE_SQL_TEMPLATE, modelInfo.getModelId(), i);
+      try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+        int count = 0;
+        while (resultSet.next()) {
+          count++;
+        }
+        // Ensure the call inference returns results
+        Assert.assertTrue(count > 0);
+      }
+    }
+  }
+
+  @Test
+  public void errorCallInferenceTestInTree() throws SQLException {
+    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))";
+      errorTest(statement, sql, "1505: model [notFound404] has not been created.");
+    }
+  }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
new file mode 100644
index 0000000000000..a23eec97497df
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeConcurrentForecastIT { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class); + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)"; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareDataForTableModel(); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + private static void prepareDataForTableModel() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE root"); + statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); + } + } + } + + @Test + public void concurrentGPUForecastTest() throws SQLException, InterruptedException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) { + concurrentGPUForecastTest(modelInfo); + } + } + + public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException, InterruptedException { + final int forecastLength = 512; + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + // Single forecast request can be processed successfully + final String forecastSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength); + final int threadCnt = 10; + final int loop = 100; + final String devices = "0,1"; + statement.execute( + String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelOnSpecifiedDevice( + statement, modelInfo.getModelId(), modelInfo.getModelType(), devices); + long startTime = System.currentTimeMillis(); + concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength); + long endTime = System.currentTimeMillis(); + LOGGER.info( + String.format( + "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", + modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - 
startTime)); + statement.execute( + String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java deleted file mode 100644 index a08990d472fe6..0000000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import com.google.common.collect.ImmutableSet; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.HashSet; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeConcurrentInferenceIT { - - private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class); - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataForTreeModel(); - prepareDataForTableModel(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - private static void prepareDataForTreeModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE root.AI"); - statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)", - i, Math.sin(i * Math.PI / 1440))); - } - } - } - - private static void prepareDataForTableModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { 
- statement.execute("CREATE DATABASE root"); - statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); - } - } - } - - // @Test - public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException { - concurrentGPUCallInferenceTest("timer_xl"); - concurrentGPUCallInferenceTest("sundial"); - } - - private void concurrentGPUCallInferenceTest(String modelId) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - concurrentInference( - statement, - String.format( - "CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)", - modelId, predictLength), - threadCnt, - loop, - predictLength); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - String forecastTableFunctionSql = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d"; - String forecastUDTFSql = - "SELECT forecast(s, 'MODEL_ID'='%s', 'PREDICT_LENGTH'='%d') FROM root.AI"; - - @Test - public void concurrentGPUForecastTest() throws SQLException, InterruptedException { - concurrentGPUForecastTest("timer_xl", forecastUDTFSql); - concurrentGPUForecastTest("sundial", forecastUDTFSql); - concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql); - concurrentGPUForecastTest("sundial", forecastTableFunctionSql); - } - - public void concurrentGPUForecastTest(String modelId, String selectSql) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - long startTime = System.currentTimeMillis(); - concurrentInference( - statement, - String.format(selectSql, modelId, predictLength), - threadCnt, - loop, - predictLength); - long endTime = System.currentTimeMillis(); - LOGGER.info( - String.format( - "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", - modelId, threadCnt * loop, threadCnt, loop, endTime - startTime)); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) - throws SQLException, InterruptedException { - Set targetDevices = ImmutableSet.copyOf(device.split(",")); - LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); - for (int retry = 0; retry < 200; retry++) { - Set foundDevices = new HashSet<>(); - try (final ResultSet resultSet = - statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { - while (resultSet.next()) { - String deviceId = resultSet.getString("DeviceId"); - String loadedModelId = resultSet.getString("ModelId"); - int count = 
resultSet.getInt("Count(instances)"); - LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); - if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { - foundDevices.add(deviceId); - LOGGER.info("Model {} is loaded to device {}", modelId, device); - } - } - if (foundDevices.containsAll(targetDevices)) { - LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); - return; - } - } - TimeUnit.SECONDS.sleep(3); - } - Assert.fail("Model " + modelId + " is not loaded on device " + device); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java new file mode 100644 index 0000000000000..8953bec07a745 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeForecastIT { + + private static final String[] WRITE_SQL_IN_TABLE = + new String[] { + "CREATE DATABASE root", + "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", + }; + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM root.AI) ORDER BY time)"; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareData(WRITE_SQL_IN_TABLE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void 
tearDown() throws Exception {
+    EnvFactory.getEnv().cleanClusterEnvironment();
+  }
+
+  @Test
+  public void forecastTableFunctionTest() throws SQLException {
+    for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
+      try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+          Statement statement = connection.createStatement()) {
+        forecastTableFunctionTest(statement, modelInfo);
+      }
+    }
+  }
+
+  public void forecastTableFunctionTest(
+      Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
+    // Invoke the forecast table function for each specified model; the result set should not be empty.
+    for (int i = 0; i < 4; i++) {
+      String forecastTableFunctionSQL =
+          String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i);
+      try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
+        int count = 0;
+        while (resultSet.next()) {
+          count++;
+        }
+        // Ensure the forecast returns results
+        Assert.assertTrue(count > 0);
+      }
+    }
+  }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java
deleted file mode 100644
index 70f7a1d9f9eb7..0000000000000
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java
+++ /dev/null
@@ -1,344 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
-
 *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData; -import static org.junit.Assert.assertEquals; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeInferenceSQLIT { - - static String[] WRITE_SQL_IN_TREE = - new String[] { - "set configuration \"trusted_uri_pattern\"='.*'", - "create model identity using uri \"" + EXAMPLE_MODEL_PATH + "\"", - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; - - static String[] WRITE_SQL_IN_TABLE = - new String[] { - "CREATE DATABASE root", - "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", - }; - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TREE); - prepareTableData(WRITE_SQL_IN_TABLE); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - // @Test - public void callInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void callInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - public void callInferenceTest(Statement statement) throws SQLException { - // SQL0: Invoke timer-sundial and timer-xl to inference, the result should success - try 
(ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(sundial, \"select s1 from root.AI\", generateTime=true, predict_length=720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(timer_xl, \"select s2 from root.AI\", generateTime=true, predict_length=256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", generateTime=true)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "CALL INFERENCE(identity, \"select s2,s0,s3,s1 from root.AI\", generateTime=false)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "CALL INFERENCE(naive_forecaster, \"select s0 from root.AI\", predict_length=3, generateTime=true)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "CALL INFERENCE(holtwinters, \"select s0 from root.AI\", predict_length=6, generateTime=true)"; - // TODO: enable following tests after refactor the CALL INFERENCE - - // try (ResultSet resultSet = statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2,output3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - - // @Test - public void errorCallInferenceTestInTree() throws SQLException { - 
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void errorCallInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - public void errorCallInferenceTest(Statement statement) { - String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))"; - errorTest(statement, sql, "1505: model [notFound404] has not been created."); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", window=head(2))"; - // TODO: enable following tests after refactor the CALL INFERENCE - // errorTest(statement, sql, "701: Window output 2 is not equal to input size of model 7"); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI limit 5\")"; - // errorTest( - // statement, - // sql, - // "301: The number of rows 5 in the input data does not match the model input 7. Try to - // use LIMIT in SQL or WINDOW in CALL INFERENCE"); - sql = "CREATE MODEL 中文 USING URI \"" + EXAMPLE_MODEL_PATH + "\""; - errorTest(statement, sql, "701: ModelId can only contain letters, numbers, and underscores"); - } - - @Test - public void selectForecastTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - // SQL0: Invoke timer-sundial and timer-xl to forecast, the result should success - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s1 FROM root.AI) ORDER BY time, output_length=>720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s1"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'timer_xl', input=>(SELECT time,s2 FROM root.AI) ORDER BY time, output_length=>256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s2"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s0,s1,s2,s3 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s2,s0,s3,s1 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "SELECT * FROM FORECAST(model_id=>'naive_forecaster', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>3)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "SELECT * FROM FORECAST(model_id=>'holtwinters', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>6)"; - // TODO: enable following tests after refactor the FORECAST - // try (ResultSet resultSet = 
statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2,s3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s2,s0,s3,s1"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index 93351c0178526..56b9a5bbda732 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -35,14 +35,14 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.TimeUnit; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; public class AINodeInstanceManagementIT { - private static final int WAITING_TIME_SEC = 30; private static final Set TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1")); @BeforeClass @@ -85,52 +85,18 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter } // Load sundial to each device - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS 0")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - while (resultSet.next()) { - Assert.assertEquals("0", resultSet.getString("DeviceID")); - Assert.assertEquals("Timer-Sundial", resultSet.getString("ModelType")); - 
Assert.assertTrue(resultSet.getInt("Count(instances)") > 1);
-        }
-      }
-      try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-        checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-        final Set<String> resultDevices = new HashSet<>();
-        while (resultSet.next()) {
-          resultDevices.add(resultSet.getString("DeviceID"));
-        }
-        Assert.assertEquals(TARGET_DEVICES, resultDevices);
-      }
+      statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+      checkModelOnSpecifiedDevice(statement, "sundial", "sundial", String.join(",", TARGET_DEVICES));
 
       // Load timer_xl to each device
-      statement.execute("LOAD MODEL timer_xl TO DEVICES \"cpu,0,1\"");
-      TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-      try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-        checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-        final Set<String> resultDevices = new HashSet<>();
-        while (resultSet.next()) {
-          if (resultSet.getString("ModelType").equals("Timer-XL")) {
-            resultDevices.add(resultSet.getString("DeviceID"));
-          }
-          Assert.assertTrue(resultSet.getInt("Count(instances)") > 1);
-        }
-        Assert.assertEquals(TARGET_DEVICES, resultDevices);
-      }
+      statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+      checkModelOnSpecifiedDevice(statement, "timer_xl", "timer_xl", String.join(",", TARGET_DEVICES));
 
       // Clean every device
-      statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
-      statement.execute("UNLOAD MODEL timer_xl FROM DEVICES \"cpu,0,1\"");
-      TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-      try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-        checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-        Assert.assertFalse(resultSet.next());
-      }
+      statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+      statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+      checkModelNotOnSpecifiedDevice(statement, "timer_xl", String.join(",", TARGET_DEVICES));
+      checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
   }
 
   private static final int LOOP_CNT = 10;
 
@@ -141,23 +107,9 @@ public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException
         Statement statement = connection.createStatement()) {
       for (int i = 0; i < LOOP_CNT; i++) {
         statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
-        TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-        try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-          ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-          checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-          final Set<String> resultDevices = new HashSet<>();
-          while (resultSet.next()) {
-            resultDevices.add(resultSet.getString("DeviceID"));
-          }
-          Assert.assertEquals(TARGET_DEVICES, resultDevices);
-        }
+        checkModelOnSpecifiedDevice(statement, "sundial", "sundial", String.join(",", TARGET_DEVICES));
         statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
-        TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-        try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-          ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-          checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-          Assert.assertFalse(resultSet.next());
-        }
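+        // checkModelNotOnSpecifiedDevice presumably polls SHOW LOADED MODELS until no
+        // instance of the model remains on the given devices, so no fixed sleep is needed.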
@@ -170,12 +122,7 @@ public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedExcept statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); } - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC * LOOP_CNT); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES)); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 2a1461e4a15b6..25cdf0f8ceef7 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -19,6 +19,7 @@ package org.apache.iotdb.ainode.it; +import org.apache.iotdb.ainode.utils.AINodeTestUtils; import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo; import org.apache.iotdb.it.env.EnvFactory; import org.apache.iotdb.it.framework.IoTDBTestRunner; @@ -36,13 +37,8 @@ import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; -import java.util.AbstractMap; -import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; import static org.junit.Assert.assertEquals; @@ -54,36 +50,6 @@ @Category({AIClusterIT.class}) public class AINodeModelManageIT { - private static final Map BUILT_IN_MODEL_MAP = - Stream.of( - new AbstractMap.SimpleEntry<>( - "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "holtwinters", - new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "exponential_smoothing", - new FakeModelInfo( - "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "naive_forecaster", - new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "stl_forecaster", - new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "gaussian_hmm", - new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "sundial", new FakeModelInfo("sundial", "Timer-Sundial", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "timer_xl", new FakeModelInfo("timer_xl", "Timer-XL", "BUILT-IN", "ACTIVE"))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - @BeforeClass public static void setUp() throws Exception { // Init 1C1D1A cluster environment @@ -95,7 +61,7 @@ public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment(); } - @Test + // @Test public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -103,7 +69,7 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup } } - @Test + // @Test public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -114,8 +80,7 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru private void userDefinedModelManagementTest(Statement statement) throws SQLException, InterruptedException { final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = - "create model operationTest using uri \"" + EXAMPLE_MODEL_PATH + "\""; + final String registerSql = "create model operationTest using uri \"" + "\""; final String showSql = "SHOW MODELS operationTest"; final String dropSql = "DROP MODEL operationTest"; @@ -208,10 +173,10 @@ private void showBuiltInModelTest(Statement statement) throws SQLException { resultSet.getString(2), resultSet.getString(3), resultSet.getString(4)); - assertTrue(BUILT_IN_MODEL_MAP.containsKey(modelInfo.getModelId())); - assertEquals(BUILT_IN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); + assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId())); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); } } - assertEquals(BUILT_IN_MODEL_MAP.size(), built_in_model_count); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.size(), built_in_model_count); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index cbb0b03b22997..d9ddb6a4e097d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -19,30 +19,70 @@ package org.apache.iotdb.ainode.utils; -import java.io.File; +import com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.util.AbstractMap; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; public class AINodeTestUtils { - public static final String EXAMPLE_MODEL_PATH = - "file://" - + System.getProperty("user.dir") - + File.separator - + "src" - + File.separator - + "test" - + File.separator - + "resources" - + File.separator - + "ainode-example"; + public static final Map BUILTIN_LTSM_MAP = + Stream.of( + new AbstractMap.SimpleEntry<>( + "sundial", new FakeModelInfo("sundial", "sundial", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "timer_xl", new FakeModelInfo("timer_xl", "timer", "BUILT-IN", "ACTIVE"))) + .collect(Collectors.toMap(Map.Entry::getKey, 
Map.Entry::getValue)); + + public static final Map BUILTIN_MODEL_MAP; + + static { + Map tmp = + Stream.of( + new AbstractMap.SimpleEntry<>( + "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "holtwinters", + new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "exponential_smoothing", + new FakeModelInfo( + "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "naive_forecaster", + new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "stl_forecaster", + new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "gaussian_hmm", + new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")), + new AbstractMap.SimpleEntry<>( + "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + tmp.putAll(BUILTIN_LTSM_MAP); + BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp); + } + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeTestUtils.class); public static void checkHeader(ResultSetMetaData resultSetMetaData, String title) throws SQLException { @@ -94,6 +134,68 @@ public static void concurrentInference( } } + public static void checkModelOnSpecifiedDevice( + Statement statement, String modelId, String modelType, String device) + throws SQLException, InterruptedException { + Set targetDevices = ImmutableSet.copyOf(device.split(",")); + LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); + for (int retry = 0; retry < 200; retry++) { + Set foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString("DeviceId"); + String loadedModelId = resultSet.getString("ModelId"); + String loadedModelType = resultSet.getString("ModelType"); + int count = resultSet.getInt("Count(instances)"); + LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); + if (loadedModelId.equals(modelId) + && loadedModelType.equals(modelType) + && targetDevices.contains(deviceId) + && count > 0) { + foundDevices.add(deviceId); + LOGGER.info("Model {} is loaded to device {}", modelId, device); + } + } + if (foundDevices.containsAll(targetDevices)) { + LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + fail("Model " + modelId + " is not loaded on device " + device); + } + + public static void checkModelNotOnSpecifiedDevice( + Statement statement, String modelId, String device) + throws SQLException, InterruptedException { + Set targetDevices = ImmutableSet.copyOf(device.split(",")); + LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices); + for (int retry = 0; retry < 50; retry++) { + Set foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString("DeviceId"); + String loadedModelId = resultSet.getString("ModelId"); + int count = 
resultSet.getInt("Count(instances)"); + LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); + if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { + foundDevices.add(deviceId); + LOGGER.info("Model {} is loaded to device {}", modelId, device); + } + } + if (foundDevices.isEmpty()) { + LOGGER.info("Model {} is unloaded from devices {}.", modelId, targetDevices); + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + fail("Model " + modelId + " is still loaded on device " + device); + } + public static class FakeModelInfo { private final String modelId; diff --git a/integration-test/src/test/resources/ainode-example/config.yaml b/integration-test/src/test/resources/ainode-example/config.yaml deleted file mode 100644 index 665acb8704e24..0000000000000 --- a/integration-test/src/test/resources/ainode-example/config.yaml +++ /dev/null @@ -1,5 +0,0 @@ -configs: - input_shape: [7, 4] - output_shape: [7, 4] - input_type: ["float32", "float32", "float32", "float32"] - output_type: ["float32", "float32", "float32", "float32"] diff --git a/integration-test/src/test/resources/ainode-example/model.pt b/integration-test/src/test/resources/ainode-example/model.pt deleted file mode 100644 index 67d4aec6999f1b677d7e71e2415ba3178f7f618b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1906 zcmWIWW@cev;NW1u0D=sH44EmZc_o=8mHH`(C5d_k**R`bybMvupn)klKE5QsC^;iO zUJp#`<>l$+=BJeAq!#PtWagzN7IAq(jo~U}&}^*LhydAEQk0mPmzkGd$k-7f2IR+Q z7RRTR=H$dDB_?N=Cl;l|XXNK+7c%*kCKWR41$eV_IGE>6O9!e1;Q*ksMS#x6bhiB#-dViGfQ?(V%DD<11r@UFA?#QIlCwVffiWf>m5x#By; zQ&>r6;m-W5=so`@Z`yUBQ9fazXs5P_&@Un7ZR@1x`&G-h96P%IoF&KZ8*3iSdV8{2 zJa(p5;FmJZ6N~OPWNgcxadu0uml~Ve@@M8{`=2l}omY+CA*g;)if{S*Z1ZqU@u&P- zs+mE-=bLt7=Q3c>y@3WF_E@)JFd!rN^ioojO4H-P2}FmafNWrjkN`T!%|FQ3F(f|R zGsGi4I3&o^&pkfG(aFcPkRbvn%TUPJTFB(hkPJy+S(znz@dcU5**U3PNu`-NDe;+k zB{`YJC0vEf8nGIwB|+W{-VE)9EMN-AYAs}K2PdJ8ANSg30nGzpP!hr(24W0C$YGFI zT#}eqQVdD{d}zLFVA2GeoU6|n4m6GdgmIfJz+i_kxh%D)I5R)b&B+SQOnm7LUCx*b z6t@@WrH3*BZ3bc7whJNKo>WK%01K|K~Mo?hEOFrbnMGz!`0x-%!h;~E?gq*pI zP_%9b^5EKuE|1Wihn#S2P|W|vNRIi442y0PazX}`%Lwob7+^>~LCO~BW*{d=0fYfS zRtPha8PE)Xt{XWa38Cn|gsdB$fYJ3MN3%SN{xD$fg!=${;tTL*W7C0Zl4I6|YiEbD nV6+4{^)N8}0A+X}0O|uv2|yJ9V+AP23d#%&>_7-o4^ayM@w(>s From 64364173236d729ba70e818bd14c5cad9bdb9cef Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Mon, 1 Dec 2025 20:32:08 +0800 Subject: [PATCH 08/38] call inference bug fix --- .../process/ai/InferenceOperator.java | 71 +--------- .../plan/analyze/AnalyzeVisitor.java | 134 ++---------------- .../plan/analyze/ModelFetcher.java | 13 +- .../plan/node/process/AI/InferenceNode.java | 3 +- .../model/ModelInferenceDescriptor.java | 61 +------- .../iotdb/commons/model/ModelInformation.java | 43 +++--- .../iotdb/commons/model/ModelTable.java | 4 +- 7 files changed, 53 insertions(+), 276 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java index a1e22c73b4e5b..ace55bd0ecf75 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java @@ -21,7 +21,6 @@ import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; import 
org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException; import org.apache.iotdb.db.protocol.client.an.AINodeClient; import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; @@ -29,9 +28,6 @@ import org.apache.iotdb.db.queryengine.execution.operator.Operator; import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; @@ -76,7 +72,6 @@ public class InferenceOperator implements ProcessOperator { private int resultIndex = 0; private List results; private final TsBlockSerde serde = new TsBlockSerde(); - private InferenceWindowType windowType = null; private final boolean generateTimeColumn; private long maxTimestamp; @@ -110,10 +105,6 @@ public InferenceOperator( this.maxReturnSize = maxReturnSize; this.totalRow = 0; - if (modelInferenceDescriptor.getInferenceWindowParameter() != null) { - windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType(); - } - if (generateTimeColumn) { this.interval = 0; this.minTimestamp = Long.MAX_VALUE; @@ -238,62 +229,6 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) { } } - private TWindowParams getWindowParams() { - TWindowParams windowParams; - if (windowType == null) { - return null; - } - if (windowType == InferenceWindowType.COUNT) { - CountInferenceWindowParameter countInferenceWindowParameter = - (CountInferenceWindowParameter) modelInferenceDescriptor.getInferenceWindowParameter(); - windowParams = new TWindowParams(); - windowParams.setWindowInterval((int) countInferenceWindowParameter.getInterval()); - windowParams.setWindowStep((int) countInferenceWindowParameter.getStep()); - } else { - windowParams = null; - } - return windowParams; - } - - private TsBlock preProcess(TsBlock inputTsBlock) { - // boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn(); - boolean notBuiltIn = false; - if (windowType == null || windowType == InferenceWindowType.HEAD) { - if (notBuiltIn - && totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data does not match the model input %s. Try to use LIMIT in SQL or WINDOW in CALL INFERENCE", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - return inputTsBlock; - } else if (windowType == InferenceWindowType.COUNT) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. 
", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - } else if (windowType == InferenceWindowType.TAIL) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. ", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - // Tail window logic: get the latest data for inference - long windowSize = - (int) - ((BottomInferenceWindowParameter) - modelInferenceDescriptor.getInferenceWindowParameter()) - .getWindowSize(); - return inputTsBlock.subTsBlock((int) (totalRow - windowSize)); - } - return inputTsBlock; - } - private void submitInferenceTask() { if (generateTimeColumn) { @@ -302,9 +237,6 @@ private void submitInferenceTask() { TsBlock inputTsBlock = inputTsBlockBuilder.build(); - TsBlock finalInputTsBlock = preProcess(inputTsBlock); - TWindowParams windowParams = getWindowParams(); - inferenceExecutionFuture = Futures.submit( () -> { @@ -313,8 +245,7 @@ private void submitInferenceTask() { .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { return client.inference( new TInferenceReq( - modelInferenceDescriptor.getModelName(), - serde.serialize(finalInputTsBlock))); + modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock))); } catch (Exception e) { throw new ModelInferenceProcessException(e.getMessage()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 34a289b76c9da..dc56fe118b7b3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -25,7 +25,6 @@ import org.apache.iotdb.commons.conf.IoTDBConstant; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -55,14 +54,6 @@ import org.apache.iotdb.db.queryengine.common.schematree.IMeasurementSchemaInfo; import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.HeadInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.TailInferenceWindow; import org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet; import 
org.apache.iotdb.db.queryengine.plan.analyze.load.LoadTsFileAnalyzer; import org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; @@ -425,46 +416,14 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem return; } - // Get model metadata from configNode and do some check + // Get model metadata from AINode String modelId = queryStatement.getModelId(); TSStatus status = modelFetcher.fetchModel(modelId, analysis); if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { throw new GetModelInfoException(status.getMessage()); } - // set inference window if there is - if (queryStatement.isSetInferenceWindow()) { - InferenceWindow window = queryStatement.getInferenceWindow(); - if (InferenceWindowType.HEAD == window.getType()) { - long windowSize = ((HeadInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - if (queryStatement.hasLimit() && queryStatement.getRowLimit() < windowSize) { - throw new SemanticException( - "Limit in Sql should be larger than window size in inference"); - } - // optimize head window by limitNode - queryStatement.setRowLimit(windowSize); - } else if (InferenceWindowType.TAIL == window.getType()) { - long windowSize = ((TailInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new BottomInferenceWindowParameter(windowSize); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } else if (InferenceWindowType.COUNT == window.getType()) { - CountInferenceWindow countInferenceWindow = (CountInferenceWindow) window; - // checkWindowSize(countInferenceWindow.getInterval(), modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new CountInferenceWindowParameter( - countInferenceWindow.getInterval(), countInferenceWindow.getStep()); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } - } - - // set inference attributes if there is + // Set inference attributes if there is if (queryStatement.hasInferenceAttributes()) { analysis .getModelInferenceDescriptor() @@ -472,12 +431,6 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem } } - private void checkWindowSize(long windowSize, ModelInformation modelInformation) { - if (modelInformation.isBuiltIn()) { - return; - } - } - private ISchemaTree analyzeSchema( QueryStatement queryStatement, Analysis analysis, @@ -1717,22 +1670,11 @@ static void analyzeOutput( } if (queryStatement.hasModelInference()) { - ModelInformation modelInformation = analysis.getModelInformation(); // check input - checkInputShape(modelInformation, outputExpressions); - checkInputType(analysis, modelInformation, outputExpressions); - + checkInputType(analysis, outputExpressions); // set output List columnHeaders = new ArrayList<>(); - int[] outputShape = modelInformation.getOutputShape(); - TSDataType[] outputDataType = modelInformation.getOutputDataType(); - for (int i = 0; i < outputShape[1]; i++) { - columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME + i, outputDataType[i])); - } - analysis - .getModelInferenceDescriptor() - .setOutputColumnNames( - columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList())); + columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME, TSDataType.DOUBLE)); boolean isIgnoreTimestamp = !queryStatement.isGenerateTime(); 
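On the operator side of this commit (InferenceOperator above), the submission path that survives the deletions is small enough to read in one piece. A consolidated sketch, assuming the operator's fields as shown in the diff; the getInstance() accessor on AINodeClientManager is an assumption here, not visible in the hunks:

    // Consolidated post-refactor submission path: no window pre-processing,
    // the whole accumulated TsBlock is serialized and shipped with the model id.
    private TInferenceResp callInference(TsBlock inputTsBlock) {
      try (AINodeClient client =
          AINodeClientManager.getInstance() // accessor name assumed
              .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
        return client.inference(
            new TInferenceReq(
                modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock)));
      } catch (Exception e) {
        throw new ModelInferenceProcessException(e.getMessage());
      }
    }

Windowing decisions appear to move entirely to the AINode side, so the DataNode no longer needs the model's input shape to slice or validate the block before sending it.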
analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); return; @@ -1756,74 +1698,16 @@ static void analyzeOutput( analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); } - // check if the result of SQL matches the input of model - private static void checkInputShape( - ModelInformation modelInformation, List> outputExpressions) { - if (modelInformation.isBuiltIn()) { - modelInformation.setInputColumnSize(outputExpressions.size()); - return; - } - - // check inputShape - int[] inputShape = modelInformation.getInputShape(); - if (inputShape.length != 2) { - throw new SemanticException( - String.format( - "The input shape of model is not correct, the dimension of input shape should be 2, actual dimension is %d", - inputShape.length)); - } - int columnNumber = inputShape[1]; - if (columnNumber != outputExpressions.size()) { - throw new SemanticException( - String.format( - "The column number of SQL result does not match the number of model input [%d] for inference", - columnNumber)); - } - } - private static void checkInputType( - Analysis analysis, - ModelInformation modelInformation, - List> outputExpressions) { - - if (modelInformation.isBuiltIn()) { - TSDataType[] inputType = new TSDataType[outputExpressions.size()]; - for (int i = 0; i < outputExpressions.size(); i++) { - Expression inputExpression = outputExpressions.get(i).left; - TSDataType inputDataType = analysis.getType(inputExpression); - if (!inputDataType.isNumeric()) { - throw new SemanticException( - String.format( - "The type of SQL result column [%s in %d] should be numeric when inference", - inputDataType, i)); - } - inputType[i] = inputDataType; - } - modelInformation.setInputDataType(inputType); - return; - } - - TSDataType[] inputType = modelInformation.getInputDataType(); - if (inputType.length != modelInformation.getInputShape()[1]) { - throw new SemanticException( - String.format( - "The inputType does not match the input shape [%d] for inference", - modelInformation.getInputShape()[1])); - } - for (int i = 0; i < inputType.length; i++) { + Analysis analysis, List> outputExpressions) { + for (int i = 0; i < outputExpressions.size(); i++) { Expression inputExpression = outputExpressions.get(i).left; TSDataType inputDataType = analysis.getType(inputExpression); - boolean isExpressionNumeric = inputDataType.isNumeric(); - boolean isModelNumeric = inputType[i].isNumeric(); - if (isExpressionNumeric && isModelNumeric) { - // every model supports numeric by default - continue; - } - if (inputDataType != inputType[i]) { + if (!inputDataType.isNumeric()) { throw new SemanticException( String.format( - "The type of SQL result column [%s in %d] does not match the type of model input [%s] when inference", - inputDataType, i, inputType[i])); + "The type of SQL result column [%s in %d] should be numeric when inference", + inputDataType, i)); } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java index df729ca0ee35f..b4123c237bbd3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java @@ -20,18 +20,13 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.IClientManager; 
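Read through the interleaved additions and removals, the input validation that remains in AnalyzeVisitor is a single numeric gate. Consolidated, the rewritten method reads roughly as follows (Pair generics restored for readability, since the hunks above show the same calls):

    // Consolidated checkInputType after this commit: every projected column must
    // be numeric; shape and arity checks against ModelInformation are gone.
    private static void checkInputType(
        Analysis analysis, List<Pair<Expression, String>> outputExpressions) {
      for (int i = 0; i < outputExpressions.size(); i++) {
        Expression inputExpression = outputExpressions.get(i).left;
        TSDataType inputDataType = analysis.getType(inputExpression);
        if (!inputDataType.isNumeric()) {
          throw new SemanticException(
              String.format(
                  "The type of SQL result column [%s in %d] should be numeric when inference",
                  inputDataType, i));
        }
      }
    }

Any shape mismatch that the dropped checkInputShape used to catch at analysis time would presumably now surface from the AINode at inference time instead.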
-import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; +import org.apache.iotdb.commons.model.ModelInformation; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; // TODO: This class should contact with AINode directly and cache model info in DataNode public class ModelFetcher implements IModelFetcher { - private final IClientManager configNodeClientManager = - ConfigNodeClientManager.getInstance(); - private static final class ModelFetcherHolder { private static final ModelFetcher INSTANCE = new ModelFetcher(); @@ -46,7 +41,9 @@ public static ModelFetcher getInstance() { private ModelFetcher() {} @Override - public TSStatus fetchModel(String modelName, Analysis analysis) { + public TSStatus fetchModel(String modelId, Analysis analysis) { + analysis.setModelInferenceDescriptor( + new ModelInferenceDescriptor(new ModelInformation(modelId))); return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java index 09205c9eb5647..a01acf86db57f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java @@ -31,6 +31,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.List; import java.util.Objects; @@ -90,7 +91,7 @@ public PlanNode clone() { @Override public List getOutputColumnNames() { - return modelInferenceDescriptor.getOutputColumnNames(); + return Collections.singletonList("output"); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java index b7c6aaa4f4b01..1301ec97eb32e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java @@ -19,9 +19,7 @@ package org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; import org.apache.tsfile.utils.ReadWriteIOUtils; @@ -36,19 +34,15 @@ public class ModelInferenceDescriptor { - private final TEndPoint targetAINode; - private ModelInformation modelInformation; + private final ModelInformation modelInformation; private List outputColumnNames; - private InferenceWindowParameter inferenceWindowParameter; private Map inferenceAttributes; - public ModelInferenceDescriptor(TEndPoint targetAINode) { - this.targetAINode = targetAINode; + public ModelInferenceDescriptor(ModelInformation modelInformation) { + this.modelInformation = 
modelInformation; } private ModelInferenceDescriptor(ByteBuffer buffer) { - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelInformation = ModelInformation.deserialize(buffer); int outputColumnNamesSize = ReadWriteIOUtils.readInt(buffer); if (outputColumnNamesSize == 0) { @@ -59,12 +53,6 @@ private ModelInferenceDescriptor(ByteBuffer buffer) { this.outputColumnNames.add(ReadWriteIOUtils.readString(buffer)); } } - boolean hasInferenceWindowParameter = ReadWriteIOUtils.readBool(buffer); - if (hasInferenceWindowParameter) { - this.inferenceWindowParameter = InferenceWindowParameter.deserialize(buffer); - } else { - this.inferenceWindowParameter = null; - } int inferenceAttributesSize = ReadWriteIOUtils.readInt(buffer); if (inferenceAttributesSize == 0) { this.inferenceAttributes = null; @@ -85,24 +73,12 @@ public Map getInferenceAttributes() { return inferenceAttributes; } - public void setInferenceWindowParameter(InferenceWindowParameter inferenceWindowParameter) { - this.inferenceWindowParameter = inferenceWindowParameter; - } - - public InferenceWindowParameter getInferenceWindowParameter() { - return inferenceWindowParameter; - } - public ModelInformation getModelInformation() { return modelInformation; } - public TEndPoint getTargetAINode() { - return targetAINode; - } - - public String getModelName() { - return modelInformation.getModelName(); + public String getModelId() { + return modelInformation.getModelId(); } public void setOutputColumnNames(List outputColumnNames) { @@ -114,8 +90,6 @@ public List getOutputColumnNames() { } public void serialize(ByteBuffer byteBuffer) { - ReadWriteIOUtils.write(targetAINode.ip, byteBuffer); - ReadWriteIOUtils.write(targetAINode.port, byteBuffer); modelInformation.serialize(byteBuffer); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, byteBuffer); @@ -125,12 +99,6 @@ public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(outputColumnName, byteBuffer); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, byteBuffer); - } else { - ReadWriteIOUtils.write(true, byteBuffer); - inferenceWindowParameter.serialize(byteBuffer); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, byteBuffer); } else { @@ -143,8 +111,6 @@ public void serialize(ByteBuffer byteBuffer) { } public void serialize(DataOutputStream stream) throws IOException { - ReadWriteIOUtils.write(targetAINode.ip, stream); - ReadWriteIOUtils.write(targetAINode.port, stream); modelInformation.serialize(stream); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, stream); @@ -154,12 +120,6 @@ public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(outputColumnName, stream); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, stream); - } else { - ReadWriteIOUtils.write(true, stream); - inferenceWindowParameter.serialize(stream); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, stream); } else { @@ -184,20 +144,13 @@ public boolean equals(Object o) { return false; } ModelInferenceDescriptor that = (ModelInferenceDescriptor) o; - return targetAINode.equals(that.targetAINode) - && modelInformation.equals(that.modelInformation) + return modelInformation.equals(that.modelInformation) && outputColumnNames.equals(that.outputColumnNames) - && inferenceWindowParameter.equals(that.inferenceWindowParameter) && inferenceAttributes.equals(that.inferenceAttributes); } @Override public int hashCode() 
{ - return Objects.hash( - targetAINode, - modelInformation, - outputColumnNames, - inferenceWindowParameter, - inferenceAttributes); + return Objects.hash(modelInformation, outputColumnNames, inferenceAttributes); } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java index 3fa107685438e..01968833db7fb 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -32,9 +32,12 @@ public class ModelInformation { + private static final int[] DEFAULT_MODEL_INPUT_SHAPE = new int[] {2880, 1}; + private static final int[] DEFAULT_MODEL_OUTPUT_SHAPE = new int[] {720, 1}; + ModelType modelType; - private final String modelName; + private final String modelId; private final int[] inputShape; @@ -48,9 +51,17 @@ public class ModelInformation { String attribute = ""; + public ModelInformation(String modelId) { + this.modelId = modelId; + this.inputShape = DEFAULT_MODEL_INPUT_SHAPE; + this.inputDataType = new TSDataType[] {TSDataType.DOUBLE}; + this.outputShape = DEFAULT_MODEL_OUTPUT_SHAPE; + this.outputDataType = new TSDataType[] {TSDataType.DOUBLE}; + } + public ModelInformation( ModelType modelType, - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, @@ -58,7 +69,7 @@ public ModelInformation( String attribute, ModelStatus status) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -68,14 +79,14 @@ public ModelInformation( } public ModelInformation( - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, TSDataType[] outputDataType, String attribute) { this.modelType = ModelType.USER_DEFINED; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -83,9 +94,9 @@ public ModelInformation( this.attribute = attribute; } - public ModelInformation(String modelName, ModelStatus status) { + public ModelInformation(String modelId, ModelStatus status) { this.modelType = ModelType.BUILT_IN_FORECAST; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[0]; this.outputShape = new int[0]; this.outputDataType = new TSDataType[0]; @@ -94,9 +105,9 @@ public ModelInformation(String modelName, ModelStatus status) { } // init built-in modelInformation - public ModelInformation(ModelType modelType, String modelName) { + public ModelInformation(ModelType modelType, String modelId) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[2]; this.outputShape = new int[2]; this.inputDataType = new TSDataType[0]; @@ -116,8 +127,8 @@ public void updateStatus(ModelStatus status) { this.status = status; } - public String getModelName() { - return modelName; + public String getModelId() { + return modelId; } public void setInputLength(int length) { @@ -197,7 +208,7 @@ public void setAttribute(String attribute) { public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + 
ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -222,7 +233,7 @@ public void serialize(DataOutputStream stream) throws IOException { public void serialize(FileOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -247,7 +258,7 @@ public void serialize(FileOutputStream stream) throws IOException { public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(modelType.ordinal(), byteBuffer); ReadWriteIOUtils.write(status.ordinal(), byteBuffer); - ReadWriteIOUtils.write(modelName, byteBuffer); + ReadWriteIOUtils.write(modelId, byteBuffer); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -353,7 +364,7 @@ public static ModelInformation deserialize(InputStream stream) throws IOExceptio public ByteBuffer serializeShowModelResult() throws IOException { PublicBAOS buffer = new PublicBAOS(); DataOutputStream stream = new DataOutputStream(buffer); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelType.toString(), stream); ReadWriteIOUtils.write(status.toString(), stream); ReadWriteIOUtils.write(Arrays.toString(inputShape), stream); @@ -370,7 +381,7 @@ public ByteBuffer serializeShowModelResult() throws IOException { public boolean equals(Object obj) { if (obj instanceof ModelInformation) { ModelInformation other = (ModelInformation) obj; - return modelName.equals(other.modelName) + return modelId.equals(other.modelId) && modelType.equals(other.modelType) && Arrays.equals(inputShape, other.inputShape) && Arrays.equals(outputShape, other.outputShape) diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java index 64aff12f284ef..6c6100086316e 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java @@ -42,7 +42,7 @@ public boolean containsModel(String modelId) { } public void addModel(ModelInformation modelInformation) { - modelInfoMap.put(modelInformation.getModelName(), modelInformation); + modelInfoMap.put(modelInformation.getModelId(), modelInformation); } public void removeModel(String modelId) { @@ -63,7 +63,7 @@ public ModelInformation getModelInformationById(String modelId) { public void clearFailedModel() { for (ModelInformation modelInformation : modelInfoMap.values()) { if (modelInformation.getStatus() == ModelStatus.UNAVAILABLE) { - modelInfoMap.remove(modelInformation.getModelName()); + modelInfoMap.remove(modelInformation.getModelId()); } } } From c247bd14c9224ee88826ec74cc35ee92ba594725 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Tue, 2 Dec 2025 14:43:52 +0800 Subject: [PATCH 09/38] Remove useless codes in CN --- .../consensus/request/ConfigPhysicalPlan.java | 16 --- .../request/write/model/CreateModelPlan.java | 79 ------------ .../write/model/DropModelInNodePlan.java | 70 ---------- .../request/write/model/DropModelPlan.java | 79 ------------ .../write/model/UpdateModelInfoPlan.java | 122 ------------------ .../impl/node/RemoveAINodeProcedure.java | 10 +- .../procedure/state/RemoveAINodeState.java | 1 - 7 files changed, 1 insertion(+), 376 deletions(-) delete 
mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java index 23b0a4e149d88..65c1ee0a9fed5 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java @@ -50,10 +50,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -572,18 +568,6 @@ public static ConfigPhysicalPlan create(final ByteBuffer buffer) throws IOExcept case UPDATE_CQ_LAST_EXEC_TIME: plan = new UpdateCQLastExecTimePlan(); break; - case CreateModel: - plan = new CreateModelPlan(); - break; - case UpdateModelInfo: - plan = new UpdateModelInfoPlan(); - break; - case DropModel: - plan = new DropModelPlan(); - break; - case DropModelInNode: - plan = new DropModelInNodePlan(); - break; case CreatePipePlugin: plan = new CreatePipePluginPlan(); break; diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java deleted file mode 100644 index 61e37cdd21877..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class CreateModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public CreateModelPlan() { - super(ConfigPhysicalPlanType.CreateModel); - } - - public CreateModelPlan(String modelName) { - super(ConfigPhysicalPlanType.CreateModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - CreateModelPlan that = (CreateModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java deleted file mode 100644 index 885543f84e156..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
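CreateModelPlan above is representative: all four removed plans were thin ConfigPhysicalPlan wrappers following the same serialize/deserialize protocol, which is why they can be dropped wholesale once model metadata leaves the ConfigNode. Schematically (class and type names here are placeholders, condensed from the removed classes; imports as in those classes):

    // Shared skeleton of the removed plans: the plan-type short is written first
    // so the dispatcher can route the buffer, then the payload fields follow.
    public class SomeModelPlan extends ConfigPhysicalPlan {
      private String field;

      public SomeModelPlan() {
        super(ConfigPhysicalPlanType.SomeType); // type tag drives deserialization dispatch
      }

      @Override
      protected void serializeImpl(DataOutputStream stream) throws IOException {
        stream.writeShort(getType().getPlanType()); // plan type first, then payload
        ReadWriteIOUtils.write(field, stream);
      }

      @Override
      protected void deserializeImpl(ByteBuffer buffer) throws IOException {
        field = ReadWriteIOUtils.readString(buffer); // type short already consumed by dispatcher
      }
    }
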
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelInNodePlan extends ConfigPhysicalPlan { - - private int nodeId; - - public DropModelInNodePlan() { - super(ConfigPhysicalPlanType.DropModelInNode); - } - - public DropModelInNodePlan(int nodeId) { - super(ConfigPhysicalPlanType.DropModelInNode); - this.nodeId = nodeId; - } - - public int getNodeId() { - return nodeId; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - stream.writeInt(nodeId); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - nodeId = buffer.getInt(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof DropModelInNodePlan)) return false; - DropModelInNodePlan that = (DropModelInNodePlan) o; - return nodeId == that.nodeId; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), nodeId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java deleted file mode 100644 index 813b116c645c5..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public DropModelPlan() { - super(ConfigPhysicalPlanType.DropModel); - } - - public DropModelPlan(String modelName) { - super(ConfigPhysicalPlanType.DropModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - DropModelPlan that = (DropModelPlan) o; - return modelName.equals(that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java deleted file mode 100644 index ce7219e428139..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - -public class UpdateModelInfoPlan extends ConfigPhysicalPlan { - - private String modelName; - private ModelInformation modelInformation; - - // The node which has the model which is only updated in model registration - private List nodeIds; - - public UpdateModelInfoPlan() { - super(ConfigPhysicalPlanType.UpdateModelInfo); - } - - public UpdateModelInfoPlan(String modelName, ModelInformation modelInformation) { - super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = Collections.emptyList(); - } - - public UpdateModelInfoPlan( - String modelName, ModelInformation modelInformation, List nodeIds) { - super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = nodeIds; - } - - public String getModelName() { - return modelName; - } - - public ModelInformation getModelInformation() { - return modelInformation; - } - - public List getNodeIds() { - return nodeIds; - } - - public void setNodeIds(List nodeIds) { - this.nodeIds = nodeIds; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - this.modelInformation.serialize(stream); - ReadWriteIOUtils.write(nodeIds.size(), stream); - for (Integer nodeId : nodeIds) { - ReadWriteIOUtils.write(nodeId, stream); - } - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - this.modelName = ReadWriteIOUtils.readString(buffer); - this.modelInformation = ModelInformation.deserialize(buffer); - int size = ReadWriteIOUtils.readInt(buffer); - this.nodeIds = new ArrayList<>(); - for (int i = 0; i < size; i++) { - this.nodeIds.add(ReadWriteIOUtils.readInt(buffer)); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - UpdateModelInfoPlan that = (UpdateModelInfoPlan) o; - return modelName.equals(that.modelName) - && modelInformation.equals(that.modelInformation) - && nodeIds.equals(that.nodeIds); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName, modelInformation, nodeIds); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java index 98056fc1768ea..2a1c6881b1413 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java @@ -23,7 +23,6 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus; import 
org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; import org.apache.iotdb.confignode.procedure.exception.ProcedureException; import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState; @@ -65,13 +64,6 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState st try { switch (state) { - case MODEL_DELETE: - env.getConfigManager() - .getConsensusManager() - .write(new DropModelInNodePlan(removedAINode.aiNodeId)); - // Cause the AINode is removed, so we don't need to remove the model file. - setNextState(RemoveAINodeState.NODE_STOP); - break; case NODE_STOP: TSStatus resp = null; try (AINodeClient client = @@ -149,7 +141,7 @@ protected int getStateId(RemoveAINodeState removeAINodeState) { @Override protected RemoveAINodeState getInitialState() { - return RemoveAINodeState.MODEL_DELETE; + return RemoveAINodeState.NODE_STOP; } @Override diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java index 8a1a6a1bb03b5..49820df663616 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java @@ -20,7 +20,6 @@ package org.apache.iotdb.confignode.procedure.state; public enum RemoveAINodeState { - MODEL_DELETE, NODE_STOP, NODE_REMOVE } From 1b7fb2759651760bd7bfb8f649baf85ef94d054e Mon Sep 17 00:00:00 2001 From: Gewu <89496957+RkGrit@users.noreply.github.com> Date: Tue, 2 Dec 2025 16:47:35 +0800 Subject: [PATCH 10/38] Support loading inference pipelines for user-defined models (#16835) * stash * Support loading inference pipelines for user-defined models * Support loading inference pipelines for different models --- iotdb-core/ainode/iotdb/ainode/core/config.py | 13 +++ .../ainode/iotdb/ainode/core/constant.py | 1 + .../ainode/iotdb/ainode/core/exception.py | 5 +- .../core/inference/pipeline/__init__.py | 1 - .../core/inference/pipeline/basic_pipeline.py | 4 +- .../inference/pipeline/pipeline_loader.py | 29 +++--- .../pool_scheduler/basic_pool_scheduler.py | 1 + .../ainode/core/manager/inference_manager.py | 2 +- .../ainode/core/manager/model_manager.py | 8 +- .../ainode/core/model/model_constants.py | 1 - .../iotdb/ainode/core/model/model_info.py | 5 +- .../iotdb/ainode/core/model/model_loader.py | 24 +++-- .../iotdb/ainode/core/model/model_storage.py | 99 ++++++++++++------- .../core/model/sktime/pipeline_sktime.py | 6 +- .../ainode/iotdb/ainode/core/model/utils.py | 8 +- .../ainode/iotdb/ainode/core/rpc/handler.py | 3 +- 16 files changed, 136 insertions(+), 74 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index f30e1ecf73fff..8f9f256dfc16a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -36,6 +36,7 @@ AINODE_INFERENCE_MEMORY_USAGE_RATIO, AINODE_INFERENCE_MODEL_MEM_USAGE_MAP, AINODE_LOG_DIR, + AINODE_MODELS_BUILTIN_DIR, AINODE_MODELS_DIR, AINODE_RPC_ADDRESS, AINODE_RPC_PORT, @@ -94,6 +95,7 @@ def __init__(self): # Directory to save models self._ain_models_dir = 
AINODE_MODELS_DIR + self._ain_models_builtin_dir = AINODE_MODELS_BUILTIN_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ -202,6 +204,12 @@ def get_ain_models_dir(self) -> str: def set_ain_models_dir(self, ain_models_dir: str) -> None: self._ain_models_dir = ain_models_dir + def get_ain_models_builtin_dir(self) -> str: + return self._ain_models_builtin_dir + + def set_ain_models_builtin_dir(self, ain_models_builtin_dir: str) -> None: + self._ain_models_builtin_dir = ain_models_builtin_dir + def get_ain_system_dir(self) -> str: return self._ain_system_dir @@ -366,6 +374,11 @@ def _load_config_from_file(self) -> None: if "ain_models_dir" in config_keys: self._config.set_ain_models_dir(file_configs["ain_models_dir"]) + if "ain_models_builtin_dir" in config_keys: + self._config.set_ain_models_builtin_dir( + file_configs["ain_models_builtin_dir"] + ) + if "ain_system_dir" in config_keys: self._config.set_ain_system_dir(file_configs["ain_system_dir"]) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 3576ac711ce0e..abd288eee8d93 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -59,6 +59,7 @@ ) AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") +AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model" AINODE_SYSTEM_DIR = "data/ainode/system" AINODE_LOG_DIR = "logs" diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index 91ad096418872..30b9d54dcc7df 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -17,7 +17,10 @@ # import re -from iotdb.ainode.core.model.model_constants import MODEL_WEIGHTS_FILE_IN_PT, MODEL_CONFIG_FILE_IN_YAML +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_YAML, + MODEL_WEIGHTS_FILE_IN_PT, +) class _BaseError(Exception): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py index a4797b632bb18..2a1e720805f29 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py @@ -15,4 +15,3 @@ # specific language governing permissions and limitations # under the License. 
# - diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index 2d967aea9bce5..438b8a2611ee7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -28,9 +28,7 @@ class BasicPipeline(ABC): def __init__(self, model_id, **infer_kwargs): self.model_id = model_id self.device = infer_kwargs.get("device", "cpu") - self.model = ModelManager().load_model( - model_id, device_map=self.device - ) + self.model = ModelManager().load_model(model_id, device_map=self.device) def _preprocess(self, inputs): """ diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py index be0fb996b2a48..f7004547e112f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -16,31 +16,34 @@ # under the License. # +import os from pathlib import Path +from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ModelCategory from iotdb.ainode.core.model.model_storage import ModelInfo -from iotdb.ainode.core.model.utils import temporary_sys_path, import_class_from_path +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path logger = Logger() def load_pipeline(model_info: ModelInfo, device: str, **kwargs): if model_info.category == ModelCategory.BUILTIN: - if model_info.model_id == "timer_xl": - from iotdb.ainode.core.model.timer_xl.pipeline_timer import TimerPipeline - pipeline_cls = TimerPipeline - elif model_info.model_id == "sundial": - from iotdb.ainode.core.model.sundial.pipeline_sundial import SundialPipeline - pipeline_cls = SundialPipeline - else: - logger.error( - f"Unsupported built-in model {model_info.model_id}." - ) - return None + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." 
+ + model_info.model_id + ) + pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls) else: - module_parent = str(Path(model_info.path).parent.absolute()) + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + module_parent = str(Path(model_path).parent.absolute()) with temporary_sys_path(module_parent): pipeline_cls = import_class_from_path( model_info.model_id, model_info.pipeline_cls diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index d5b54280e96b3..9fbc1b0fca4f3 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -41,6 +41,7 @@ logger = Logger() + def _estimate_shared_pool_size_by_total_mem( device: torch.device, existing_model_infos: List[ModelInfo], diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 30e71ebd75fba..183c942e3b5aa 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -64,7 +64,7 @@ class InferenceManager: ) # How often to check for requests in the result queue def __init__(self): - self._model_manager = ModelManager() + self._model_manager = ModelManager() self._model_mem_usage_map: Dict[str, int] = ( {} ) # store model memory usage for each model diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index 4cafbbecd8c0b..c7b1a525d2cb0 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -36,6 +36,7 @@ logger = Logger() + @singleton class ModelManager: def __init__(self): @@ -49,14 +50,15 @@ def register_model( try: if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) - return TRegisterModelResp( - get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) except ValueError as e: return TRegisterModelResp( get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) ) except Exception as e: - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))) + return TRegisterModelResp( + get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + ) def show_models(self, req: TShowModelsReq) -> TShowModelsResp: return self._model_storage.show_models(req) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py index 1e096af379dff..9f1801b5073a7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -17,7 +17,6 @@ # from enum import Enum - # Model file constants MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" MODEL_CONFIG_FILE_IN_JSON = "config.json" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 905b13fef2f7d..9ba510947fa4c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ 
b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -20,6 +20,7 @@ from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates + class ModelInfo: def __init__( self, @@ -30,7 +31,6 @@ def __init__( model_cls: str = "", pipeline_cls: str = "", repo_id: str = "", - path: str = "", auto_map: Optional[Dict] = None, _transformers_registered: bool = False, @@ -42,7 +42,6 @@ def __init__( self.model_cls = model_cls self.pipeline_cls = pipeline_cls self.repo_id = repo_id - self.path = path self.auto_map = auto_map # If exists, indicates it's a Transformers model self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers @@ -50,7 +49,7 @@ def __repr__(self): return ( f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " f"category={self.category.value}, state={self.state.value}, " - f"path={self.path}, has_auto_map={self.auto_map is not None})" + f"has_auto_map={self.auto_map is not None})" ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index dc8220aad9a25..f4fd85366ed1e 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -72,6 +72,11 @@ def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): trust_remote_code = kwargs.get("trust_remote_code", True) train_from_scratch = kwargs.get("train_from_scratch", False) + model_path = os.path.join( + self.storage.get_models_dir(), + model_info.category.value, + model_info.model_id, + ) if model_info.category == ModelCategory.BUILTIN: if model_info.model_id == "timer_xl": from iotdb.ainode.core.model.timer_xl.configuration_timer import ( @@ -100,7 +105,7 @@ def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): f"Unsupported built-in Transformers model {model_info.model_id}." 
) else: - model_config = AutoConfig.from_pretrained(model_info.path) + model_config = AutoConfig.from_pretrained(model_path) if ( type(model_config) in AutoModelForTimeSeriesPrediction._model_mapping.keys() @@ -132,7 +137,7 @@ def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): ) else: model = load_class.from_pretrained( - model_info.path, + model_path, trust_remote_code=trust_remote_code, device_map=device_map, ) @@ -142,11 +147,16 @@ def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): def load_model_from_pt(self, model_info: ModelInfo, **kwargs): device_map = kwargs.get("device_map", "cpu") acceleration = kwargs.get("acceleration", False) - model_path = os.path.join(model_info.path, "model.pt") - if not os.path.exists(model_path): - logger.error(f"Model file not found at {model_path}.") - raise ModelNotExistError(model_path) - model = torch.jit.load(model_path) + model_path = os.path.join( + self.storage.get_models_dir(), + model_info.category.value, + model_info.model_id, + ) + model_file = os.path.join(model_path, "model.pt") + if not os.path.exists(model_file): + logger.error(f"Model file not found at {model_file}.") + raise ModelNotExistError(model_file) + model = torch.jit.load(model_file) if ( isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index de16d05c36c1d..bb94ced0b5bca 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -21,7 +21,7 @@ import os import shutil from pathlib import Path -from typing import List, Optional, Dict +from typing import Dict, List, Optional from huggingface_hub import hf_hub_download, snapshot_download from transformers import AutoConfig, AutoModelForCausalLM @@ -30,14 +30,27 @@ from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates, UriType, \ - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, MODEL_CONFIG_FILE_IN_JSON +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + ModelCategory, + ModelStates, + UriType, +) from iotdb.ainode.core.model.model_info import ( BUILTIN_HF_TRANSFORMERS_MODEL_MAP, BUILTIN_SKTIME_MODEL_MAP, - ModelInfo) -from iotdb.ainode.core.model.utils import ensure_init_file, load_model_config_in_json, parse_uri_type, get_parsed_uri, \ - validate_model_files, temporary_sys_path, import_class_from_path + ModelInfo, +) +from iotdb.ainode.core.model.utils import ( + ensure_init_file, + get_parsed_uri, + import_class_from_path, + load_model_config_in_json, + parse_uri_type, + temporary_sys_path, + validate_model_files, +) from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus @@ -88,7 +101,9 @@ def _discover_category(self, category: ModelCategory): elif category == ModelCategory.USER_DEFINED: for model_id in os.listdir(category_path): if os.path.isdir(os.path.join(category_path, model_id)): - self._process_user_defined_model_directory(os.path.join(category_path, model_id), model_id) + self._process_user_defined_model_directory( + os.path.join(category_path, model_id), model_id + ) def 
_discover_builtin_models(self, category_path: str): # Register SKTIME models directly from map @@ -104,14 +119,16 @@ def _discover_builtin_models(self, category_path: str): os.makedirs(model_dir, exist_ok=True) self._process_builtin_model_directory(model_dir, model_id) - def _process_builtin_model_directory( - self, model_dir: str, model_id: str - ): + def _process_builtin_model_directory(self, model_dir: str, model_id: str): """Handling the discovery logic for a builtin model directory.""" ensure_init_file(model_dir) with self._lock_pool.get_lock(model_id).write_lock(): - self._models[ModelCategory.BUILTIN.value][model_id] = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id] - self._models[ModelCategory.BUILTIN.value][model_id].state = ModelStates.ACTIVATING + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id] + ) + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.ACTIVATING def _download_model_if_necessary() -> bool: """Returns: True if the model already exists or was downloaded successfully, False otherwise.""" weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) if not os.path.exists(weights_path): try: hf_hub_download( repo_id=model_info.repo_id, filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, local_dir=model_dir, ) except Exception as e: - logger.error(f"Failed to download model weights from HuggingFace: {e}") + logger.error( + f"Failed to download model weights from HuggingFace: {e}" + ) return False if not os.path.exists(config_path): try: hf_hub_download( repo_id=model_info.repo_id, filename=MODEL_CONFIG_FILE_IN_JSON, local_dir=model_dir, ) except Exception as e: - logger.error(f"Failed to download model config from HuggingFace: {e}") + logger.error( + f"Failed to download model config from HuggingFace: {e}" + ) return False return True future = self._executor.submit(_download_model_if_necessary) future.add_done_callback( - lambda f, mid=model_id: self._callback_model_download_result( - f, mid - ) + lambda f, mid=model_id: self._callback_model_download_result(f, mid) ) - def _callback_model_download_result( - self, future, model_id: str - ): + def _callback_model_download_result(self, future, model_id: str): """Callback function for handling model download results""" with self._lock_pool.get_lock(model_id).write_lock(): try: if future.result(): model_info = self._models[ModelCategory.BUILTIN.value][model_id] model_info.state = ModelStates.ACTIVE - config_path = os.path.join(model_info.path, MODEL_CONFIG_FILE_IN_JSON) + config_path = os.path.join( + self._models_dir, + ModelCategory.BUILTIN.value, + model_id, + MODEL_CONFIG_FILE_IN_JSON, + ) if os.path.exists(config_path): with open(config_path, "r", encoding="utf-8") as f: config = json.load(f) if model_info.model_type == "": model_info.model_type = config.get("model_type", "") - model_info.auto_map = config.get("auto_map") + model_info.auto_map = config.get("auto_map", None) logger.info( f"Model {model_id} downloaded successfully and is ready to use."
) @@ -173,19 +195,21 @@ def _callback_model_download_result(self, future, model_id: str): logger.warning(f"Failed to download model {model_id}.") except Exception as e: logger.error(f"Error in download callback for model {model_id}: {e}") - self._models[ModelCategory.BUILTIN.value][model_id].state = ModelStates.INACTIVE + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.INACTIVE def _process_user_defined_model_directory(self, model_dir: str, model_id: str): """Handling the discovery logic for a user-defined model directory.""" config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) model_type = "" auto_map = {} - pipeline_class = "" + pipeline_cls = "" if os.path.exists(config_path): config = load_model_config_in_json(config_path) model_type = config.get("model_type", "") auto_map = config.get("auto_map", None) - pipeline_class = config.get("pipeline_class", "") + pipeline_cls = config.get("pipeline_cls", "") with self._lock_pool.get_lock(model_id).write_lock(): model_info = ModelInfo( @@ -193,8 +217,7 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): model_type=model_type, category=ModelCategory.USER_DEFINED, state=ModelStates.ACTIVE, - pipeline_cls=pipeline_class, - path=str(model_dir), + pipeline_cls=pipeline_cls, auto_map=auto_map, _transformers_registered=False, # Lazy registration ) @@ -211,7 +234,9 @@ def register_model(self, model_id: str, uri: str) -> bool: uri_type = parse_uri_type(uri) parsed_uri = get_parsed_uri(uri) - model_dir = os.path.join(self._models_dir, ModelCategory.USER_DEFINED.value, model_id) + model_dir = os.path.join( + self._models_dir, ModelCategory.USER_DEFINED.value, model_id + ) os.makedirs(model_dir, exist_ok=True) ensure_init_file(model_dir) @@ -224,7 +249,7 @@ def register_model(self, model_id: str, uri: str) -> bool: config = load_model_config_in_json(config_path) model_type = config.get("model_type", "") auto_map = config.get("auto_map") - pipeline_class = config.get("pipeline_class", "") + pipeline_cls = config.get("pipeline_cls", "") with self._lock_pool.get_lock(model_id).write_lock(): model_info = ModelInfo( @@ -232,8 +257,7 @@ def register_model(self, model_id: str, uri: str) -> bool: model_type=model_type, category=ModelCategory.USER_DEFINED, state=ModelStates.ACTIVE, - pipeline_cls=pipeline_class, - path=str(model_dir), + pipeline_cls=pipeline_cls, auto_map=auto_map, _transformers_registered=False, # Register later ) @@ -286,7 +310,6 @@ def _fetch_model_from_local(self, source_path: str, storage_path: str): shutil.copy2(file, storage_dir / file.name) return - def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ Register Transformers model to auto-loading mechanism (internal method) @@ -299,7 +322,10 @@ def _register_transformers_model(self, model_info: ModelInfo) -> bool: auto_model_path = auto_map.get("AutoModelForCausalLM") try: - module_parent = str(Path(model_info.path).parent.absolute()) + model_path = os.path.join( + self._models_dir, model_info.category.value, model_info.model_id + ) + module_parent = str(Path(model_path).parent.absolute()) with temporary_sys_path(module_parent): config_class = import_class_from_path( model_info.model_id, auto_config_path @@ -455,7 +481,9 @@ def delete_model(self, model_id: str) -> None: if model_info.category == ModelCategory.BUILTIN: raise BuiltInModelDeletionError(model_id) model_info.state = ModelStates.DROPPING - model_path = Path(model_info.path) + model_path = Path( + os.path.join(self._models_dir, model_info.category.value, model_id) + ) if
model_path.exists(): try: shutil.rmtree(model_path) @@ -540,3 +568,6 @@ def get_registered_models(self) -> List[str]: for category_dict in self._models.values(): model_ids.extend(category_dict.keys()) return model_ids + + def get_models_dir(self): + return self._models_dir diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py index 004222db7cad5..0d283dd2c7a5a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -39,7 +39,11 @@ def infer(self, inputs, **infer_kwargs): # Batch processing: convert each row to Series outputs = [] for i in range(input_ids.shape[0]): - series = pd.Series(input_ids[i].cpu().numpy() if isinstance(input_ids, torch.Tensor) else input_ids[i]) + series = pd.Series( + input_ids[i].cpu().numpy() + if isinstance(input_ids, torch.Tensor) + else input_ids[i] + ) output = self.model.generate(series) outputs.append(output) output = np.array(outputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py index ce7524bcf729c..1cd0ee44912d5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -24,10 +24,8 @@ from typing import Dict, Tuple from iotdb.ainode.core.model.model_constants import ( - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, MODEL_CONFIG_FILE_IN_JSON, -) -from iotdb.ainode.core.model.model_constants import ( + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, UriType, ) @@ -79,7 +77,7 @@ def validate_model_files(model_dir: str) -> Tuple[str, str]: # Create __init__.py file to ensure model directory can be imported as a module init_file = os.path.join(model_dir, "__init__.py") if not os.path.exists(init_file): - with open(init_file, 'w'): + with open(init_file, "w"): pass return config_path, weights_path @@ -96,5 +94,5 @@ def ensure_init_file(dir_path: str): init_file = os.path.join(dir_path, "__init__.py") os.makedirs(dir_path, exist_ok=True) if not os.path.exists(init_file): - with open(init_file, 'w'): + with open(init_file, "w"): pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index 24792607268b9..ba836a7747d84 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -29,6 +29,7 @@ TAIHeartbeatResp, TDeleteModelReq, TForecastReq, + TForecastResp, TInferenceReq, TInferenceResp, TLoadModelReq, @@ -138,4 +139,4 @@ def _ensure_model_is_registered(self, model_id: str) -> TSStatus: code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, message=f"Model [{model_id}] is not available. You can use 'SHOW MODELS' to retrieve the available models.", ) - return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) \ No newline at end of file + return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) From 98ab1b3d8db21d18515a8fdf7d87741632ade15b Mon Sep 17 00:00:00 2001 From: Gewu <89496957+RkGrit@users.noreply.github.com> Date: Wed, 3 Dec 2025 16:30:59 +0800 Subject: [PATCH 11/38] Main process manages models, child process loads models. 
(#16850) --- .../core/inference/pipeline/basic_pipeline.py | 20 +- .../inference/pipeline/pipeline_loader.py | 2 +- .../ainode/core/manager/model_manager.py | 6 +- .../ainode/iotdb/ainode/core/manager/utils.py | 4 +- .../ainode/core/model/model_constants.py | 7 + .../iotdb/ainode/core/model/model_info.py | 5 +- .../iotdb/ainode/core/model/model_loader.py | 238 ++++++++---------- .../iotdb/ainode/core/model/model_storage.py | 7 +- .../core/model/sktime/pipeline_sktime.py | 4 +- .../core/model/sundial/pipeline_sundial.py | 4 +- .../core/model/timer_xl/pipeline_timer.py | 4 +- .../ainode/iotdb/ainode/core/rpc/handler.py | 4 +- 12 files changed, 147 insertions(+), 158 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index 438b8a2611ee7..e0bfd8c43f4c5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -21,14 +21,14 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.manager.model_manager import ModelManager +from iotdb.ainode.core.model.model_loader import load_model class BasicPipeline(ABC): - def __init__(self, model_id, **infer_kwargs): - self.model_id = model_id + def __init__(self, model_info, **infer_kwargs): + self.model_info = model_info self.device = infer_kwargs.get("device", "cpu") - self.model = ModelManager().load_model(model_id, device_map=self.device) + self.model = load_model(model_info, device_map=self.device) def _preprocess(self, inputs): """ @@ -45,8 +45,8 @@ def _postprocess(self, output: torch.Tensor): class ForecastPipeline(BasicPipeline): - def __init__(self, model_id, **infer_kwargs): - super().__init__(model_id, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **infer_kwargs): + super().__init__(model_info, **infer_kwargs) def _preprocess(self, inputs): if len(inputs.shape) != 2: @@ -63,8 +63,8 @@ def _postprocess(self, output: torch.Tensor): class ClassificationPipeline(BasicPipeline): - def __init__(self, model_id, **infer_kwargs): - super().__init__(model_id, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **infer_kwargs): + super().__init__(model_info, **infer_kwargs) def _preprocess(self, inputs): pass @@ -80,8 +80,8 @@ def _postprocess(self, output: torch.Tensor): class ChatPipeline(BasicPipeline): - def __init__(self, model_id, **infer_kwargs): - super().__init__(model_id, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **infer_kwargs): + super().__init__(model_info, **infer_kwargs) def _preprocess(self, inputs): pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py index f7004547e112f..f2dd9bdc73715 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -49,4 +49,4 @@ def load_pipeline(model_info: ModelInfo, device: str, **kwargs): model_info.model_id, model_info.pipeline_cls ) - return pipeline_cls(model_info.model_id, device=device) + return pipeline_cls(model_info, device=device) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index c7b1a525d2cb0..4aa179558c1bb 100644 ---
a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -21,7 +21,7 @@ from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_loader import ModelLoader +from iotdb.ainode.core.model.model_loader import load_model from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.decorator import singleton @@ -41,7 +41,6 @@ class ModelManager: def __init__(self): self._model_storage = ModelStorage() - self._model_loader = ModelLoader(storage=self._model_storage) def register_model( self, @@ -75,7 +74,8 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus: return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) def load_model(self, model_id: str, **kwargs) -> Any: - return self._model_loader.load_model(model_id=model_id, **kwargs) + model_info = self.get_model_info(model_id) + return load_model(model_info=model_info, **kwargs) def get_model_info( self, diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index afb87ee8c5acf..87a528f02b572 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -25,6 +25,7 @@ from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager +from iotdb.ainode.core.model.model_loader import load_model logger = Logger() @@ -46,7 +47,8 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = ModelManager().load_model(model_id).to(device) + model_info = ModelManager().get_model_info(model_id) + model = load_model(model_info).to(device) torch.cuda.synchronize(device) end = torch.cuda.memory_reserved(device) usage = end - start diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py index 9f1801b5073a7..c42ec98551b83 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -24,6 +24,13 @@ MODEL_CONFIG_FILE_IN_YAML = "config.yaml" +# Model file constants +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" +MODEL_CONFIG_FILE_IN_JSON = "config.json" +MODEL_WEIGHTS_FILE_IN_PT = "model.pt" +MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + + class ModelCategory(Enum): BUILTIN = "builtin" USER_DEFINED = "user_defined" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 9ba510947fa4c..718ead530dd2c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -28,17 +28,18 @@ def __init__( category: ModelCategory, state: ModelStates, model_type: str = "", + config_cls: str = "", model_cls: str = "", pipeline_cls: str = "", repo_id: str = "", auto_map: Optional[Dict] = None, - _transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type self.category = category self.state = state + self.config_cls = config_cls self.model_cls = model_cls self.pipeline_cls = 
pipeline_cls self.repo_id = repo_id @@ -113,6 +114,7 @@ def __repr__(self): category=ModelCategory.BUILTIN, state=ModelStates.INACTIVE, model_type="timer", + config_cls="configuration_timer.TimerConfig", model_cls="modeling_timer.TimerForPrediction", pipeline_cls="pipeline_timer.TimerPipeline", repo_id="thuml/timer-base-84m", @@ -122,6 +124,7 @@ def __repr__(self): category=ModelCategory.BUILTIN, state=ModelStates.INACTIVE, model_type="sundial", + config_cls="configuration_sundial.SundialConfig", model_cls="modeling_sundial.SundialForPrediction", pipeline_cls="pipeline_sundial.SundialPipeline", repo_id="thuml/sundial-base-128m", diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index f4fd85366ed1e..fba47de9be59d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -17,6 +17,7 @@ # import os +from pathlib import Path from typing import Any import torch @@ -30,151 +31,126 @@ AutoModelForTokenClassification, ) +from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ModelCategory from iotdb.ainode.core.model.model_info import ModelInfo -from iotdb.ainode.core.model.model_storage import ModelStorage from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path logger = Logger() -class ModelLoader: - """Model loader - unified interface for loading different types of models""" - - def __init__(self, storage: ModelStorage): - self.storage = storage - - def load_model(self, model_id: str, **kwargs) -> Any: - # Lazy registration: if it's a Transformers model and not registered, register it first - model_info = self.storage.ensure_transformers_registered(model_id) - if not model_info: - logger.error( - f"Model {model_id} failed to register to Transformers, cannot load." +def load_model(model_info: ModelInfo, **kwargs) -> Any: + if model_info.auto_map is not None: + model = load_model_from_transformers(model_info, **kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_info.model_id) + else: + model = load_model_from_pt(model_info, **kwargs) + + logger.info( + f"Model {model_info.model_id} loaded to device {model.device} successfully." + ) + return model + + +def load_model_from_transformers(model_info: ModelInfo, **kwargs): + device_map = kwargs.get("device_map", "cpu") + trust_remote_code = kwargs.get("trust_remote_code", True) + train_from_scratch = kwargs.get("train_from_scratch", False) + + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + + if model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." 
+ + model_info.model_id + ) + config_cls = import_class_from_path(module_name, model_info.config_cls) + model_cls = import_class_from_path(module_name, model_info.model_cls) + elif model_info.model_cls and model_info.config_cls: + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + config_cls = import_class_from_path( + model_info.model_id, model_info.config_cls ) - return None - - if model_info.auto_map is not None: - model = self.load_model_from_transformers(model_info, **kwargs) + model_cls = import_class_from_path( + model_info.model_id, model_info.model_cls + ) + else: + config_cls = AutoConfig.from_pretrained(model_path) + if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys(): + model_cls = AutoModelForTimeSeriesPrediction + elif ( + type(config_cls) in AutoModelForNextSentencePrediction._model_mapping.keys() + ): + model_cls = AutoModelForNextSentencePrediction + elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys(): + model_cls = AutoModelForSeq2SeqLM + elif ( + type(config_cls) in AutoModelForSequenceClassification._model_mapping.keys() + ): + model_cls = AutoModelForSequenceClassification + elif type(config_cls) in AutoModelForTokenClassification._model_mapping.keys(): + model_cls = AutoModelForTokenClassification else: - if model_info.model_type == "sktime": - model = create_sktime_model(model_id) - else: - model = self.load_model_from_pt(model_info, **kwargs) + model_cls = AutoModelForCausalLM + + if train_from_scratch: + model = model_cls.from_config( + config_cls, trust_remote_code=trust_remote_code, device_map=device_map + ) + else: + model = model_cls.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + device_map=device_map, + ) - logger.info(f"Model {model_id} loaded to device {model.device} successfully.") + return model + + +def load_model_from_pt(model_info: ModelInfo, **kwargs): + device_map = kwargs.get("device_map", "cpu") + acceleration = kwargs.get("acceleration", False) + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + model_file = os.path.join(model_path, "model.pt") + if not os.path.exists(model_file): + logger.error(f"Model file not found at {model_file}.") + raise ModelNotExistError(model_file) + model = torch.jit.load(model_file) + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: return model + try: + model = torch.compile(model) + except Exception as e: + logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") + return model.to(device_map) - def load_model_from_transformers(self, model_info: ModelInfo, **kwargs): - model_config, load_class = None, None - device_map = kwargs.get("device_map", "cpu") - trust_remote_code = kwargs.get("trust_remote_code", True) - train_from_scratch = kwargs.get("train_from_scratch", False) - model_path = os.path.join( - self.storage.get_models_dir(), - model_info.category.value, - model_info.model_id, - ) - if model_info.category == ModelCategory.BUILTIN: - if model_info.model_id == "timer_xl": - from iotdb.ainode.core.model.timer_xl.configuration_timer import ( - TimerConfig, - ) - - model_config = TimerConfig() - from iotdb.ainode.core.model.timer_xl.modeling_timer import ( - TimerForPrediction, - ) - - load_class = TimerForPrediction - elif model_info.model_id == "sundial": - from iotdb.ainode.core.model.sundial.configuration_sundial import ( - SundialConfig, - ) 
- - model_config = SundialConfig() - from iotdb.ainode.core.model.sundial.modeling_sundial import ( - SundialForPrediction, - ) - - load_class = SundialForPrediction - else: - logger.error( - f"Unsupported built-in Transformers model {model_info.model_id}." - ) - else: - model_config = AutoConfig.from_pretrained(model_path) - if ( - type(model_config) - in AutoModelForTimeSeriesPrediction._model_mapping.keys() - ): - load_class = AutoModelForTimeSeriesPrediction - elif ( - type(model_config) - in AutoModelForNextSentencePrediction._model_mapping.keys() - ): - load_class = AutoModelForNextSentencePrediction - elif type(model_config) in AutoModelForSeq2SeqLM._model_mapping.keys(): - load_class = AutoModelForSeq2SeqLM - elif ( - type(model_config) - in AutoModelForSequenceClassification._model_mapping.keys() - ): - load_class = AutoModelForSequenceClassification - elif ( - type(model_config) - in AutoModelForTokenClassification._model_mapping.keys() - ): - load_class = AutoModelForTokenClassification - else: - load_class = AutoModelForCausalLM - - if train_from_scratch: - model = load_class.from_config( - model_config, trust_remote_code=trust_remote_code, device_map=device_map - ) - else: - model = load_class.from_pretrained( - model_path, - trust_remote_code=trust_remote_code, - device_map=device_map, - ) +def load_model_for_efficient_inference(self): + # TODO: An efficient model loading method for inference based on model_arguments + pass - return model - def load_model_from_pt(self, model_info: ModelInfo, **kwargs): - device_map = kwargs.get("device_map", "cpu") - acceleration = kwargs.get("acceleration", False) - model_path = os.path.join( - self.storage.get_models_dir(), - model_info.category.value, - model_info.model_id, - ) - model_file = os.path.join(model_path, "model.pt") - if not os.path.exists(model_file): - logger.error(f"Model file not found at {model_file}.") - raise ModelNotExistError(model_file) - model = torch.jit.load(model_file) - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - try: - model = torch.compile(model) - except Exception as e: - logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") - return model.to(device_map) - - def load_model_for_efficient_inference(self): - # TODO: An efficient model loading method for inference based on model_arguments - pass - - def load_model_for_powerful_finetune(self): - # TODO: An powerful model loading method for finetune based on model_arguments - pass - - def unload_model(self): - pass +def load_model_for_powerful_finetune(self): + # TODO: A powerful model loading method for finetune based on model_arguments + pass + + +def unload_model(self): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index bb94ced0b5bca..b3f3ee023156d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -555,6 +555,10 @@ def get_model_infos( def is_model_registered(self, model_id: str) -> bool: """Check if model is registered (search in _models)""" + # Lazy registration: if it's a Transformers model and not registered, register it first + if self.ensure_transformers_registered(model_id) is None: + return False + with self._lock_pool.get_lock("").read_lock(): for category_dict in self._models.values(): if model_id in category_dict: @@ -568,6 +572,3 @@ def get_registered_models(self) -> List[str]: for
category_dict in self._models.values(): model_ids.extend(category_dict.keys()) return model_ids - - def get_models_dir(self): - return self._models_dir diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py index 0d283dd2c7a5a..b9032aaeb064b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -24,8 +24,8 @@ class SktimePipeline(BasicPipeline): - def __init__(self, model_id, **infer_kwargs): - super().__init__(model_id, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **infer_kwargs): + super().__init__(model_info, **infer_kwargs) def _preprocess(self, inputs): return super()._preprocess(inputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 4d761d0f00ad0..add2973578f03 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -22,8 +22,8 @@ class SundialPipeline(ForecastPipeline): - def __init__(self, model_id, **infer_kwargs): - super().__init__(model_id, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **infer_kwargs): + super().__init__(model_info, **infer_kwargs) def _preprocess(self, inputs): return super()._preprocess(inputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index 36e91e9f91b4e..38e5effa27b22 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -22,8 +22,8 @@ class TimerPipeline(ForecastPipeline): - def __init__(self, model_id, **infer_kwargs): - super().__init__(model_id, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **infer_kwargs): + super().__init__(model_info, **infer_kwargs) def _preprocess(self, inputs): return super()._preprocess(inputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index ba836a7747d84..6c4eedeb99f7f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -41,7 +41,7 @@ TShowModelsReq, TShowModelsResp, TTrainingReq, - TUnloadModelReq, TForecastResp, + TUnloadModelReq, ) from iotdb.thrift.common.ttypes import TSStatus @@ -137,6 +137,6 @@ def _ensure_model_is_registered(self, model_id: str) -> TSStatus: if not self._model_manager.is_model_registered(model_id): return TSStatus( code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not available. You can use 'SHOW MODELS' to retrieve the available models.", + message=f"Model [{model_id}] is not registered yet.
You can use 'SHOW MODELS' to retrieve the available models.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) From c066df48a6d9fb1bcc7d1641a798c2d168c24e2c Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Wed, 3 Dec 2025 20:03:59 +0800 Subject: [PATCH 12/38] unify some parameter names --- .../ainode/iotdb/ainode/core/constant.py | 5 +- .../core/inference/pipeline/basic_pipeline.py | 19 +++---- .../inference/pipeline/pipeline_loader.py | 8 +-- .../ainode/core/inference/pool_controller.py | 57 ++++++++++--------- .../pool_scheduler/basic_pool_scheduler.py | 2 +- .../ainode/core/manager/model_manager.py | 7 +-- .../ainode/iotdb/ainode/core/manager/utils.py | 4 +- .../iotdb/ainode/core/model/model_loader.py | 14 ++--- .../core/model/sktime/pipeline_sktime.py | 4 +- .../core/model/sundial/pipeline_sundial.py | 4 +- .../core/model/timer_xl/pipeline_timer.py | 4 +- 11 files changed, 64 insertions(+), 64 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index abd288eee8d93..100ef0138eb91 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -49,10 +49,13 @@ # AINode inference configuration AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 + +# TODO: Should be optimized AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { "sundial": 1036 * 1024**2, # 1036 MiB - "timer_xl": 856 * 1024**2, # 856 MiB + "timer": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes + AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( 1.2 # the overhead ratio for inference, used to estimate the pool size diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index e0bfd8c43f4c5..489caf7863c30 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -25,16 +25,15 @@ class BasicPipeline(ABC): - def __init__(self, model_info, **infer_kwargs): + def __init__(self, model_info, **model_kwargs): self.model_info = model_info - self.device = infer_kwargs.get("device", "cpu") - self.model = load_model(model_info, device_map=self.device) + self.device = model_kwargs.get("device", "cpu") + self.model = load_model(model_info, device_map=self.device, **model_kwargs) def _preprocess(self, inputs): """ Preprocess the input before inference, including shape validation and value transformation. 
""" - # TODO: Integrate with the data processing pipeline operators pass def _postprocess(self, output: torch.Tensor): @@ -45,8 +44,8 @@ def _postprocess(self, output: torch.Tensor): class ForecastPipeline(BasicPipeline): - def __init__(self, model_info, **infer_kwargs): - super().__init__(model_info, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): if len(inputs.shape) != 2: @@ -63,8 +62,8 @@ def _postprocess(self, output: torch.Tensor): class ClassificationPipeline(BasicPipeline): - def __init__(self, model_info, **infer_kwargs): - super().__init__(model_info, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): pass @@ -80,8 +79,8 @@ def _postprocess(self, output: torch.Tensor): class ChatPipeline(BasicPipeline): - def __init__(self, model_info, **infer_kwargs): - super().__init__(model_info, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py index f2dd9bdc73715..2225e1a53045b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -28,14 +28,14 @@ logger = Logger() -def load_pipeline(model_info: ModelInfo, device: str, **kwargs): +def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs): if model_info.category == ModelCategory.BUILTIN: - module_name = ( + module_id = ( AINodeDescriptor().get_config().get_ain_models_builtin_dir() + "." + model_info.model_id ) - pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls) + pipeline_cls = import_class_from_path(module_id, model_info.pipeline_cls) else: model_path = os.path.join( os.getcwd(), @@ -49,4 +49,4 @@ def load_pipeline(model_info: ModelInfo, device: str, **kwargs): model_info.model_id, model_info.pipeline_cls ) - return pipeline_cls(model_info, device=device) + return pipeline_cls(model_info, device=device, **model_kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 5af5bb95102a6..8ffa89ffd6752 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -22,7 +22,6 @@ from concurrent.futures import wait from typing import Dict, Optional -import torch import torch.multiprocessing as mp from iotdb.ainode.core.exception import InferenceModelInternalError @@ -73,7 +72,7 @@ def __init__(self, result_queue: mp.Queue): thread_name_prefix=ThreadName.INFERENCE_POOL_CONTROLLER.value ) - # =============== Pool Management =============== + # =============== Automatic Pool Management (Developing) =============== @synchronized(threading.Lock()) def first_req_init(self, model_id: str, device): """ @@ -104,33 +103,35 @@ def _first_pool_init(self, model_id: str, device_str: str): Initialize the first pool for the given model_id. Ensure the pool is ready before returning. 
""" - device = torch.device(device_str) - device_id = device.index - - first_queue = mp.Queue() - ready_event = mp.Event() - first_pool = InferenceRequestPool( - pool_id=0, - model_id=model_id, - device=device_str, - request_queue=first_queue, - result_queue=self._result_queue, - ready_event=ready_event, - ) - first_pool.start() - self._register_pool(model_id, device_str, 0, first_pool, first_queue) - - if not ready_event.wait(timeout=30): - self._erase_pool(model_id, device_id, 0) - logger.error( - f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" - ) - else: - self.set_state(model_id, device_id, 0, PoolState.RUNNING) - logger.info( - f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" - ) + pass + # device = torch.device(device_str) + # device_id = device.index + # + # first_queue = mp.Queue() + # ready_event = mp.Event() + # first_pool = InferenceRequestPool( + # pool_id=0, + # model_id=model_id, + # device=device_str, + # request_queue=first_queue, + # result_queue=self._result_queue, + # ready_event=ready_event, + # ) + # first_pool.start() + # self._register_pool(model_id, device_str, 0, first_pool, first_queue) + # + # if not ready_event.wait(timeout=30): + # self._erase_pool(model_id, device_id, 0) + # logger.error( + # f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" + # ) + # else: + # self.set_state(model_id, device_id, 0, PoolState.RUNNING) + # logger.info( + # f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" + # ) + # =============== Pool Management =============== def load_model(self, model_id: str, device_id_list: list[str]): """ Load the model to the specified devices asynchronously. diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 9fbc1b0fca4f3..d2e7292ecd8ff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -63,7 +63,7 @@ def _estimate_shared_pool_size_by_total_mem( mem_usages: Dict[str, float] = {} for model_info in all_models: mem_usages[model_info.model_id] = ( - MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO + MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO ) # Evaluate system resources and get TOTAL memory diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index 4aa179558c1bb..8ffb33d91e2d2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -60,6 +60,7 @@ def register_model( ) def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + self._refresh() return self._model_storage.show_models(req) def delete_model(self, req: TDeleteModelReq) -> TSStatus: @@ -73,10 +74,6 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus: logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - def load_model(self, model_id: str, **kwargs) -> Any: - model_info = self.get_model_info(model_id) - return load_model(model_info=model_info, **kwargs) - def get_model_info( self, model_id: str, @@ -91,7 +88,7 @@ def get_model_infos( ) -> List[ModelInfo]: return self._model_storage.get_model_infos(category, model_type) - def refresh(self): + def 
_refresh(self): """Refresh the model list (re-scan the file system)""" self._model_storage.discover_all_models() diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 87a528f02b572..2b032bb29fbdd 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -82,7 +82,7 @@ def evaluate_system_resources(device: torch.device) -> dict: def estimate_pool_size(device: torch.device, model_id: str) -> int: model_info = ModelManager().get_model_info(model_id) - if model_info is None or model_info.model_id not in MODEL_MEM_USAGE_MAP: + if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." ) @@ -91,7 +91,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int: system_res = evaluate_system_resources(device) free_mem = system_res["free_mem"] - mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO + mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage) if size <= 0: logger.error( diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index fba47de9be59d..aace7183a7e1a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -42,14 +42,14 @@ logger = Logger() -def load_model(model_info: ModelInfo, **kwargs) -> Any: +def load_model(model_info: ModelInfo, **model_kwargs) -> Any: if model_info.auto_map is not None: - model = load_model_from_transformers(model_info, **kwargs) + model = load_model_from_transformers(model_info, **model_kwargs) else: if model_info.model_type == "sktime": model = create_sktime_model(model_info.model_id) else: - model = load_model_from_pt(model_info, **kwargs) + model = load_model_from_pt(model_info, **model_kwargs) logger.info( f"Model {model_info.model_id} loaded to device {model.device} successfully." 
@@ -57,10 +57,10 @@ def load_model(model_info: ModelInfo, **kwargs) -> Any: return model -def load_model_from_transformers(model_info: ModelInfo, **kwargs): - device_map = kwargs.get("device_map", "cpu") - trust_remote_code = kwargs.get("trust_remote_code", True) - train_from_scratch = kwargs.get("train_from_scratch", False) +def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): + device_map = model_kwargs.get("device_map", "cpu") + trust_remote_code = model_kwargs.get("trust_remote_code", True) + train_from_scratch = model_kwargs.get("train_from_scratch", False) model_path = os.path.join( os.getcwd(), diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py index b9032aaeb064b..ccac477ca83ae 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -24,8 +24,8 @@ class SktimePipeline(BasicPipeline): - def __init__(self, model_info, **infer_kwargs): - super().__init__(model_info, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): return super()._preprocess(inputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index add2973578f03..e24df5ef842ac 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -22,8 +22,8 @@ class SundialPipeline(ForecastPipeline): - def __init__(self, model_info, **infer_kwargs): - super().__init__(model_info, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): return super()._preprocess(inputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index 38e5effa27b22..c1802186d71dd 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -22,8 +22,8 @@ class TimerPipeline(ForecastPipeline): - def __init__(self, model_info, **infer_kwargs): - super().__init__(model_info, infer_kwargs=infer_kwargs) + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): return super()._preprocess(inputs) From 0e584b735907e7eb6d07101c2500892c8962a5e2 Mon Sep 17 00:00:00 2001 From: RkGrit Date: Wed, 3 Dec 2025 19:28:25 +0800 Subject: [PATCH 13/38] support pipeline for sktime models --- .../core/inference/pipeline/basic_pipeline.py | 20 ++--- .../inference/pipeline/pipeline_loader.py | 9 +- .../iotdb/ainode/core/model/model_loader.py | 2 +- .../core/model/sktime/arima/config.json | 23 ++--- .../core/model/sktime/configuration_sktime.py | 83 +++++++++++-------- .../sktime/exponential_smoothing/config.json | 2 +- .../model/sktime/gaussian_hmm/config.json | 10 ++- .../core/model/sktime/gmm_hmm/config.json | 6 +- .../core/model/sktime/modeling_sktime.py | 43 +++++----- .../model/sktime/naive_forecaster/config.json | 3 +- .../core/model/sktime/pipeline_sktime.py | 6 +- .../model/sktime/stl_forecaster/config.json | 12 ++- .../core/model/sktime/stray/config.json | 2 +- 
13 files changed, 121 insertions(+), 100 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index 489caf7863c30..8a5734036e65f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -34,13 +34,13 @@ def _preprocess(self, inputs): """ Preprocess the input before inference, including shape validation and value transformation. """ - pass + return inputs def _postprocess(self, output: torch.Tensor): """ Post-process the outputs after the entire inference task. """ - pass + return output class ForecastPipeline(BasicPipeline): @@ -58,7 +58,7 @@ def forecast(self, inputs, **infer_kwargs): pass def _postprocess(self, output: torch.Tensor): - pass + return output class ClassificationPipeline(BasicPipeline): @@ -66,16 +66,13 @@ def __init__(self, model_info, **model_kwargs): super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): - pass + return inputs def classify(self, inputs, **kwargs): pass - def _post_decode(self): - pass - def _postprocess(self, output: torch.Tensor): - pass + return output class ChatPipeline(BasicPipeline): @@ -83,13 +80,10 @@ def __init__(self, model_info, **model_kwargs): super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): - pass + return inputs def chat(self, inputs, **kwargs): pass - def _post_decode(self): - pass - def _postprocess(self, output: torch.Tensor): - pass + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py index 2225e1a53045b..e221b2e69150c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -29,13 +29,16 @@ def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs): - if model_info.category == ModelCategory.BUILTIN: - module_id = ( + if model_info.model_type == "sktime": + from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline + pipeline_cls = SktimePipeline + elif model_info.category == ModelCategory.BUILTIN: + module_name = ( AINodeDescriptor().get_config().get_ain_models_builtin_dir() + "." + model_info.model_id ) - pipeline_cls = import_class_from_path(module_id, model_info.pipeline_cls) + pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls) else: model_path = os.path.join( os.getcwd(), diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index aace7183a7e1a..738c4a1c70910 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -52,7 +52,7 @@ def load_model(model_info: ModelInfo, **model_kwargs) -> Any: model = load_model_from_pt(model_info, **model_kwargs) logger.info( - f"Model {model_info.model_id} loaded to device {model.device} successfully." + f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully." 
) return model diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json index dcdc133529090..87fa8859c4806 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -1,22 +1,13 @@ { "model_type": "sktime", - "name": "ARIMA", + "model_id": "arima", "predict_length": 1, "order": [1, 0, 0], - "seasonal_order": [0, 0, 0, 0], - "method": "lbfgs", - "maxiter": 1, - "suppress_warnings": true, - "out_of_sample_size": 0, - "scoring": "mse", - "with_intercept": true, - "time_varying_regression": false, - "enforce_stationarity": true, - "enforce_invertibility": true, - "simple_differencing": false, - "measurement_error": false, - "mle_regression": true, - "hamilton_representation": false, - "concentrate_scale": false + "season_length": 1, + "seasonal_order": [0, 0, 0], + "include_mean": true, + "include_drift": false, + "biasadj": false, + "method": "CSS-ML" } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py index bd780da3a73fb..6f08c8cd2c626 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -16,10 +16,6 @@ # under the License. # -""" -Sktime model configuration module - simplified version -""" - from dataclasses import dataclass, field from typing import Any, Dict, List, Union @@ -50,12 +46,16 @@ class AttributeConfig: def validate_value(self, value): """Validate if the value meets the requirements""" if self.type == "int": + if value is None: + return True # Allow None for optional int parameters if not isinstance(value, int): raise WrongAttributeTypeError(self.name, "int") if self.low is not None and self.high is not None: if not (self.low <= value <= self.high): raise NumericalRangeException(self.name, value, self.low, self.high) elif self.type == "float": + if value is None: + return True # Allow None for optional float parameters if not isinstance(value, (int, float)): raise WrongAttributeTypeError(self.name, "float") value = float(value) @@ -63,11 +63,15 @@ def validate_value(self, value): if not (self.low <= value <= self.high): raise NumericalRangeException(self.name, value, self.low, self.high) elif self.type == "str": + if value is None: + return True # Allow None for optional str parameters if not isinstance(value, str): raise WrongAttributeTypeError(self.name, "str") if self.choices and value not in self.choices: raise StringRangeException(self.name, value, self.choices) elif self.type == "bool": + if value is None: + return True # Allow None for optional bool parameters if not isinstance(value, bool): raise WrongAttributeTypeError(self.name, "bool") elif self.type == "list": @@ -87,22 +91,30 @@ def validate_value(self, value): def parse(self, string_value: str): """Parse string value to corresponding type""" if self.type == "int": + if string_value.lower() == "none" or string_value.strip() == "": + return None try: return int(string_value) except: raise WrongAttributeTypeError(self.name, "int") elif self.type == "float": + if string_value.lower() == "none" or string_value.strip() == "": + return None try: return float(string_value) except: raise WrongAttributeTypeError(self.name, "float") elif self.type == "str": + if string_value.lower() == "none" or string_value.strip() == 
"": + return None return string_value elif self.type == "bool": if string_value.lower() == "true": return True elif string_value.lower() == "false": return False + elif string_value.lower() == "none" or string_value.strip() == "": + return None else: raise WrongAttributeTypeError(self.name, "bool") elif self.type == "list": @@ -142,9 +154,10 @@ def parse(self, string_value: str): MODEL_CONFIGS = { "NAIVE_FORECASTER": { "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), - "pipeline": AttributeConfig( - "pipeline", "last", "str", choices=["last", "mean"] + "strategy": AttributeConfig( + "strategy", "last", "str", choices=["last", "mean", "drift"] ), + "window_length": AttributeConfig("window_length", None, "int"), "sp": AttributeConfig("sp", 1, "int", 1, 5000), }, "EXPONENTIAL_SMOOTHING": { @@ -163,48 +176,40 @@ def parse(self, string_value: str): "ARIMA": { "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int), + "season_length": AttributeConfig("season_length", 1, "int", 1, 5000), "seasonal_order": AttributeConfig( - "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int + "seasonal_order", (0, 0, 0), "tuple", value_type=int ), + "include_mean": AttributeConfig("include_mean", True, "bool"), + "include_drift": AttributeConfig("include_drift", False, "bool"), + "include_constant": AttributeConfig("include_constant", None, "bool"), + "blambda": AttributeConfig("blambda", None, "float"), + "biasadj": AttributeConfig("biasadj", False, "bool"), "method": AttributeConfig( "method", - "lbfgs", - "str", - choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], - ), - "maxiter": AttributeConfig("maxiter", 1, "int", 1, 5000), - "suppress_warnings": AttributeConfig("suppress_warnings", True, "bool"), - "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000), - "scoring": AttributeConfig( - "scoring", - "mse", + "CSS-ML", "str", - choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], - ), - "with_intercept": AttributeConfig("with_intercept", True, "bool"), - "time_varying_regression": AttributeConfig( - "time_varying_regression", False, "bool" + choices=["CSS-ML", "ML", "CSS"], ), - "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"), - "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"), - "simple_differencing": AttributeConfig("simple_differencing", False, "bool"), - "measurement_error": AttributeConfig("measurement_error", False, "bool"), - "mle_regression": AttributeConfig("mle_regression", True, "bool"), - "hamilton_representation": AttributeConfig( - "hamilton_representation", False, "bool" - ), - "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"), }, "STL_FORECASTER": { "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), "sp": AttributeConfig("sp", 2, "int", 1, 5000), "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000), + "trend": AttributeConfig("trend", None, "int"), + "low_pass": AttributeConfig("low_pass", None, "int"), "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000), "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000), "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000), + "robust": AttributeConfig("robust", False, "bool"), "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000), "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 5000), "low_pass_jump": 
AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + "inner_iter": AttributeConfig("inner_iter", None, "int"), + "outer_iter": AttributeConfig("outer_iter", None, "int"), + "forecaster_trend": AttributeConfig("forecaster_trend", None, "str"), + "forecaster_seasonal": AttributeConfig("forecaster_seasonal", None, "str"), + "forecaster_resid": AttributeConfig("forecaster_resid", None, "str"), }, "GAUSSIAN_HMM": { "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), @@ -219,15 +224,17 @@ def parse(self, string_value: str): "startprob_prior", 1.0, "float", -1e10, 1e10 ), "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), - "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), - "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), - "covars_prior": AttributeConfig("covars_prior", 1e-2, "float", -1e10, 1e10), - "covars_weight": AttributeConfig("covars_weight", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 0.01, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1, "float", -1e10, 1e10), "algorithm": AttributeConfig( "algorithm", "viterbi", "str", choices=["viterbi", "map"] ), + "random_state": AttributeConfig("random_state", None, "float"), "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), "init_params": AttributeConfig( "init_params", "stmc", "str", choices=["stmc", "stm"] @@ -247,6 +254,8 @@ def parse(self, string_value: str): "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", None, "float"), + "covars_weight": AttributeConfig("covars_weight", None, "float"), "algorithm": AttributeConfig( "algorithm", "viterbi", "str", choices=["viterbi", "map"] ), @@ -254,10 +263,12 @@ def parse(self, string_value: str): "covariance_type", "diag", "str", - choices=["sperical", "diag", "full", "tied"], + choices=["spherical", "diag", "full", "tied"], ), + "random_state": AttributeConfig("random_state", None, "int"), "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), "init_params": AttributeConfig( "init_params", "stmcw", diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json index d6002fb26e87a..4126d9de857a6 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json @@ -1,6 +1,6 @@ { "model_type": "sktime", - "name": "ExponentialSmoothing", + "model_id": "exponential_smoothing", "predict_length": 1, "damped_trend": false, "initialization_method": "estimated", diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json index 
3392e1c0b57c8..94f7d7ec659fc 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json @@ -1,18 +1,20 @@ { "model_type": "sktime", - "name": "GaussianHMM", + "model_id": "gaussian_hmm", "n_components": 1, "covariance_type": "diag", "min_covar": 0.001, "startprob_prior": 1.0, "transmat_prior": 1.0, - "means_prior": 0.0, - "means_weight": 0.0, + "means_prior": 0, + "means_weight": 0, "covars_prior": 0.01, - "covars_weight": 1.0, + "covars_weight": 1, "algorithm": "viterbi", + "random_state": null, "n_iter": 10, "tol": 0.01, + "verbose": false, "params": "stmc", "init_params": "stmc", "implementation": "log" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json index 235f8ae642da4..fb19d1aaf86d9 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json @@ -1,6 +1,6 @@ { "model_type": "sktime", - "name": "GMMHMM", + "model_id": "gmm_hmm", "n_components": 1, "n_mix": 1, "min_covar": 0.001, @@ -9,10 +9,14 @@ "weights_prior": 1.0, "means_prior": 0.0, "means_weight": 0.0, + "covars_prior": null, + "covars_weight": null, "algorithm": "viterbi", "covariance_type": "diag", + "random_state": null, "n_iter": 10, "tol": 0.01, + "verbose": false, "init_params": "stmcw", "params": "stmcw", "implementation": "log" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py index f272e3dda3579..8efe25ba0a3e7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -16,18 +16,15 @@ # under the License. 
# -""" -Sktime model implementation module - simplified version -""" - from abc import abstractmethod from typing import Any, Dict import numpy as np +import pandas as pd from sklearn.preprocessing import MinMaxScaler from sktime.detection.hmm_learn import GMMHMM, GaussianHMM from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA +from statsforecast.models import ARIMA from sktime.forecasting.exp_smoothing import ExponentialSmoothing from sktime.forecasting.naive import NaiveForecaster from sktime.forecasting.trend import STLForecaster @@ -51,7 +48,7 @@ def __init__(self, attributes: Dict[str, Any]): self._model = None @abstractmethod - def generate(self, data): + def generate(self, data, **kwargs): """Execute generation/inference""" raise NotImplementedError @@ -59,12 +56,15 @@ def generate(self, data): class ForecastingModel(SktimeModel): """Base class for forecasting models""" - def generate(self, data): + def generate(self, data, **kwargs): """Execute forecasting""" try: - predict_length = self._attributes["predict_length"] + predict_length = kwargs.get("predict_length", self._attributes["predict_length"]) self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) + if isinstance(self._model, ARIMA): + output = self._model.predict(h=predict_length)['mean'] + else: + output = self._model.predict(fh=range(predict_length)) return np.array(output, dtype=np.float64) except Exception as e: raise InferenceModelInternalError(str(e)) @@ -73,12 +73,15 @@ def generate(self, data): class DetectionModel(SktimeModel): """Base class for detection models""" - def generate(self, data): + def generate(self, data, **kwargs): """Execute detection""" try: - self._model.fit(data) - output = self._model.predict(data) - return np.array(output, dtype=np.int32) + predict_length = kwargs.get("predict_length", data.size) + output = self._model.fit_transform(data[:predict_length]) + if isinstance(output, pd.DataFrame): + return np.array(output["labels"], dtype=np.int32) + else: + return np.array(output, dtype=np.int32) except Exception as e: raise InferenceModelInternalError(str(e)) @@ -89,7 +92,7 @@ class ArimaModel(ForecastingModel): def __init__(self, attributes: Dict[str, Any]): super().__init__(attributes) self._model = ARIMA( - **{k: v for k, v in attributes.items() if k != "predict_length"} + **{k: v for k, v in attributes.items() if k != "predict_length" and v is not None} ) @@ -144,14 +147,16 @@ class STRAYModel(DetectionModel): def __init__(self, attributes: Dict[str, Any]): super().__init__(attributes) - self._model = STRAY(**attributes) + self._model = STRAY( + **{k: v for k, v in attributes.items() if v is not None} + ) - def generate(self, data): + def generate(self, data, **kwargs): """STRAY requires special handling: normalize first""" try: - data = MinMaxScaler().fit_transform(data) - output = self._model.fit_transform(data) - return np.array(output, dtype=np.int32) + scaled_data = MinMaxScaler().fit_transform(data.values.reshape(-1, 1)) + scaled_data = pd.Series(scaled_data.flatten()) + return super().generate(scaled_data, **kwargs) except Exception as e: raise InferenceModelInternalError(str(e)) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json index 20d8c1ed32b5c..3dadd7c3b1e5d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json +++ 
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json @@ -1,8 +1,9 @@ { "model_type": "sktime", - "name": "NaiveForecaster", + "model_id": "naive_forecaster", "predict_length": 1, "strategy": "last", + "window_length": null, "sp": 1 } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py index ccac477ca83ae..68698a3e44910 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -26,11 +26,13 @@ class SktimePipeline(BasicPipeline): def __init__(self, model_info, **model_kwargs): super().__init__(model_info, model_kwargs=model_kwargs) + model_kwargs.pop("device", None) def _preprocess(self, inputs): return super()._preprocess(inputs) def infer(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) input_ids = self._preprocess(inputs) # Convert to pandas Series for sktime (sktime expects Series or DataFrame) @@ -44,7 +46,7 @@ def infer(self, inputs, **infer_kwargs): if isinstance(input_ids, torch.Tensor) else input_ids[i] ) - output = self.model.generate(series) + output = self.model.generate(series, predict_length=predict_length) outputs.append(output) output = np.array(outputs) else: @@ -53,7 +55,7 @@ def infer(self, inputs, **infer_kwargs): series = pd.Series(input_ids.squeeze().cpu().numpy()) else: series = pd.Series(input_ids.squeeze()) - output = self.model.generate(series) + output = self.model.generate(series, predict_length=predict_length) # Add batch dimension if needed if len(output.shape) == 1: output = output[np.newaxis, :] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json index 1005f9d944e9d..bfe71dbc48614 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json @@ -1,14 +1,22 @@ { "model_type": "sktime", - "name": "STLForecaster", + "model_id": "stl_forecaster", "predict_length": 1, "sp": 2, "seasonal": 7, + "trend": null, + "low_pass": null, "seasonal_deg": 1, "trend_deg": 1, "low_pass_deg": 1, + "robust": false, "seasonal_jump": 1, "trend_jump": 1, - "low_pass_jump": 1 + "low_pass_jump": 1, + "inner_iter": null, + "outer_iter": null, + "forecaster_trend": null, + "forecaster_seasonal": null, + "forecaster_resid": null } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json index 64c64aa9e0514..e5bcc03cd0714 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json @@ -1,6 +1,6 @@ { "model_type": "sktime", - "name": "STRAY", + "model_id": "stray", "alpha": 0.01, "k": 10, "knn_algorithm": "brute", From a8bfbcc61336a64032a23eafd7a933c441c6800a Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Wed, 3 Dec 2025 20:14:01 +0800 Subject: [PATCH 14/38] Update model_loader.py --- iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index 738c4a1c70910..a6e3b1f7b5e38 100644 --- 
a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -142,15 +142,15 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs): return model.to(device_map) -def load_model_for_efficient_inference(self): +def load_model_for_efficient_inference(): # TODO: An efficient model loading method for inference based on model_arguments pass -def load_model_for_powerful_finetune(self): +def load_model_for_powerful_finetune(): # TODO: An powerful model loading method for finetune based on model_arguments pass -def unload_model(self): +def unload_model(): pass From 8d00ce77ea6dc245b72bd3cea320893f2b2e7e9a Mon Sep 17 00:00:00 2001 From: Gewu <89496957+RkGrit@users.noreply.github.com> Date: Thu, 4 Dec 2025 10:16:38 +0800 Subject: [PATCH 15/38] support various pipeline Interfaces and support arima with sktime package (#16861) --- .../core/inference/inference_request_pool.py | 24 +++++++++--- .../core/inference/pipeline/basic_pipeline.py | 10 ++--- .../inference/pipeline/pipeline_loader.py | 1 + .../ainode/core/manager/inference_manager.py | 14 +++++-- .../ainode/iotdb/ainode/core/manager/utils.py | 4 +- .../core/model/sktime/arima/config.json | 24 +++++++++--- .../core/model/sktime/configuration_sktime.py | 37 ++++++++++++++----- .../core/model/sktime/modeling_sktime.py | 17 ++++----- .../core/model/sktime/pipeline_sktime.py | 10 ++--- .../core/model/sundial/pipeline_sundial.py | 9 ++++- .../core/model/timer_xl/pipeline_timer.py | 9 ++++- iotdb-core/ainode/pyproject.toml | 2 +- 12 files changed, 111 insertions(+), 50 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 2fb00988bef9b..8164612336239 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -30,6 +30,7 @@ from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline, ClassificationPipeline, ChatPipeline from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, @@ -116,11 +117,24 @@ def _step(self): for requests in grouped_requests: batch_inputs = self._batcher.batch_request(requests).to(self.device) - batch_output = self._inference_pipeline.infer( - batch_inputs, - predict_length=requests[0].max_new_tokens, - revin=True, - ) + if isinstance(self._inference_pipeline, ForecastPipeline): + batch_output = self._inference_pipeline.forecast( + batch_inputs, + predict_length=requests[0].max_new_tokens, + revin=True, + ) + elif isinstance(self._inference_pipeline, ClassificationPipeline): + batch_output = self._inference_pipeline.classify( + batch_inputs, + # more infer kwargs can be added here + ) + elif isinstance(self._inference_pipeline, ChatPipeline): + batch_output = self._inference_pipeline.chat( + batch_inputs, + # more infer kwargs can be added here + ) + else: + self._logger.error("[Inference] Unsupported pipeline type.") offset = 0 for request in requests: request.output_tensor = request.output_tensor.to(self.device) diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py index 8a5734036e65f..82601e398059c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -16,11 +16,10 @@ # under the License. # -from abc import ABC +from abc import ABC, abstractmethod import torch -from iotdb.ainode.core.exception import InferenceModelInternalError from iotdb.ainode.core.model.model_loader import load_model @@ -48,12 +47,9 @@ def __init__(self, model_info, **model_kwargs): super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): - if len(inputs.shape) != 2: - raise InferenceModelInternalError( - f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" - ) return inputs + @abstractmethod def forecast(self, inputs, **infer_kwargs): pass @@ -68,6 +64,7 @@ def __init__(self, model_info, **model_kwargs): def _preprocess(self, inputs): return inputs + @abstractmethod def classify(self, inputs, **kwargs): pass @@ -82,6 +79,7 @@ def __init__(self, model_info, **model_kwargs): def _preprocess(self, inputs): return inputs + @abstractmethod def chat(self, inputs, **kwargs): pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py index e221b2e69150c..a30038dd5feff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -31,6 +31,7 @@ def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs): if model_info.model_type == "sktime": from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline + pipeline_cls = SktimePipeline elif model_info.category == ModelCategory.BUILTIN: module_name = ( diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 183c942e3b5aa..24f9fa883218f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -35,6 +35,7 @@ InferenceRequestProxy, ) from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline, ClassificationPipeline, ChatPipeline from iotdb.ainode.core.inference.pool_controller import PoolController from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger @@ -210,9 +211,16 @@ def _run( else: model_info = self._model_manager.get_model_info(model_id) inference_pipeline = load_pipeline(model_info, device="cpu") - outputs = inference_pipeline.infer( - inputs, predict_length=predict_length, **inference_attrs - ) + if isinstance(inference_pipeline, ForecastPipeline): + outputs = inference_pipeline.forecast( + inputs, predict_length=predict_length, **inference_attrs + ) + elif isinstance(inference_pipeline, ClassificationPipeline): + outputs = inference_pipeline.classify(inputs) + elif isinstance(inference_pipeline, ChatPipeline): + outputs = inference_pipeline.chat(inputs) + else: + logger.error("[Inference] Unsupported pipeline type.") outputs = convert_to_binary(pd.DataFrame(outputs[0])) # construct response diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 2b032bb29fbdd..23a98f26bbffa 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -91,7 +91,9 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int: system_res = evaluate_system_resources(device) free_mem = system_res["free_mem"] - mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO + mem_usage = ( + MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO + ) size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage) if size <= 0: logger.error( diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json index 87fa8859c4806..1561124badd12 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -3,11 +3,23 @@ "model_id": "arima", "predict_length": 1, "order": [1, 0, 0], - "season_length": 1, - "seasonal_order": [0, 0, 0], - "include_mean": true, - "include_drift": false, - "biasadj": false, - "method": "CSS-ML" + "seasonal_order": [0, 0, 0, 0], + "start_params": null, + "method": "lbfgs", + "maxiter": 50, + "suppress_warnings": false, + "out_of_sample_size": 0, + "scoring": "mse", + "scoring_args": null, + "trend": null, + "with_intercept": true, + "time_varying_regression": false, + "enforce_stationarity": true, + "enforce_invertibility": true, + "simple_differencing": false, + "measurement_error": false, + "mle_regression": true, + "hamilton_representation": false, + "concentrate_scale": false } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py index 6f08c8cd2c626..261de3c9abe7c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -176,21 +176,40 @@ def parse(self, string_value: str): "ARIMA": { "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int), - "season_length": AttributeConfig("season_length", 1, "int", 1, 5000), "seasonal_order": AttributeConfig( - "seasonal_order", (0, 0, 0), "tuple", value_type=int + "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int ), - "include_mean": AttributeConfig("include_mean", True, "bool"), - "include_drift": AttributeConfig("include_drift", False, "bool"), - "include_constant": AttributeConfig("include_constant", None, "bool"), - "blambda": AttributeConfig("blambda", None, "float"), - "biasadj": AttributeConfig("biasadj", False, "bool"), + "start_params": AttributeConfig("start_params", None, "str"), "method": AttributeConfig( "method", - "CSS-ML", + "lbfgs", "str", - choices=["CSS-ML", "ML", "CSS"], + choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], ), + "maxiter": AttributeConfig("maxiter", 50, "int", 1, 5000), + "suppress_warnings": AttributeConfig("suppress_warnings", False, "bool"), + "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000), + "scoring": AttributeConfig( + "scoring", + "mse", + "str", + choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], + ), + "scoring_args": AttributeConfig("scoring_args", None, "str"), + "trend": 
AttributeConfig("trend", None, "str"), + "with_intercept": AttributeConfig("with_intercept", True, "bool"), + "time_varying_regression": AttributeConfig( + "time_varying_regression", False, "bool" + ), + "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"), + "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"), + "simple_differencing": AttributeConfig("simple_differencing", False, "bool"), + "measurement_error": AttributeConfig("measurement_error", False, "bool"), + "mle_regression": AttributeConfig("mle_regression", True, "bool"), + "hamilton_representation": AttributeConfig( + "hamilton_representation", False, "bool" + ), + "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"), }, "STL_FORECASTER": { "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py index 8efe25ba0a3e7..eca812d35ec9a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -24,7 +24,7 @@ from sklearn.preprocessing import MinMaxScaler from sktime.detection.hmm_learn import GMMHMM, GaussianHMM from sktime.detection.stray import STRAY -from statsforecast.models import ARIMA +from sktime.forecasting.arima import ARIMA from sktime.forecasting.exp_smoothing import ExponentialSmoothing from sktime.forecasting.naive import NaiveForecaster from sktime.forecasting.trend import STLForecaster @@ -59,12 +59,11 @@ class ForecastingModel(SktimeModel): def generate(self, data, **kwargs): """Execute forecasting""" try: - predict_length = kwargs.get("predict_length", self._attributes["predict_length"]) + predict_length = kwargs.get( + "predict_length", self._attributes["predict_length"] + ) self._model.fit(data) - if isinstance(self._model, ARIMA): - output = self._model.predict(h=predict_length)['mean'] - else: - output = self._model.predict(fh=range(predict_length)) + output = self._model.predict(fh=range(predict_length)) return np.array(output, dtype=np.float64) except Exception as e: raise InferenceModelInternalError(str(e)) @@ -92,7 +91,7 @@ class ArimaModel(ForecastingModel): def __init__(self, attributes: Dict[str, Any]): super().__init__(attributes) self._model = ARIMA( - **{k: v for k, v in attributes.items() if k != "predict_length" and v is not None} + **{k: v for k, v in attributes.items() if k != "predict_length"} ) @@ -147,9 +146,7 @@ class STRAYModel(DetectionModel): def __init__(self, attributes: Dict[str, Any]): super().__init__(attributes) - self._model = STRAY( - **{k: v for k, v in attributes.items() if v is not None} - ) + self._model = STRAY(**{k: v for k, v in attributes.items() if v is not None}) def generate(self, data, **kwargs): """STRAY requires special handling: normalize first""" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py index 68698a3e44910..ced21f29a2b82 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -20,18 +20,18 @@ import pandas as pd import torch -from iotdb.ainode.core.inference.pipeline.basic_pipeline import BasicPipeline +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class SktimePipeline(BasicPipeline): +class 
SktimePipeline(ForecastPipeline): def __init__(self, model_info, **model_kwargs): + model_kwargs.pop("device", None) # sktime models run on CPU super().__init__(model_info, model_kwargs=model_kwargs) - model_kwargs.pop("device", None) def _preprocess(self, inputs): - return super()._preprocess(inputs) + return inputs - def infer(self, inputs, **infer_kwargs): + def forecast(self, inputs, **infer_kwargs): predict_length = infer_kwargs.get("predict_length", 96) input_ids = self._preprocess(inputs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index e24df5ef842ac..8b33597ab0273 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -19,6 +19,7 @@ import torch from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.exception import InferenceModelInternalError class SundialPipeline(ForecastPipeline): @@ -26,9 +27,13 @@ def __init__(self, model_info, **model_kwargs): super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): - return super()._preprocess(inputs) + if len(inputs.shape) != 2: + raise InferenceModelInternalError( + f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" + ) + return inputs - def infer(self, inputs, **infer_kwargs): + def forecast(self, inputs, **infer_kwargs): predict_length = infer_kwargs.get("predict_length", 96) num_samples = infer_kwargs.get("num_samples", 10) revin = infer_kwargs.get("revin", True) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index c1802186d71dd..bc0620ec9632c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -19,6 +19,7 @@ import torch from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline +from iotdb.ainode.core.exception import InferenceModelInternalError class TimerPipeline(ForecastPipeline): @@ -26,9 +27,13 @@ def __init__(self, model_info, **model_kwargs): super().__init__(model_info, model_kwargs=model_kwargs) def _preprocess(self, inputs): - return super()._preprocess(inputs) + if len(inputs.shape) != 2: + raise InferenceModelInternalError( + f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" + ) + return inputs - def infer(self, inputs, **infer_kwargs): + def forecast(self, inputs, **infer_kwargs): predict_length = infer_kwargs.get("predict_length", 96) revin = infer_kwargs.get("revin", True) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index 331cb8ab32a34..788af64c8699f 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -93,7 +93,7 @@ scipy = "^1.12.0" pandas = "^2.3.2" scikit-learn = "^1.7.1" statsmodels = "^0.14.5" -sktime = "0.38.5" +sktime = "0.40.1" # ---- Optimizers / utils ---- optuna = "^4.4.0" From 8e8cd6f28b77471e9c9c4527235428f564d3da2c Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Thu, 4 Dec 2025 10:19:32 +0800 Subject: [PATCH 16/38] spotless ainode codes --- .../iotdb/ainode/core/inference/inference_request_pool.py | 6 +++++- .../ainode/iotdb/ainode/core/manager/inference_manager.py | 6 +++++- 
.../iotdb/ainode/core/model/sundial/pipeline_sundial.py | 2 +- .../iotdb/ainode/core/model/timer_xl/pipeline_timer.py | 2 +- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 8164612336239..7011cc7f7ba69 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -30,7 +30,11 @@ from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest -from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline, ClassificationPipeline, ChatPipeline +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, +) from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 24f9fa883218f..5fc2c91cbc39a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -34,8 +34,12 @@ InferenceRequest, InferenceRequestProxy, ) +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, +) from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline -from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline, ClassificationPipeline, ChatPipeline from iotdb.ainode.core.inference.pool_controller import PoolController from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 8b33597ab0273..85b6f7db2ffef 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -18,8 +18,8 @@ import torch -from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline class SundialPipeline(ForecastPipeline): diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index bc0620ec9632c..c0f00b1f5caf3 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -18,8 +18,8 @@ import torch -from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline class TimerPipeline(ForecastPipeline): From 71daf3b45f80fa7ebfdd2cc510dcd6c353054c37 Mon Sep 17 00:00:00 2001 From: Gewu <89496957+RkGrit@users.noreply.github.com> Date: Thu, 4 
Dec 2025 15:41:04 +0800 Subject: [PATCH 17/38] Add dependencies of python packages for arima, gaussian_hmm and efficient model loading (#16865) --- iotdb-core/ainode/pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index 788af64c8699f..cc62089e24436 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -94,6 +94,9 @@ pandas = "^2.3.2" scikit-learn = "^1.7.1" statsmodels = "^0.14.5" sktime = "0.40.1" +pmdarima = "2.1.1" +hmmlearn = "0.3.2" +accelerate = "^1.10.1" # ---- Optimizers / utils ---- optuna = "^4.4.0" From 381aea89a7045a6d7b3fe7541747e161a827f3e4 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Thu, 4 Dec 2025 16:49:35 +0800 Subject: [PATCH 18/38] Fix call inference cannot specify outputLength --- .../ainode/it/AINodeCallInferenceIT.java | 16 +++++++++++++--- iotdb-core/ainode/iotdb/ainode/core/config.py | 16 +++++++--------- .../ainode/iotdb/ainode/core/constant.py | 2 +- .../core/inference/inference_request.py | 16 ++++++++-------- .../core/inference/inference_request_pool.py | 4 ++-- .../ainode/core/manager/inference_manager.py | 19 +++++++++---------- iotdb-core/ainode/pyproject.toml | 2 +- .../process/ai/InferenceOperator.java | 3 ++- .../src/main/thrift/ainode.thrift | 8 +------- 9 files changed, 44 insertions(+), 42 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java index 6bdb3e25b91b5..5368c584443fa 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -34,10 +34,12 @@ import java.sql.Connection; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; @@ -55,7 +57,8 @@ public class AINodeCallInferenceIT { }; private static final String CALL_INFERENCE_SQL_TEMPLATE = - "CALL INFERENCE(%s, \"select s%d from root.AI\")"; + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; + private static final int DEFAULT_OUTPUT_LENGTH = 48; @BeforeClass public static void setUp() throws Exception { @@ -93,14 +96,21 @@ public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo // Invoke call inference for specified models, there should exist result. 
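     // Each iteration pins DEFAULT_OUTPUT_LENGTH both as the LIMIT on the input
     // series and as the outputLength inference attribute, so the loop below can
     // assert the "Time,output" header and an exact forecast row count rather
     // than merely a non-empty result.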
       for (int i = 0; i < 4; i++) {
         String callInferenceSQL =
-            String.format(CALL_INFERENCE_SQL_TEMPLATE, modelInfo.getModelId(), i);
+            String.format(
+                CALL_INFERENCE_SQL_TEMPLATE,
+                modelInfo.getModelId(),
+                i,
+                DEFAULT_OUTPUT_LENGTH,
+                DEFAULT_OUTPUT_LENGTH);
         try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+          ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+          checkHeader(resultSetMetaData, "Time,output");
           int count = 0;
           while (resultSet.next()) {
             count++;
           }
           // Ensure the call inference returns results
-          Assert.assertTrue(count > 0);
+          Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
         }
       }
     }
diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py
index 8f9f256dfc16a..e465df7e36d29 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/config.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/config.py
@@ -32,7 +32,7 @@
     AINODE_CONF_POM_FILE_NAME,
     AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
     AINODE_INFERENCE_EXTRA_MEMORY_RATIO,
-    AINODE_INFERENCE_MAX_PREDICT_LENGTH,
+    AINODE_INFERENCE_MAX_OUTPUT_LENGTH,
     AINODE_INFERENCE_MEMORY_USAGE_RATIO,
     AINODE_INFERENCE_MODEL_MEM_USAGE_MAP,
     AINODE_LOG_DIR,
@@ -75,9 +75,7 @@ def __init__(self):
         self._ain_inference_batch_interval_in_ms: int = (
             AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
         )
-        self._ain_inference_max_predict_length: int = (
-            AINODE_INFERENCE_MAX_PREDICT_LENGTH
-        )
+        self._ain_inference_max_output_length: int = AINODE_INFERENCE_MAX_OUTPUT_LENGTH
         self._ain_inference_model_mem_usage_map: dict[str, int] = (
             AINODE_INFERENCE_MODEL_MEM_USAGE_MAP
         )
@@ -160,13 +158,13 @@ def set_ain_inference_batch_interval_in_ms(
     ) -> None:
         self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms

-    def get_ain_inference_max_predict_length(self) -> int:
-        return self._ain_inference_max_predict_length
+    def get_ain_inference_max_output_length(self) -> int:
+        return self._ain_inference_max_output_length

-    def set_ain_inference_max_predict_length(
-        self, ain_inference_max_predict_length: int
+    def set_ain_inference_max_output_length(
+        self, ain_inference_max_output_length: int
     ) -> None:
-        self._ain_inference_max_predict_length = ain_inference_max_predict_length
+        self._ain_inference_max_output_length = ain_inference_max_output_length

     def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]:
         return self._ain_inference_model_mem_usage_map
diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index 100ef0138eb91..c0b19a570d20a 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -48,7 +48,7 @@

 # AINode inference configuration
 AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
-AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
+AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880

 # TODO: Should be optimized
 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
index a70445c5efd4c..50634914c2737 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
@@ -39,7 +39,7 @@ def __init__(
         req_id: str,
         model_id: str,
         inputs: torch.Tensor,
-        max_new_tokens: int = 96,
+        output_length: int = 96,
         **infer_kwargs,
     ):
         if inputs.ndim == 1:
@@ -49,8 +49,8 @@ def __init__(
         self.model_id = model_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
-        self.max_new_tokens = (
-            max_new_tokens  # Number of time series data points to generate
+        self.output_length = (
+            output_length  # Number of time series data points to generate
         )
         self.batch_size = inputs.size(0)
@@ -61,7 +61,7 @@ def __init__(

         # Preallocate output buffer [batch_size, max_new_tokens]
         self.output_tensor = torch.zeros(
-            self.batch_size, max_new_tokens, device="cpu"
+            self.batch_size, output_length, device="cpu"
         )  # shape: [self.batch_size, max_new_steps]

     def mark_running(self):
@@ -73,7 +73,7 @@ def mark_finished(self):
     def is_finished(self) -> bool:
         return (
             self.state == InferenceRequestState.FINISHED
-            or self.cur_step_idx >= self.max_new_tokens
+            or self.cur_step_idx >= self.output_length
         )

     def write_step_output(self, step_output: torch.Tensor):
@@ -83,11 +83,11 @@ def write_step_output(self, step_output: torch.Tensor):

         batch_size, step_size = step_output.shape
         end_idx = self.cur_step_idx + step_size
-        if end_idx > self.max_new_tokens:
+        if end_idx > self.output_length:
             self.output_tensor[:, self.cur_step_idx :] = step_output[
-                :, : self.max_new_tokens - self.cur_step_idx
+                :, : self.output_length - self.cur_step_idx
             ]
-            self.cur_step_idx = self.max_new_tokens
+            self.cur_step_idx = self.output_length
         else:
             self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
             self.cur_step_idx = end_idx
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 7011cc7f7ba69..a6c415a6c848b 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -115,7 +115,7 @@ def _step(self):

         grouped_requests = defaultdict(list)
         for req in all_requests:
-            key = (req.inputs.shape[1], req.max_new_tokens)
+            key = (req.inputs.shape[1], req.output_length)
             grouped_requests[key].append(req)
         grouped_requests = list(grouped_requests.values())

@@ -124,7 +124,7 @@
             if isinstance(self._inference_pipeline, ForecastPipeline):
                 batch_output = self._inference_pipeline.forecast(
                     batch_inputs,
-                    predict_length=requests[0].max_new_tokens,
+                    predict_length=requests[0].output_length,
                     revin=True,
                 )
             elif isinstance(self._inference_pipeline, ClassificationPipeline):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index 5fc2c91cbc39a..1ce2e84e05929 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -189,18 +189,18 @@ def _run(
         inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")

         inference_attrs = extract_attrs(req)
-        predict_length = int(inference_attrs.pop("predict_length", 96))
+        output_length = int(inference_attrs.pop("output_length", 96))
         if (
-            predict_length
-            > AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
+            output_length
+            > AINodeDescriptor().get_config().get_ain_inference_max_output_length()
         ):
             raise NumericalRangeException(
                 "output_length",
                 1,
                 AINodeDescriptor()
                 .get_config()
-                .get_ain_inference_max_predict_length(),
-                predict_length,
+                .get_ain_inference_max_output_length(),
+                output_length,
             )

         if self._pool_controller.has_request_pools(model_id):
@@ -208,7 +208,7 @@ def _run(
                 req_id=generate_req_id(),
                 model_id=model_id,
                 inputs=inputs,
-                max_new_tokens=predict_length,
+                output_length=output_length,
             )
             outputs = self._process_request(infer_req)
             outputs = convert_to_binary(pd.DataFrame(outputs[0]))
@@ -217,7 +217,7 @@ def _run(
             inference_pipeline = load_pipeline(model_info, device="cpu")
             if isinstance(inference_pipeline, ForecastPipeline):
                 outputs = inference_pipeline.forecast(
-                    inputs, predict_length=predict_length, **inference_attrs
+                    inputs, predict_length=output_length, **inference_attrs
                 )
             elif isinstance(inference_pipeline, ClassificationPipeline):
                 outputs = inference_pipeline.classify(inputs)
@@ -246,7 +246,7 @@ def forecast(self, req: TForecastReq):
             data_getter=lambda r: r.inputData,
             deserializer=deserialize,
             extract_attrs=lambda r: {
-                "predict_length": r.outputLength,
+                "output_length": r.outputLength,
                 **(r.options or {}),
             },
             resp_cls=TForecastResp,
@@ -259,8 +259,7 @@ def inference(self, req: TInferenceReq):
             data_getter=lambda r: r.dataset,
             deserializer=deserialize,
             extract_attrs=lambda r: {
-                "window_interval": getattr(r.windowParams, "windowInterval", None),
-                "window_step": getattr(r.windowParams, "windowStep", None),
+                "output_length": int(r.inferenceAttributes.pop("outputLength", 96)),
                 **(r.inferenceAttributes or {}),
             },
             resp_cls=TInferenceResp,
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index cc62089e24436..e15a7724bbc4c 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -79,7 +79,7 @@ exclude = [
 python = ">=3.11.0,<3.14.0"

 # ---- DL / HF stack ----
-torch = ">=2.7.0"
+torch = "^2.7.1"
 torchmetrics = "^1.8.0"
 transformers = "==4.56.2"
 tokenizers = ">=0.22.0,<=0.23.0"
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index ace55bd0ecf75..29e5580311d0b 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -245,7 +245,8 @@ private void submitInferenceTask() {
             .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
       return client.inference(
           new TInferenceReq(
-              modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock)));
+                  modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock))
+              .setInferenceAttributes(modelInferenceDescriptor.getInferenceAttributes()));
     } catch (Exception e) {
       throw new ModelInferenceProcessException(e.getMessage());
     }
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index e0680cd29b794..1dc2f025f5c34 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -60,13 +60,7 @@ struct TRegisterModelResp {
 struct TInferenceReq {
   1: required string modelId
   2: required binary dataset
-  3: optional TWindowParams windowParams
-  4: optional map<string, string> inferenceAttributes
-}
-
-struct TWindowParams {
-  1: required i32 windowInterval
-  2: required i32 windowStep
+  3: optional map<string, string> inferenceAttributes
 }

 struct TInferenceResp {
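The rename above threads output_length through the request, pool, and manager layers; the subtle part is the buffer clipping in write_step_output. Below is a minimal, self-contained sketch of that clipping rule — the names are simplified stand-ins rather than the AINode classes, and only torch is assumed:

import torch

# Accumulate step outputs into a fixed [batch_size, output_length] buffer,
# clipping the final step when the model generates past the requested length.
batch_size, output_length = 2, 96
output_tensor = torch.zeros(batch_size, output_length)
cur_step_idx = 0

def write_step(step_output: torch.Tensor) -> None:
    global cur_step_idx
    _, step_size = step_output.shape
    end_idx = cur_step_idx + step_size
    if end_idx > output_length:
        # Keep only the prefix that still fits in the preallocated buffer
        output_tensor[:, cur_step_idx:] = step_output[:, : output_length - cur_step_idx]
        cur_step_idx = output_length
    else:
        output_tensor[:, cur_step_idx:end_idx] = step_output
        cur_step_idx = end_idx

# Two 64-point steps fill exactly 96 points; the 32-point overshoot is dropped.
write_step(torch.randn(batch_size, 64))
write_step(torch.randn(batch_size, 64))
assert cur_step_idx == output_length

Note that the pool groups requests by (input length, output_length) before batching, so every request in a batch shares the same buffer shape.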
From f8751af22ed65714ea49ca67f985579199e657fc Mon Sep 17 00:00:00 2001
From: Gewu <89496957+RkGrit@users.noreply.github.com>
Date: Thu, 4 Dec 2025 23:12:41 +0800
Subject: [PATCH 19/38] If model is already ACTIVATING or ACTIVE, skip
 duplicate update and download (#16868)

---
 .../ainode/iotdb/ainode/core/model/model_storage.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index b3f3ee023156d..5194ed4df1bd7 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -123,6 +123,14 @@ def _process_builtin_model_directory(self, model_dir: str, model_id: str):
         """Handling the discovery logic for a builtin model directory."""
         ensure_init_file(model_dir)
         with self._lock_pool.get_lock(model_id).write_lock():
+            # Check if model already exists and is in a valid state
+            existing_model = self._models[ModelCategory.BUILTIN.value].get(model_id)
+            if existing_model:
+                # If model is already ACTIVATING or ACTIVE, skip duplicate download
+                if existing_model.state in (ModelStates.ACTIVATING, ModelStates.ACTIVE):
+                    return
+
+            # If the model does not exist or is INACTIVE, we'll try to update its info and download its weights
             self._models[ModelCategory.BUILTIN.value][model_id] = (
                 BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id]
             )

From 89db67f32df3ff5e67450fb1c761b9bb9ee64c6a Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Fri, 5 Dec 2025 12:53:20 +0800
Subject: [PATCH 20/38] Update builtin model path in CI envs

---
 .../org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
index e118d6c3a98ff..34fd7e85240cf 100644
--- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
+++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
@@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
   private static final String PROPERTIES_FILE = "iotdb-ainode.properties";
   public static final String CONFIG_PATH = "conf";
   public static final String SCRIPT_PATH = "sbin";
-  public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights";
+  public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
   public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";

   private void replaceAttribute(String[] keys, String[] values, String filePath) {

From ef84301661d0d44fdceb25249c5b6197790b8ca9 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Fri, 5 Dec 2025 17:13:33 +0800
Subject: [PATCH 21/38] update torch version, should be less than 2.8.0

---
 iotdb-core/ainode/pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index e15a7724bbc4c..3773f69a847a5 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -79,7 +79,7 @@ exclude = [
 python = ">=3.11.0,<3.14.0"

 # ---- DL / HF stack ----
-torch = "^2.7.1"
+torch = "^2.7.1,<2.8.0"
 torchmetrics = "^1.8.0"
 transformers = "==4.56.2"
 tokenizers = ">=0.22.0,<=0.23.0"
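PATCH 19's guard makes builtin-model discovery idempotent: re-running discovery while a model is ACTIVATING or ACTIVE must not re-trigger the weights download. A rough sketch of the pattern, with simplified names (a plain dict and threading.Lock stand in for the real registry and lock pool):

from enum import Enum
from threading import Lock

class ModelStates(Enum):
    INACTIVE = 0
    ACTIVATING = 1
    ACTIVE = 2

_models: dict[str, ModelStates] = {}
_lock = Lock()

def process_builtin_model(model_id: str) -> None:
    with _lock:
        state = _models.get(model_id)
        if state in (ModelStates.ACTIVATING, ModelStates.ACTIVE):
            return  # duplicate discovery: skip the update and download
        _models[model_id] = ModelStates.ACTIVATING
    # ... fetch weights from HuggingFace here, outside the locked fast path ...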
From 626fbc87f40043b6bc14f807525687570cefcd22 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Sun, 7 Dec 2025 13:43:06 +0800
Subject: [PATCH 22/38] Fix IoTDBDatabaseIT

---
 .../org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index 609a228022f8d..e7bab16ad1ff3 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -639,7 +639,7 @@ public void testInformationSchema() throws SQLException {
       TestUtils.assertResultSetEqual(
           statement.executeQuery("count devices from tables where status = 'USING'"),
           "count(devices),",
-          Collections.singleton("18,"));
+          Collections.singleton("19,"));
       TestUtils.assertResultSetEqual(
           statement.executeQuery(
               "select * from columns where table_name = 'queries' or database = 'test'"),

From bacd0101377a4379dc9410114dd91be95f30bfdf Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Sun, 7 Dec 2025 15:23:25 +0800
Subject: [PATCH 23/38] More essential libs when packaging

---
 iotdb-core/ainode/ainode.spec | 34 +++++++++-------------------------
 1 file changed, 9 insertions(+), 25 deletions(-)

diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index a131b2bcff217..30a69e4d73820 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -44,11 +44,17 @@ essential_libraries = {
     'torch': True,  # Keep collect_all for torch as it has many dynamic imports
     'transformers': True,  # Keep collect_all for transformers
     'safetensors': True,  # Keep collect_all for safetensors
+    'numpy': True,
+    'scipy': True,
+    'pandas': True,
+    'scikit-learn': True,
+    'statsmodels': True,
+    'sktime': True,
+    'pmdarima': True,
+    'hmmlearn': True,
+    'accelerate': True
 }

-# For other libraries, use selective collection to speed up startup
-other_libraries = ['sktime', 'scipy', 'pandas', 'sklearn', 'statsmodels', 'optuna']
-
 for lib in essential_libraries:
     try:
         lib_datas, lib_binaries, lib_hiddenimports = collect_all(lib)
@@ -58,28 +64,6 @@ for lib in essential_libraries:
     except Exception:
         pass

-# For other libraries, only collect submodules (lighter weight)
-# This relies on PyInstaller's dependency analysis to include what's actually used
-for lib in other_libraries:
-    try:
-        submodules = collect_submodules(lib)
-        all_hiddenimports.extend(submodules)
-        # Only collect essential data files and binaries, not all submodules
-        # This significantly reduces startup time
-        try:
-            lib_datas, lib_binaries, _ = collect_all(lib)
-            all_datas.extend(lib_datas)
-            all_binaries.extend(lib_binaries)
-        except Exception:
-            # If collect_all fails, try collect_data_files for essential data only
-            try:
-                lib_datas = collect_data_files(lib)
-                all_datas.extend(lib_datas)
-            except Exception:
-                pass
-    except Exception:
-        pass
-
 # Project-specific packages that need their submodules collected
 # Only list top-level packages - collect_submodules will recursively collect all submodules
 TOP_LEVEL_PACKAGES = [

From 57a6959fc7cf1972c2a932dc6dd93ad77f889764 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Sun, 7 Dec 2025 16:13:31 +0800
Subject: [PATCH 24/38] use sklearn rather than scikit-learn in .spec

---
 iotdb-core/ainode/ainode.spec | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index 30a69e4d73820..1e9b66f877f8e 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -47,7 +47,7 @@ essential_libraries = {
     'numpy': True,
     'scipy': True,
     'pandas': True,
-    'scikit-learn': True,
+    'sklearn': True,
     'statsmodels': True,
     'sktime': True,
     'pmdarima': True,
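PATCH 23 and PATCH 24 hinge on how PyInstaller's collect_all resolves packages: it takes the import name, not the PyPI distribution name, which is why 'scikit-learn' silently collected nothing and had to become 'sklearn'. A reduced sketch of the collection loop (PyInstaller is assumed to be installed; the library list is illustrative):

from PyInstaller.utils.hooks import collect_all

datas, binaries, hiddenimports = [], [], []
for lib in ("numpy", "scipy", "pandas", "sklearn"):  # import names, not distribution names
    try:
        d, b, h = collect_all(lib)  # returns (datas, binaries, hiddenimports)
        datas.extend(d)
        binaries.extend(b)
        hiddenimports.extend(h)
    except Exception:
        pass  # an uninstalled optional library should not abort the build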
From b6ebe39e5b6df6766f16783ea877ed381ffe1ea6 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Sun, 7 Dec 2025 17:30:11 +0800
Subject: [PATCH 25/38] delete useless dependency

---
 .../iotdb/ainode/core/inference/utils.py | 45 +++++++++----------
 1 file changed, 22 insertions(+), 23 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py
index cf10b5b2cd4dc..d17f9fbcec536 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py
@@ -19,7 +19,6 @@
 import string

 import torch
-from transformers.modeling_outputs import MoeCausalLMOutputWithPast


 def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
@@ -56,25 +55,25 @@ def _slice_pkv(pkv, s, e):
     return out


-def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes):
-    """
-    split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes)
-    split_sizes[i] = ith request's batch_size。
-    """
-    outs = []
-    start = 0
-    for bsz in split_sizes:
-        end = start + bsz
-        outs.append(
-            MoeCausalLMOutputWithPast(
-                loss=_slice_tensor(batch_out.loss, start, end),
-                logits=batch_out.logits[start:end],
-                past_key_values=_slice_pkv(batch_out.past_key_values, start, end),
-                hidden_states=_slice_tuple_of_tensors(
-                    batch_out.hidden_states, start, end
-                ),
-                attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end),
-            )
-        )
-        start = end
-    return outs
+# def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes):
+#     """
+#     split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes)
+#     split_sizes[i] = ith request's batch_size。
+#     """
+#     outs = []
+#     start = 0
+#     for bsz in split_sizes:
+#         end = start + bsz
+#         outs.append(
+#             MoeCausalLMOutputWithPast(
+#                 loss=_slice_tensor(batch_out.loss, start, end),
+#                 logits=batch_out.logits[start:end],
+#                 past_key_values=_slice_pkv(batch_out.past_key_values, start, end),
+#                 hidden_states=_slice_tuple_of_tensors(
+#                     batch_out.hidden_states, start, end
+#                 ),
+#                 attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end),
+#             )
+#         )
+#         start = end
+#     return outs

From 08a0ea60062b7da753757a49f780cdc64abb4fb1 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 10:26:21 +0800
Subject: [PATCH 26/38] append hidden imports

---
 iotdb-core/ainode/ainode.spec | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index 1e9b66f877f8e..f1fbfccf6d4be 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -64,6 +64,24 @@ for lib in essential_libraries:
     except Exception:
         pass

+# Some dependencies might still be missing on specific operating systems; manually import them in that case
+extra_hidden = [
+    # torch dynamo polyfills
+    'torch._dynamo.polyfills',
+    'torch._dynamo.polyfills.functools',
+
+    # torch flex attention
+    'torch.nn.attention.flex_attention',
+
+    # transformers
+    'transformers.masking_utils',
+    'transformers.generation.utils',
+    'transformers.models.auto.modeling_auto',
+    'transformers.models.auto.auto_factory',
+]
+
+all_hiddenimports.extend(extra_hidden)
+
 # Project-specific packages that need their submodules collected
 # Only list top-level packages - collect_submodules will recursively collect all submodules
 TOP_LEVEL_PACKAGES = [
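A quick way to sanity-check entries like the ones PATCH 26 appends is to confirm they are importable in the build environment, since a hidden import that cannot be resolved is still missing from the bundle. A small, optional smoke check (not part of the .spec; the module names are the ones the patch adds):

import importlib

extra_hidden = [
    "torch._dynamo.polyfills",
    "torch.nn.attention.flex_attention",
    "transformers.masking_utils",
    "transformers.generation.utils",
]
for name in extra_hidden:
    try:
        importlib.import_module(name)
    except ImportError as exc:
        print(f"hidden import {name!r} is not importable: {exc}")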
From 305409de439b368f31ccae70d26aefb9601b5683 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 11:17:57 +0800
Subject: [PATCH 27/38] update dependency collection in .spec

---
 iotdb-core/ainode/ainode.spec    | 9 ++++++---
 iotdb-core/ainode/pyproject.toml | 4 ++--
 2 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index f1fbfccf6d4be..77df7dcf70d31 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -41,9 +41,12 @@ all_hiddenimports = []
 # Only collect essential data files and binaries for critical libraries
 # This reduces startup time by avoiding unnecessary module imports
 essential_libraries = {
-    'torch': True,  # Keep collect_all for torch as it has many dynamic imports
-    'transformers': True,  # Keep collect_all for transformers
-    'safetensors': True,  # Keep collect_all for safetensors
+    'torch': True,
+    'transformers': True,
+    'tokenizers': True,
+    'huggingface_hub': True,
+    'safetensors': True,
+    'hf_xet': True,
     'numpy': True,
     'scipy': True,
     'pandas': True,
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 3773f69a847a5..335476cee9973 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -79,7 +79,7 @@ exclude = [
 python = ">=3.11.0,<3.14.0"

 # ---- DL / HF stack ----
-torch = "^2.7.1,<2.8.0"
+torch = "^2.7.2,<2.8.0"
 torchmetrics = "^1.8.0"
 transformers = "==4.56.2"
 tokenizers = ">=0.22.0,<=0.23.0"
@@ -115,7 +115,7 @@ black = "25.1.0"
 isort = "6.0.1"
 setuptools = ">=75.3.0"
 joblib = ">=1.4.2"
-urllib3 = ">=2.2.3"
+urllib3 = "^2.5.0"

 [tool.poetry.scripts]
 ainode = "iotdb.ainode.core.script:main"

From 4d26751ee480ea62f7330438458f582d4deaab03 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 11:21:06 +0800
Subject: [PATCH 28/38] accelerate ainode compile

---
 .github/workflows/cluster-it-1c1d1a.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/cluster-it-1c1d1a.yml b/.github/workflows/cluster-it-1c1d1a.yml
index d4c40fa7ad889..b0c6a6e8d7967 100644
--- a/.github/workflows/cluster-it-1c1d1a.yml
+++ b/.github/workflows/cluster-it-1c1d1a.yml
@@ -43,7 +43,7 @@ jobs:
       - uses: actions/checkout@v4
       - name: Build AINode
         shell: bash
-        run: mvn clean package -DskipTests -P with-ainode
+        run: mvn clean package -pl iotdb-core/ainode -P with-ainode -am -DskipTests
       - name: IT Test
         shell: bash
        run: |

From 322e94cf7324c06d6e721040d7d7c982d8f5845d Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 11:23:25 +0800
Subject: [PATCH 29/38] update dependency version

---
 iotdb-core/ainode/pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 335476cee9973..e93bb3dfdaf10 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -79,7 +79,7 @@ exclude = [
 python = ">=3.11.0,<3.14.0"

 # ---- DL / HF stack ----
-torch = "^2.7.2,<2.8.0"
+torch = "^2.8.0,<2.9.0"
 torchmetrics = "^1.8.0"
 transformers = "==4.56.2"
 tokenizers = ">=0.22.0,<=0.23.0"
@@ -115,7 +115,7 @@ black = "25.1.0"
 isort = "6.0.1"
 setuptools = ">=75.3.0"
 joblib = ">=1.4.2"
-urllib3 = "^2.5.0"
+urllib3 = "2.6.0"

 [tool.poetry.scripts]
 ainode = "iotdb.ainode.core.script:main"

From 45f349e17683c3f1df4f38992b1b5d0b21b4bcf4 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 12:38:11 +0800
Subject: [PATCH 30/38] remove useless pre-build process

---
 .github/workflows/cluster-it-1c1d1a.yml | 3 ---
 iotdb-core/ainode/ainode.spec           | 3 ---
 2 files changed, 6 deletions(-)

diff --git a/.github/workflows/cluster-it-1c1d1a.yml b/.github/workflows/cluster-it-1c1d1a.yml
index b0c6a6e8d7967..67be8f1a5f6c8 100644
--- a/.github/workflows/cluster-it-1c1d1a.yml
+++ b/.github/workflows/cluster-it-1c1d1a.yml
@@ -41,9 +41,6 @@ jobs:
     steps:
       - uses: actions/checkout@v4
-      - name: Build AINode
-        shell: bash
-        run: mvn clean package -pl iotdb-core/ainode -P with-ainode -am -DskipTests
       - name: IT Test
         shell: bash
         run: |
diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index 77df7dcf70d31..606f21b0ecc5a 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -124,9 +124,6 @@ multiprocessing_modules = [
 # Additional dependencies that may need explicit import
 # These are external libraries that might use dynamic imports
 external_dependencies = [
-    'huggingface_hub',
-    'tokenizers',
-    'hf_xet',
     'einops',
     'dynaconf',
     'tzlocal',

From 3a45d713ddbfebebe3839e0445bcb1b7fe117de7 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 13:10:14 +0800
Subject: [PATCH 31/38] Update ainode.spec

---
 iotdb-core/ainode/ainode.spec | 37 +++++++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index 606f21b0ecc5a..ab507c6128c32 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -68,11 +68,8 @@ for lib in essential_libraries:
         pass

 # Some dependencies might still be missing on specific operating systems; manually import them in that case
+# Use collect_submodules for packages with dynamic imports to ensure all submodules are included
 extra_hidden = [
-    # torch dynamo polyfills
-    'torch._dynamo.polyfills',
-    'torch._dynamo.polyfills.functools',
-
     # torch flex attention
     'torch.nn.attention.flex_attention',
@@ -85,6 +82,38 @@ extra_hidden = [

 all_hiddenimports.extend(extra_hidden)

+# Collect all submodules for torch._dynamo.polyfills recursively
+# This is critical because torch._dynamo.polyfills contains multiple submodules that are dynamically imported
+# and may not be detected by collect_all or static analysis
+try:
+    torch_polyfills_submodules = collect_submodules('torch._dynamo.polyfills')
+    all_hiddenimports.extend(torch_polyfills_submodules)
+    print(f"Collected {len(torch_polyfills_submodules)} submodules from torch._dynamo.polyfills")
+except Exception as e:
+    # If collection fails, add the known modules manually as fallback
+    print(f"Warning: Failed to collect torch._dynamo.polyfills submodules: {e}")
+    all_hiddenimports.extend([
+        'torch._dynamo.polyfills',
+        'torch._dynamo.polyfills.functools',
+        'torch._dynamo.polyfills.operator',
+        'torch._dynamo.polyfills.collections',
+    ])
+
+# Collect submodules for transformers packages that use dynamic imports
+# These packages may have submodules that are not detected by collect_all
+transformers_dynamic_packages = [
+    'transformers.generation',
+    'transformers.models.auto',
+]
+for package in transformers_dynamic_packages:
+    try:
+        submodules = collect_submodules(package)
+        all_hiddenimports.extend(submodules)
+        print(f"Collected {len(submodules)} submodules from {package}")
+    except Exception as e:
+        print(f"Warning: Failed to collect submodules from {package}: {e}")
+        # Continue - the modules in extra_hidden should still work
+
 # Project-specific packages that need their submodules collected
 # Only list top-level packages - collect_submodules will recursively collect all submodules
 TOP_LEVEL_PACKAGES = [
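The collect_submodules calls that PATCH 31 introduces exist because PyInstaller's static analysis only follows literal import statements; anything resolved through importlib at runtime is invisible to it. A minimal illustration of the failure mode (transformers is assumed to be installed, and 'bert' is just an example value; the auto classes dispatch this way):

import importlib

def load_model_module(model_type: str):
    # The target module is computed at runtime, so a frozen build will not
    # contain it unless it was declared as a hidden import.
    return importlib.import_module(f"transformers.models.{model_type}")

module = load_model_module("bert")
print(module.__name__)  # transformers.models.bert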
From 1168d21d56be43b850eb9799d17764a4f40c9e9b Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 13:48:48 +0800
Subject: [PATCH 32/38] update hidden import collection

---
 iotdb-core/ainode/ainode.spec | 107 ++++++++++++++++++----------------
 1 file changed, 58 insertions(+), 49 deletions(-)

diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index ab507c6128c32..dd0fb72347664 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -67,69 +67,78 @@ for lib in essential_libraries:
     except Exception:
         pass

-# Some dependencies might still be missing on specific operating systems; manually import them in that case
-# Use collect_submodules for packages with dynamic imports to ensure all submodules are included
-extra_hidden = [
-    # torch flex attention
-    'torch.nn.attention.flex_attention',
-
-    # transformers
-    'transformers.masking_utils',
-    'transformers.generation.utils',
-    'transformers.models.auto.modeling_auto',
-    'transformers.models.auto.auto_factory',
-]
-all_hiddenimports.extend(extra_hidden)
-
-# Collect all submodules for torch._dynamo.polyfills recursively
-# This is critical because torch._dynamo.polyfills contains multiple submodules that are dynamically imported
-# and may not be detected by collect_all or static analysis
-try:
-    torch_polyfills_submodules = collect_submodules('torch._dynamo.polyfills')
-    all_hiddenimports.extend(torch_polyfills_submodules)
-    print(f"Collected {len(torch_polyfills_submodules)} submodules from torch._dynamo.polyfills")
-except Exception as e:
-    # If collection fails, add the known modules manually as fallback
-    print(f"Warning: Failed to collect torch._dynamo.polyfills submodules: {e}")
-    all_hiddenimports.extend([
-        'torch._dynamo.polyfills',
-        'torch._dynamo.polyfills.functools',
-        'torch._dynamo.polyfills.operator',
-        'torch._dynamo.polyfills.collections',
-    ])
-
-# Collect submodules for transformers packages that use dynamic imports
-# These packages may have submodules that are not detected by collect_all
-transformers_dynamic_packages = [
-    'transformers.generation',
-    'transformers.models.auto',
-]
-for package in transformers_dynamic_packages:
+# Helper function to collect submodules with fallback
+def collect_submodules_with_fallback(package, fallback_modules=None, package_name=None):
+    """
+    Collect all submodules for a package, with fallback to manual module list if collection fails.
+
+    Args:
+        package: Package name to collect submodules from
+        fallback_modules: List of module names to add if collection fails (optional)
+        package_name: Display name for logging (defaults to package)
+    """
+    if package_name is None:
+        package_name = package
     try:
         submodules = collect_submodules(package)
         all_hiddenimports.extend(submodules)
-        print(f"Collected {len(submodules)} submodules from {package}")
+        print(f"Collected {len(submodules)} submodules from {package_name}")
     except Exception as e:
-        print(f"Warning: Failed to collect submodules from {package}: {e}")
-        # Continue - the modules in extra_hidden should still work
+        print(f"Warning: Failed to collect {package_name} submodules: {e}")
+        if fallback_modules:
+            all_hiddenimports.extend(fallback_modules)
+            print(f"Using fallback modules for {package_name}")
+
+
+# Packages that need submodule collection due to dynamic imports
+# Format: (package_name, fallback_modules_list, display_name)
+submodule_collection_configs = [
+    # torch._dynamo.polyfills - critical for torch dynamo functionality
+    (
+        'torch._dynamo.polyfills',
+        [
+            'torch._dynamo.polyfills',
+            'torch._dynamo.polyfills.functools',
+            'torch._dynamo.polyfills.operator',
+            'torch._dynamo.polyfills.collections',
+        ],
+        'torch._dynamo.polyfills'
+    ),
+    # transformers packages with dynamic imports
+    ('transformers.generation', None, 'transformers.generation'),
+    ('transformers.models.auto', None, 'transformers.models.auto'),
+    # scipy.stats - contains many private modules (starting with _)
+    (
+        'scipy.stats',
+        [
+            'scipy.stats._variation',
+            'scipy.stats._morestats',
+            'scipy.stats._stats',
+            'scipy.stats._distn_infrastructure',
+        ],
+        'scipy.stats'
+    ),
+    # sklearn - has many submodules that may be dynamically imported
+    ('sklearn', None, 'sklearn'),
+]
+
+# Collect submodules for all configured packages
+for package, fallback_modules, display_name in submodule_collection_configs:
+    collect_submodules_with_fallback(package, fallback_modules, display_name)

 # Project-specific packages that need their submodules collected
 # Only list top-level packages - collect_submodules will recursively collect all submodules
-TOP_LEVEL_PACKAGES = [
+project_packages = [
     'iotdb.ainode.core',  # This will include all sub-packages: manager, model, inference, etc.
     'iotdb.thrift',  # This will include all thrift sub-packages
 ]

 # Collect all submodules for project packages automatically
 # Using top-level packages avoids duplicate collection
-for package in TOP_LEVEL_PACKAGES:
-    try:
-        submodules = collect_submodules(package)
-        all_hiddenimports.extend(submodules)
-    except Exception:
-        # If package doesn't exist or collection fails, add the package itself
-        all_hiddenimports.append(package)
+# If collection fails, add the package itself as fallback
+for package in project_packages:
+    collect_submodules_with_fallback(package, fallback_modules=[package], package_name=package)

 # Add parent packages to ensure they are included
 all_hiddenimports.extend(['iotdb', 'iotdb.ainode'])
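Assuming the helper runs inside the .spec where all_hiddenimports is in scope, usage looks like this; the fallback lists below are illustrative rather than an exact mirror of the config table above:

# A package with known-critical private submodules gets an explicit fallback,
# so a failed scan still pins the modules that matter most.
collect_submodules_with_fallback(
    'scipy.stats',
    fallback_modules=['scipy.stats._stats', 'scipy.stats._distn_infrastructure'],
)
# A project package falls back to itself rather than vanishing from the bundle.
collect_submodules_with_fallback('iotdb.thrift', fallback_modules=['iotdb.thrift'])

Centralizing the try/except in one helper is also what lets PATCH 33 below trim the config table without touching any error handling.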
'iotdb.thrift', # This will include all thrift sub-packages ] # Collect all submodules for project packages automatically # Using top-level packages avoids duplicate collection -for package in TOP_LEVEL_PACKAGES: - try: - submodules = collect_submodules(package) - all_hiddenimports.extend(submodules) - except Exception: - # If package doesn't exist or collection fails, add the package itself - all_hiddenimports.append(package) +# If collection fails, add the package itself as fallback +for package in project_packages: + collect_submodules_with_fallback(package, fallback_modules=[package], package_name=package) # Add parent packages to ensure they are included all_hiddenimports.extend(['iotdb', 'iotdb.ainode']) From 7a6865e330c81fe10f67a9218e9f0f51555d7ed7 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Mon, 8 Dec 2025 14:28:21 +0800 Subject: [PATCH 33/38] Update ainode.spec --- iotdb-core/ainode/ainode.spec | 41 ++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec index dd0fb72347664..11420c623d50b 100644 --- a/iotdb-core/ainode/ainode.spec +++ b/iotdb-core/ainode/ainode.spec @@ -58,6 +58,7 @@ essential_libraries = { 'accelerate': True } +# Collect all libraries using collect_all (includes data files and binaries) for lib in essential_libraries: try: lib_datas, lib_binaries, lib_hiddenimports = collect_all(lib) @@ -67,6 +68,25 @@ for lib in essential_libraries: except Exception: pass +# Additionally collect ALL submodules for libraries that commonly have dynamic imports +# This is a more aggressive approach but ensures we don't miss any modules +# Libraries that are known to have many dynamic imports and submodules +libraries_with_dynamic_imports = [ + 'scipy', # Has many subpackages: stats, interpolate, optimize, linalg, sparse, signal, etc. 
+ 'sklearn', # Has many submodules that may be dynamically imported + 'transformers', # Has dynamic model loading + 'torch', # Has many submodules, especially _dynamo.polyfills +] + +# Collect all submodules for these libraries to ensure comprehensive coverage +for lib in libraries_with_dynamic_imports: + try: + submodules = collect_submodules(lib) + all_hiddenimports.extend(submodules) + print(f"Collected {len(submodules)} submodules from {lib}") + except Exception as e: + print(f"Warning: Failed to collect submodules from {lib}: {e}") + # Helper function to collect submodules with fallback def collect_submodules_with_fallback(package, fallback_modules=None, package_name=None): @@ -91,10 +111,13 @@ def collect_submodules_with_fallback(package, fallback_modules=None, package_nam print(f"Using fallback modules for {package_name}") -# Packages that need submodule collection due to dynamic imports +# Additional specific packages that need submodule collection +# Note: scipy, sklearn, transformers, torch are already collected above via libraries_with_dynamic_imports +# This section is for more specific sub-packages that need special handling # Format: (package_name, fallback_modules_list, display_name) submodule_collection_configs = [ # torch._dynamo.polyfills - critical for torch dynamo functionality + # (torch is already collected above, but this ensures polyfills are included) ( 'torch._dynamo.polyfills', [ @@ -105,22 +128,10 @@ submodule_collection_configs = [ ], 'torch._dynamo.polyfills' ), - # transformers packages with dynamic imports + # transformers sub-packages with dynamic imports + # (transformers is already collected above, but these specific sub-packages may need extra attention) ('transformers.generation', None, 'transformers.generation'), ('transformers.models.auto', None, 'transformers.models.auto'), - # scipy.stats - contains many private modules (starting with _) - ( - 'scipy.stats', - [ - 'scipy.stats._variation', - 'scipy.stats._morestats', - 'scipy.stats._stats', - 'scipy.stats._distn_infrastructure', - ], - 'scipy.stats' - ), - # sklearn - has many submodules that may be dynamically imported - ('sklearn', None, 'sklearn'), ] # Collect submodules for all configured packages From 344fffae632044ba47ef54afd0543322023597ec Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Mon, 8 Dec 2025 15:18:13 +0800 Subject: [PATCH 34/38] Update ainode.spec --- iotdb-core/ainode/ainode.spec | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec index 11420c623d50b..bde8b845fb8c5 100644 --- a/iotdb-core/ainode/ainode.spec +++ b/iotdb-core/ainode/ainode.spec @@ -154,6 +154,19 @@ for package in project_packages: # Add parent packages to ensure they are included all_hiddenimports.extend(['iotdb', 'iotdb.ainode']) +# Fix circular import issues in scipy.stats +# scipy.stats has circular imports that can cause issues in PyInstaller +# We need to ensure _stats is imported before scipy.stats tries to import it +# This helps resolve the "partially initialized module" error +scipy_stats_critical_modules = [ + 'scipy.stats._stats', # Core stats module, must be imported first + 'scipy.stats._stats_py', # Python implementation + 'scipy.stats._continuous_distns', # Continuous distributions + 'scipy.stats._discrete_distns', # Discrete distributions + 'scipy.stats.distributions', # Distribution base classes +] +all_hiddenimports.extend(scipy_stats_critical_modules) + # Multiprocessing support for 
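The failure PATCH 34 addresses is plain import ordering: scipy.stats initializes a web of private submodules, and if a partially initialized scipy.stats is re-entered, Python fails with a "partially initialized module" error. Forcing the critical module in first sidesteps it — a diagnostic sketch (scipy is assumed to be installed; _stats is a private module, so this is illustrative only):

import scipy.stats._stats  # noqa: F401  -- load the extension module first
import scipy.stats as stats

print(stats.norm.cdf(0.0))  # 0.5 once initialization completed cleanly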
From 5cc12be4c0e422863fc1408d2548aeb3dbf86a95 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 16:12:52 +0800
Subject: [PATCH 35/38] Fix system dir location

---
 iotdb-core/ainode/iotdb/ainode/core/constant.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index c0b19a570d20a..6cb7d91ed9828 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -63,7 +63,7 @@

 AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models")
 AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model"
-AINODE_SYSTEM_DIR = "data/ainode/system"
+AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system")
 AINODE_LOG_DIR = "logs"

 # AINode log
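PATCH 35 (and the AINODE_LOG_DIR change in PATCH 36 below) anchor every state directory at IOTDB_AINODE_HOME instead of the process working directory, so launching the node from a different cwd no longer scatters data. A sketch of the convention — the environment-variable fallback here is illustrative, not the project's actual resolution logic:

import os

IOTDB_AINODE_HOME = os.environ.get("IOTDB_AINODE_HOME", ".")
AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models")
AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system")
AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs")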
\"select s0,s1,s2 from root.AI\", window=head(5))"; + String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\")"; errorTest(statement, sql, "1505: model [notFound404] has not been created."); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java index a23eec97497df..64029c1e34b8e 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -97,8 +97,7 @@ public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo) final String devices = "0,1"; statement.execute( String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices)); - checkModelOnSpecifiedDevice( - statement, modelInfo.getModelId(), modelInfo.getModelType(), devices); + checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices); long startTime = System.currentTimeMillis(); concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength); long endTime = System.currentTimeMillis(); diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java index 8953bec07a745..025fd50a60f69 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -46,12 +46,12 @@ public class AINodeForecastIT { private static final String[] WRITE_SQL_IN_TABLE = new String[] { - "CREATE DATABASE root", - "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", + "CREATE DATABASE db", + "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", }; private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM root.AI) ORDER BY time)"; + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)"; @BeforeClass public static void setUp() throws Exception { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index 56b9a5bbda732..2ae1b860cd230 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -86,11 +86,11 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter // Load sundial to each device statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES)); - checkModelOnSpecifiedDevice(statement, "sundial", "sundial", TARGET_DEVICES.toString()); + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); // Load timer_xl to each device statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES)); - checkModelOnSpecifiedDevice(statement, "timer_xl", "timer_xl", TARGET_DEVICES.toString()); + checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); // Clean every device statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES)); @@ -107,7 +107,7 @@ public void 
repeatLoadAndUnloadTest() throws SQLException, InterruptedException Statement statement = connection.createStatement()) { for (int i = 0; i < LOOP_CNT; i++) { statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - checkModelOnSpecifiedDevice(statement, "sundial", "sundial", TARGET_DEVICES.toString()); + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 25cdf0f8ceef7..037f8f331b2f7 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -131,7 +131,7 @@ private void userDefinedModelManagementTest(Statement statement) public void dropBuiltInModelErrorTestInTree() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed"); + errorTest(statement, "drop model sundial", "1510: Built-in model sundial can't be removed"); } } @@ -139,7 +139,7 @@ public void dropBuiltInModelErrorTestInTree() throws SQLException { public void dropBuiltInModelErrorTestInTable() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed"); + errorTest(statement, "drop model sundial", "1510: Built-in model sundial can't be removed"); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index d9ddb6a4e097d..0de90c42925fe 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -45,9 +45,9 @@ public class AINodeTestUtils { public static final Map BUILTIN_LTSM_MAP = Stream.of( new AbstractMap.SimpleEntry<>( - "sundial", new FakeModelInfo("sundial", "sundial", "BUILT-IN", "ACTIVE")), + "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "timer_xl", new FakeModelInfo("timer_xl", "timer", "BUILT-IN", "ACTIVE"))) + "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); public static final Map BUILTIN_MODEL_MAP; @@ -56,27 +56,25 @@ public class AINodeTestUtils { Map tmp = Stream.of( new AbstractMap.SimpleEntry<>( - "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")), + "arima", new FakeModelInfo("arima", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "holtwinters", - new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")), + "holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( "exponential_smoothing", - new FakeModelInfo( - "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")), + new 
FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( "naive_forecaster", - new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")), + new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( "stl_forecaster", - new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")), + new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( "gaussian_hmm", - new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")), + new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")), + "gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")), new AbstractMap.SimpleEntry<>( - "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE"))) + "stray", new FakeModelInfo("stray", "sktime", "builtin", "active"))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); tmp.putAll(BUILTIN_LTSM_MAP); BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp); @@ -134,8 +132,7 @@ public static void concurrentInference( } } - public static void checkModelOnSpecifiedDevice( - Statement statement, String modelId, String modelType, String device) + public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) throws SQLException, InterruptedException { Set targetDevices = ImmutableSet.copyOf(device.split(",")); LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); @@ -146,13 +143,9 @@ public static void checkModelOnSpecifiedDevice( while (resultSet.next()) { String deviceId = resultSet.getString("DeviceId"); String loadedModelId = resultSet.getString("ModelId"); - String loadedModelType = resultSet.getString("ModelType"); int count = resultSet.getInt("Count(instances)"); LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); - if (loadedModelId.equals(modelId) - && loadedModelType.equals(modelType) - && targetDevices.contains(deviceId) - && count > 0) { + if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { foundDevices.add(deviceId); LOGGER.info("Model {} is loaded to device {}", modelId, device); } diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 6cb7d91ed9828..d8f730c829c89 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -64,7 +64,7 @@ AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model" AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system") -AINODE_LOG_DIR = "logs" +AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs") # AINode log LOG_FILE_TYPE = ["all", "info", "warn", "error"] @@ -93,27 +93,6 @@ def get_status_code(self) -> int: return self.value -class TaskType(Enum): - FORECAST = "forecast" - - -class OptionsKey(Enum): - # common - TASK_TYPE = "task_type" - MODEL_TYPE = "model_type" - AUTO_TUNING = "auto_tuning" - INPUT_VARS = "input_vars" - - # forecast - INPUT_LENGTH = "input_length" - PREDICT_LENGTH = "predict_length" - PREDICT_INDEX_LIST = "predict_index_list" - INPUT_TYPE_LIST = "input_type_list" - - def name(self) -> str: - return self.value - - class HyperparameterName(Enum): # Training hyperparameter LEARNING_RATE = 
"learning_rate" @@ -132,10 +111,3 @@ class HyperparameterName(Enum): def name(self): return self.value - - -class ModelInputName(Enum): - DATA_X = "data_x" - TIME_STAMP_X = "time_stamp_x" - TIME_STAMP_Y = "time_stamp_y" - DEC_INP = "dec_inp" From ac6e4df56600e7df78b23c0bb4009c7c84876043 Mon Sep 17 00:00:00 2001 From: Yongzao <532741407@qq.com> Date: Mon, 8 Dec 2025 20:44:55 +0800 Subject: [PATCH 37/38] Fix error CIs --- .../iotdb/ainode/it/AINodeCallInferenceIT.java | 9 --------- .../apache/iotdb/ainode/it/AINodeForecastIT.java | 13 ++++--------- .../apache/iotdb/ainode/it/AINodeModelManageIT.java | 4 ++-- 3 files changed, 6 insertions(+), 20 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java index 523a7f9aadb4a..5bfd8360cf9af 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -115,13 +115,4 @@ public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo } } } - - @Test - public void errorCallInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\")"; - errorTest(statement, sql, "1505: model [notFound404] has not been created."); - } - } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java index 025fd50a60f69..a06656d4adace 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -38,18 +38,11 @@ import java.sql.Statement; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; @RunWith(IoTDBTestRunner.class) @Category({AIClusterIT.class}) public class AINodeForecastIT { - private static final String[] WRITE_SQL_IN_TABLE = - new String[] { - "CREATE DATABASE db", - "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", - }; - private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)"; @@ -57,13 +50,15 @@ public class AINodeForecastIT { public static void setUp() throws Exception { // Init 1C1D1A cluster environment EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TABLE); try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE db"); + statement.execute( + "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); for (int i = 0; i < 2880; i++) { statement.execute( String.format( - "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", i, (float) i, (double) i, i, i)); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 
From ac6e4df56600e7df78b23c0bb4009c7c84876043 Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 20:44:55 +0800
Subject: [PATCH 37/38] Fix error CIs

---
 .../iotdb/ainode/it/AINodeCallInferenceIT.java      |  9 ---------
 .../apache/iotdb/ainode/it/AINodeForecastIT.java    | 13 ++++---------
 .../apache/iotdb/ainode/it/AINodeModelManageIT.java |  4 ++--
 3 files changed, 6 insertions(+), 20 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
index 523a7f9aadb4a..5bfd8360cf9af 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
@@ -115,13 +115,4 @@ public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo
       }
     }
   }
-
-  @Test
-  public void errorCallInferenceTestInTree() throws SQLException {
-    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
-        Statement statement = connection.createStatement()) {
-      String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\")";
-      errorTest(statement, sql, "1505: model [notFound404] has not been created.");
-    }
-  }
 }
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
index 025fd50a60f69..a06656d4adace 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -38,18 +38,11 @@
 import java.sql.Statement;

 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
-import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;

 @RunWith(IoTDBTestRunner.class)
 @Category({AIClusterIT.class})
 public class AINodeForecastIT {

-  private static final String[] WRITE_SQL_IN_TABLE =
-      new String[] {
-        "CREATE DATABASE db",
-        "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)",
-      };
-
   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
       "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)";

@@ -57,13 +50,15 @@ public class AINodeForecastIT {
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
     EnvFactory.getEnv().initClusterEnvironment(1, 1);
-    prepareData(WRITE_SQL_IN_TABLE);
     try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
+      statement.execute("CREATE DATABASE db");
+      statement.execute(
+          "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
       for (int i = 0; i < 2880; i++) {
         statement.execute(
             String.format(
-                "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+                "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
                 i, (float) i, (double) i, i, i));
       }
     }
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index 037f8f331b2f7..b92b80aecf321 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -131,7 +131,7 @@ private void userDefinedModelManagementTest(Statement statement)
   public void dropBuiltInModelErrorTestInTree() throws SQLException {
     try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      errorTest(statement, "drop model sundial", "1510: Built-in model sundial can't be removed");
+      errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial");
     }
   }

@@ -139,7 +139,7 @@ public void dropBuiltInModelErrorTestInTree() throws SQLException {
   public void dropBuiltInModelErrorTestInTable() throws SQLException {
     try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      errorTest(statement, "drop model sundial", "1510: Built-in model sundial can't be removed");
+      errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial");
     }
   }

From 474a807c2836d8811688053d5ea0d14ef0a1be0a Mon Sep 17 00:00:00 2001
From: Yongzao <532741407@qq.com>
Date: Mon, 8 Dec 2025 21:36:57 +0800
Subject: [PATCH 38/38] spotless

---
 .../java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java | 1 -
 1 file changed, 1 deletion(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
index 5bfd8360cf9af..44e280eca169b 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
@@ -40,7 +40,6 @@

 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
 import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;

 @RunWith(IoTDBTestRunner.class)