Skip to content

Commit 093e570

Browse files
zhangyubo0722TingquanGao
authored andcommitted
support input config path
1 parent d8ea10b commit 093e570

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

paddlex/modules/base/evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def __init__(self, config):
5151
self.global_config = config.Global
5252
self.eval_config = config.Evaluate
5353

54-
config_path = self.get_config_path(self.eval_config.weight_path)
54+
config_path = self.eval_config.get("basic_config_path", None)
55+
if not config_path:
56+
config_path = self.get_config_path(self.eval_config.weight_path)
5557

5658
self.pdx_config, self.pdx_model = build_model(
5759
self.global_config.model, config_path=config_path

paddlex/modules/base/exportor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def __init__(self, config):
5151
self.global_config = config.Global
5252
self.export_config = config.Export
5353

54-
config_path = self.get_config_path(self.export_config.weight_path)
54+
config_path = self.export_config.get("basic_config_path", None)
55+
if not config_path:
56+
config_path = self.get_config_path(self.export_config.weight_path)
5557

5658
self.pdx_config, self.pdx_model = build_model(
5759
self.global_config.model, config_path=config_path

paddlex/modules/base/trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,11 @@ def __init__(self, config: AttrDict):
5050
self.global_config = config.Global
5151
self.train_config = config.Train
5252
self.benchmark_config = config.get("Benchmark", None)
53+
config_path = self.train_config.get("basic_config_path", None)
5354

54-
self.pdx_config, self.pdx_model = build_model(self.global_config.model)
55+
self.pdx_config, self.pdx_model = build_model(
56+
self.global_config.model, config_path=config_path
57+
)
5558

5659
def train(self, *args, **kwargs):
5760
"""execute model training"""

0 commit comments

Comments
 (0)