enh: ensure any (source-agnostic) Lorax-enabled adapters are usable #3

Merged: 3 commits, Nov 8, 2023
server/text_generation_server/models/__init__.py (57 changes: 25 additions & 32 deletions)
@@ -76,6 +76,8 @@
 if MISTRAL:
     __all__.append(FlashMistral)
 
+LORAX_ENABLED_MODEL_TYPES = os.environ.get("PREDIBASE_LORAX_ENABLED_MODEL_TYPES", "").split(",")
+
 
 def get_model(
     model_id: str,
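
Note on the allowlist parsing above: os.environ.get(..., "").split(",") never yields an empty list, since splitting the empty string returns [""]. A minimal standalone sketch, with an assumed environment value that is not part of the PR:

import os

# Assumed value for illustration; the real deployment sets this variable.
os.environ["PREDIBASE_LORAX_ENABLED_MODEL_TYPES"] = "llama,mistral"

enabled = os.environ.get("PREDIBASE_LORAX_ENABLED_MODEL_TYPES", "").split(",")
print(enabled)  # ['llama', 'mistral']

# Edge case: when the variable is unset, the default "" splits to [''],
# so the allowlist is [''] rather than empty and no real model type matches.
del os.environ["PREDIBASE_LORAX_ENABLED_MODEL_TYPES"]
print(os.environ.get("PREDIBASE_LORAX_ENABLED_MODEL_TYPES", "").split(","))  # ['']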
@@ -88,10 +90,29 @@ def get_model(
     source: str,
     adapter_source: str,
 ) -> Model:
-    if len(adapter_id) > 0:
-        logger.warning(
-            "adapter_id is only supported for FlashLlama models and will be "
-            "ignored for other models."
+    config_dict = None
+    if source == "s3":
+        # change the model id to be the local path to the folder so
+        # we can load the config_dict locally
+        logger.info(f"Using the local files since we are coming from s3")
+        model_path = get_s3_model_local_dir(model_id)
+        logger.info(f"model_path: {model_path}")
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_path, revision=revision, trust_remote_code=trust_remote_code
+        )
+        logger.info(f"config_dict: {config_dict}")
+        model_id = model_path
+    elif source == "hub":
+        config_dict, _ = PretrainedConfig.get_config_dict(
+            model_id, revision=revision, trust_remote_code=trust_remote_code
+        )
+    else:
+        raise ValueError(f"Unknown source {source}")
+
+    model_type = config_dict["model_type"]
+    if len(adapter_id) > 0 and model_type not in LORAX_ENABLED_MODEL_TYPES:
+        raise ValueError(
+            f"adapter_id is only supported for models with type {LORAX_ENABLED_MODEL_TYPES}."
         )
 
     if dtype is None:
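
The effect of this hunk is that the config is loaded before any adapter validation, so the model type is known regardless of whether the weights come from the hub or from s3, and the old hard-coded llama warning is replaced by the env-driven allowlist. A rough sketch of the resulting flow, using hypothetical helper names and an assumed allowlist value rather than code from the PR:

from transformers import PretrainedConfig

# Assumed allowlist value; in the server it comes from the
# PREDIBASE_LORAX_ENABLED_MODEL_TYPES environment variable.
LORAX_ENABLED_MODEL_TYPES = ["llama", "mistral"]

def resolve_model_type(model_id_or_path: str) -> str:
    # get_config_dict accepts both hub model IDs and local paths,
    # which is what makes the new check source-agnostic.
    config_dict, _ = PretrainedConfig.get_config_dict(model_id_or_path)
    return config_dict["model_type"]

def check_adapter_supported(adapter_id: str, model_type: str) -> None:
    # Mirrors the gate in the diff: a non-empty adapter_id is only
    # accepted for allowlisted model types.
    if len(adapter_id) > 0 and model_type not in LORAX_ENABLED_MODEL_TYPES:
        raise ValueError(
            f"adapter_id is only supported for models with type {LORAX_ENABLED_MODEL_TYPES}."
        )

check_adapter_supported("", "gpt_bigcode")              # ok: no adapter requested
check_adapter_supported("some-adapter", "llama")        # ok: llama is allowlisted
check_adapter_supported("some-adapter", "gpt_bigcode")  # raises ValueError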
@@ -133,34 +154,6 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
-
-    config_dict = None
-    if source == "s3":
-        # change the model id to be the local path to the folder so
-        # we can load the config_dict locally
-        logger.info(f"Using the local files since we are coming from s3")
-        model_path = get_s3_model_local_dir(model_id)
-        logger.info(f"model_path: {model_path}")
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            model_path, revision=revision, trust_remote_code=trust_remote_code
-        )
-        logger.info(f"config_dict: {config_dict}")
-        if config_dict["model_type"] != "llama":
-            raise ValueError(f"Unsupported model type {config_dict['model_type']} for s3 imports")
-        model_id = model_path
-    elif source == "hub":
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            model_id, revision=revision, trust_remote_code=trust_remote_code
-        )
-    else:
-        raise ValueError(f"Unknown source {source}")
-
-    # ensure that if an adapter source other than hub is used,
-    # that we only allow a model type of llama
-    if adapter_source != "hub" and config_dict["model_type"] != "llama":
-        raise ValueError(f"Unsupported model type {config_dict['model_type']} for adapter source {adapter_source}")
-
-    model_type = config_dict["model_type"]
 
     if model_type == "gpt_bigcode":
         if FLASH_ATTENTION: