Format Python code according to PEP8
Fotiligner authored and github-actions[bot] committed Dec 10, 2023
1 parent 0aaaaaf commit 4a3b320
Showing 2 changed files with 109 additions and 72 deletions.
49 changes: 31 additions & 18 deletions recbole/data/transform.py
@@ -24,7 +24,7 @@ def construct_transform(config):
         "crop_itemseq": CropItemSequence,
         "reorder_itemseq": ReorderItemSequence,
         "user_defined": UserDefinedTransform,
-        "random_itemseq": RandomAugmentationSequence
+        "random_itemseq": RandomAugmentationSequence,
     }
     if config["transform"] not in str2transform:
         raise NotImplementedError(
@@ -222,15 +222,15 @@ def __call__(self, dataset, interaction):
         interaction.update(Interaction(new_dict))
         return interaction


 class RandomAugmentationSequence:
     def __init__(self, config):
         self.ITEM_SEQ = config["ITEM_ID_FIELD"] + config["LIST_SUFFIX"]
         self.RANDOM_ITEM_SEQ = "Random_" + self.ITEM_SEQ
         self.ITEM_SEQ_LEN = config["ITEM_LIST_LENGTH_FIELD"]
         self.ITEM_ID = config["ITEM_ID_FIELD"]
         config["RANDOM_ITEM_SEQ"] = self.RANDOM_ITEM_SEQ

     def __call__(self, dataset, interaction):
         item_seq = interaction[self.ITEM_SEQ]
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
@@ -255,10 +255,10 @@ def __call__(self, dataset, interaction):
                 aug_seq, aug_len = self.item_mask(seq, n_items, length)
             elif switch[0] == 2:
                 aug_seq, aug_len = self.item_reorder(seq, length)

             aug_seq1.append(aug_seq)
             aug_len1.append(aug_len)

             if switch[1] == 0:
                 aug_seq, aug_len = self.item_crop(seq, length)
             elif switch[1] == 1:
@@ -268,42 +268,55 @@ def __call__(self, dataset, interaction):

             aug_seq2.append(aug_seq)
             aug_len2.append(aug_len)

         new_dict = {
-            "aug1" : torch.stack(aug_seq1),
-            "aug1_len" : torch.stack(aug_len1),
-            "aug2" : torch.stack(aug_seq2),
-            "aug2_len" : torch.stack(aug_len2)
+            "aug1": torch.stack(aug_seq1),
+            "aug1_len": torch.stack(aug_len1),
+            "aug2": torch.stack(aug_seq2),
+            "aug2_len": torch.stack(aug_len2),
         }
         interaction.update(Interaction(new_dict))
         return interaction

     def item_crop(self, item_seq, item_seq_len, eta=0.6):
         num_left = math.floor(item_seq_len * eta)
         crop_begin = random.randint(0, item_seq_len - num_left)
         croped_item_seq = np.zeros(item_seq.shape[0])
         if crop_begin + num_left < item_seq.shape[0]:
-            croped_item_seq[:num_left] = item_seq.cpu().detach().numpy()[crop_begin:crop_begin + num_left]
+            croped_item_seq[:num_left] = (
+                item_seq.cpu().detach().numpy()[crop_begin : crop_begin + num_left]
+            )
         else:
             croped_item_seq[:num_left] = item_seq.cpu().detach().numpy()[crop_begin:]
-        return torch.tensor(croped_item_seq, dtype=torch.long, device=item_seq.device),\
-            torch.tensor(num_left, dtype=torch.long, device=item_seq.device)
+        return torch.tensor(
+            croped_item_seq, dtype=torch.long, device=item_seq.device
+        ), torch.tensor(num_left, dtype=torch.long, device=item_seq.device)

     def item_mask(self, item_seq, n_items, item_seq_len, gamma=0.3):
         num_mask = math.floor(item_seq_len * gamma)
         mask_index = random.sample(range(item_seq_len), k=num_mask)
         masked_item_seq = item_seq.cpu().detach().numpy().copy()
-        masked_item_seq[mask_index] = n_items - 1  # token 0 has been used for semantic masking
-        return torch.tensor(masked_item_seq, dtype=torch.long, device=item_seq.device), item_seq_len
+        masked_item_seq[mask_index] = (
+            n_items - 1
+        )  # token 0 has been used for semantic masking
+        return (
+            torch.tensor(masked_item_seq, dtype=torch.long, device=item_seq.device),
+            item_seq_len,
+        )

     def item_reorder(self, item_seq, item_seq_len, beta=0.6):
         num_reorder = math.floor(item_seq_len * beta)
         reorder_begin = random.randint(0, item_seq_len - num_reorder)
         reordered_item_seq = item_seq.cpu().detach().numpy().copy()
         shuffle_index = list(range(reorder_begin, reorder_begin + num_reorder))
         random.shuffle(shuffle_index)
-        reordered_item_seq[reorder_begin:reorder_begin + num_reorder] = reordered_item_seq[shuffle_index]
-        return torch.tensor(reordered_item_seq, dtype=torch.long, device=item_seq.device), item_seq_len
+        reordered_item_seq[
+            reorder_begin : reorder_begin + num_reorder
+        ] = reordered_item_seq[shuffle_index]
+        return (
+            torch.tensor(reordered_item_seq, dtype=torch.long, device=item_seq.device),
+            item_seq_len,
+        )


 class CropItemSequence:
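For context: `RandomAugmentationSequence` produces two independently augmented views of each sequence by cropping, masking, or reordering. A minimal, self-contained sketch of the three operators on a toy padded sequence (hypothetical data, not RecBole API; the `eta`, `gamma`, and `beta` defaults mirror the method signatures above):

```python
import math
import random

import numpy as np

# A toy padded sequence: five real items followed by zero padding.
seq = np.array([3, 7, 2, 9, 5, 0, 0, 0])
seq_len = 5
n_items = 100  # item_mask uses n_items - 1 as the mask token (0 is padding)

def crop(seq, seq_len, eta=0.6):
    """Keep a random contiguous window of floor(eta * seq_len) items."""
    num_left = math.floor(seq_len * eta)
    begin = random.randint(0, seq_len - num_left)
    out = np.zeros_like(seq)
    out[:num_left] = seq[begin : begin + num_left]
    return out, num_left

def mask(seq, seq_len, n_items, gamma=0.3):
    """Replace floor(gamma * seq_len) random positions with the mask token."""
    idx = random.sample(range(seq_len), k=math.floor(seq_len * gamma))
    out = seq.copy()
    out[idx] = n_items - 1
    return out, seq_len

def reorder(seq, seq_len, beta=0.6):
    """Shuffle a random contiguous window of floor(beta * seq_len) items."""
    num = math.floor(seq_len * beta)
    begin = random.randint(0, seq_len - num)
    out = seq.copy()
    window = list(range(begin, begin + num))
    random.shuffle(window)
    out[begin : begin + num] = out[window]
    return out, seq_len

print(crop(seq, seq_len))      # e.g. (array([2, 9, 5, 0, 0, 0, 0, 0]), 3)
print(mask(seq, seq_len, n_items))
print(reorder(seq, seq_len))
```

As in `item_mask` above, `n_items - 1` serves as the mask token because token 0 is already reserved for padding. In the class itself, `__call__` draws two independent `switch` codes per sequence and stacks the results into `aug1`/`aug2`, which CL4SRec consumes in `calculate_loss`.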
132 changes: 78 additions & 54 deletions recbole/model/sequential_recommender/cl4srec.py
@@ -18,25 +18,27 @@ def __init__(self, config, dataset):
         super(CL4SRec, self).__init__(config, dataset)

         # load parameters info
-        self.n_layers = config['n_layers']
-        self.n_heads = config['n_heads']
-        self.hidden_size = config['hidden_size']
-        self.inner_size = config['inner_size']
-        self.hidden_dropout_prob = config['hidden_dropout_prob']
-        self.attn_dropout_prob = config['attn_dropout_prob']
-        self.hidden_act = config['hidden_act']
-        self.layer_norm_eps = config['layer_norm_eps']
-        self.batch_size = config['train_batch_size']
-        self.lmd = config['lmd']
-        self.tau = config['tau']
-        self.sim = config['sim']
-
-        self.initializer_range = config['initializer_range']
-        self.loss_type = config['loss_type']
+        self.n_layers = config["n_layers"]
+        self.n_heads = config["n_heads"]
+        self.hidden_size = config["hidden_size"]
+        self.inner_size = config["inner_size"]
+        self.hidden_dropout_prob = config["hidden_dropout_prob"]
+        self.attn_dropout_prob = config["attn_dropout_prob"]
+        self.hidden_act = config["hidden_act"]
+        self.layer_norm_eps = config["layer_norm_eps"]
+
+        self.batch_size = config["train_batch_size"]
+        self.lmd = config["lmd"]
+        self.tau = config["tau"]
+        self.sim = config["sim"]
+
+        self.initializer_range = config["initializer_range"]
+        self.loss_type = config["loss_type"]

         # define layers and loss
-        self.item_embedding = nn.Embedding(self.n_items + 1, self.hidden_size, padding_idx=0)
+        self.item_embedding = nn.Embedding(
+            self.n_items + 1, self.hidden_size, padding_idx=0
+        )
         self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size)
         self.trm_encoder = TransformerEncoder(
             n_layers=self.n_layers,
@@ -46,27 +48,27 @@ def __init__(self, config, dataset):
             hidden_dropout_prob=self.hidden_dropout_prob,
             attn_dropout_prob=self.attn_dropout_prob,
             hidden_act=self.hidden_act,
-            layer_norm_eps=self.layer_norm_eps
+            layer_norm_eps=self.layer_norm_eps,
         )

         self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
         self.dropout = nn.Dropout(self.hidden_dropout_prob)

-        if self.loss_type == 'BPR':
+        if self.loss_type == "BPR":
             self.loss_fct = BPRLoss()
-        elif self.loss_type == 'CE':
+        elif self.loss_type == "CE":
             self.loss_fct = nn.CrossEntropyLoss()
         else:
             raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!")

         self.mask_default = self.mask_correlated_samples(batch_size=self.batch_size)
         self.nce_fct = nn.CrossEntropyLoss()

         # parameters initialization
         self.apply(self._init_weights)

     def _init_weights(self, module):
-        """ Initialize the weights """
+        """Initialize the weights"""
         if isinstance(module, (nn.Linear, nn.Embedding)):
             # Slightly different from the TF version which uses truncated_normal for initialization
             # cf https://github.com/pytorch/pytorch/pull/5617
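The hyperparameters read in `__init__` above map one-to-one onto RecBole config keys. A hedged sketch of a `config_dict` that would drive this model (key names taken from the code; the values and the `"CL4SRec"` registry name are illustrative assumptions, not part of this commit):

```python
# Hypothetical usage; assumes the model is registered as "CL4SRec" in RecBole.
from recbole.quick_start import run_recbole

config_dict = {
    "n_layers": 2,
    "n_heads": 2,
    "hidden_size": 64,
    "inner_size": 256,
    "hidden_dropout_prob": 0.5,
    "attn_dropout_prob": 0.5,
    "hidden_act": "gelu",
    "layer_norm_eps": 1e-12,
    "initializer_range": 0.02,
    "loss_type": "CE",
    "lmd": 0.1,    # weight of the contrastive term
    "tau": 1.0,    # InfoNCE temperature
    "sim": "dot",  # or "cos"
    "transform": "random_itemseq",  # routes batches through RandomAugmentationSequence
}
run_recbole(model="CL4SRec", dataset="ml-100k", config_dict=config_dict)
```

Setting `transform: random_itemseq` is what activates the `RandomAugmentationSequence` added in transform.py, so `interaction["aug1"]`/`["aug2"]` exist by the time `calculate_loss` runs.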
@@ -80,7 +82,9 @@ def _init_weights(self, module):
     def get_attention_mask(self, item_seq):
         """Generate left-to-right uni-directional attention mask for multi-head attention."""
         attention_mask = (item_seq > 0).long()
-        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)  # torch.int64
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(
+            2
+        )  # torch.int64
         # mask for left-to-right unidirectional
         max_len = attention_mask.size(-1)
         attn_shape = (1, max_len, max_len)
@@ -89,12 +93,16 @@ def get_attention_mask(self, item_seq):
         subsequent_mask = subsequent_mask.long().to(item_seq.device)

         extended_attention_mask = extended_attention_mask * subsequent_mask
-        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = extended_attention_mask.to(
+            dtype=next(self.parameters()).dtype
+        )  # fp16 compatibility
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask

     def forward(self, item_seq, item_seq_len):
-        position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device)
+        position_ids = torch.arange(
+            item_seq.size(1), dtype=torch.long, device=item_seq.device
+        )
         position_ids = position_ids.unsqueeze(0).expand_as(item_seq)
         position_embedding = self.position_embedding(position_ids)

@@ -105,7 +113,9 @@ def forward(self, item_seq, item_seq_len):

         extended_attention_mask = self.get_attention_mask(item_seq)

-        trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True)
+        trm_output = self.trm_encoder(
+            input_emb, extended_attention_mask, output_all_encoded_layers=True
+        )
         output = trm_output[-1]
         output = self.gather_indexes(output, item_seq_len - 1)
         return output  # [B H]
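Two details in the hunks above are easy to miss: the attention mask combines a padding mask with a causal mask in additive form, and the final prediction state is gathered at each sequence's last real position. A standalone sketch under toy shapes (the `tril` call stands in for the elided `triu`-and-invert lines; `gather_last` mirrors the pattern of RecBole's `gather_indexes` helper, not its exact source):

```python
import torch

# --- causal + padding mask, additive form ---
item_seq = torch.tensor([[5, 3, 0, 0]])        # one sequence, two real items
pad_mask = (item_seq > 0).long()                # [B, L]
extended = pad_mask.unsqueeze(1).unsqueeze(2)   # [B, 1, 1, L]
L = item_seq.size(-1)
causal = torch.tril(torch.ones((1, L, L), dtype=torch.long))  # keep j <= i
extended = extended * causal                    # [B, 1, L, L]
extended = (1.0 - extended.float()) * -10000.0  # 0 = attend, -10000 = block

# --- gather the hidden state at the last valid position ---
def gather_last(output, gather_index):
    """output: [B, L, H]; gather_index: [B]; returns [B, H]."""
    gather_index = gather_index.view(-1, 1, 1).expand(-1, -1, output.size(-1))
    return output.gather(dim=1, index=gather_index).squeeze(1)

output = torch.randn(1, L, 8)                   # fake encoder output [B, L, H]
seq_len = torch.tensor([2])                     # two real items
last_hidden = gather_last(output, seq_len - 1)  # [1, 8]
print(extended.shape, last_hidden.shape)
```

The additive -10000.0 form drives masked positions to effectively zero attention weight after softmax while staying finite in fp16, which is why the mask is first cast to the parameter dtype.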
@@ -115,18 +125,20 @@ def calculate_loss(self, interaction):
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
         seq_output = self.forward(item_seq, item_seq_len)
         pos_items = interaction[self.POS_ITEM_ID]
-        if self.loss_type == 'BPR':
+        if self.loss_type == "BPR":
             neg_items = interaction[self.NEG_ITEM_ID]
             pos_items_emb = self.item_embedding(pos_items)
             neg_items_emb = self.item_embedding(neg_items)
             pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B]
             neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B]
             loss = self.loss_fct(pos_score, neg_score)
         else:  # self.loss_type = 'CE'
-            test_item_emb = self.item_embedding.weight[:self.n_items]  # unpad the augmentation mask
+            test_item_emb = self.item_embedding.weight[
+                : self.n_items
+            ]  # unpad the augmentation mask
             logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1))
             loss = self.loss_fct(logits, pos_items)

         # # NCE
         # aug_item_seq1, aug_len1, aug_item_seq2, aug_len2 = self.augment(item_seq, item_seq_len)
         # # aug_item_seq1, aug_len1, aug_item_seq2, aug_len2 = \
@@ -136,15 +148,22 @@ def calculate_loss(self, interaction):

         seq_output1 = self.forward(interaction["aug1"], interaction["aug1_len"])
         seq_output2 = self.forward(interaction["aug2"], interaction["aug2_len"])

-        nce_logits, nce_labels = self.info_nce(seq_output1, seq_output2, temp=self.tau, batch_size=item_seq_len.shape[0], sim='dot')
+        nce_logits, nce_labels = self.info_nce(
+            seq_output1,
+            seq_output2,
+            temp=self.tau,
+            batch_size=item_seq_len.shape[0],
+            sim="dot",
+        )

         nce_loss = self.nce_fct(nce_logits, nce_labels)

         with torch.no_grad():
-            alignment, uniformity = self.decompose(seq_output1, seq_output2, seq_output,
-                                                   batch_size=item_seq_len.shape[0])
+            alignment, uniformity = self.decompose(
+                seq_output1, seq_output2, seq_output, batch_size=item_seq_len.shape[0]
+            )

         return loss + self.lmd * nce_loss, alignment, uniformity
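The contrastive term enters the objective with weight `lmd`, i.e. total = rec_loss + lmd * nce_loss, while `alignment` and `uniformity` are returned purely as no-grad diagnostics. A toy illustration of the weighting (hypothetical loss values, just to show the arithmetic):

```python
import torch

lmd = 0.1                      # weight of the contrastive term (config["lmd"])
rec_loss = torch.tensor(2.31)  # hypothetical next-item CE/BPR loss
nce_loss = torch.tensor(4.87)  # hypothetical InfoNCE loss over the two views
total = rec_loss + lmd * nce_loss
print(total)                   # tensor(2.7970)
```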

     def decompose(self, z_i, z_j, origin_z, batch_size):
@@ -153,27 +172,27 @@ def decompose(self, z_i, z_j, origin_z, batch_size):
         Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
         """
         N = 2 * batch_size

         z = torch.cat((z_i, z_j), dim=0)

         # pairwise l2 distance
         sim = torch.cdist(z, z, p=2)

         sim_i_j = torch.diag(sim, batch_size)
         sim_j_i = torch.diag(sim, -batch_size)

         positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
         alignment = positive_samples.mean()

         # pairwise l2 distance
         sim = torch.cdist(origin_z, origin_z, p=2)
         mask = torch.ones((batch_size, batch_size), dtype=bool)
         mask = mask.fill_diagonal_(0)
         negative_samples = sim[mask].reshape(batch_size, -1)
         uniformity = torch.log(torch.exp(-2 * negative_samples).mean())

         return alignment, uniformity
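`decompose` reports alignment (mean L2 distance between the two views of the same sequence) and a uniformity-style log-mean-exp over pairwise distances of the un-augmented outputs. The same computation, restated standalone on random toy embeddings to make the shapes explicit:

```python
import torch

torch.manual_seed(0)
batch_size, hidden = 4, 8
z_i = torch.randn(batch_size, hidden)       # view 1 embeddings
z_j = torch.randn(batch_size, hidden)       # view 2 embeddings
origin_z = torch.randn(batch_size, hidden)  # un-augmented embeddings

z = torch.cat((z_i, z_j), dim=0)            # [2B, H]
dist = torch.cdist(z, z, p=2)               # pairwise L2 distances
positives = torch.cat(
    (torch.diag(dist, batch_size), torch.diag(dist, -batch_size)), dim=0
)
alignment = positives.mean()                # distance between paired views

dist0 = torch.cdist(origin_z, origin_z, p=2)
mask = torch.ones((batch_size, batch_size), dtype=bool).fill_diagonal_(0)
uniformity = torch.log(torch.exp(-2 * dist0[mask]).mean())
print(alignment.item(), uniformity.item())
```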

     def mask_correlated_samples(self, batch_size):
         N = 2 * batch_size
         mask = torch.ones((N, N), dtype=bool)
@@ -182,35 +201,38 @@ def mask_correlated_samples(self, batch_size):
             mask[i, batch_size + i] = 0
             mask[batch_size + i, i] = 0
         return mask
-    def info_nce(self, z_i, z_j, temp, batch_size, sim='dot'):
+
+    def info_nce(self, z_i, z_j, temp, batch_size, sim="dot"):
         """
         We do not sample negative examples explicitly.
         Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples.
         """
         N = 2 * batch_size

         z = torch.cat((z_i, z_j), dim=0)

-        if sim == 'cos':
-            sim = nn.functional.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) / temp
-        elif sim == 'dot':
+        if sim == "cos":
+            sim = (
+                nn.functional.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2)
+                / temp
+            )
+        elif sim == "dot":
             sim = torch.mm(z, z.T) / temp

         sim_i_j = torch.diag(sim, batch_size)
         sim_j_i = torch.diag(sim, -batch_size)

         positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
         if batch_size != self.batch_size:
             mask = self.mask_correlated_samples(batch_size)
         else:
             mask = self.mask_default
         negative_samples = sim[mask].reshape(N, -1)

         labels = torch.zeros(N).to(positive_samples.device).long()
         logits = torch.cat((positive_samples, negative_samples), dim=1)
         return logits, labels

     def predict(self, interaction):
         item_seq = interaction[self.ITEM_SEQ]
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
@@ -224,6 +246,8 @@ def full_sort_predict(self, interaction):
         item_seq = interaction[self.ITEM_SEQ]
         item_seq_len = interaction[self.ITEM_SEQ_LEN]
         seq_output = self.forward(item_seq, item_seq_len)
-        test_items_emb = self.item_embedding.weight[:self.n_items]  # unpad the augmentation mask
+        test_items_emb = self.item_embedding.weight[
+            : self.n_items
+        ]  # unpad the augmentation mask
         scores = torch.matmul(seq_output, test_items_emb.transpose(0, 1))  # [B n_items]
         return scores
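To sanity-check `info_nce`: with dot-product similarity, each of the 2*B views gets one positive (its counterpart view) in logits column 0 and 2*B - 2 in-batch negatives, so the cross-entropy labels are all zeros. A standalone sketch on toy tensors (the same arithmetic, outside the class):

```python
import torch
import torch.nn as nn

def mask_correlated_samples(batch_size):
    """True everywhere except the diagonal and each (i, i+B) / (i+B, i) pair."""
    N = 2 * batch_size
    mask = torch.ones((N, N), dtype=bool).fill_diagonal_(0)
    for i in range(batch_size):
        mask[i, batch_size + i] = 0
        mask[batch_size + i, i] = 0
    return mask

batch_size, hidden, temp = 3, 8, 1.0
torch.manual_seed(0)
z_i, z_j = torch.randn(batch_size, hidden), torch.randn(batch_size, hidden)

N = 2 * batch_size
z = torch.cat((z_i, z_j), dim=0)
sim = torch.mm(z, z.T) / temp

sim_i_j = torch.diag(sim, batch_size)    # z_i[k] vs z_j[k]
sim_j_i = torch.diag(sim, -batch_size)   # z_j[k] vs z_i[k]
positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
negative_samples = sim[mask_correlated_samples(batch_size)].reshape(N, -1)

logits = torch.cat((positive_samples, negative_samples), dim=1)  # [6, 5]
labels = torch.zeros(N, dtype=torch.long)                        # positive is col 0
loss = nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())
```

`mask_default` in `__init__` simply caches this boolean mask for the common full-batch case, so it is only rebuilt when the last batch of an epoch is smaller.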
