Skip to content

Commit

Permalink
initial attempt to handle warnings at bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
kratsg committed Jun 26, 2024
1 parent 7259d60 commit 4412a76
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/pyhf/optimize/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,30 @@ def _internal_minimize(
raise exceptions.FailedMinimization(result)
return result

def _internal_postprocess(self, fitresult, stitch_pars, return_uncertainties=False):
def _internal_postprocess(
self, fitresult, stitch_pars, /, par_bounds, *, return_uncertainties=False
):
"""
Post-process the fit result.
Args:
fitresult (scipy.optimize.OptimizeResult): Fit result from :func:`_internal_minimize`
stitch_pars (:obj:`func`): callable that stitches fixed parameters into the unfixed parameters
par_bounds (:obj:`list` of :obj:`list`/:obj:`tuple`): The extrema of values the model parameters
are allowed to reach in the fit.
The shape should be ``(n, 2)`` for ``n`` model parameters.
return_uncertainties (:obj:`bool`): Return uncertainties on the fitted parameters. Default is off (``False``).
Returns:
fitresult (scipy.optimize.OptimizeResult): A modified version of the fit result.
"""
tensorlib, _ = get_backend()

# TODO: check how to handle this for batching

Check notice on line 89 in src/pyhf/optimize/mixins.py

View check run for this annotation

codefactor.io / CodeFactor

src/pyhf/optimize/mixins.py#L89

Unresolved comment '# TODO: check how to handle this for batching'. (C100)
for par_index, (fitted_par, bound) in enumerate(zip(fitresult.x, par_bounds)):
if fitted_par in bound:
log.warning(f'parameter at index {par_index} is at the bounds')

# stitch in missing parameters (e.g. fixed parameters)
fitted_pars = stitch_pars(tensorlib.astensor(fitresult.x))

Expand Down Expand Up @@ -195,7 +210,10 @@ def minimize(
**minimizer_kwargs, options=kwargs, par_names=par_names
)
result = self._internal_postprocess(
result, stitch_pars, return_uncertainties=return_uncertainties
result,
stitch_pars,
par_bounds=minimizer_kwargs['bounds'],
return_uncertainties=return_uncertainties,
)

_returns = [result.x]
Expand Down

0 comments on commit 4412a76

Please sign in to comment.