forked from derrian-distro/LoRA_Easy_Training_Scripts
-
Notifications
You must be signed in to change notification settings - Fork 0
/
json_functions.py
62 lines (54 loc) · 2.89 KB
/
json_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import json
import os
import sys
import time
def save_json(path, obj: dict) -> None:
    """Write *obj* as an indented JSON config file into the directory *path*.

    The file is named ``config-<time.time()>.json``. When
    ``obj['save_json_name']`` is set (truthy), it is named
    ``config-<time.time()>-<save_json_name>.json`` instead.

    Before writing, ``list_of_json_to_run`` is set to ``None`` and
    ``save_json_only`` to ``False``. This is done on *obj* itself, so the
    caller's dict is modified too. These values are cleared so they do not
    affect the settings when the saved config is loaded back up.

    :param path: directory in which the config file is created.
    :param obj: settings to serialise; all values must be JSON-serialisable.
    :raises TypeError: if a value in *obj* cannot be serialised by ``json.dump``
        (a partially written file may be left behind).
    """
    # set these to None and False to prevent them from modifying the output when loaded back up
    obj['list_of_json_to_run'] = None
    obj['save_json_only'] = False
    # .get: a missing 'save_json_name' behaves like an empty one instead of raising KeyError
    suffix = obj.get('save_json_name')
    name = f"config-{time.time()}-{suffix}.json" if suffix else f"config-{time.time()}.json"
    # `with` guarantees the handle is closed even if json.dump raises part-way through
    with open(os.path.join(path, name), "w") as fp:
        json.dump(obj, fp=fp, indent=4)
# Maps argument names as they may appear in a loaded JSON config to the names
# this UI uses for the same settings.
_UI_NAME_SCHEME = {
    "pretrained_model_name_or_path": "base_model",
    "logging_dir": "log_dir",
    "train_data_dir": "img_folder",
    "reg_data_dir": "reg_img_folder",
    "output_dir": "output_folder",
    "max_resolution": "train_resolution",
    "lr_scheduler": "scheduler",
    "lr_warmup": "warmup_lr_ratio",
    "train_batch_size": "batch_size",
    "epoch": "num_epochs",
    "save_at_n_epochs": "save_every_n_epochs",
    "num_cpu_threads_per_process": "num_workers",
    "enable_bucket": "buckets",
    "save_model_as": "save_as",
    "shuffle_caption": "shuffle_captions",
    "resume": "load_previous_save_state",
    "network_dim": "net_dim",
    "gradient_accumulation_steps": "gradient_acc_steps",
    "output_name": "change_output_name",
    "network_alpha": "alpha",
    "lr_scheduler_num_cycles": "cosine_restarts",
    "lr_scheduler_power": "scheduler_power",
}
# Renamed settings that must be integers; a bad value aborts the whole load.
_STRICT_INT_UI_KEYS = frozenset({"batch_size", "num_epochs"})
# Settings coerced when merged into obj (a None value is kept as None).
_INT_KEYS = frozenset({"keep_tokens"})
_FLOAT_KEYS = frozenset({"learning_rate", "unet_lr", "text_encoder_lr", "warmup_lr_ratio"})


def _rename_to_ui_names(json_obj: dict) -> None:
    """Copy each key of *json_obj* found in ``_UI_NAME_SCHEME`` to its UI name, in place.

    The original key is left in the dict. ``batch_size`` and ``num_epochs``
    are converted to ``int``; if that fails, a message is printed and the
    process exits with status 1.
    """
    for key in list(json_obj):
        ui_key = _UI_NAME_SCHEME.get(key)
        if ui_key is None:
            continue
        json_obj[ui_key] = json_obj[key]
        if ui_key in _STRICT_INT_UI_KEYS:
            try:
                json_obj[ui_key] = int(json_obj[ui_key])
            except (TypeError, ValueError):
                # TypeError covers JSON null / lists / objects, which int() rejects
                print(f"attempting to load {key} from json failed as input isn't an integer")
                # sys.exit instead of quit(): quit is only injected by `site` for interactive use
                sys.exit(1)


def load_json(path, obj: dict) -> dict:
    """Read the JSON config at *path* and merge its settings into *obj*.

    First, keys listed in ``_UI_NAME_SCHEME`` are also made available under
    their UI names (see ``_rename_to_ui_names``). Then every JSON key that
    already exists in *obj* overwrites it, except:

    * keys contained in ``obj['json_load_skip_list']`` (if that is set), and
    * ``save_json_folder``, which is always reset to ``None``.

    JSON keys that are not already in *obj* are ignored. ``keep_tokens`` is
    converted to ``int``, and the learning-rate settings and
    ``warmup_lr_ratio`` to ``float``. Every value that differs from the
    current one is reported through ``print_change``.

    :param path: path of the JSON file to read.
    :param obj: current settings; updated in place.
    :returns: *obj*, the same dict that was passed in.
    :raises SystemExit: with status 1 if ``batch_size``/``num_epochs`` is not an integer.
    :raises ValueError: (or ``TypeError``) if a ``keep_tokens`` or learning-rate
        value cannot be converted.
    """
    with open(path) as f:
        json_obj = json.load(f)
    print("loaded json, setting variables...")
    _rename_to_ui_names(json_obj)
    obj['save_json_folder'] = None
    # .get: a missing or empty skip list means "skip nothing" instead of raising KeyError
    skip_list = obj.get("json_load_skip_list") or ()
    for key, value in json_obj.items():
        if key in skip_list or key == "save_json_folder" or key not in obj:
            continue
        if value is not None:
            if key in _INT_KEYS:
                value = int(value)
            elif key in _FLOAT_KEYS:
                value = float(value)
        if obj[key] != value:
            print_change(key, obj[key], value)
        obj[key] = value
    print("completed changing variables.")
    return obj
def print_change(value, old, new):
    """Print a one-line notice that setting *value* went from *old* to *new*."""
    message = "{} changed from {} to {}".format(value, old, new)
    print(message)