Skip to content

Commit

Permalink
fix: lora download hang-up; don't hard crash on lora dl fail
Browse files (browse the repository at this point in the history)
  • Loading branch information
tazlin committed Oct 3, 2023
1 parent b51c4f5 commit 96a6ea1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
13 changes: 11 additions & 2 deletions hordelib/horde.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -345,7 +345,11 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
for ti in payload.get("tis"):
# Determine the actual ti filename
if not SharedModelManager.manager.ti.is_local_model(str(ti["name"])):
adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"]))
try:
adhoc_ti = SharedModelManager.manager.ti.fetch_adhoc_ti(str(ti["name"]))
except Exception as e:
logger.info(f"Error fetching adhoc TI {ti['name']}: ({type(e).__name__}) {e}")
adhoc_ti = None
if not adhoc_ti:
logger.info(f"Adhoc TI requested '{ti['name']}' could not be found in CivitAI. Ignoring!")
continue
Expand Down Expand Up @@ -389,7 +393,12 @@ def _final_pipeline_adjustments(self, payload, pipeline_data):
for lora in payload.get("loras"):
# Determine the actual lora filename
if not SharedModelManager.manager.lora.is_model_available(str(lora["name"])):
adhoc_lora = SharedModelManager.manager.lora.fetch_adhoc_lora(str(lora["name"]))
logger.debug(f"Adhoc lora requested '{lora['name']}' not yet downloaded. Downloading...")
try:
adhoc_lora = SharedModelManager.manager.lora.fetch_adhoc_lora(str(lora["name"]))
except Exception as e:
logger.info(f"Error fetching adhoc lora {lora['name']}: ({type(e).__name__}) {e}")
adhoc_lora = None
if not adhoc_lora:
logger.info(f"Adhoc lora requested '{lora['name']}' could not be found in CivitAI. Ignoring!")
continue
Expand Down
21 changes: 20 additions & 1 deletion hordelib/model_manager/lora.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -103,6 +103,12 @@ def load_model_database(self) -> None:
if self.models_db_path.exists():
try:
self.model_reference = json.loads((self.models_db_path).read_text())

for lora in self.model_reference.values():
self._index_ids[lora["id"]] = lora["name"].lower()
orig_name = lora.get("orig_name", lora["name"]).lower()
self._index_orig_names[orig_name] = lora["name"].lower()

logger.info("Loaded model reference from disk.")
except json.JSONDecodeError:
logger.error(f"Could not load {self.models_db_name} model reference from disk! Bad JSON?")
Expand Down Expand Up @@ -367,8 +373,11 @@ def _download_lora(self, lora):
# Start download threads if they aren't already started
while len(self._download_threads) < self.MAX_DOWNLOAD_THREADS:
thread_iter = len(self._download_threads)
logger.debug(f"Starting download thread {thread_iter}")
thread = threading.Thread(target=self._download_thread, daemon=True, args=(thread_iter,))
self._download_threads[thread_iter] = {"thread": thread, "lora": None}
logger.debug(f"Started download thread {thread_iter}")
logger.debug(f"Download threads: {self._download_threads}")
thread.start()

# Add this lora to the download queue
Expand Down Expand Up @@ -462,6 +471,7 @@ def are_download_threads_idle(self):

def fuzzy_find_lora_key(self, lora_name):
# sname = Sanitizer.remove_version(lora_name).lower()
logger.debug(f"Looking for lora {lora_name}")
if type(lora_name) is int or lora_name.isdigit():
if int(lora_name) in self._index_ids:
return self._index_ids[int(lora_name)]
Expand Down Expand Up @@ -753,6 +763,11 @@ def get_lora_last_use(self, lora_name):
return datetime.strptime(lora["last_used"], "%Y-%m-%d %H:%M:%S")

def fetch_adhoc_lora(self, lora_name, timeout=45):
if isinstance(lora_name, str):
if lora_name in self.model_reference:
self._touch_lora(lora_name)
return lora_name

if type(lora_name) is int or lora_name.isdigit():
url = f"https://civitai.com/api/v1/models/{lora_name}"
else:
Expand All @@ -776,7 +791,8 @@ def fetch_adhoc_lora(self, lora_name, timeout=45):
if fuzzy_find:
logger.debug(f"Found lora with ID: {fuzzy_find}")
return fuzzy_find
self._download_queue.append(lora)
self._download_lora(lora)

# We need to wait a bit to make sure the threads pick up the download
time.sleep(self.THREAD_WAIT_TIME)
self.wait_for_downloads(timeout)
Expand All @@ -799,6 +815,9 @@ def do_baselines_match(self, lora_name, model_details):

@override
def is_model_available(self, model_name):
if model_name in self.model_reference:
return True

found_model_name = self.fuzzy_find_lora_key(model_name)
if found_model_name is None:
return False
Expand Down
6 changes: 6 additions & 0 deletions hordelib/model_manager/ti.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,6 +90,12 @@ def load_model_database(self) -> None:
if self.models_db_path.exists():
try:
self.model_reference = json.loads((self.models_db_path).read_text())

for ti in self.model_reference.values():
self._index_ids[ti["id"]] = ti["name"].lower()
orig_name = ti.get("orig_name", ti["name"]).lower()
self._index_orig_names[orig_name] = ti["name"].lower()

logger.info("Loaded model reference from disk.")
except json.JSONDecodeError:
logger.error(f"Could not load {self.models_db_name} model reference from disk! Bad JSON?")
Expand Down

0 comments on commit 96a6ea1

Please sign in to comment.