Implemented Kestrel

ardagoreci committed Nov 24, 2023
1 parent 8141cb7 commit 80d3e7d
Showing 8 changed files with 243 additions and 58 deletions.
19 changes: 19 additions & 0 deletions configs/data/protein.yaml
@@ -0,0 +1,19 @@
_target_: src.data.protein_datamodule.ProteinDataModule
data_dir: "./data/"
resolution_thr: 3.5 # Resolution threshold for PDB structures
min_seq_id: 0.3 # Minimum sequence identity for MMseqs2 clustering
crop_size: 384 # The number of residues to crop the proteins to.
max_length: 10_000 # Entries whose total chain length exceeds max_length are discarded.
use_fraction: 1.0 # the fraction of the clusters to use (first N in alphabetic order)
entry_type: "chain" # { "biounit", "chain", "pair" } the type of entries to generate
classes_to_exclude: ['homomers', 'heteromers'] # a list of classes to exclude from the dataset
mask_residues: False # if True, the masked residues will be added to the output
lower_limit: 15 # the lower limit of the number of residues to mask
upper_limit: 100 # the upper limit of the number of residues to mask
mask_frac: null # if given, the number of residues to mask is mask_frac times the length of the chain
mask_sequential: False # if True, the masked residues will be neighbors in the sequence; otherwise geometric mask
mask_whole_chains: False # if True, the whole chain is masked
force_binding_sites_frac: 0.15
batch_size: 64 # The batch size. Defaults to `64`.
num_workers: 0 # The number of workers. Defaults to `0`.
pin_memory: False # Whether to pin memory. Defaults to `False`.
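
For orientation, this config is resolved through Hydra's `_target_` mechanism. A minimal sketch of the wiring, assuming the standard Lightning-Hydra template layout (the hook calls are the stock LightningDataModule API, not code from this commit):

# Sketch: turning configs/data/protein.yaml into a datamodule instance.
import hydra
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/data/protein.yaml")
dm = hydra.utils.instantiate(cfg)  # ProteinDataModule(data_dir="./data/", ...)
dm.prepare_data()  # standard LightningDataModule hooks
dm.setup("fit")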
36 changes: 36 additions & 0 deletions configs/model/kestrel.yaml
@@ -0,0 +1,36 @@
_target_: src.models.kestrel_module.KestrelLitModule

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

structure_net:
  _target_: src.models.structure_net.StructureNet
  c_s: 384
  c_z: 128
  n_structure_layer: 4
  n_structure_block: 1
  c_hidden_ipa: 16
  n_head_ipa: 12
  n_qk_point: 4
  n_v_point: 8
  ipa_dropout: 0.1
  n_structure_transition_layer: 1
  structure_transition_dropout: 0.1

pair_feature_net:
  _target_: src.models.pair_feature_net.PairFeatureNet
  c_z: 128
  relpos_k: 32

# compile model for faster training with pytorch 2.0
compile: false
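
The `_partial_: true` flags mean Hydra returns factories rather than built objects, so the LightningModule can bind the missing arguments later (presumably in KestrelLitModule.configure_optimizers, which is not shown in this diff). A minimal sketch of that pattern:

# Sketch of Hydra's `_partial_` pattern; the stand-in parameter list is
# illustrative, not the real model's.
import hydra
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/model/kestrel.yaml")
make_optimizer = hydra.utils.instantiate(cfg.optimizer)  # functools.partial(torch.optim.Adam, lr=0.001, weight_decay=0.0)
make_scheduler = hydra.utils.instantiate(cfg.scheduler)  # functools.partial(ReduceLROnPlateau, mode="min", ...)

params = [torch.nn.Parameter(torch.zeros(3))]  # stand-in for model parameters
optimizer = make_optimizer(params=params)
scheduler = make_scheduler(optimizer=optimizer)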
6 changes: 3 additions & 3 deletions configs/train.yaml
@@ -4,10 +4,10 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - data: mnist
-  - model: mnist
+  - data: protein # mnist
+  - model: kestrel # mnist
   - callbacks: default
-  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
+  - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
   - trainer: default
   - paths: default
   - extras: default
15 changes: 12 additions & 3 deletions src/data/protein_datamodule.py
@@ -271,21 +271,30 @@ def train_dataloader(self) -> DataLoader[Any]:
         """Create and return the train dataloader.
         :return: The train dataloader.
         """
-        return proteinflow.ProteinLoader(self.data_train, batch_size=self.batch_size_per_device)
+        return proteinflow.ProteinLoader(self.data_train,
+                                         batch_size=self.batch_size_per_device,
+                                         num_workers=self.hparams.num_workers,
+                                         pin_memory=self.hparams.pin_memory)

     def val_dataloader(self) -> DataLoader[Any]:
         """Create and return the validation dataloader.
         :return: The validation dataloader.
         """
-        return proteinflow.ProteinLoader(self.data_val, batch_size=self.batch_size_per_device)
+        return proteinflow.ProteinLoader(self.data_val,
+                                         batch_size=self.batch_size_per_device,
+                                         num_workers=self.hparams.num_workers,
+                                         pin_memory=self.hparams.pin_memory)

     def test_dataloader(self) -> DataLoader[Any]:
         """Create and return the test dataloader.
         :return: The test dataloader.
         """
-        return proteinflow.ProteinLoader(self.data_test, batch_size=self.batch_size_per_device)
+        return proteinflow.ProteinLoader(self.data_test,
+                                         batch_size=self.batch_size_per_device,
+                                         num_workers=self.hparams.num_workers,
+                                         pin_memory=self.hparams.pin_memory)

     def teardown(self, stage: Optional[str] = None) -> None:
         """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
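
A quick, hypothetical smoke test for the new loader arguments (we assume proteinflow.ProteinLoader yields dict batches; the keys and shapes are library-defined):

# Hypothetical check that num_workers/pin_memory reach the loaders:
# num_workers > 0 parallelizes batch preparation, and pin_memory=True
# speeds host-to-GPU copies.
import hydra
from omegaconf import OmegaConf

dm = hydra.utils.instantiate(OmegaConf.load("configs/data/protein.yaml"))
dm.prepare_data()
dm.setup("fit")
batch = next(iter(dm.train_dataloader()))
print({k: getattr(v, "shape", type(v)) for k, v in batch.items()})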
66 changes: 65 additions & 1 deletion src/diffusion/losses.py
@@ -10,7 +10,7 @@ def compute_fape_squared(
     pred_positions: torch.Tensor,
     target_positions: torch.Tensor,
     positions_mask: torch.Tensor,
-    length_scale: float,
+    length_scale: float = 10.0,
     l2_clamp_distance: Optional[float] = None,
     eps=1e-8,
 ) -> torch.Tensor:
@@ -73,3 +73,67 @@ def compute_fape_squared(
     normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))

     return normed_error
+
+
+def fape_squared_with_clamp(
+    pred_frames: Rigids,
+    target_frames: Rigids,
+    frames_mask: torch.Tensor,
+    pred_positions: torch.Tensor,
+    target_positions: torch.Tensor,
+    positions_mask: torch.Tensor,
+    use_clamped_fape: float = 0.9,
+    l2_clamp_distance: float = 100.0,  # (10 Å)^2, a squared-distance cutoff
+    eps: float = 1e-4,
+    **kwargs,
+) -> torch.Tensor:
+    """Compute squared FAPE loss with clamping.
+    Args:
+        pred_frames:
+            [*, N_frames] Rigids object of predicted frames
+        target_frames:
+            [*, N_frames] Rigids object of ground truth frames
+        frames_mask:
+            [*, N_frames] binary mask for the frames
+        pred_positions:
+            [*, N_pts, 3] predicted atom positions
+        target_positions:
+            [*, N_pts, 3] ground truth positions
+        positions_mask:
+            [*, N_pts] positions mask
+        use_clamped_fape:
+            weight of the clamped FAPE term in the final loss; the
+            unclamped term is weighted by (1 - use_clamped_fape)
+        l2_clamp_distance:
+            Cutoff above which squared distance errors are disregarded.
+        eps:
+            Small value used to regularize denominators
+    Returns:
+        [*] loss tensor
+    """
+    fape_loss = compute_fape_squared(pred_frames=pred_frames,
+                                     target_frames=target_frames,
+                                     frames_mask=frames_mask,
+                                     pred_positions=pred_positions,
+                                     target_positions=target_positions,
+                                     positions_mask=positions_mask,
+                                     l2_clamp_distance=l2_clamp_distance,
+                                     eps=eps)
+    if use_clamped_fape is not None:
+        # The unclamped term must skip the distance cutoff; otherwise both
+        # terms would be identical and the weighting a no-op.
+        unclamped_fape_loss = compute_fape_squared(pred_frames=pred_frames,
+                                                   target_frames=target_frames,
+                                                   frames_mask=frames_mask,
+                                                   pred_positions=pred_positions,
+                                                   target_positions=target_positions,
+                                                   positions_mask=positions_mask,
+                                                   l2_clamp_distance=None,
+                                                   eps=eps)
+        use_clamped_fape = fape_loss.new_tensor(use_clamped_fape)  # match loss device/dtype
+        # Blend the two to provide a useful training signal even early on in training.
+        fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
+            1 - use_clamped_fape
+        )
+
+    # Average over the batch dimension
+    fape_loss = torch.mean(fape_loss)
+
+    return fape_loss
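
To make the clamped/unclamped weighting concrete, a toy numeric sketch (invented values; the length_scale normalization inside compute_fape_squared is ignored here):

import torch

# With use_clamped_fape = 0.9 and l2_clamp_distance = 100.0, a large
# early-training squared error of 400 saturates the clamped term but
# still contributes through the unclamped one.
clamped = torch.tensor(100.0)   # min(400.0, 100.0)
unclamped = torch.tensor(400.0)
w = 0.9
loss = clamped * w + unclamped * (1 - w)
print(loss)  # tensor(130.) = 90 + 40, so large errors stay visible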