In ProteinMPNN, the main sampling functions, sample() and tied_sample(), are defined in protein_mpnn_utils.py as methods of the ProteinMPNN class.
class ProteinMPNN():
    ...
    def sample(self,...):
        ...
At the beginning of the file, import the watermarking classes and create the logits processor. Note that the private key must be replaced with your own key.
######################################
### IMPORT WATERMARK PACKAGE #########
######################################
from proteinwatermark import (
    DeltaGumbel_Reweight,
    WatermarkLogitsProcessor,
)
delta_wp = WatermarkLogitsProcessor(
    b"private key",
    DeltaGumbel_Reweight(),
    context_code_length=5,
)
######################################
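Since the key has to be replaced with your own, one option is to keep it out of the source file altogether. The following is a minimal sketch, assuming the key is supplied through an environment variable. The helper `load_private_key` and the variable name `WATERMARK_PRIVATE_KEY` are hypothetical and are not part of proteinwatermark.

```python
import os


def load_private_key(env_var: str = "WATERMARK_PRIVATE_KEY") -> bytes:
    """Read the watermark key from the environment (hypothetical helper)."""
    value = os.environ.get(env_var)
    if not value:
        # Fail early instead of silently falling back to a placeholder key.
        raise RuntimeError(f"Set {env_var} to your own private key before sampling.")
    # Encode to bytes, matching the bytes literal used in the setup above.
    return value.encode("utf-8")
```

The returned bytes could then be passed as the first argument in place of b"private key".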
Then use the logits processor to modify the logits inside the sample function.
def sample(self, ...):
    ...
    for t_ in range(N_nodes):
        ...
        if ...:
            ...
        else:
            ...
            logits = self.W_out(h_V_t) / temperature
            logits = logits-constant[None,:]*1e8+constant_bias[None,:]/temperature+bias_by_res_gathered/temperature
            #################################################################
            ### MODIFY THE LOGITS TO ADD WATERMARK ##########################
            S_record = S.detach().cpu().numpy() # CURRENT GENERATED SEQUENCES
            logits = delta_wp("order_agnoistic", # since it is proteinMPNN
                              S_record, # current sequences
                              logits.detach().cpu().numpy(),
                              current_pos=t.long().detach().cpu())
            logits = torch.FloatTensor(logits).to(device)
            #################################################################
            ...
    ...
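To give some intuition for what a keyed logits processor does at this step, here is a toy sketch in pure Python. It is not the proteinwatermark implementation. The function `toy_delta_gumbel` is hypothetical and only illustrates the general idea of seeding a random generator from the private key and the recent context, then using the Gumbel-max trick to put all probability mass on one residue.

```python
import hashlib
import hmac
import math
import random


def toy_delta_gumbel(key: bytes, context: list, logits: list) -> list:
    """Toy keyed reweighting: returns logits with all mass on one token."""
    # Derive a deterministic seed from the private key and the context tokens.
    msg = ",".join(str(c) for c in context).encode()
    seed = int.from_bytes(hmac.new(key, msg, hashlib.sha256).digest(), "big")
    rng = random.Random(seed)

    # Gumbel-max trick: the argmax of (logit + Gumbel noise) is distributed
    # like a sample from softmax(logits).
    best, best_score = 0, -math.inf
    for i, logit in enumerate(logits):
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        score = logit - math.log(-math.log(u))
        if score > best_score:
            best, best_score = i, score

    # "Delta" output: only the selected token keeps a finite logit.
    return [0.0 if i == best else -math.inf for i in range(len(logits))]
```

With the same key and context the function always selects the same token, which is the property a detector holding the key can later check for.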