Fixed memory usage issue.

ryanemenecker committed Nov 12, 2024
1 parent ed4ffe7 commit bf01e09
Showing 2 changed files with 18 additions and 43 deletions.
11 changes: 2 additions & 9 deletions metapredict/backend/architectures.py
@@ -103,16 +103,9 @@ def forward(self, x):
             [batch_dim X sequence_length X num_classes]
         """
 
-        # Set initial states
-        # h0 and c0 dimensions: [num_layers*2 X batch_size X hidden_size]
-        #h0 = torch.zeros(self.num_layers*2, # *2 for bidirection
-        #                 x.size(0), self.hidden_size).to(self.device)
-        #c0 = torch.zeros(self.num_layers*2,
-        #                 x.size(0), self.hidden_size).to(self.device)
-
         # Forward propagate LSTM
         # out: tensor of shape: [batch_size, seq_length, hidden_size*2]
-        out, (h_n, c_n) = self.lstm(x)
+        out, _ = self.lstm(x)
 
         # Decode the hidden state for each time step
         fc_out = self.fc(out)
@@ -242,7 +235,7 @@ def forward(self, x):
         """
         # Forward propagate LSTM
         # out: tensor of shape: [batch_size, seq_length, lstm_hidden_size*2]
-        out, (h_n, c_n) = self.lstm(x)
+        out, _ = self.lstm(x)
         out = self.layer_norm(out)
         for layer in self.linear_layers:
             out = layer(out)
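Both hunks in this file make the same change: the LSTM's final hidden and cell states were never read after the forward pass, so the commit stops binding them to names. A minimal sketch of the resulting many-to-many forward pass (sizes here are illustrative, not metapredict's actual hyperparameters):

import torch
import torch.nn as nn

# Illustrative sizes; the real values come from the params dictionary.
lstm = nn.LSTM(input_size=20, hidden_size=8, num_layers=2,
               batch_first=True, bidirectional=True)
fc = nn.Linear(8 * 2, 2)    # hidden_size*2 for bidirection -> num_classes

x = torch.randn(4, 50, 20)  # [batch_size, seq_length, input_size]
out, _ = lstm(x)            # [4, 50, 16]; the (h_n, c_n) tuple is discarded
fc_out = fc(out)            # [4, 50, 2]: per-residue class scores

Omitting explicit initial states is also why the commented-out h0/c0 block could go: nn.LSTM zero-initializes them when none are passed, which is exactly what that code built by hand.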
50 changes: 16 additions & 34 deletions metapredict/backend/predictor.py
@@ -278,9 +278,12 @@ def take_care_of_version(version_input):
 
 # function to load model
 # A variable to store the loaded model
-'''
 loaded_models = {}
 
+# gets model. This lets us avoid iteratively loading the model
+# because it can check the global dictionary to see if the model
+# has already been loaded.
+# if you don't do this, you start getting memory issues
 def get_model(model_name, params, predictor_path, device):
     global loaded_models  # Ensure the dictionary is accessible across calls
 
@@ -307,7 +310,6 @@ def get_model(model_name, params, predictor_path, device):
     # Store the loaded model in the dictionary using the model_name as key
     loaded_models[model_name] = model
     return model
-'''
 
# ....................................................................................
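This restored get_model function is the heart of the fix: a module-level dictionary memoizes loaded models, so repeated predict() calls reuse one instance instead of re-reading weights from disk each time. A minimal sketch of the pattern, with build_model as a hypothetical stand-in for the architecture construction (it is not part of metapredict's API):

import torch

loaded_models = {}  # module-level, so it survives across calls

def get_cached_model(model_name, predictor_path, device, build_model):
    """Load a model from disk once; hand back the cached instance afterwards."""
    global loaded_models
    if model_name in loaded_models:
        return loaded_models[model_name]   # cache hit: no disk I/O, no new tensors
    model = build_model()                  # construct the untrained architecture
    state = torch.load(predictor_path, map_location=device)
    model.load_state_dict(state)           # fill in the trained weights
    model.to(device)
    loaded_models[model_name] = model      # remember it for the next call
    return model

Keying the cache on model_name matches the f'disorder_{version}' and f'pLDDT_{version}' keys used below, so the disorder and pLDDT predictors each get their own slot.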

@@ -574,24 +576,11 @@ def predict(inputs,
     ##
     ## ....................................................................................
 
-    # load network. We do this differently depending on if we used
-    # pytorch or pytorch-lightning to make the network.
-    if params['used_lightning']==False:
-        model=architectures.BRNN_MtM(input_size=params['input_size'],
-            hidden_size=params['hidden_size'], num_layers=params['num_layers'],
-            num_classes=params['num_classes'], device=device)
-        network=torch.load(predictor_path, map_location=device)
-        model.load_state_dict(network)
-
-    else:
-        # if it's a pytorch-lightning, we can just use load_from_checkpoint
-        model=architectures.BRNN_MtM_lightning
-        model = model.load_from_checkpoint(predictor_path, map_location=device)
-
-    #model = get_model(model_name=version,
-    #                  params=params,
-    #                  predictor_path=predictor_path,
-    #                  device=device)
+    # load model
+    model = get_model(model_name=f'disorder_{version}',
+                      params=params,
+                      predictor_path=predictor_path,
+                      device=device)
 
     # set to eval mode
     model.eval()
@@ -612,7 +601,7 @@ def predict(inputs,
     start_time = time.time()
 
     # encode the sequence
-    seq_vector = encode_sequence.one_hot(inputs)
+    seq_vector = encode_sequence.one_hot(inputs)
     seq_vector = seq_vector.to(device)
     seq_vector = seq_vector.view(1, len(seq_vector), -1)
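In this single-sequence path, the encoder returns a [sequence_length, alphabet_size] matrix that needs a leading batch dimension before the LSTM sees it. A rough, illustrative stand-in for encode_sequence.one_hot (the real encoder lives in metapredict's encode_sequence module):

import torch

AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
AA_TO_IDX = {aa: i for i, aa in enumerate(AMINO_ACIDS)}

def one_hot(seq):
    """One-hot encode an amino acid string into a [len(seq), 20] float tensor."""
    vec = torch.zeros(len(seq), len(AMINO_ACIDS))
    for i, aa in enumerate(seq):
        vec[i, AA_TO_IDX[aa]] = 1.0
    return vec

seq_vector = one_hot('MKVLAT')                        # [6, 20]
seq_vector = seq_vector.view(1, len(seq_vector), -1)  # [1, 6, 20]: batch of one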

Expand Down Expand Up @@ -819,7 +808,7 @@ def predict(inputs,

# lstm forward pass.
with torch.no_grad():
outputs, (ht,ct) = model.lstm(packed_seqs)
outputs, _ = model.lstm(packed_seqs)

# unpack the packed sequence
outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
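For batches of unequal-length sequences, the surrounding code packs them before the LSTM and pads them back afterwards; dropping the (ht, ct) tuple, as the commit does here, discards two tensors that were never read. A self-contained sketch of that pack, forward, unpack round trip (sizes are illustrative):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=20, hidden_size=8, batch_first=True, bidirectional=True)
seqs = [torch.randn(5, 20), torch.randn(3, 20)]  # two sequences, lengths 5 and 3
lengths = torch.tensor([5, 3])                   # descending, so enforce_sorted holds

padded = nn.utils.rnn.pad_sequence(seqs, batch_first=True)            # [2, 5, 20]
packed = nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first=True)

with torch.no_grad():              # inference only: build no autograd graph
    outputs, _ = lstm(packed)      # final hidden states discarded, as in the diff

outputs, out_lengths = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
# outputs: [2, 5, 16]; positions past each true length are zero padding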
@@ -1154,17 +1143,10 @@ def predict_pLDDT(inputs,
 
     # load network. We do this differently depending on if we used
     # pytorch or pytorch-lightning to make the network.
-    if params['used_lightning']==False:
-        model=architectures.BRNN_MtM(input_size=params['input_size'],
-            hidden_size=params['hidden_size'], num_layers=params['num_layers'],
-            num_classes=params['num_classes'], device=device)
-        network=torch.load(predictor_path, map_location=device)
-        model.load_state_dict(network)
-
-    else:
-        # if it's a pytorch-lightning, we can just use load_from_checkpoint
-        model=architectures.BRNN_MtM_lightning
-        model = model.load_from_checkpoint(predictor_path, map_location=device)
+    model = get_model(model_name=f'pLDDT_{version}',
+                      params=params,
+                      predictor_path=predictor_path,
+                      device=device)
 
     # set to eval mode
     model.eval()
@@ -1408,7 +1390,7 @@ def predict_pLDDT(inputs,
 
     # lstm forward pass.
    with torch.no_grad():
-        outputs, (ht,ct) = model.lstm(packed_seqs)
+        outputs, _ = model.lstm(packed_seqs)
 
     # unpack the packed sequence
     outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
