diff --git a/src/ae.py b/src/ae.py
index 479d1e6..577d765 100644
--- a/src/ae.py
+++ b/src/ae.py
@@ -40,7 +40,10 @@ def from_pretrained(cls, model_uri: str, **kwargs):
         from safetensors.torch import load_file
         from .ae_nn import AutoEncoder
 
-        base = pathlib.Path(huggingface_hub.snapshot_download(model_uri))
+        try:
+            base = pathlib.Path(huggingface_hub.snapshot_download(model_uri))
+        except Exception:
+            base = pathlib.Path(model_uri)
 
         enc_cfg = OmegaConf.load(base / "encoder_conf.yml").model
         dec_cfg = OmegaConf.load(base / "decoder_conf.yml").model
diff --git a/src/world_engine.py b/src/world_engine.py
index 38eea42..77b6969 100644
--- a/src/world_engine.py
+++ b/src/world_engine.py
@@ -36,6 +36,7 @@ def __init__(
         """
         model_uri: HF URI or local folder containing model.safetensors and config.yaml
         quant: None | w8a8 | nvfp4
+        model_config_overrides: Dict to override model config values
         """
         self.device, self.dtype = device, dtype
@@ -49,7 +50,8 @@ def __init__(
         self.prompt_encoder = None
 
         if self.model_cfg.prompt_conditioning is not None:
-            self.prompt_encoder = PromptEncoder("google/umt5-xl", dtype=dtype).to(device).eval()  # TODO: dont hardcode
+            pe_uri = getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl")
+            self.prompt_encoder = PromptEncoder(pe_uri, dtype=dtype).to(device).eval()
 
         self.model = WorldModel.from_pretrained(model_uri, cfg=self.model_cfg).to(device=device, dtype=dtype).eval()
         apply_inference_patches(self.model)
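For reviewers: a minimal standalone sketch of the resolution logic the first hunk introduces, using only `pathlib` and `huggingface_hub`; the function name `resolve_model_dir` is illustrative and not part of the repo.

```python
import pathlib

import huggingface_hub


def resolve_model_dir(model_uri: str) -> pathlib.Path:
    """Prefer the HF Hub, then fall back to a local folder.

    snapshot_download() raises when model_uri is not a reachable Hub
    repo id (offline machine, private repo without a token, or simply
    a filesystem path), in which case model_uri is used verbatim, as
    in the new try/except in from_pretrained.
    """
    try:
        return pathlib.Path(huggingface_hub.snapshot_download(model_uri))
    except Exception:
        return pathlib.Path(model_uri)
```

One caveat on the world_engine.py hunk: whether `getattr(self.model_cfg, "prompt_encoder_uri", "google/umt5-xl")` ever returns the default depends on how the OmegaConf object handles missing keys (struct-mode configs raise an AttributeError subclass, which triggers the default; non-struct configs may return None instead). `OmegaConf.select(self.model_cfg, "prompt_encoder_uri", default="google/umt5-xl")` would make the fallback unambiguous.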