From 3887abdfb08519de9cd08870d0a81ed697a01f30 Mon Sep 17 00:00:00 2001
From: zhkai <522351448@qq.com>
Date: Sun, 30 Jul 2023 08:50:25 +0200
Subject: [PATCH] data_handler: add graph-based feature clustering via
 GraphicalLasso

Kai
---
 data_handler.py | 83 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 71 insertions(+), 12 deletions(-)

diff --git a/data_handler.py b/data_handler.py
index 8f0b03f..289daaf 100644
--- a/data_handler.py
+++ b/data_handler.py
@@ -2,6 +2,8 @@
 import pandas as pd
 import tensorflow as tf
 from load_data import load_data, smooth, transform
+from sklearn.covariance import GraphicalLasso, EmpiricalCovariance
+from sklearn import preprocessing
 
 class DataHandler:
     def __init__(
@@ -12,7 +14,8 @@ def __init__(
         window_batch_size=10,
         window_shift=1,
         splits=[0.70, 0.15, 0.15],
-        predict_timestamp=3
+        predict_timestamp=3,
+        num_per_group=5
     ):
         """Create a DataHandler which is able to load and manipulate data in different ways."""
         #self.is_normalized = False
@@ -25,6 +28,7 @@ def __init__(
         self.window_shift = window_shift
         self.window_batch_size = window_batch_size # can find a best from 8 10 16
         self.max_num_features = num_features
+        self.num_per_group = num_per_group
         self.predict_timestamp = predict_timestamp
         self.clusters = None
         self.clusters_func = None
@@ -33,6 +37,8 @@ def __init__(
         self._load_data(config)
         self.use_splits(splits)
         self.clusters_abund, self.clusters_abund_size = self._make_abundance_clusters()
+        self.clusters_graph, self.clusters_graph_size = self._make_graph_clusters()
+        assert self.clusters_abund_size == self.clusters_graph_size
 
     @property
     def train(self):
@@ -44,7 +50,7 @@ def train(self):
     @property
     def val(self):
         if self._train_val_index == self._val_test_index:
-            return self.test # if no val data, it should be test
+            return self.test
         elif self.clusters is None:
             return self._all.iloc[self._train_val_index:self._val_test_index, :self.max_num_features]
         else:
@@ -74,17 +80,18 @@ def all_nontrans(self):
     @property
     def train_batched(self):
         """Batches of training data."""
-        return self._make_batched_dataset(self.train, endindex=True)
+        return self._make_batched_dataset(self._all.iloc[:self._train_val_index+self.predict_timestamp, self.clusters], endindex=True)
 
     @property
     def val_batched(self):
         """Batches of validation data."""
-        return self._make_batched_dataset(self.val, endindex=True)
+        return self._make_batched_dataset(self._all.iloc[self._train_val_index-self.predict_timestamp:self._val_test_index+self.predict_timestamp,
+                                                         self.clusters], endindex=True)
 
     @property
     def test_batched(self):
         """Batches of test data."""
-        return self._make_batched_dataset(self.test, endindex=True)
+        return self._make_batched_dataset(self._all.iloc[self._val_test_index-self.predict_timestamp:, self.clusters], endindex=True)
 
     @property
     def all_batched(self):
@@ -130,13 +137,13 @@ def _make_batched_dataset(self, dataset, endindex):
         )
 
     def _make_abundance_clusters(self):
-        clust = np.zeros(self.clusters_func.shape, dtype=int)
-        if self.max_num_features is None:
-            return clust
+        clust = np.zeros(self.max_num_features, dtype=int)
+        if self.num_per_group is None:
+            return clust, 1
         i = 0
         c = 0
-        while i < (clust.size - self.max_num_features):
-            for _ in range(self.max_num_features):
+        while i < (clust.size - self.num_per_group):
+            for _ in range(self.num_per_group):
                 clust[i] = c
                 i += 1
             c += 1
@@ -146,6 +153,52 @@ def _make_abundance_clusters(self):
         c += 1
         return clust, c
 
+    def _make_graph_clusters(self):
+        clust = np.zeros(self.max_num_features, dtype=int)
+        if self.num_per_group is None:
+            return clust, 1
+        clust = clust - 1
+        scaler = preprocessing.StandardScaler()
+        x = self._all.iloc[:, :]
+        scaler.fit(x)
+        graph_train_data = scaler.transform(x, copy=True)
+        # Estimate a sparse precision matrix; fall back to the empirical one.
+        try:
+            cov_init = GraphicalLasso(alpha=0.0001, mode='cd', max_iter=500, assume_centered=True).fit(graph_train_data)
+        except Exception:
+            print('GraphicalLasso failed, falling back to EmpiricalCovariance precision_')
+            cov_init = EmpiricalCovariance(store_precision=True, assume_centered=True).fit(graph_train_data)
+        adj_mx = np.abs(cov_init.precision_)
+        # Symmetrically normalized affinity (2D)^-1/2 (A + D) (2D)^-1/2, D = diag(A.sum(1)).
+        d = np.array(adj_mx.sum(1))
+        d_add = np.diag(d)
+        d = d * 2
+        d_inv = np.power(d, -0.5).flatten()
+        d_inv[np.isinf(d_inv)] = 0.
+        d_mat_inv = np.diag(d_inv)
+        self.graph_matrix = d_mat_inv.dot(adj_mx+d_add).dot(d_mat_inv)
+        graph_matrix = self.graph_matrix.copy()
+
+        # Greedily assign each unassigned feature and its strongest neighbours to one cluster.
+        i = 0
+        c = 0
+        while i < self.max_num_features:
+            if clust[i] != -1:
+                i += 1
+                continue
+            clust[i] = c
+            temp = graph_matrix[i]
+            top_number = temp.argsort()
+            graph_matrix[:, i] = -1
+            assert i in top_number[-self.num_per_group:]
+            for j in range(self.num_per_group):
+                if clust[top_number[-j-1]] == -1:
+                    clust[top_number[-j-1]] = c
+                    graph_matrix[:, top_number[-j-1]] = -1
+            c += 1
+            i += 1
+        return clust, c
+
     def use_cluster(self, number, cluster_type='abund'):
         """Cluster type is which type of cluster to use:
             abund: means that the x most abundant are in the first cluster,
@@ -165,19 +218,21 @@ def use_cluster(self, number, cluster_type='abund'):
         elif cluster_type == 'abund':
             self.clusters = self.clusters_abund == number
             self._only_mark_first_max_num_features()
+        elif cluster_type == 'graph':
+            self.clusters = self.clusters_graph == number
         else:
             self.clusters = None
             raise Exception('Unknown cluster type.')
 
     def _only_mark_first_max_num_features(self):
         """Only use the x most abundant taxa in a given cluster."""
-        if self.max_num_features is None:
+        if self.num_per_group is None:
             return
         count = 0
         for i in range(self.clusters.size):
             if self.clusters[i]:
                 count += 1
-                if count > self.max_num_features:
+                if count > self.num_per_group:
                     self.clusters[i] = False
 
     def use_splits(self, splits):
@@ -197,6 +252,9 @@ def get_metadata(self, dataframe, attribute):
 
     def _load_data(self, config):
         data_raw, meta, func_tax, clusters_func, functions = load_data(config)
+        data_raw = data_raw[:self.max_num_features]
+        func_tax = func_tax[:self.max_num_features]
+        clusters_func = clusters_func[:self.max_num_features]
         data_smooth = smooth(data_raw, factor = config['smoothing_factor'])
         data_transformed, mean, std, min, max, transform_type = transform(data_smooth, transform = config['transform'])
 
@@ -209,6 +267,7 @@ def _load_data(self, config):
                                 columns=func_tax[:,0])
 
         self.data_raw = data_raw # N * T
+        self.data_transformed = data_transformed
        self.func_tax = func_tax
         self.meta = meta
         self.transform_mean = mean
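
Note for reviewers: the grouping loop in _make_graph_clusters can be read
independently of the covariance estimation. Below is a minimal standalone
sketch of the same greedy top-k assignment, assuming a precomputed symmetric
affinity matrix; the names affinity_groups, affinity, and labels are
illustrative only and are not part of the patch:

    import numpy as np

    def affinity_groups(affinity, num_per_group):
        """Greedily group each feature with its strongest neighbours.

        affinity: (n, n) symmetric matrix of non-negative weights, a
        hypothetical stand-in for self.graph_matrix in the patch.
        Returns (labels, n_groups), where labels[i] is feature i's group.
        """
        aff = affinity.copy()
        n = aff.shape[0]
        labels = np.full(n, -1, dtype=int)    # -1 marks "not yet assigned"
        group = 0
        for i in range(n):
            if labels[i] != -1:
                continue                      # already claimed by an earlier group
            ranked = aff[i].argsort()         # ascending, strongest neighbours last
            labels[i] = group
            aff[:, i] = -1                    # feature i can no longer be claimed
            for j in ranked[-num_per_group:]:
                if labels[j] == -1:
                    labels[j] = group
                    aff[:, j] = -1
            group += 1
        return labels, group

    # Toy usage with a random symmetric affinity matrix:
    rng = np.random.default_rng(0)
    a = rng.random((12, 12))
    labels, n_groups = affinity_groups((a + a.T) / 2, num_per_group=4)

On why the assert in the patch is safe: with A = |precision|, D = diag(A.sum(1))
and the affinity (2D)^-1/2 (A + D) (2D)^-1/2, each diagonal entry is
1/2 + a_ii/(2*d_i) >= 1/2, while every off-diagonal entry a_ij/(2*sqrt(d_i*d_j))
is at most 1/2 (since a_ij <= d_i and a_ij <= d_j). Each feature therefore ranks
in its own top num_per_group, so `assert i in top_number[-self.num_per_group:]`
is expected to hold apart from exact ties.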