Skip to content

Commit d169dc8

Browse files
jess-farmerRabiyaF
andauthored
Added number of iterations/function evaluations to fitting report (fitbenchmarking#1307)
* add number of iterations/function evaluations to fitting report * fix failing test * add iteration counts for matlab minimizers * record iter count and func evals separately for minimizers that output both * fix typo Co-authored-by: RabiyaF <47083562+RabiyaF@users.noreply.github.com> * fix typo Co-authored-by: RabiyaF <47083562+RabiyaF@users.noreply.github.com> * use str() instead of f-string Co-authored-by: RabiyaF <47083562+RabiyaF@users.noreply.github.com> * address review comments * ruff formatting fixes * fix failing test * fix failing test --------- Co-authored-by: RabiyaF <47083562+RabiyaF@users.noreply.github.com>
1 parent aa15cf7 commit d169dc8

26 files changed

+747
-496
lines changed

docs/source/users/install_instructions/fitbenchmarking.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ where valid strings ``option-x`` are:
8484
* ``levmar`` -- installs the `levmar <http://users.ics.forth.gr/~lourakis/levmar/>`_ fitting package (suitable for Python up to 3.8, see :ref:`levmar-install`). Note that the interface we use also requires BLAS and LAPLACK to be installed on the system, and calls to this minimizer will fail if these libraries are not present.
8585
* ``mantid`` -- installs the `h5py <https://pypi.org/project/h5py/>`_ and `pyyaml <https://pypi.org/project/PyYAML/>`_ modules.
8686
* ``matlab`` -- installs the `dill <https://pypi.org/project/dill/>`_ module required to run matlab controllers in fitbenchmarking
87-
* ``minuit`` -- installs the `Minuit <http://seal.web.cern.ch/seal/snapshot/work-packages/mathlibs/minuit/>`_ fitting package.
87+
* ``minuit`` -- installs the `Minuit <https://scikit-hep.org/iminuit/>`_ fitting package.
8888
* ``SAS`` -- installs the `Sasmodels <https://github.com/SasView/sasmodels>`_ fitting package and the `tinycc <https://pypi.org/project/tinycc/>`_ module.
8989
* ``numdifftools`` -- installs the `numdifftools <https://numdifftools.readthedocs.io/en/latest/index.html>`_ numerical differentiation package.
9090
* ``nlopt``-- installs the `NLopt <https://github.com/DanielBok/nlopt-python#installation>`_ fitting package.

fitbenchmarking/controllers/base_controller.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ def __init__(self, cost_func):
172172

173173
self.par_names = self.problem.param_names
174174

175+
# save iteration count
176+
self.iteration_count = None
177+
178+
# save number of function evaluations
179+
self.func_evals = None
180+
175181
@property
176182
def flag(self):
177183
"""
@@ -465,7 +471,12 @@ def check_attributes(self):
465471
A helper function which checks all required attributes are set
466472
in software controllers
467473
"""
468-
values = {"_flag": int, "final_params": np.ndarray}
474+
values = {
475+
"_flag": int,
476+
"final_params": np.ndarray,
477+
"iteration_count": (int, type(None)),
478+
"func_evals": (int, type(None)),
479+
}
469480

470481
for attr_name, attr_type in values.items():
471482
attr = getattr(self, attr_name)

fitbenchmarking/controllers/ceres_controller.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,8 @@ def cleanup(self):
204204
self.flag = 2
205205

206206
self.final_params = self.result
207+
208+
self.iteration_count = (
209+
self.ceres_summary.num_successful_steps
210+
+ self.ceres_summary.num_unsuccessful_steps
211+
)

fitbenchmarking/controllers/dfo_controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,4 @@ def cleanup(self):
104104
self.flag = 2
105105

106106
self.final_params = self._popt
107+
self.func_evals = self._soln.nf

fitbenchmarking/controllers/gradient_free_controller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def setup(self):
123123
for i in range(len(self.initial_params))
124124
}
125125

126+
self.iteration_count = 1000
126127
self.initialize = {"warm_start": param_dict}
127128

128129
def _feval(self, p):
@@ -151,7 +152,7 @@ def fit(self):
151152
method_to_call = getattr(gfo, self.minimizer)
152153

