add torch_compile_backend to config
souryadey committed Oct 14, 2023
1 parent 697b651 commit df20ab8
Showing 4 changed files with 26 additions and 8 deletions.
8 changes: 7 additions & 1 deletion dlkoopman/config.py
@@ -21,6 +21,9 @@ class Config():
- **use_cuda** (*bool, optional*) - If `True`, tensor computations will take place on CUDA GPUs if available.
- **torch_compile_backend** (*str / None, optional*) - The backend to use for `torch.compile()`, a feature added in torch major version 2 to potentially speed up computation. For the full lists of possible backends, run `torch._dynamo.list_backends()` and `torch._dynamo.list_backends(None)`. See the [`torch.compile()` documentation](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details.
- If you are using a major version of torch less than 2 or you set `torch_compile_backend = None`, the DLKoopman neural nets will not pass through `torch.compile()`.
- **normalize_Xdata** (*bool, optional*) - If `True`, all input states (training, validation, test) are divided by the maximum absolute value in the training data.
- Note that normalizing data is a generally good technique for deep learning, and is normally done for each feature \\(f\\) in the input data as
$$X_f = \\frac{X_f-\\text{offset}_f}{\\text{scale}_f}$$
@@ -47,12 +50,14 @@ class Config():
    def __init__(self,
        precision = "float",
        use_cuda = True,
+        torch_compile_backend = "aot_eager",
        normalize_Xdata = True,
        use_exact_eigenvectors = True,
        sigma_threshold = 1e-25
    ):
        self.precision = precision
        self.use_cuda = use_cuda
+        self.torch_compile_backend = torch_compile_backend
        self.normalize_Xdata = normalize_Xdata
        self.use_exact_eigenvectors = use_exact_eigenvectors
        self.sigma_threshold = sigma_threshold
@@ -61,6 +66,8 @@ def __init__(self,
            raise ConfigValidationError(f'`precision` must be either of "half" / "float" / "double", instead found {precision}')
        if use_cuda not in [True, False]:
            raise ConfigValidationError(f'`use_cuda` must be either True or False, instead found {use_cuda}')
+        if torch_compile_backend not in torch._dynamo.list_backends() + torch._dynamo.list_backends(None) + [None]:
+            raise ConfigValidationError(f'`torch_compile_backend` must be either None or one of the options obtained from running `torch._dynamo.list_backends()` or `torch._dynamo.list_backends(None)`, instead found {torch_compile_backend}')
        if normalize_Xdata not in [True, False]:
            raise ConfigValidationError(f'`normalize_Xdata` must be either True or False, instead found {normalize_Xdata}')
        if use_exact_eigenvectors not in [True, False]:
@@ -71,4 +78,3 @@ def __init__(self,
        self.RTYPE = torch.half if self.precision=="half" else torch.float if self.precision=="float" else torch.double
        self.CTYPE = torch.chalf if self.precision=="half" else torch.cfloat if self.precision=="float" else torch.cdouble
        self.DEVICE = torch.device("cuda" if self.use_cuda and torch.cuda.is_available() else "cpu")
-        self.BACKEND = "aot_eager"
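
For context, here is a minimal sketch of how the new option could be set (hypothetical usage, not part of the commit; it relies only on the `Config` signature shown above):

    from dlkoopman.config import Config

    cfg = Config()                                     # default: compile the nets with the "aot_eager" backend
    cfg = Config(torch_compile_backend="inductor")     # any backend listed by torch._dynamo.list_backends()
    cfg = Config(torch_compile_backend=None)           # skip torch.compile() entirely, even on torch >= 2
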
4 changes: 2 additions & 2 deletions dlkoopman/state_pred.py
@@ -185,8 +185,8 @@ def __init__(self,
            batch_norm = batch_norm
        )
        self.ae.to(dtype=self.cfg.RTYPE, device=self.cfg.DEVICE)
-        if int(torch.__version__[0]) > 1:
-            self.ae = torch.compile(self.ae, backend=self.cfg.BACKEND)
+        if utils.is_torch_2() and self.cfg.torch_compile_backend is not None:
+            self.ae = torch.compile(self.ae, backend=self.cfg.torch_compile_backend)

        ## Get rank and ensure it's a valid value
        full_rank = min(self.encoded_size,len(self.dh.ttr)-1) #this is basically min(Y.shape), where Y is defined in _dmd_linearize()
8 changes: 4 additions & 4 deletions dlkoopman/traj_pred.py
@@ -159,16 +159,16 @@ def __init__(self,
            batch_norm = batch_norm
        )
        self.ae.to(dtype=self.cfg.RTYPE, device=self.cfg.DEVICE)
-        if int(torch.__version__[0]) > 1:
-            self.ae = torch.compile(self.ae, backend=self.cfg.BACKEND)
+        if utils.is_torch_2() and self.cfg.torch_compile_backend is not None:
+            self.ae = torch.compile(self.ae, backend=self.cfg.torch_compile_backend)

        ## Define linear layer
        self.Knet = nets.Knet(
            size = encoded_size
        )
        self.Knet.to(dtype=self.cfg.RTYPE, device=self.cfg.DEVICE)
-        if int(torch.__version__[0]) > 1:
-            self.Knet = torch.compile(self.Knet, backend=self.cfg.BACKEND)
+        if utils.is_torch_2() and self.cfg.torch_compile_backend is not None:
+            self.Knet = torch.compile(self.Knet, backend=self.cfg.torch_compile_backend)

        ## Define params
        self.params = list(self.ae.parameters()) + list(self.Knet.parameters())
14 changes: 13 additions & 1 deletion dlkoopman/utils.py
@@ -15,7 +15,8 @@
    'scale': False,
    'shift': False,
    'extract_item': False,
-    'moving_avg': False
+    'moving_avg': False,
+    'is_torch_2': False
}


@@ -185,3 +186,14 @@ def moving_avg(inp, window_size = 3):
    if window_size > len(inp):
        raise ValueError(f"'window_size' must be <= length of 'inp', but {window_size} > {len(inp)}")
    return type(inp)(np.convolve(inp, np.ones(window_size), 'valid') / window_size)
+
+
+def is_torch_2() -> bool:
+    """
+    Check if the major version of torch being used is 2 or not.
+    Return False if major version cannot be determined.
+    """
+    try:
+        return int(torch.__version__[0]) == 2
+    except ValueError:
+        return False
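
As a small self-contained illustration (not part of the commit), the snippet below mirrors the logic of `is_torch_2()` on some representative version strings (the sample strings are made up for demonstration):

    # Mirrors is_torch_2(): inspect the first character of the version string
    for v in ["2.0.1", "2.1.0+cu118", "1.13.1", "custom-build"]:
        try:
            major_is_2 = int(v[0]) == 2
        except ValueError:
            major_is_2 = False       # non-numeric leading character -> major version cannot be determined
        print(v, major_is_2)         # prints True, True, False, False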
