diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index bc06edf04..6401f4086 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -76,6 +76,8 @@ if MISTRAL:
     __all__.append(FlashMistral)
 
 
+LORAX_ENABLED_MODEL_TYPES = os.environ.get("PREDIBASE_LORAX_ENABLED_MODEL_TYPES", "").split(",")
+
 
 def get_model(
     model_id: str,
@@ -88,10 +90,29 @@ def get_model(
     source: str,
     adapter_source: str,
 ) -> Model:
-    if len(adapter_id) > 0:
-        logger.warning(
-            "adapter_id is only supported for FlashLlama models and will be "
-            "ignored for other models."
+    config_dict = None
+    if source == "s3":
+        # change the model id to be the local path to the folder so
+        # we can load the config_dict locally
+        logger.info(f"Using the local files since we are coming from s3")
+        model_path = get_s3_model_local_dir(model_id)
+        logger.info(f"model_path: {model_path}")
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_path, revision=revision, trust_remote_code=trust_remote_code
+        )
+        logger.info(f"config_dict: {config_dict}")
+        model_id = model_path
+    elif source == "hub":
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+    else:
+        raise ValueError(f"Unknown source {source}")
+
+    model_type = config_dict["model_type"]
+    if len(adapter_id) > 0 and model_type not in LORAX_ENABLED_MODEL_TYPES:
+        raise ValueError(
+            f"adapter_id is only supported for models with type {LORAX_ENABLED_MODEL_TYPES}."
         )
 
     if dtype is None:
@@ -133,34 +154,6 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-
-    config_dict = None
-    if source == "s3":
-        # change the model id to be the local path to the folder so
-        # we can load the config_dict locally
-        logger.info(f"Using the local files since we are coming from s3")
-        model_path = get_s3_model_local_dir(model_id)
-        logger.info(f"model_path: {model_path}")
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            model_path, revision=revision, trust_remote_code=trust_remote_code
-        )
-        logger.info(f"config_dict: {config_dict}")
-        if config_dict["model_type"] != "llama":
-            raise ValueError(f"Unsupported model type {config_dict['model_type']} for s3 imports")
-        model_id = model_path
-    elif source == "hub":
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-    else:
-        raise ValueError(f"Unknown source {source}")
-
-    # ensure that if an adapter source other than hub is used,
-    # that we only allow a model type of llama
-    if adapter_source != "hub" and config_dict["model_type"] != "llama":
-        raise ValueError(f"Unsupported model type {config_dict['model_type']} for adapter source {adapter_source}")
-
-    model_type = config_dict["model_type"]
     if model_type == "gpt_bigcode":
         if FLASH_ATTENTION:
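
Review note (not part of the patch): `"".split(",")` evaluates to `[""]`, not `[]`, so when `PREDIBASE_LORAX_ENABLED_MODEL_TYPES` is unset or empty the allowlist contains only the empty string and any request with a non-empty `adapter_id` will now raise, where the old code merely warned. A minimal sketch of the parsing behavior, alongside a hypothetical stricter parser (`parse_enabled_model_types` is an illustrative name, not in this diff):

```python
import os

# Behavior of the expression as written in the diff: an unset or empty
# variable yields [""], so no real model_type (e.g. "llama") can ever match.
assert "".split(",") == [""]
assert "llama,mistral".split(",") == ["llama", "mistral"]

# Hypothetical stricter variant (illustrative only): drop empty entries and
# surrounding whitespace so an unset variable parses to an empty allowlist.
def parse_enabled_model_types(raw: str) -> list[str]:
    return [t.strip() for t in raw.split(",") if t.strip()]

assert parse_enabled_model_types("") == []
assert parse_enabled_model_types("llama, mistral") == ["llama", "mistral"]
```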
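For context, a standalone sketch of how the new gate behaves. Since the allowlist is computed at module import time, the variable must be set in the server's environment before startup; `check_adapter` is a hypothetical helper that mirrors the check added inside `get_model`, and the values below are illustrative:

```python
import os

# Illustrative: set before the module is imported, since the allowlist is
# read from the environment exactly once at import time.
os.environ["PREDIBASE_LORAX_ENABLED_MODEL_TYPES"] = "llama,mistral"

LORAX_ENABLED_MODEL_TYPES = os.environ.get("PREDIBASE_LORAX_ENABLED_MODEL_TYPES", "").split(",")

# Mirrors the check added in get_model: a hard error replaces the old warning.
def check_adapter(adapter_id: str, model_type: str) -> None:
    if len(adapter_id) > 0 and model_type not in LORAX_ENABLED_MODEL_TYPES:
        raise ValueError(
            f"adapter_id is only supported for models with type {LORAX_ENABLED_MODEL_TYPES}."
        )

check_adapter("my-org/my-lora", "llama")            # allowed
try:
    check_adapter("my-org/my-lora", "gpt_bigcode")  # raises: not in the allowlist
except ValueError as e:
    print(e)
```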