Skip to content

Commit b1c18e5

Browse files
committed
refactoring config setting
1 parent d1bceca commit b1c18e5

File tree

1 file changed

+38
-58
lines changed

1 file changed

+38
-58
lines changed

chefboost/commons/functions.py

Lines changed: 38 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -158,65 +158,45 @@ def initializeParams(config: Optional[dict] = None) -> dict:
158158
Returns:
159159
config (dict): final configuration
160160
"""
161-
if config == None:
161+
if config is None:
162162
config = {}
163163

164-
algorithm = "ID3"
165-
enableRandomForest = False
166-
num_of_trees = 5
167-
enableMultitasking = False
168-
enableGBM = False
169-
epochs = 10
170-
learning_rate = 1
171-
max_depth = 5
172-
enableAdaboost = False
173-
num_of_weak_classifier = 4
174-
enableParallelism = False
175-
num_cores = int(multiprocessing.cpu_count() / 2) # allocate half of your total cores
176-
# num_cores = int((3*multiprocessing.cpu_count())/4) #allocate 3/4 of your total cores
177-
# num_cores = multiprocessing.cpu_count()
178-
179-
for key, value in config.items():
180-
if key == "algorithm":
181-
algorithm = value
182-
# ---------------------------------
183-
elif key == "enableRandomForest":
184-
enableRandomForest = value
185-
elif key == "num_of_trees":
186-
num_of_trees = value
187-
elif key == "enableMultitasking":
188-
enableMultitasking = value
189-
# ---------------------------------
190-
elif key == "enableGBM":
191-
enableGBM = value
192-
elif key == "epochs":
193-
epochs = value
194-
elif key == "learning_rate":
195-
learning_rate = value
196-
elif key == "max_depth":
197-
max_depth = value
198-
# ---------------------------------
199-
elif key == "enableAdaboost":
200-
enableAdaboost = value
201-
elif key == "num_of_weak_classifier":
202-
num_of_weak_classifier = value
203-
# ---------------------------------
204-
elif key == "enableParallelism":
205-
enableParallelism = value
206-
elif key == "num_cores":
207-
num_cores = value
208-
209-
config["algorithm"] = algorithm
210-
config["enableRandomForest"] = enableRandomForest
211-
config["num_of_trees"] = num_of_trees
212-
config["enableMultitasking"] = enableMultitasking
213-
config["enableGBM"] = enableGBM
214-
config["epochs"] = epochs
215-
config["learning_rate"] = learning_rate
216-
config["max_depth"] = max_depth
217-
config["enableAdaboost"] = enableAdaboost
218-
config["num_of_weak_classifier"] = num_of_weak_classifier
219-
config["enableParallelism"] = enableParallelism
220-
config["num_cores"] = num_cores
164+
# set these default values if they are not mentioned in config
165+
if config.get("algorithm") is None:
166+
config["algorithm"] = "ID3"
167+
168+
if config.get("enableRandomForest") is None:
169+
config["enableRandomForest"] = False
170+
171+
if config.get("num_of_trees") is None:
172+
config["num_of_trees"] = 5
173+
174+
if config.get("enableMultitasking") is None:
175+
config["enableMultitasking"] = False
176+
177+
if config.get("enableGBM") is None:
178+
config["enableGBM"] = False
179+
180+
if config.get("epochs") is None:
181+
config["epochs"] = 10
182+
183+
if config.get("learning_rate") is None:
184+
config["learning_rate"] = 1
185+
186+
if config.get("max_depth") is None:
187+
config["max_depth"] = 5
188+
189+
if config.get("enableAdaboost") is None:
190+
config["enableAdaboost"] = False
191+
192+
if config.get("num_of_weak_classifier") is None:
193+
config["num_of_weak_classifier"] = 4
194+
195+
if config.get("enableParallelism") is None:
196+
config["enableParallelism"] = False
197+
198+
if config.get("num_cores") is None:
199+
# allocate half of your total cores
200+
config["num_cores"] = int(multiprocessing.cpu_count() / 2)
221201

222202
return config

0 commit comments

Comments
 (0)