153154
opt = method_to_call(self.search_space)
154-
opt.search(self._feval, n_iter=1000, verbosity=False)
155+
opt.search(self._feval, n_iter=self.iteration_count, verbosity=False)
155156
self.results = opt.best_para
156157
self._status = 0 if self.results is not None else 1
157158

fitbenchmarking/controllers/gsl_controller.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def __init__(self, cost_func):
8080
self._abserror = None
8181
self._relerror = None
8282
self._maxits = None
83+
self._nits = None
8384

8485
def _prediction_error(self, p, data=None):
8586
"""
@@ -225,7 +226,7 @@ def fit(self):
225226
"""
226227
Run problem with GSL
227228
"""
228-
for _ in range(self._maxits):
229+
for n in range(self._maxits):
229230
status = self._solver.iterate()
230231
# check if the method has converged
231232
if self.minimizer in self._residual_methods:
@@ -244,6 +245,7 @@ def fit(self):
244245
)
245246
if status == errno.GSL_SUCCESS:
246247
self.flag = 0
248+
self._nits = n + 1
247249
break
248250
if status != errno.GSL_CONTINUE:
249251
self.flag = 2
@@ -256,3 +258,6 @@ def cleanup(self):
256258
will be read from
257259
"""
258260
self.final_params = self._solver.getx()
261+
self.iteration_count = (
262+
self._maxits if self._nits is None else self._nits
263+
)

fitbenchmarking/controllers/levmar_controller.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,6 @@ def cleanup(self):
122122
self.flag = 1
123123
else:
124124
self.flag = 2
125+
126+
self.iteration_count = self._info[2]
127+
self.func_evals = self._info[4]

