Skip to content

Commit 00e66b9

Browse files
committed
Support different models for the gradient
1 parent 3946e19 commit 00e66b9

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

hoi/metrics/gradient_oinfo.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,12 @@ class GradientOinfo(HOIEstimator):
4545
_negative = "synergy"
4646
_symmetric = True
4747

48-
def __init__(self, x, y, multiplets=None, verbose=None):
48+
def __init__(self, x, y, multiplets=None, base_model=Oinfo, verbose=None):
4949
kw_oinfo = dict(multiplets=multiplets, verbose=verbose)
5050
HOIEstimator.__init__(self, x=x, y=None, **kw_oinfo)
51-
self._oinf_tr = Oinfo(x, y=y, **kw_oinfo)
52-
self._oinf_tf = Oinfo(x, **kw_oinfo)
51+
self._oinf_tr = base_model(x, y=y, **kw_oinfo)
52+
self._oinf_tf = base_model(x, **kw_oinfo)
53+
self.__name__ = self.__name__ + "(%s)" % base_model.__name__
5354

5455
def fit(self, minsize=2, maxsize=None, method="gcmi", **kwargs):
5556
"""Compute the Gradient O-information.

0 commit comments

Comments
 (0)