diff --git a/domainlab/algos/trainers/a_trainer.py b/domainlab/algos/trainers/a_trainer.py index 9bc126705..d9bda1513 100644 --- a/domainlab/algos/trainers/a_trainer.py +++ b/domainlab/algos/trainers/a_trainer.py @@ -127,7 +127,9 @@ def init_business(self, model, task, observer, device, aconf, flag_accept=True): self._decoratee.init_business( model, task, observer, device, aconf, flag_accept ) - self.model = model + self.model = self._decoratee + else: + self.model = model self.task = task self.task.init_business(trainer=self, args=aconf) self.model.list_d_tr = self.task.list_domain_tr