Skip to content

Commit 5abae22

Browse files
authored
feat: parse env var strings to expected config value types (#3107)
* fix: add try_parse_bool for env var strings to enable config overrides of boolean values * fix: fallback to given value if not parseable * feat: extend eval to all valid types * fix: remove return type * fix: prevent strange type conversions by providing expected type * feat: add tests
1 parent 04d7648 commit 5abae22

File tree

3 files changed

+168
-39
lines changed

3 files changed

+168
-39
lines changed

modules/config.py

Lines changed: 79 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
import json
33
import math
44
import numbers
5+
56
import args_manager
67
import tempfile
78
import modules.flags
89
import modules.sdxl_styles
910

1011
from modules.model_loader import load_file_from_url
11-
from modules.extra_utils import makedirs_with_log, get_files_from_folder
12+
from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var
1213
from modules.flags import OutputFormat, Performance, MetadataScheme
1314

1415

@@ -200,14 +201,15 @@ def get_dir_or_set_default(key, default_value, as_array=False, make_directory=Fa
200201
path_outputs = get_path_output()
201202

202203

203-
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False):
204+
def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, expected_type=None):
204205
global config_dict, visited_keys
205206

206207
if key not in visited_keys:
207208
visited_keys.append(key)
208209

209210
v = os.getenv(key)
210211
if v is not None:
212+
v = try_eval_env_var(v, expected_type)
211213
print(f"Environment: {key} = {v}")
212214
config_dict[key] = v
213215

