Commit
Fix continuous learning phase and update forward method
devin-ai-integration[bot] committed Aug 20, 2024
1 parent 2d9b279 commit 8d40c78
Showing 4 changed files with 65 additions and 20 deletions.
Binary file added models/neurocoder_model.pth
Binary file not shown.
9 changes: 5 additions & 4 deletions src/models/advanced_architecture.py
@@ -111,7 +111,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, m
return output

class AdvancedNeuroCoder(nn.Module):
- def __init__(self, vocab_size: int, d_model: int = 768, n_layers: int = 12, n_heads: int = 12, num_tasks: int = 3):
+ def __init__(self, vocab_size: int, d_model: int = 768, n_layers: int = 12, n_heads: int = 12, num_tasks: int = 4):
super(AdvancedNeuroCoder, self).__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.embedding = nn.Embedding(vocab_size, d_model)
@@ -131,9 +131,9 @@ def __init__(self, vocab_size: int, d_model: int = 768, n_layers: int = 12, n_he
self.criterion = nn.CrossEntropyLoss()
self.to(self.device)

- def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, adj_matrix: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
- x = x.to(self.device)
- x = self.embedding(x)
+ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor = None, adj_matrix: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
+ input_ids = input_ids.to(self.device)
+ x = self.embedding(input_ids)

# Ensure x has the correct shape for inception layer
batch_size, seq_len, d_model = x.shape
@@ -151,6 +151,7 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, adj_matr

# Apply attention mask if provided
if attention_mask is not None:
+ attention_mask = attention_mask.to(self.device)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000.0

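The attention-mask handling added here follows the common additive-mask convention: broadcast the padding mask over the attention dimensions, then turn padded positions into large negative logits so softmax drives them to zero. A minimal, self-contained sketch of that idea (the shapes and values below are illustrative assumptions, not taken from this repository):

import torch

# Padding mask: 1 = real token, 0 = padding (batch_size=1, seq_len=4)
attention_mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.float)

# Broadcast to (batch, 1, 1, seq_len) and convert padded positions to large
# negatives, mirroring the unsqueeze / -10000.0 pattern in the updated forward
extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_mask = (1.0 - extended_mask) * -10000.0

# Adding the mask to attention logits pushes padded keys to ~zero probability
scores = torch.randn(1, 1, 4, 4)                 # (batch, heads, query_len, key_len)
weights = torch.softmax(scores + extended_mask, dim=-1)
print(weights[0, 0, 0])                          # last entry is effectively 0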
74 changes: 59 additions & 15 deletions src/models/model_training.py
@@ -1,3 +1,6 @@
+ import os
+ import shutil
+ import logging
import torch
import torch.nn as nn
import torch.optim as optim
@@ -11,8 +14,6 @@
from bayes_opt import BayesianOptimization
from torch.nn.utils.rnn import pad_sequence

from src.models.advanced_architecture import AdvancedNeuroCoder

# AdvancedNeuroCoder is now imported and will be used instead of the previous NeuroCoder class

def load_datasets():
@@ -195,28 +196,55 @@ def objective(learning_rate, weight_decay, warmup_steps, num_epochs):

def continuous_learning(model: AdvancedNeuroCoder, new_data: List[Dict[str, torch.Tensor]]):
optimizer = AdamW(model.parameters(), lr=1e-5)
- criterion = nn.CrossEntropyLoss()
+ token_criterion = nn.CrossEntropyLoss(ignore_index=-100)
+ task_criterion = nn.CrossEntropyLoss()

model.train()
for batch in new_data:
optimizer.zero_grad()
- input_ids = batch['input_ids']
- attention_mask = batch['attention_mask']
- labels = batch['labels']
- task = batch['task']

- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
- loss = criterion(outputs.view(-1, outputs.size(-1)), labels.view(-1))
+ # Ensure all tensors are on the correct device and have the correct shape
+ input_ids = batch['input_ids'].to(model.device)
+ attention_mask = batch['attention_mask'].to(model.device)
+ labels = batch['labels'].to(model.device)
+ task_labels = batch['task_labels'].to(model.device)

+ # Ensure input tensors have the correct shape (batch_size, sequence_length)
+ if input_ids.dim() == 1:
+ input_ids = input_ids.unsqueeze(0)
+ attention_mask = attention_mask.unsqueeze(0)
+ labels = labels.unsqueeze(0)
+ task_labels = task_labels.unsqueeze(0)

+ token_outputs, task_outputs = model(input_ids=input_ids, attention_mask=attention_mask)

+ # Ensure token_outputs and labels have the same shape
+ if token_outputs.shape[1] != labels.shape[1]:
+ min_len = min(token_outputs.shape[1], labels.shape[1])
+ token_outputs = token_outputs[:, :min_len, :]
+ labels = labels[:, :min_len]

+ # Mask out padding tokens
+ mask = (labels != -100).float()
+ token_loss = token_criterion(token_outputs.contiguous().view(-1, token_outputs.size(-1)), labels.contiguous().view(-1))
+ token_loss = (token_loss * mask.view(-1)).sum() / mask.sum()

+ task_loss = task_criterion(task_outputs, task_labels.squeeze())
+ loss = token_loss + task_loss

# Add regularization to preserve existing knowledge
- for param, old_param in zip(model.parameters(), model.old_params):
- loss += 0.001 * torch.sum((param - old_param) ** 2)
+ if hasattr(model, 'old_params'):
+ for param, old_param in zip(model.parameters(), model.old_params):
+ loss += 0.001 * torch.sum((param - old_param) ** 2)

loss.backward()
optimizer.step()

# Update old parameters
model.old_params = [param.clone().detach() for param in model.parameters()]

print(f"Continuous learning completed. Final loss: {loss.item():.4f}")

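continuous_learning now expects each batch to provide token-level labels (with -100 marking ignored positions) and a per-sequence task label. A hypothetical batch matching that contract, for illustration only (key names come from the diff; batch size, sequence length, and vocabulary size are assumptions):

import torch

dummy_batch = {
    'input_ids': torch.randint(0, 10000, (2, 16)),           # (batch_size, seq_len)
    'attention_mask': torch.ones(2, 16, dtype=torch.long),   # 1 = attend, 0 = padding
    'labels': torch.randint(0, 10000, (2, 16)),              # token-level targets
    'task_labels': torch.randint(0, 4, (2,)),                # one of num_tasks = 4 classes
}
dummy_batch['labels'][:, -4:] = -100                         # mark a padded tail as ignored

# continuous_learning(model, [dummy_batch])  # any iterable of such batches works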
if __name__ == "__main__":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AdvancedNeuroCoder(vocab_size=10000).to(device) # Adjust vocab_size as needed
@@ -244,11 +272,27 @@ def continuous_learning(model: AdvancedNeuroCoder, new_data: List[Dict[str, torc
config.update(optimized_config)

# Train the model
- train_model(model, train_loader, val_loader, config)
+ trained_model = train_model(model, train_loader, val_loader, config)

# Continuous learning
- new_data = load_datasets() # Load new data periodically
- continuous_learning(model, new_data)
+ new_train_data, new_val_data = load_datasets() # Load new data periodically
+ new_train_loader = DataLoader(new_train_data, batch_size=32, shuffle=True)
+ continuous_learning(trained_model, new_train_loader)

# Save the trained model
- torch.save(model.state_dict(), 'neurocoder_model.pth')
+ try:
+ save_dir = os.path.join(os.getcwd(), 'models')
+ os.makedirs(save_dir, exist_ok=True)
+ save_path = os.path.join(save_dir, 'neurocoder_model.pth')
+ torch.save(trained_model.state_dict(), save_path)
+ print(f"Model saved successfully at {save_path}")
+ except Exception as e:
+ print(f"Error saving model: {str(e)}")
+ print(f"Current working directory: {os.getcwd()}")
+ print(f"Attempted save path: {save_path}")
+ print(f"Is directory writable? {os.access(save_dir, os.W_OK)}")
+ print(f"Free disk space: {shutil.disk_usage(save_dir).free / (1024 * 1024 * 1024):.2f} GB")

+ # Log model architecture and hyperparameters
+ logging.info(f"Model Architecture:\n{trained_model}")
+ logging.info(f"Final Hyperparameters: {config}")
2 changes: 1 addition & 1 deletion tests/test_model_training.py
@@ -16,7 +16,7 @@ def test_forward_pass(model):
token_output, task_output = model(input_ids, attention_mask)
assert token_output is not None and task_output is not None
assert token_output.shape == (1, 4, 10000) # (batch_size, sequence_length, vocab_size)
- assert task_output.shape == (1, 3) # (batch_size, num_tasks)
+ assert task_output.shape == (1, 4) # (batch_size, num_tasks)

def test_model_output_range(model):
input_ids = torch.tensor([[1, 2, 3, 4]])
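These assertions rely on a pytest `model` fixture that builds an AdvancedNeuroCoder with vocab_size=10000 and num_tasks=4. The fixture itself is outside this diff; a plausible sketch, hypothetical and for illustration only:

import pytest
from src.models.advanced_architecture import AdvancedNeuroCoder

@pytest.fixture
def model():
    # Hypothetical fixture; the real one lives in the unchanged part of the test file
    return AdvancedNeuroCoder(vocab_size=10000)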
