From 0575c2e8ce8bce9bf63e1b58bbc0f4a37e41b1a0 Mon Sep 17 00:00:00 2001 From: TayTroye Date: Thu, 26 Oct 2023 15:32:15 +0000 Subject: [PATCH] Format Python code according to PEP8 --- .../model/sequential_recommender/fearec.py | 434 ++++++++++++------ tests/model/test_model_auto.py | 2 +- 2 files changed, 287 insertions(+), 149 deletions(-) diff --git a/recbole/model/sequential_recommender/fearec.py b/recbole/model/sequential_recommender/fearec.py index cc3297de5..6c0885ccd 100644 --- a/recbole/model/sequential_recommender/fearec.py +++ b/recbole/model/sequential_recommender/fearec.py @@ -11,6 +11,7 @@ from recbole.model.abstract_recommender import SequentialRecommender + # from recbole.model.layers import FEAEncoder #考虑把encoder放进来 from recbole.model.loss import BPRLoss from recbole.data.interaction import Interaction @@ -23,24 +24,28 @@ def __init__(self, config, dataset): # load parameters info self.dataset = dataset self.config = config - self.n_layers = config['n_layers'] - self.n_heads = config['n_heads'] - self.hidden_size = config['hidden_size'] # same as embedding_size - self.inner_size = config['inner_size'] # the dimensionality in feed-forward layer - self.hidden_dropout_prob = config['hidden_dropout_prob'] - self.attn_dropout_prob = config['attn_dropout_prob'] - self.hidden_act = config['hidden_act'] - self.layer_norm_eps = config['layer_norm_eps'] - - self.lmd = config['lmd'] - self.lmd_sem = config['lmd_sem'] - - self.initializer_range = config['initializer_range'] - self.loss_type = config['loss_type'] + self.n_layers = config["n_layers"] + self.n_heads = config["n_heads"] + self.hidden_size = config["hidden_size"] # same as embedding_size + self.inner_size = config[ + "inner_size" + ] # the dimensionality in feed-forward layer + self.hidden_dropout_prob = config["hidden_dropout_prob"] + self.attn_dropout_prob = config["attn_dropout_prob"] + self.hidden_act = config["hidden_act"] + self.layer_norm_eps = config["layer_norm_eps"] + + self.lmd = config["lmd"] + self.lmd_sem = config["lmd_sem"] + + self.initializer_range = config["initializer_range"] + self.loss_type = config["loss_type"] self.same_item_index = self.get_same_item_index(dataset) # define layers and loss - self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0) + self.item_embedding = nn.Embedding( + self.n_items, self.hidden_size, padding_idx=0 + ) self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size) self.item_encoder = FEAEncoder( n_layers=self.n_layers, @@ -57,19 +62,19 @@ def __init__(self, config, dataset): self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self.dropout = nn.Dropout(self.hidden_dropout_prob) - if self.loss_type == 'BPR': + if self.loss_type == "BPR": self.loss_fct = BPRLoss() - elif self.loss_type == 'CE': + elif self.loss_type == "CE": self.loss_fct = nn.CrossEntropyLoss() else: raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") - self.ssl = config['contrast'] - self.tau = config['tau'] - self.sim = config['sim'] - self.fredom = config['fredom'] - self.fredom_type = config['fredom_type'] - self.batch_size = config['train_batch_size'] + self.ssl = config["contrast"] + self.tau = config["tau"] + self.sim = config["sim"] + self.fredom = config["fredom"] + self.fredom_type = config["fredom_type"] + self.batch_size = config["train_batch_size"] self.mask_default = self.mask_correlated_samples(batch_size=self.batch_size) self.aug_nce_fct = nn.CrossEntropyLoss() self.sem_aug_nce_fct = nn.CrossEntropyLoss() @@ 
-77,13 +82,13 @@ def __init__(self, config, dataset): # parameters initialization self.apply(self._init_weights) - - def get_same_item_index(self,dataset): + def get_same_item_index(self, dataset): same_target_index = {} - aug_path = dataset.config['data_path'] + '/semantic_augmentationindex.pkl' + aug_path = dataset.config["data_path"] + "/semantic_augmentationindex.pkl" import os + if os.path.exists(aug_path): - with open(aug_path, 'rb') as file: + with open(aug_path, "rb") as file: same_target_index = pickle.load(file) file.close() else: @@ -100,14 +105,14 @@ def get_same_item_index(self,dataset): count += 1 # print(item_id,"index is",same_target_index[item_id]) - with open(aug_path, 'wb') as file: - pickle.dump(same_target_index,file) + with open(aug_path, "wb") as file: + pickle.dump(same_target_index, file) file.close() return same_target_index def _init_weights(self, module): - """ Initialize the weights """ + """Initialize the weights""" if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 @@ -134,7 +139,9 @@ def get_attention_mask(self, item_seq): """ 最终,这个方法返回一个用于多头注意力的 attention mask,确保了在每个位置上, 只能注意到该位置之前的位置。这有助于模型按照时间顺序逐步生成输出,过滤掉未来的信息 """ attention_mask = (item_seq > 0).long() - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze( + 2 + ) # torch.int64 # mask for left-to-right unidirectional max_len = attention_mask.size(-1) attn_shape = (1, max_len, max_len) @@ -143,25 +150,32 @@ def get_attention_mask(self, item_seq): subsequent_mask = subsequent_mask.long().to(item_seq.device) extended_attention_mask = extended_attention_mask * subsequent_mask - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def get_bi_attention_mask(self, item_seq): """Generate bidirectional attention mask for multi-head attention.""" attention_mask = (item_seq > 0).long() - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # torch.int64 + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze( + 2 + ) # torch.int64 # bidirectional mask - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype + ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 return extended_attention_mask def forward(self, item_seq, item_seq_len): - position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device) + position_ids = torch.arange( + item_seq.size(1), dtype=torch.long, device=item_seq.device + ) position_ids = position_ids.unsqueeze(0).expand_as(item_seq) position_embedding = self.position_embedding(position_ids) - item_emb = self.item_embedding(item_seq) input_emb = item_emb + position_embedding input_emb = self.LayerNorm(input_emb) @@ -170,9 +184,9 @@ def forward(self, item_seq, item_seq_len): extended_attention_mask = self.get_attention_mask(item_seq) # extended_attention_mask = self.get_bi_attention_mask(item_seq) - - - trm_output = self.item_encoder(input_emb, 
extended_attention_mask, output_all_encoded_layers=True) + trm_output = self.item_encoder( + input_emb, extended_attention_mask, output_all_encoded_layers=True + ) output = trm_output[-1] output = self.gather_indexes(output, item_seq_len - 1) @@ -189,10 +203,7 @@ def uniformity(x): x = abs(x) return torch.pdist(x, p=2).pow(2).mul(-2).exp().mean().log() - - def calculate_loss(self, interaction): - same_item_index = self.same_item_index sem_pos_lengths = [] sem_pos_seqs = [] @@ -206,26 +217,27 @@ def calculate_loss(self, interaction): print("error") while True: sample_index = random.choice(targets_index) - cur_item_list = interaction[self.ITEM_SEQ][i].to('cpu') - sample_item_list = dataset.inter_feat['item_id_list'][sample_index] - are_equal = torch.equal(cur_item_list,sample_item_list) - sample_item_length = dataset.inter_feat['item_length'][sample_index] + cur_item_list = interaction[self.ITEM_SEQ][i].to("cpu") + sample_item_list = dataset.inter_feat["item_id_list"][sample_index] + are_equal = torch.equal(cur_item_list, sample_item_list) + sample_item_length = dataset.inter_feat["item_length"][sample_index] if not are_equal or lens == 1: sem_pos_lengths.append(sample_item_length) sem_pos_seqs.append(sample_item_list) - break; + break sem_pos_lengths = torch.stack(sem_pos_lengths).to(self.device) sem_pos_seqs = torch.stack(sem_pos_seqs).to(self.device) - interaction.update(Interaction({'sem_aug': sem_pos_seqs, 'sem_aug_lengths': sem_pos_lengths})) - + interaction.update( + Interaction({"sem_aug": sem_pos_seqs, "sem_aug_lengths": sem_pos_lengths}) + ) item_seq = interaction[self.ITEM_SEQ] item_seq_len = interaction[self.ITEM_SEQ_LEN] seq_output = self.forward(item_seq, item_seq_len) pos_items = interaction[self.POS_ITEM_ID] - if self.loss_type == 'BPR': + if self.loss_type == "BPR": neg_items = interaction[self.NEG_ITEM_ID] pos_items_emb = self.item_embedding(pos_items) neg_items_emb = self.item_embedding(neg_items) @@ -237,12 +249,16 @@ def calculate_loss(self, interaction): logits = torch.matmul(seq_output, test_item_emb.transpose(0, 1)) loss = self.loss_fct(logits, pos_items) - # Unsupervised NCE - if self.ssl in ['us', 'un']: + if self.ssl in ["us", "un"]: aug_seq_output = self.forward(item_seq, item_seq_len) - nce_logits, nce_labels = self.info_nce(seq_output, aug_seq_output, temp=self.tau, - batch_size=item_seq_len.shape[0], sim=self.sim) + nce_logits, nce_labels = self.info_nce( + seq_output, + aug_seq_output, + temp=self.tau, + batch_size=item_seq_len.shape[0], + sim=self.sim, + ) # nce_logits = torch.mm(seq_output, aug_seq_output.T) # nce_labels = torch.tensor(list(range(nce_logits.shape[0])), dtype=torch.long, device=item_seq.device) @@ -255,12 +271,20 @@ def calculate_loss(self, interaction): loss += self.lmd * self.aug_nce_fct(nce_logits, nce_labels) # Supervised NCE - if self.ssl in ['us', 'su']: - sem_aug, sem_aug_lengths = interaction['sem_aug'], interaction['sem_aug_lengths'] + if self.ssl in ["us", "su"]: + sem_aug, sem_aug_lengths = ( + interaction["sem_aug"], + interaction["sem_aug_lengths"], + ) sem_aug_seq_output = self.forward(sem_aug, sem_aug_lengths) - sem_nce_logits, sem_nce_labels = self.info_nce(seq_output, sem_aug_seq_output, temp=self.tau, - batch_size=item_seq_len.shape[0], sim=self.sim) + sem_nce_logits, sem_nce_labels = self.info_nce( + seq_output, + sem_aug_seq_output, + temp=self.tau, + batch_size=item_seq_len.shape[0], + sim=self.sim, + ) # sem_nce_logits = torch.mm(seq_output, sem_aug_seq_output.T) / self.tau # sem_nce_labels = 
torch.tensor(list(range(sem_nce_logits.shape[0])), dtype=torch.long, device=item_seq.device) @@ -271,13 +295,21 @@ def calculate_loss(self, interaction): loss += self.lmd_sem * self.aug_nce_fct(sem_nce_logits, sem_nce_labels) - if self.ssl == 'us_x': + if self.ssl == "us_x": aug_seq_output = self.forward(item_seq, item_seq_len) - sem_aug, sem_aug_lengths = interaction['sem_aug'], interaction['sem_aug_lengths'] + sem_aug, sem_aug_lengths = ( + interaction["sem_aug"], + interaction["sem_aug_lengths"], + ) sem_aug_seq_output = self.forward(sem_aug, sem_aug_lengths) - sem_nce_logits, sem_nce_labels = self.info_nce(aug_seq_output, sem_aug_seq_output, temp=self.tau, - batch_size=item_seq_len.shape[0], sim=self.sim) + sem_nce_logits, sem_nce_labels = self.info_nce( + aug_seq_output, + sem_aug_seq_output, + temp=self.tau, + batch_size=item_seq_len.shape[0], + sim=self.sim, + ) loss += self.lmd_sem * self.aug_nce_fct(sem_nce_logits, sem_nce_labels) # with torch.no_grad(): @@ -286,15 +318,22 @@ def calculate_loss(self, interaction): # frequency domain loss if self.fredom: - seq_output_f = torch.fft.rfft(seq_output, dim=1, norm='ortho') - aug_seq_output_f = torch.fft.rfft(aug_seq_output, dim=1, norm='ortho') - sem_aug_seq_output_f = torch.fft.rfft(sem_aug_seq_output, dim=1, norm='ortho') - if self.fredom_type in ['us', 'un']: + seq_output_f = torch.fft.rfft(seq_output, dim=1, norm="ortho") + aug_seq_output_f = torch.fft.rfft(aug_seq_output, dim=1, norm="ortho") + sem_aug_seq_output_f = torch.fft.rfft( + sem_aug_seq_output, dim=1, norm="ortho" + ) + if self.fredom_type in ["us", "un"]: loss += 0.1 * abs(seq_output_f - aug_seq_output_f).flatten().mean() - if self.fredom_type in ['us', 'su']: - loss += 0.1 * abs(seq_output_f - sem_aug_seq_output_f).flatten().mean() - if self.fredom_type == 'us_x': - loss += 0.1 * abs(aug_seq_output_f - sem_aug_seq_output_f).flatten().mean() + if self.fredom_type in ["us", "su"]: + loss += ( + 0.1 * abs(seq_output_f - sem_aug_seq_output_f).flatten().mean() + ) + if self.fredom_type == "us_x": + loss += ( + 0.1 + * abs(aug_seq_output_f - sem_aug_seq_output_f).flatten().mean() + ) return loss @@ -307,7 +346,7 @@ def mask_correlated_samples(self, batch_size): mask[batch_size + i, i] = 0 return mask - def info_nce(self, z_i, z_j, temp, batch_size, sim='dot'): + def info_nce(self, z_i, z_j, temp, batch_size, sim="dot"): """ We do not sample negative examples explicitly. Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other 2(N − 1) augmented examples within a minibatch as negative examples. @@ -316,9 +355,12 @@ def info_nce(self, z_i, z_j, temp, batch_size, sim='dot'): z = torch.cat((z_i, z_j), dim=0) - if sim == 'cos': - sim = nn.functional.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) / temp - elif sim == 'dot': + if sim == "cos": + sim = ( + nn.functional.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) + / temp + ) + elif sim == "dot": sim = torch.mm(z, z.T) / temp sim_i_j = torch.diag(sim, batch_size) @@ -380,7 +422,6 @@ def full_sort_predict(self, interaction): return scores - class HybridAttention(nn.Module): """ Hybrid Attention layer: combine time domain self-attention layer and frequency domain attention layer. 
@@ -394,7 +435,16 @@ class HybridAttention(nn.Module): """ - def __init__(self, n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, layer_norm_eps, i, config): + def __init__( + self, + n_heads, + hidden_size, + hidden_dropout_prob, + attn_dropout_prob, + layer_norm_eps, + i, + config, + ): super(HybridAttention, self).__init__() if hidden_size % n_heads != 0: raise ValueError( @@ -402,7 +452,7 @@ def __init__(self, n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, "heads (%d)" % (hidden_size, n_heads) ) - self.factor = config['topk_factor'] + self.factor = config["topk_factor"] self.scale = None self.mask_flag = True self.output_attention = False @@ -419,29 +469,45 @@ def __init__(self, n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) self.out_dropout = nn.Dropout(hidden_dropout_prob) self.filter_mixer = None - self.global_ratio = config['global_ratio'] - self.n_layers = config['n_layers'] + self.global_ratio = config["global_ratio"] + self.n_layers = config["n_layers"] if self.global_ratio > (1 / self.n_layers): - print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers))) - self.filter_mixer = 'G' + print( + "{}>{}:{}".format( + self.global_ratio, + 1 / self.n_layers, + self.global_ratio > (1 / self.n_layers), + ) + ) + self.filter_mixer = "G" else: - print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers))) - self.filter_mixer = 'L' - self.max_item_list_length = config['MAX_ITEM_LIST_LENGTH'] - self.dual_domain = config['dual_domain'] - self.slide_step = ((self.max_item_list_length // 2 + 1) * (1 - self.global_ratio)) // (self.n_layers - 1) + print( + "{}>{}:{}".format( + self.global_ratio, + 1 / self.n_layers, + self.global_ratio > (1 / self.n_layers), + ) + ) + self.filter_mixer = "L" + self.max_item_list_length = config["MAX_ITEM_LIST_LENGTH"] + self.dual_domain = config["dual_domain"] + self.slide_step = ( + (self.max_item_list_length // 2 + 1) * (1 - self.global_ratio) + ) // (self.n_layers - 1) self.local_ratio = 1 / self.n_layers self.filter_size = self.local_ratio * (self.max_item_list_length // 2 + 1) - if self.filter_mixer == 'G': + if self.filter_mixer == "G": self.w = self.global_ratio self.s = self.slide_step - if self.filter_mixer == 'L': + if self.filter_mixer == "L": self.w = self.local_ratio self.s = self.filter_size - self.left = int(((self.max_item_list_length // 2 + 1) * (1 - self.w)) - (i * self.s)) + self.left = int( + ((self.max_item_list_length // 2 + 1) * (1 - self.w)) - (i * self.s) + ) self.right = int((self.max_item_list_length // 2 + 1) - i * self.s) # random: @@ -453,7 +519,7 @@ def __init__(self, n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, self.k_index = list(range(self.left, self.right)) self.v_index = list(range(self.left, self.right)) # if sample in time domain - self.std = config['std'] + self.std = config["std"] if self.std: self.time_q_index = self.q_index self.time_k_index = self.k_index @@ -463,15 +529,18 @@ def __init__(self, n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, self.time_k_index = list(range(self.max_item_list_length // 2 + 1)) self.time_v_index = list(range(self.max_item_list_length // 2 + 1)) - print('modes_q={}, index_q={}'.format(len(self.q_index), self.q_index)) - print('modes_k={}, index_k={}'.format(len(self.k_index), self.k_index)) - print('modes_v={}, index_v={}'.format(len(self.v_index), self.v_index)) + 
print("modes_q={}, index_q={}".format(len(self.q_index), self.q_index)) + print("modes_k={}, index_k={}".format(len(self.k_index), self.k_index)) + print("modes_v={}, index_v={}".format(len(self.v_index), self.v_index)) - if self.config['dual_domain']: - self.spatial_ratio = self.config['spatial_ratio'] + if self.config["dual_domain"]: + self.spatial_ratio = self.config["spatial_ratio"] def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) x = x.view(*new_x_shape) # [256, 50, 2, 32] # return x.permute(0, 2, 1, 3) # [256, 2, 50, 32] return x @@ -496,8 +565,13 @@ def time_delay_agg_training(self, values, corr): delays_agg = torch.zeros_like(values).float() for i in range(top_k): pattern = torch.roll(tmp_values, -int(index[i]), -1) - delays_agg = delays_agg + pattern * \ - (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) + delays_agg = delays_agg + pattern * ( + tmp_corr[:, i] + .unsqueeze(1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, head, channel, length) + ) return delays_agg def time_delay_agg_inference(self, values, corr): @@ -510,8 +584,14 @@ def time_delay_agg_inference(self, values, corr): channel = values.shape[2] length = values.shape[3] # index init - init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0) \ - .repeat(batch, head, channel, 1).to(values.device) + init_index = ( + torch.arange(length) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0) + .repeat(batch, head, channel, 1) + .to(values.device) + ) # find top k top_k = int(self.factor * math.log(length)) mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) @@ -522,10 +602,17 @@ def time_delay_agg_inference(self, values, corr): tmp_values = values.repeat(1, 1, 1, 2) delays_agg = torch.zeros_like(values).float() for i in range(top_k): - tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) + tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze( + 1 + ).repeat(1, head, channel, length) pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) - delays_agg = delays_agg + pattern * \ - (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) + delays_agg = delays_agg + pattern * ( + tmp_corr[:, i] + .unsqueeze(1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, head, channel, length) + ) return delays_agg def forward(self, input_tensor, attention_mask): @@ -533,20 +620,19 @@ def forward(self, input_tensor, attention_mask): mixed_key_layer = self.key_layer(input_tensor) mixed_value_layer = self.value_layer(input_tensor) - - #trans挺快,query + # trans挺快,query queries = self.transpose_for_scores(mixed_query_layer) keys = self.transpose_for_scores(mixed_key_layer) values = self.transpose_for_scores(mixed_value_layer) - #这段代码是注意力机制中的一部分,涉及到对查询(queries)和键(keys)进行频域变换(FFT),以应用频域注意力 + # 这段代码是注意力机制中的一部分,涉及到对查询(queries)和键(keys)进行频域变换(FFT),以应用频域注意力 # B, H, L, E = query_layer.shape # AutoFormer B, L, H, E = queries.shape # print("qqqq", queries.shape) [256,50,2,32,] _, S, _, D = values.shape if L > S: - zeros = torch.zeros_like(queries[:, :(L - S), :]).float() + zeros = torch.zeros_like(queries[:, : (L - S), :]).float() values = torch.cat([values, zeros], dim=1) keys = torch.cat([keys, zeros], dim=1) else: @@ -557,32 +643,39 @@ def forward(self, input_tensor, attention_mask): q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 
1).contiguous(), dim=-1) k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) - # 装到采样的空盒子里 - q_fft_box = torch.zeros(B, H, E, len(self.q_index), device=q_fft.device, dtype=torch.cfloat) + q_fft_box = torch.zeros( + B, H, E, len(self.q_index), device=q_fft.device, dtype=torch.cfloat + ) q_fft_box = q_fft[:, :, :, self.q_index] - k_fft_box = torch.zeros(B, H, E, len(self.k_index), device=q_fft.device, dtype=torch.cfloat) + k_fft_box = torch.zeros( + B, H, E, len(self.k_index), device=q_fft.device, dtype=torch.cfloat + ) k_fft_box = k_fft[:, :, :, self.q_index] res = q_fft_box * torch.conj(k_fft_box) - if self.config['use_filter']: + if self.config["use_filter"]: # filter weight = torch.view_as_complex(self.complex_weight) res = res * weight - box_res = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat) + box_res = torch.zeros( + B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat + ) box_res[:, :, :, self.q_index] = res corr = torch.fft.irfft(box_res, dim=-1) - # time delay agg if self.training: - V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + V = self.time_delay_agg_training( + values.permute(0, 2, 3, 1).contiguous(), corr + ).permute(0, 3, 1, 2) else: - V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) - + V = self.time_delay_agg_inference( + values.permute(0, 2, 3, 1).contiguous(), corr + ).permute(0, 3, 1, 2) # print(V.shape) new_context_layer_shape = V.size()[:-2] + (self.all_head_size,) @@ -591,28 +684,37 @@ def forward(self, input_tensor, attention_mask): if self.dual_domain: # 装到采样的空盒子里 # q - q_fft_box = torch.zeros(B, H, E, len(self.time_q_index), device=q_fft.device, dtype=torch.cfloat) + q_fft_box = torch.zeros( + B, H, E, len(self.time_q_index), device=q_fft.device, dtype=torch.cfloat + ) q_fft_box = q_fft[:, :, :, self.time_q_index] - spatial_q = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat) + spatial_q = torch.zeros( + B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat + ) spatial_q[:, :, :, self.time_q_index] = q_fft_box # k - k_fft_box = torch.zeros(B, H, E, len(self.time_k_index), device=q_fft.device, dtype=torch.cfloat) + k_fft_box = torch.zeros( + B, H, E, len(self.time_k_index), device=q_fft.device, dtype=torch.cfloat + ) k_fft_box = k_fft[:, :, :, self.time_k_index] - spatial_k = torch.zeros(B, H, E, L // 2 + 1, device=k_fft.device, dtype=torch.cfloat) + spatial_k = torch.zeros( + B, H, E, L // 2 + 1, device=k_fft.device, dtype=torch.cfloat + ) spatial_k[:, :, :, self.time_k_index] = k_fft_box - # v v_fft = torch.fft.rfft(values.permute(0, 2, 3, 1).contiguous(), dim=-1) # 装到采样的空盒子里 - v_fft_box = torch.zeros(B, H, E, len(self.time_v_index), device=v_fft.device, dtype=torch.cfloat) + v_fft_box = torch.zeros( + B, H, E, len(self.time_v_index), device=v_fft.device, dtype=torch.cfloat + ) v_fft_box = v_fft[:, :, :, self.time_v_index] - spatial_v = torch.zeros(B, H, E, L // 2 + 1, device=v_fft.device, dtype=torch.cfloat) + spatial_v = torch.zeros( + B, H, E, L // 2 + 1, device=v_fft.device, dtype=torch.cfloat + ) spatial_v[:, :, :, self.time_v_index] = v_fft_box - - queries = torch.fft.irfft(spatial_q, dim=-1) keys = torch.fft.irfft(spatial_k, dim=-1) values = torch.fft.irfft(spatial_v, dim=-1) @@ -629,16 +731,20 @@ def forward(self, input_tensor, attention_mask): attention_probs = self.attn_dropout(attention_probs) qkv = torch.matmul(attention_probs, values) # [256, 2, index, 32] 
context_layer_spatial = qkv.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer_spatial.size()[:-2] + (self.all_head_size,) + new_context_layer_shape = context_layer_spatial.size()[:-2] + ( + self.all_head_size, + ) context_layer_spatial = context_layer_spatial.view(*new_context_layer_shape) - context_layer = (1 - self.spatial_ratio) * context_layer + self.spatial_ratio * context_layer_spatial - + context_layer = ( + 1 - self.spatial_ratio + ) * context_layer + self.spatial_ratio * context_layer_spatial hidden_states = self.dense(context_layer) hidden_states = self.out_dropout(hidden_states) hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states + class FeedForward(nn.Module): """ Point-wise feed-forward layer is implemented by two dense layers. @@ -651,7 +757,9 @@ class FeedForward(nn.Module): """ - def __init__(self, hidden_size, inner_size, hidden_dropout_prob, hidden_act, layer_norm_eps): + def __init__( + self, hidden_size, inner_size, hidden_dropout_prob, hidden_act, layer_norm_eps + ): super(FeedForward, self).__init__() self.dense_1 = nn.Linear(hidden_size, inner_size) self.intermediate_act_fn = self.get_hidden_act(hidden_act) @@ -694,6 +802,7 @@ def forward(self, input_tensor): return hidden_states + class FEABlock(nn.Module): """ One transformer layer consists of a multi-head self-attention layer and a point-wise feed-forward layer. @@ -709,34 +818,54 @@ class FEABlock(nn.Module): """ def __init__( - self, n_heads, hidden_size, intermediate_size, hidden_dropout_prob, attn_dropout_prob, hidden_act, - layer_norm_eps, n, config + self, + n_heads, + hidden_size, + intermediate_size, + hidden_dropout_prob, + attn_dropout_prob, + hidden_act, + layer_norm_eps, + n, + config, ): super(FEABlock, self).__init__() self.hybrid_attention = HybridAttention( - n_heads, hidden_size, hidden_dropout_prob, attn_dropout_prob, layer_norm_eps, n, config + n_heads, + hidden_size, + hidden_dropout_prob, + attn_dropout_prob, + layer_norm_eps, + n, + config, + ) + self.feed_forward = FeedForward( + hidden_size, + intermediate_size, + hidden_dropout_prob, + hidden_act, + layer_norm_eps, ) - self.feed_forward = FeedForward(hidden_size, intermediate_size, hidden_dropout_prob, hidden_act, layer_norm_eps) def forward(self, hidden_states, attention_mask): - attention_output = self.hybrid_attention(hidden_states, attention_mask) feedforward_output = self.feed_forward(attention_output) return feedforward_output + class FEAEncoder(nn.Module): - r""" One TransformerEncoder consists of several TransformerLayers. - - - n_layers(num): num of transformer layers in transformer encoder. Default: 2 - - n_heads(num): num of attention heads for multi-head attention layer. Default: 2 - - hidden_size(num): the input and output hidden size. Default: 64 - - inner_size(num): the dimensionality in feed-forward layer. Default: 256 - - hidden_dropout_prob(float): probability of an element to be zeroed. Default: 0.5 - - attn_dropout_prob(float): probability of an attention score to be zeroed. Default: 0.5 - - hidden_act(str): activation function in feed-forward layer. Default: 'gelu' - candidates: 'gelu', 'relu', 'swish', 'tanh', 'sigmoid' - - layer_norm_eps(float): a value added to the denominator for numerical stability. Default: 1e-12 + r"""One TransformerEncoder consists of several TransformerLayers. + + - n_layers(num): num of transformer layers in transformer encoder. Default: 2 + - n_heads(num): num of attention heads for multi-head attention layer. 
Default: 2 + - hidden_size(num): the input and output hidden size. Default: 64 + - inner_size(num): the dimensionality in feed-forward layer. Default: 256 + - hidden_dropout_prob(float): probability of an element to be zeroed. Default: 0.5 + - attn_dropout_prob(float): probability of an attention score to be zeroed. Default: 0.5 + - hidden_act(str): activation function in feed-forward layer. Default: 'gelu' + candidates: 'gelu', 'relu', 'swish', 'tanh', 'sigmoid' + - layer_norm_eps(float): a value added to the denominator for numerical stability. Default: 1e-12 """ @@ -748,16 +877,25 @@ def __init__( inner_size=256, hidden_dropout_prob=0.5, attn_dropout_prob=0.5, - hidden_act='gelu', + hidden_act="gelu", layer_norm_eps=1e-12, config=None, ): - super(FEAEncoder, self).__init__() self.n_layers = n_layers self.layer = nn.ModuleList() for n in range(self.n_layers): - self.layer_ramp = FEABlock(n_heads, hidden_size, inner_size, hidden_dropout_prob, attn_dropout_prob, hidden_act, layer_norm_eps, n, config) + self.layer_ramp = FEABlock( + n_heads, + hidden_size, + inner_size, + hidden_dropout_prob, + attn_dropout_prob, + hidden_act, + layer_norm_eps, + n, + config, + ) self.layer.append(self.layer_ramp) def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 5ea569e15..310f7c158 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -754,7 +754,7 @@ def test_core_ave(self): "dnn_type": "ave", } quick_test(config_dict) - + def test_fea_rec(self): config_dict = { "model": "FEARec",
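For reference, a minimal way to exercise the reformatted FEARec model end to end, assuming a working RecBole installation: the sketch below hands RecBole's quick-start entry point a config_dict built from the hyperparameters fearec.py actually reads from `config` (n_layers, n_heads, hidden_size, inner_size, lmd, lmd_sem, contrast, tau, sim, fredom, fredom_type, topk_factor, global_ratio, dual_domain, std, spatial_ratio, use_filter, MAX_ITEM_LIST_LENGTH, train_batch_size). The concrete values and the ml-100k dataset are illustrative assumptions, not settings taken from this patch.

# Hedged sketch: run FEARec through RecBole's quick-start API.
# Only the key names below are taken from fearec.py; the values and the
# dataset choice are illustrative assumptions and may need tuning against
# RecBole's default YAML settings.
from recbole.quick_start import run_recbole

fearec_config = {
    # standard transformer-encoder settings read in FEARec.__init__
    "n_layers": 2,
    "n_heads": 2,
    "hidden_size": 64,
    "inner_size": 256,
    "hidden_dropout_prob": 0.5,
    "attn_dropout_prob": 0.5,
    "hidden_act": "gelu",
    "layer_norm_eps": 1e-12,
    "initializer_range": 0.02,
    "loss_type": "CE",
    # contrastive-learning knobs read in calculate_loss / info_nce
    "lmd": 0.1,
    "lmd_sem": 0.1,
    "contrast": "us_x",
    "tau": 1.0,
    "sim": "dot",
    "fredom": True,
    "fredom_type": "us_x",
    # frequency-domain attention knobs read in HybridAttention.__init__
    "topk_factor": 1,
    "global_ratio": 0.6,
    "dual_domain": True,
    "std": False,
    "spatial_ratio": 0.1,
    "use_filter": True,
    "MAX_ITEM_LIST_LENGTH": 50,
    "train_batch_size": 256,
}

if __name__ == "__main__":
    # ml-100k ships with RecBole's quick-start examples; any sequential dataset works.
    run_recbole(model="FEARec", dataset="ml-100k", config_dict=fearec_config)

Because this patch only reformats code (plus the trailing-whitespace fix in tests/model/test_model_auto.py), such a run should behave identically before and after applying it; the existing test_fea_rec case shown above exercises the same model through quick_test.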