diff --git a/model.py b/model.py
index c698f8b601..87cf41b39c 100644
--- a/model.py
+++ b/model.py
@@ -48,6 +48,28 @@ def __init__(self, config):
             # causal mask to ensure that attention is only applied to the left in the input sequence
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    def apply_rotary_position_embeddings(self, sinusoidal_pos, q, k):
+        # sinusoidal_pos interleaves (sin, cos) pairs along the last dim; pull them apart
+        sin, cos = sinusoidal_pos[..., 0::2], sinusoidal_pos[..., 1::2]
+        # rotate each even/odd feature pair (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos)
+        q1, q2 = q[..., 0::2], q[..., 1::2]
+        k1, k2 = k[..., 0::2], k[..., 1::2]
+        q_rot = torch.stack((q1 * cos - q2 * sin, q1 * sin + q2 * cos), dim=-1)
+        k_rot = torch.stack((k1 * cos - k2 * sin, k1 * sin + k2 * cos), dim=-1)
+        q_rot = torch.reshape(q_rot, q.shape)
+        k_rot = torch.reshape(k_rot, k.shape)
+        return q_rot, k_rot
+
+    def get_sinusoidal_embeddings(self, n_positions, dim):
+        """Generate sinusoidal positional embeddings, interleaved as (sin, cos) pairs."""
+        position = torch.arange(n_positions, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
+        sinusoidal_emb = torch.zeros((n_positions, dim))
+        sinusoidal_emb[:, 0::2] = torch.sin(position * div_term)
+        sinusoidal_emb[:, 1::2] = torch.cos(position * div_term)
+        return sinusoidal_emb.to(self.device)
 
     def forward(self, x):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
@@ -58,13 +80,17 @@ def forward(self, x):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
+        # apply rotary position embeddings
+        sinusoidal_pos = self.get_sinusoidal_embeddings(T, self.n_embd // self.n_head)
+        q_rot, k_rot = self.apply_rotary_position_embeddings(sinusoidal_pos, q, k)
+
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+            y = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
         else:
             # manual implementation of attention
-            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            att = (q_rot @ k_rot.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
             att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
             att = F.softmax(att, dim=-1)
             att = self.attn_dropout(att)
diff --git a/train.py b/train.py
index 951bda9914..8be1e82642 100644
--- a/train.py
+++ b/train.py
@@ -44,7 +44,7 @@
 wandb_project = 'owt'
 wandb_run_name = 'gpt2' # 'run' + str(time.time())
 # data
-dataset = 'openwebtext'
+dataset = 'shakespeare_char'
 gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
 batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
 block_size = 1024
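A quick way to sanity-check the rotation: with rotary position embeddings, the dot product between a rotated query at position m and a rotated key at position n should depend only on the offset m - n. The sketch below mirrors get_sinusoidal_embeddings and apply_rotary_position_embeddings as standalone functions (so it needs only torch and math, not the GPT config or the attention module) and checks that shifting both positions by the same amount leaves the score unchanged.

import math
import torch

def sinusoidal(n_positions, dim):
    # same interleaved (sin, cos) layout as get_sinusoidal_embeddings in the patch
    position = torch.arange(n_positions, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
    emb = torch.zeros((n_positions, dim))
    emb[:, 0::2] = torch.sin(position * div_term)
    emb[:, 1::2] = torch.cos(position * div_term)
    return emb

def rotate(x, pos_emb):
    # mirrors apply_rotary_position_embeddings for a single position vector
    sin, cos = pos_emb[..., 0::2], pos_emb[..., 1::2]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).reshape(x.shape)

dim, T = 64, 32
q, k = torch.randn(dim), torch.randn(dim)
pos = sinusoidal(T, dim)

score = lambda m, n: torch.dot(rotate(q, pos[m]), rotate(k, pos[n]))
# same offset (m - n = 2) at different absolute positions -> same attention score
print(torch.allclose(score(3, 1), score(20, 18), atol=1e-5))  # expect True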