Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit 0a3ec8f
Author: GreatRSingh <rakshitsingh421@gmail.com>
Date:   Fri Jan 19 17:33:57 2024 +0530

    sibling test

commit 2b83c5f
Author: GreatRSingh <rakshitsingh421@gmail.com>
Date:   Fri Jan 19 16:55:51 2024 +0530

    transform tests
  • Loading branch information
GreatRSingh committed Jan 19, 2024
1 parent d5837de commit b8813a7
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 6 deletions.
18 changes: 17 additions & 1 deletion deepchem/utils/differentiation_utils/pure_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def _check_identical_objs(objs1: List, objs2: List) -> bool:
return False
return True


class SingleSiblingPureFunction(PureFunction):
"""Implementation of PureFunction for a sibling method
Expand Down Expand Up @@ -465,7 +466,6 @@ def _set_all_obj_params(self, allobjparams: List):
allobjparams[self.cumsum_idx[i]:self.cumsum_idx[i + 1]])



def get_pure_function(fcn) -> PureFunction:
"""Get the pure function form of the function or method ``fcn``.
Expand Down Expand Up @@ -530,6 +530,22 @@ def make_sibling(*pfuncs) -> Callable[[Callable], PureFunction]:
Changing the state of the decorated function will also change the state of
``pfunc`` and its other siblings.
Examples
--------
>>> import torch
>>> from deepchem.utils.differentiation_utils import make_sibling
>>> def fcn1(x, y):
... return x + y
>>> def fcn2(x, y):
... return x - y
>>> pfunc1 = get_pure_function(fcn1)
>>> pfunc2 = get_pure_function(fcn2)
>>> @make_sibling(pfunc1)
... def fcn3(x, y):
... return x * y
>>> pfunc3(1, 2)
2
Parameters
----------
pfuncs: List[Callable]
Expand Down
2 changes: 1 addition & 1 deletion deepchem/utils/pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def to_fortran_order(V):
"Only the last two dimensions can be made Fortran order.")


def get_np_dtype(dtype: torch.dtype) -> np.dtype:
def get_np_dtype(dtype: torch.dtype) -> Any:
"""corresponding numpy dtype from the input pytorch's tensor dtype
Examples
Expand Down
39 changes: 35 additions & 4 deletions deepchem/utils/test/test_dft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def test_radial_grid():
def test_get_xw_integration():
from deepchem.utils.dft_utils import get_xw_integration
x, w = get_xw_integration(4, "chebyshev")
assert x.shape == (4, )
assert x.shape == (4,)
assert w.shape == torch.Size([4])


Expand All @@ -472,6 +472,37 @@ def test_sliced_radial_grid():


@pytest.mark.torch
def test_base_grid_transform():
from deepchem.utils.dft_utils import BaseGridTransform

def test_de2_transform():
from deepchem.utils.dft_utils import DE2Transformation
x = torch.linspace(-1, 1, 100)
r = DE2Transformation().x2r(x)
assert r.shape == torch.Size([100])
drdx = DE2Transformation().get_drdx(x)
assert drdx.shape == torch.Size([100])


@pytest.mark.torch
def test_logm3_transform():
from deepchem.utils.dft_utils import LogM3Transformation
x = torch.linspace(-1, 1, 100)
r = LogM3Transformation().x2r(x)
assert r.shape == torch.Size([100])
drdx = LogM3Transformation().get_drdx(x)
assert drdx.shape == torch.Size([100])


@pytest.mark.torch
def test_treutlerm4_transform():
from deepchem.utils.dft_utils import TreutlerM4Transformation
x = torch.linspace(-1, 1, 100)
r = TreutlerM4Transformation().x2r(x)
assert r.shape == torch.Size([100])
drdx = TreutlerM4Transformation().get_drdx(x)
assert drdx.shape == torch.Size([100])


@pytest.mark.torch
def test_get_grid_transform():
from deepchem.utils.dft_utils import get_grid_transform
transform = get_grid_transform("logm3")
transform.x2r(torch.tensor([0.5])) == torch.tensor([2.])
14 changes: 14 additions & 0 deletions deepchem/utils/test/test_differentiation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,20 @@ def fcn(x, y):
assert pfunc(1, 2) == 3


@pytest.mark.torch
def test_make_siblings():
from deepchem.utils.differentiation_utils import make_sibling

def fcn1(x, y):
return x + y

@make_sibling(fcn1)
def fcn3(x, y):
return x * y

assert fcn3(1, 2) == 2


@pytest.mark.torch
def test_wrap_gmres():
from deepchem.utils.differentiation_utils.solve import wrap_gmres
Expand Down

0 comments on commit b8813a7

Please sign in to comment.