fitbenchmarking/controllers/lmfit_controller.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def cleanup(self):
185185
will be read from
186186
"""
187187

188+
self.func_evals = self.lmfit_out.nfev
189+
188190
if self.lmfit_out.success:
189191
self.flag = 0
190192
else:

fitbenchmarking/controllers/matlab_controller.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def __init__(self, cost_func):
4343
super().__init__(cost_func)
4444
self._status = None
4545
self.result = None
46+
self._nits = None
4647

4748
def setup(self):
4849
"""
@@ -59,12 +60,13 @@ def fit(self):
5960
"""
6061
Run problem with Matlab
6162
"""
62-
[self.result, _, exitflag] = self.eng.fminsearch(
63+
[self.result, _, exitflag, output] = self.eng.fminsearch(
6364
self.eng.workspace["eval_cost_mat"],
6465
self.initial_params_mat,
65-
nargout=3,
66+
nargout=4,
6667
)
6768
self._status = int(exitflag)
69+
self._nits = int(output["iterations"])
6870

6971
def cleanup(self):
7072
"""
@@ -77,7 +79,7 @@ def cleanup(self):
7779
self.flag = 1
7880
else:
7981
self.flag = 2
80-
8182
self.final_params = np.array(
8283
self.result[0], dtype=np.float64
8384
).flatten()
85+
self.iteration_count = self._nits

fitbenchmarking/controllers/matlab_curve_controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,4 @@ def cleanup(self):
120120
self.final_params = self.eng.coeffvalues(self.eng.workspace["fitobj"])[
121121
0
122122
]
123+
self.iteration_count = int(self.eng.workspace["output"]["iterations"])

fitbenchmarking/controllers/matlab_opt_controller.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self, cost_func):
5353
self.y_data_mat = None
5454
self._status = None
5555
self.result = None
56+
self._nits = None
5657

5758
def setup(self):
5859
"""
@@ -111,7 +112,7 @@ def fit(self):
111112
"""
112113
Run problem with Matlab Optimization Toolbox
113114
"""
114-
self.result, _, _, exitflag, _ = self.eng.lsqcurvefit(
115+
self.result, _, _, exitflag, output = self.eng.lsqcurvefit(
115116
self.eng.workspace["eval_func"],
116117
self.initial_params_mat,
117118
self.x_data_mat,
@@ -122,6 +123,7 @@ def fit(self):
122123
nargout=5,
123124
)
124125
self._status = int(exitflag)
126+
self._nits = output["iterations"]
125127

126128
def cleanup(self):
127129
"""
@@ -139,3 +141,4 @@ def cleanup(self):
139141
self.final_params = np.array(
140142
self.result[0], dtype=np.float64
141143
).flatten()
144+
self.iteration_count = self._nits

fitbenchmarking/controllers/minuit_controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,6 @@ def cleanup(self):
122122
else:
123123
self.flag = 2
124124

125+
self.func_evals = self._minuit_problem.nfcn
125126
self._popt = np.array(self._minuit_problem.values)
126127
self.final_params = self._popt

fitbenchmarking/controllers/ralfit_controller.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def __init__(self, cost_func):
8383
self.param_ranges = None
8484
self._status = None
8585
self._popt = None
86+
self._iter = None
8687
self._options = {}
8788

8889
def setup(self):
@@ -161,25 +162,26 @@ def fit(self):
161162
Run problem with RALFit.
162163
"""
163164
if self.cost_func.hessian:
164-
self._popt = ral_nlls.solve(
165+
(self._popt, inform) = ral_nlls.solve(
165166
self.initial_params,
166167
self.cost_func.eval_r,
167168
self.cost_func.jac_res,
168169
self.hes_eval,
169170
options=self._options,
170171
lower_bounds=self.param_ranges[0],
171172
upper_bounds=self.param_ranges[1],
172-
)[0]
173+
)
173174
else:
174-
self._popt = ral_nlls.solve(
175+
(self._popt, inform) = ral_nlls.solve(
175176
self.initial_params,
176177
self.cost_func.eval_r,
177178
self.cost_func.jac_res,
178179
options=self._options,
179180
lower_bounds=self.param_ranges[0],
180181
upper_bounds=self.param_ranges[1],
181-
)[0]
182+
)
182183
self._status = 0 if self._popt is not None else 1
184+
self._iter = inform["iter"]
183185

184186
def cleanup(self):
185187
"""
@@ -192,3 +194,4 @@ def cleanup(self):
192194
self.flag = 2
193195

194196
self.final_params = self._popt
197+
self.iteration_count = self._iter

fitbenchmarking/controllers/scipy_controller.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,7 @@ def cleanup(self):
153153
else:
154154
self.flag = 2
155155

156+
self.func_evals = self.result.nfev
157+
self.iteration_count = self.result.nit
158+
156159
self.final_params = self._popt

fitbenchmarking/controllers/scipy_go_controller.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def __init__(self, cost_func):
5252
super().__init__(cost_func)
5353

5454
self.support_for_bounds = True
55-
self._popt = None
56-
self._status = None
55+
self._result = None
5756
self._maxiter = None
5857

5958
def setup(self):
@@ -90,25 +89,20 @@ def fit(self):
9089
fun = self.cost_func.eval_cost
9190
bounds = self.value_ranges
9291
algorithm = getattr(optimize, self.minimizer)
93-
result = algorithm(fun, bounds, **kwargs)
94-
self._popt = result.x
95-
if result.success:
96-
self._status = 0
97-
elif "Maximum number of iteration" in result.message:
98-
self._status = 1
99-
else:
100-
self._status = 2
92+
self._result = algorithm(fun, bounds, **kwargs)
10193

10294
def cleanup(self):
10395
"""
10496
Convert the result to a numpy array and populate the variables results
10597
will be read from.
10698
"""
107-
if self._status == 0:
99+
if self._result.success:
108100
self.flag = 0
109-
elif self._status == 1:
101+
elif "Maximum number of iteration reached" in self._result.message:
110102
self.flag = 1
111103
else:
112104
self.flag = 2
113105

114-
self.final_params = self._popt
106+
self.final_params = self._result.x
107+
self.iteration_count = self._result.nit
108+
self.func_evals = self._result.nfev

fitbenchmarking/controllers/scipy_leastsq_controller.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,4 +80,6 @@ def cleanup(self):
8080
else:
8181
self.flag = 2
8282

83+
self.func_evals = self.result[2]["nfev"]
84+
8385
self.final_params = self._popt

fitbenchmarking/controllers/scipy_ls_controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,5 @@ def cleanup(self):
110110
else:
111111
self.flag = 2
112112

113+
self.func_evals = self.result.nfev
113114
self.final_params = self._popt

fitbenchmarking/controllers/tests/test_controllers.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,49 @@ def test_check_invalid_final_params(self):
284284
with self.assertRaises(exceptions.ControllerAttributeError):
285285
controller.check_attributes()
286286

287+
def test_check_valid_iteration_count(self):
288+
"""
289+
BaseSoftwareController: Test iteration_count setting with valid value
290+
"""
291+
controller = DummyController(self.cost_func)
292+
controller.final_params = [1, 2, 3, 4, 5]
293+
controller.flag = 3
294+
controller.iteration_count = 10
295+
controller.check_attributes()
296+
297+
def test_check_invalid_iteration_count(self):
298+
"""
299+
BaseSoftwareController: Test iteration_count setting with invalid value
300+
"""
301+
controller = DummyController(self.cost_func)
302+
controller.final_params = [1, 2, 3, 4, 5]
303+
controller.flag = 3
304+
controller.iteration_count = 10.5
305+
with self.assertRaises(exceptions.ControllerAttributeError):
306+
controller.check_attributes()
307+
308+
def test_check_valid_func_evals(self):
309+
"""
310+
BaseSoftwareController: Test func_evals setting with valid value
311+
"""
312+
controller = DummyController(self.cost_func)
313+
controller.final_params = [1, 2, 3, 4, 5]
314+
controller.flag = 3
315+
controller.iteration_count = 10
316+
controller.func_evals = 10
317+
controller.check_attributes()
318+
319+
def test_check_invalid_func_evals(self):
320+
"""
321+
BaseSoftwareController: Test func_evals setting with invalid value
322+
"""
323+
controller = DummyController(self.cost_func)
324+
controller.final_params = [1, 2, 3, 4, 5]
325+
controller.flag = 3
326+
controller.func_evals = 10.5
327+
with self.assertRaises(exceptions.ControllerAttributeError):
328+
controller.check_attributes()
329+
287330
def test_validate_minimizer_true(self):
288331
"""
289332
BaseSoftwareController: Test validate_minimizer with valid
@@ -1125,11 +1168,12 @@ def test_scipy_go(self):
11251168

11261169
self.shared_tests.controller_run_test(controller)
11271170

1128-
controller._status = 0
11291171
self.shared_tests.check_converged(controller)
1130-
controller._status = 1
1172+
controller._result.success = False
11311173
self.shared_tests.check_max_iterations(controller)
1132-
controller._status = 2
1174+
controller._result.message = [
1175+
"Maximum number of iteration NOT reached"
1176+
]
11331177
self.shared_tests.check_diverged(controller)
11341178

11351179
def test_gradient_free(self):

fitbenchmarking/core/tests/test_fitting_benchmarking.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,8 @@ def test_benchmark_method(self):
910910
"jacobian_tag",
911911
"hessian_tag",
912912
"costfun_tag",
913+
"iteration_count",
914+
"func_evals",
913915
]:
914916
assert getattr(r, attr) == expected["results"][ix][attr]
915917
self.assertAlmostEqual(

fitbenchmarking/results_processing/fitting_report.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ def create_prob_group(result, support_pages_dir, options):
8484
n_params = result.get_n_parameters()
8585
list_params = n_params < 100
8686

87+
iteration_count = (
88+
str(result.iteration_count)
89+
if result.iteration_count
90+
else "not available"
91+
)
92+
93+
func_evals = (
94+
str(result.func_evals) if result.func_evals else "not available"
95+
)
96+
8797
if np.isnan(result.emissions):
8898
emission_disp = "N/A"
8999
else:
@@ -115,6 +125,8 @@ def create_prob_group(result, support_pages_dir, options):
115125
n_params=n_params,
116126
list_params=list_params,
117127
n_data_points=result.get_n_data_points(),
128+
iteration_count=iteration_count,
129+
func_evals=func_evals,
118130
)
119131
)
120132

0 commit comments

Comments
 (0)