@@ -252,41 +254,49 @@ def init_temp_path(path: str | None, default_path: str) -> str:
252254
key='temp_path',
253255
default_value=default_temp_path,
254256
validator=lambda x: isinstance(x, str),
257+
expected_type=str
255258
), default_temp_path)
256259
temp_path_cleanup_on_launch = get_config_item_or_set_default(
257260
key='temp_path_cleanup_on_launch',
258261
default_value=True,
259-
validator=lambda x: isinstance(x, bool)
262+
validator=lambda x: isinstance(x, bool),
263+
expected_type=bool
260264
)
261265
default_base_model_name = default_model = get_config_item_or_set_default(
262266
key='default_model',
263267
default_value='model.safetensors',
264-
validator=lambda x: isinstance(x, str)
268+
validator=lambda x: isinstance(x, str),
269+
expected_type=str
265270
)
266271
previous_default_models = get_config_item_or_set_default(
267272
key='previous_default_models',
268273
default_value=[],
269-
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x)
274+
validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x),
275+
expected_type=list
270276
)
271277
default_refiner_model_name = default_refiner = get_config_item_or_set_default(
272278
key='default_refiner',
273279
default_value='None',
274-
validator=lambda x: isinstance(x, str)
280+
validator=lambda x: isinstance(x, str),
281+
expected_type=str
275282
)
276283
default_refiner_switch = get_config_item_or_set_default(
277284
key='default_refiner_switch',
278285
default_value=0.8,
279-
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1
286+
validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1,
287+
expected_type=numbers.Number
280288
)
281289
default_loras_min_weight = get_config_item_or_set_default(
282290
key='default_loras_min_weight',
283291
default_value=-2,
284-
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
292+
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10,
293+
expected_type=numbers.Number
285294
)
286295
default_loras_max_weight = get_config_item_or_set_default(
287296
key='default_loras_max_weight',
288297
default_value=2,
289-
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10
298+
validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10,
299+
expected_type=numbers.Number
290300
)
291301
default_loras = get_config_item_or_set_default(
292302
key='default_loras',
@@ -320,38 +330,45 @@ def init_temp_path(path: str | None, default_path: str) -> str:
320330
validator=lambda x: isinstance(x, list) and all(
321331
len(y) == 3 and isinstance(y[0], bool) and isinstance(y[1], str) and isinstance(y[2], numbers.Number)
322332
or len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number)
323-
for y in x)
333+
for y in x),
334+
expected_type=list
324335
)
325336
default_loras = [(y[0], y[1], y[2]) if len(y) == 3 else (True, y[0], y[1]) for y in default_loras]
326337
default_max_lora_number = get_config_item_or_set_default(
327338
key='default_max_lora_number',
328339
default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5,
329-
validator=lambda x: isinstance(x, int) and x >= 1
340+
validator=lambda x: isinstance(x, int) and x >= 1,
341+
expected_type=int
330342
)
331343
default_cfg_scale = get_config_item_or_set_default(
332344
key='default_cfg_scale',
333345
default_value=7.0,
334-
validator=lambda x: isinstance(x, numbers.Number)
346+
validator=lambda x: isinstance(x, numbers.Number),
347+
expected_type=numbers.Number
335348
)
336349
default_sample_sharpness = get_config_item_or_set_default(
337350
key='default_sample_sharpness',
338351
default_value=2.0,
339-
validator=lambda x: isinstance(x, numbers.Number)
352+
validator=lambda x: isinstance(x, numbers.Number),
353+
expected_type=numbers.Number
340354
)
341355
default_sampler = get_config_item_or_set_default(
342356
key='default_sampler',
343357
default_value='dpmpp_2m_sde_gpu',
344-
validator=lambda x: x in modules.flags.sampler_list
358+
validator=lambda x: x in modules.flags.sampler_list,
359+
expected_type=str
345360
)
346361
default_scheduler = get_config_item_or_set_default(
347362
key='default_scheduler',
348363
default_value='karras',
349-
validator=lambda x: x in modules.flags.scheduler_list
364+
validator=lambda x: x in modules.flags.scheduler_list,
365+
expected_type=str
350366
)
351367
default_vae = get_config_item_or_set_default(
352368
key='default_vae',
353369
default_value=modules.flags.default_vae,
354-
validator=lambda x: isinstance(x, str)
370+
validator=lambda x: isinstance(x, str),
371+
expected_type=str
355372
)
356373
default_styles = get_config_item_or_set_default(
357374
key='default_styles',
@@ -360,121 +377,144 @@ def init_temp_path(path: str | None, default_path: str) -> str:
360377
"Fooocus Enhance",
361378
"Fooocus Sharp"
362379
],
363-
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x)
380+
validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x),
381+
expected_type=list
364382
)
365383
default_prompt_negative = get_config_item_or_set_default(
366384
key='default_prompt_negative',
367385
default_value='',
368386
validator=lambda x: isinstance(x, str),
369-
disable_empty_as_none=True
387+
disable_empty_as_none=True,
388+
expected_type=str
370389
)
371390
default_prompt = get_config_item_or_set_default(
372391
key='default_prompt',
373392
default_value='',
374393
validator=lambda x: isinstance(x, str),
375-
disable_empty_as_none=True
394+
disable_empty_as_none=True,
395+
expected_type=str
376396
)
377397
default_performance = get_config_item_or_set_default(
378398
key='default_performance',
379399
default_value=Performance.SPEED.value,
380-
validator=lambda x: x in Performance.list()
400+
validator=lambda x: x in Performance.list(),
401+
expected_type=str
381402
)
382403
default_advanced_checkbox = get_config_item_or_set_default(
383404
key='default_advanced_checkbox',
384405
default_value=False,
385-
validator=lambda x: isinstance(x, bool)
406+
validator=lambda x: isinstance(x, bool),
407+
expected_type=bool
386408
)
387409
default_max_image_number = get_config_item_or_set_default(
388410
key='default_max_image_number',
389411
default_value=32,
390-
validator=lambda x: isinstance(x, int) and x >= 1
412+
validator=lambda x: isinstance(x, int) and x >= 1,
413+
expected_type=int
391414
)
392415
default_output_format = get_config_item_or_set_default(
393416
key='default_output_format',
394417
default_value='png',
395-
validator=lambda x: x in OutputFormat.list()
418+
validator=lambda x: x in OutputFormat.list(),
419+
expected_type=str
396420
)
397421
default_image_number = get_config_item_or_set_default(
398422
key='default_image_number',
399423
default_value=2,
400-
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number
424+
validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number,
425+
expected_type=int
401426
)
402427
checkpoint_downloads = get_config_item_or_set_default(
403428
key='checkpoint_downloads',
404429
default_value={},
405-
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
430+
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
431+
expected_type=dict
406432
)
407433
lora_downloads = get_config_item_or_set_default(
408434
key='lora_downloads',
409435
default_value={},
410-
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
436+
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
437+
expected_type=dict
411438
)
412439
embeddings_downloads = get_config_item_or_set_default(
413440
key='embeddings_downloads',
414441
default_value={},
415-
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items())
442+
validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()),
443+
expected_type=dict
416444
)
417445
available_aspect_ratios = get_config_item_or_set_default(
418446
key='available_aspect_ratios',
419447
default_value=modules.flags.sdxl_aspect_ratios,
420-
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1
448+
validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1,
449+
expected_type=list
421450
)
422451
default_aspect_ratio = get_config_item_or_set_default(
423452
key='default_aspect_ratio',
424453
default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0],
425-
validator=lambda x: x in available_aspect_ratios
454+
validator=lambda x: x in available_aspect_ratios,
455+
expected_type=str
426456
)
427457
default_inpaint_engine_version = get_config_item_or_set_default(
428458
key='default_inpaint_engine_version',
429459
default_value='v2.6',
430-
validator=lambda x: x in modules.flags.inpaint_engine_versions
460+
validator=lambda x: x in modules.flags.inpaint_engine_versions,
461+
expected_type=str
431462
)
432463
default_cfg_tsnr = get_config_item_or_set_default(
433464
key='default_cfg_tsnr',
434465
default_value=7.0,
435-
validator=lambda x: isinstance(x, numbers.Number)
466+
validator=lambda x: isinstance(x, numbers.Number),
467+
expected_type=numbers.Number
436468
)
437469
default_clip_skip = get_config_item_or_set_default(
438470
key='default_clip_skip',
439471
default_value=2,
440-
validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max
472+
validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max,
473+
expected_type=int
441474
)
442475
default_overwrite_step = get_config_item_or_set_default(
443476
key='default_overwrite_step',
444477
default_value=-1,
445-
validator=lambda x: isinstance(x, int)
478+
validator=lambda x: isinstance(x, int),
479+
expected_type=int
446480
)
447481
default_overwrite_switch = get_config_item_or_set_default(
448482
key='default_overwrite_switch',
449483
default_value=-1,
450-
validator=lambda x: isinstance(x, int)
484+
validator=lambda x: isinstance(x, int),
485+
expected_type=int
451486
)
452487
example_inpaint_prompts = get_config_item_or_set_default(
453488
key='example_inpaint_prompts',
454489
default_value=[
455490
'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes'
456491
],
457-
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x)
492+
validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x),
493+
expected_type=list
458494
)
459495
default_black_out_nsfw = get_config_item_or_set_default(
460496
key='default_black_out_nsfw',
461497
default_value=False,
462-
validator=lambda x: isinstance(x, bool)
498+
validator=lambda x: isinstance(x, bool),
499+
expected_type=bool
463500
)
464501
default_save_metadata_to_images = get_config_item_or_set_default(
465502
key='default_save_metadata_to_images',
466503
default_value=False,
467-
validator=lambda x: isinstance(x, bool)
504+
validator=lambda x: isinstance(x, bool),
505+
expected_type=bool
468506
)
469507
default_metadata_scheme = get_config_item_or_set_default(
470508
key='default_metadata_scheme',
471509
default_value=MetadataScheme.FOOOCUS.value,
472-
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x]
510+
validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x],
511+
expected_type=str
473512
)
474513
metadata_created_by = get_config_item_or_set_default(
475514
key='metadata_created_by',
476515
default_value='',
477-
validator=lambda x: isinstance(x, str)
516+
validator=lambda x: isinstance(x, str),
517+
expected_type=str
478518
)
479519

480520
example_inpaint_prompts = [[x] for x in example_inpaint_prompts]

modules/extra_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
from ast import literal_eval
3+
24

35
def makedirs_with_log(path):
46
try:
@@ -24,3 +26,16 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None):
2426
filenames.append(path)
2527

2628
return filenames
29+
30+
31+
def try_eval_env_var(value: str, expected_type=None):
32+
try:
33+
value_eval = value
34+
if expected_type is bool:
35+
value_eval = value.title()
36+
value_eval = literal_eval(value_eval)
37+
if expected_type is not None and not isinstance(value_eval, expected_type):
38+
return value
39+
return value_eval
40+
except:
41+
return value

0 commit comments

Comments
 (0)