From 4f6396fe0b2357455e35a3756f8dce15dee4c07d Mon Sep 17 00:00:00 2001 From: taorann Date: Sat, 16 Nov 2024 00:33:33 +0800 Subject: [PATCH] about FatraGNN, including model,datasets and example (#218) * about FatraGNN, including model,datasets and example * update * Revert "update" This reverts commit fa9101b527dada348008f78353cc85e428f73f85. * update * Modifications as required * update github action --------- Co-authored-by: Guangyu Zhou <77875480+gyzhou2000@users.noreply.github.com> Co-authored-by: gyzhou2000 --- .github/workflows/test_push.yml | 4 + .github/workflows/test_pypi_package.yml | 4 + examples/fatragnn/config.yaml | 24 ++ examples/fatragnn/fatragnn_trainer.py | 293 ++++++++++++++++++++++++ examples/fatragnn/readme.md | 63 +++++ gammagl/datasets/__init__.py | 4 + gammagl/datasets/bail.py | 215 +++++++++++++++++ gammagl/datasets/credit.py | 215 +++++++++++++++++ gammagl/models/__init__.py | 4 +- gammagl/models/fatragnn.py | 189 +++++++++++++++ 10 files changed, 1014 insertions(+), 1 deletion(-) create mode 100644 examples/fatragnn/config.yaml create mode 100644 examples/fatragnn/fatragnn_trainer.py create mode 100644 examples/fatragnn/readme.md create mode 100644 gammagl/datasets/bail.py create mode 100644 gammagl/datasets/credit.py create mode 100644 gammagl/models/fatragnn.py diff --git a/.github/workflows/test_push.yml b/.github/workflows/test_push.yml index 9f9e2afa..d3884cda 100644 --- a/.github/workflows/test_push.yml +++ b/.github/workflows/test_push.yml @@ -38,6 +38,10 @@ jobs: run: | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121 + - name: Install Tensorflow + run: | + pip install tensorflow==2.11.0 + - name: Install llvmlite run: | pip install llvmlite diff --git a/.github/workflows/test_pypi_package.yml b/.github/workflows/test_pypi_package.yml index 962df92e..b2549453 100644 --- a/.github/workflows/test_pypi_package.yml +++ b/.github/workflows/test_pypi_package.yml @@ -30,6 +30,10 @@ jobs: run: | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121 + - name: Install Tensorflow + run: | + pip install tensorflow==2.11.0 + - name: Install llvmlite run: | pip install llvmlite diff --git a/examples/fatragnn/config.yaml b/examples/fatragnn/config.yaml new file mode 100644 index 00000000..3ad95fad --- /dev/null +++ b/examples/fatragnn/config.yaml @@ -0,0 +1,24 @@ +bail: + epochs: 400 + g_epochs: 5 + a_epochs: 4 + cla_epochs: 10 + dic_epochs: 8 + dtb_epochs: 5 + d_lr: 0.001 + c_lr: 0.005 + e_lr: 0.005 + g_lr: 0.05 + drope_rate: 0.1 +credit: + epochs: 600 + g_epochs: 5 + a_epochs: 2 + cla_epochs: 12 + dic_epochs: 5 + dtb_epochs: 5 + d_lr: 0.001 + c_lr: 0.01 + e_lr: 0.01 + g_lr: 0.05 + drope_rate: 0.1 \ No newline at end of file diff --git a/examples/fatragnn/fatragnn_trainer.py b/examples/fatragnn/fatragnn_trainer.py new file mode 100644 index 00000000..71f6bbe5 --- /dev/null +++ b/examples/fatragnn/fatragnn_trainer.py @@ -0,0 +1,293 @@ +import os +# os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ['TL_BACKEND'] = 'torch' +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR +import tensorlayerx as tlx +from gammagl.models import FatraGNNModel +import argparse +import numpy as np +from tensorlayerx.model import TrainOneStep, WithLoss +from sklearn.metrics import roc_auc_score +import scipy.sparse as sp +import yaml +from gammagl.datasets import Bail +from 
gammagl.datasets import Credit + + +def fair_metric(pred, labels, sens): + idx_s0 = sens == 0 + idx_s1 = sens == 1 + idx_s0_y1 = np.bitwise_and(idx_s0, labels == 1) + idx_s1_y1 = np.bitwise_and(idx_s1, labels == 1) + parity = abs(sum(pred[idx_s0]) / sum(idx_s0) - + sum(pred[idx_s1]) / sum(idx_s1)) + equality = abs(sum(pred[idx_s0_y1]) / sum(idx_s0_y1) - + sum(pred[idx_s1_y1]) / sum(idx_s1_y1)) + return parity.item(), equality.item() + + +def evaluate_ged3(net, x, edge_index, y, test_mask, sens): + net.set_eval() + flag = 0 + output = net(x, edge_index, flag) + pred_test = tlx.cast(tlx.squeeze(output[test_mask], axis=-1) > 0, y.dtype) + + acc_nums_test = (pred_test == y[test_mask]) + accs = np.sum(tlx.convert_to_numpy(acc_nums_test))/np.sum(tlx.convert_to_numpy(test_mask)) + + auc_rocs = roc_auc_score(tlx.convert_to_numpy(y[test_mask]), tlx.convert_to_numpy(output[test_mask])) + paritys, equalitys = fair_metric(tlx.convert_to_numpy(pred_test), tlx.convert_to_numpy(y[test_mask]), tlx.convert_to_numpy(sens[test_mask])) + + return accs, auc_rocs, paritys, equalitys + + +class DicLoss(WithLoss): + def __init__(self, net, loss_fn): + super(DicLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, label): + output = self.backbone_network(data['x'], data['edge_index'], data['flag']) + loss = tlx.losses.binary_cross_entropy(tlx.squeeze(output, axis=-1), tlx.cast(data['sens'], dtype=tlx.float32)) + return loss + + +class EncClaLoss(WithLoss): + def __init__(self, net, loss_fn): + super(EncClaLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, label): + output = self.backbone_network(data['x'], data['edge_index'], data['flag']) + y_train = tlx.cast(tlx.expand_dims(label[data['train_mask']], axis=1), dtype=tlx.float32) + loss = tlx.losses.binary_cross_entropy(output[data['train_mask']], y_train) + return loss + + +class EncLoss(WithLoss): + def __init__(self, net, loss_fn): + super(EncLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, label): + output = self.backbone_network(data['x'], data['edge_index'], data['flag']) + loss = tlx.losses.mean_squared_error(output, 0.5 * tlx.ones_like(output)) + return loss + + +class EdtLoss(WithLoss): + def __init__(self, net, loss_fn): + super(EdtLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, label): + output = self.backbone_network(data['x'], data['edge_index'], data['flag']) + loss = -tlx.abs(tlx.reduce_sum(output[data['train_mask']][data['t_idx_s0_y1']])) / tlx.reduce_sum(tlx.cast(data['t_idx_s0_y1'], dtype=tlx.float32)) - tlx.reduce_sum(output[data['train_mask']][data['t_idx_s1_y1']]) / tlx.reduce_sum(tlx.cast(data['t_idx_s1_y1'], dtype=tlx.float32)) + + return loss + + +class AliLoss(WithLoss): + def __init__(self, net, loss_fn): + super(AliLoss, self).__init__(backbone=net, loss_fn=loss_fn) + + def forward(self, data, label): + output = self.backbone_network(data['x'], data['edge_index'], data['flag']) + h1 = output['h1'] + h2 = output['h2'] + idx_s0_y0 = data['idx_s0_y0'] + idx_s1_y0 = data['idx_s1_y0'] + idx_s0_y1 = data['idx_s0_y1'] + idx_s1_y1 = data['idx_s1_y1'] + node_num = data['x'].shape[0] + loss_align = - node_num / (tlx.reduce_sum(tlx.cast(idx_s0_y0, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s0_y0], tlx.transpose(h2[idx_s0_y0]))) \ + - node_num / (tlx.reduce_sum(tlx.cast(idx_s0_y1, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s0_y1], tlx.transpose(h2[idx_s0_y1]))) \ + - node_num / 
(tlx.reduce_sum(tlx.cast(idx_s1_y0, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s1_y0], tlx.transpose(h2[idx_s1_y0]))) \ + - node_num / (tlx.reduce_sum(tlx.cast(idx_s1_y1, dtype=tlx.float32))) * tlx.reduce_mean(tlx.matmul(h1[idx_s1_y1], tlx.transpose(h2[idx_s1_y1]))) + + loss = loss_align * 0.01 + return loss + + +def main(args): + + # load datasets + if str.lower(args.dataset) not in ['bail', 'credit', 'pokec']: + raise ValueError('Unknown dataset: {}'.format(args.dataset)) + + if args.dataset == 'bail': + dataset = Bail(args.dataset_path, args.dataset) + + elif args.dataset == 'credit': + dataset = Credit(args.dataset_path, args.dataset) + + graphs = dataset.data + data = { + 'x':graphs[0].x, + 'y': graphs[0].y, + 'edge_index': {'edge_index': graphs[0].edge_index}, + 'sens': graphs[0].sens, + 'train_mask': graphs[0].train_mask, + } + data_test = [] + for i in range(1, len(graphs)): + data_tem = { + 'x':graphs[i].x, + 'y': graphs[i].y, + 'edge_index': graphs[i].edge_index, + 'sens': graphs[i].sens, + 'test_mask': graphs[i].train_mask | graphs[i].val_mask | graphs[i].test_mask, + } + data_test.append(data_tem) + dataset = None + graphs = None + args.num_features, args.num_classes = data['x'].shape[1], len(np.unique(tlx.convert_to_numpy(data['y']))) - 1 + args.test_set_num = len(data_test) + + t_idx_s0 = data['sens'][data['train_mask']] == 0 + t_idx_s1 = data['sens'][data['train_mask']] == 1 + t_idx_s0_y1 = tlx.logical_and(t_idx_s0, data['y'][data['train_mask']] == 1) + t_idx_s1_y1 = tlx.logical_and(t_idx_s1, data['y'][data['train_mask']] == 1) + + idx_s0 = data['sens'] == 0 + idx_s1 = data['sens'] == 1 + idx_s0_y1 = tlx.logical_and(idx_s0, data['y'] == 1) + idx_s1_y1 = tlx.logical_and(idx_s1, data['y'] == 1) + idx_s0_y0 = tlx.logical_and(idx_s0, data['y'] == 0) + idx_s1_y0 = tlx.logical_and(idx_s1, data['y'] == 0) + + data['idx_s0_y0'] = idx_s0_y0 + data['idx_s1_y0'] = idx_s1_y0 + data['idx_s0_y1'] = idx_s0_y1 + data['idx_s1_y1'] = idx_s1_y1 + data['t_idx_s0_y1'] = t_idx_s0_y1 + data['t_idx_s1_y1'] = t_idx_s1_y1 + + edge_index_np = tlx.convert_to_numpy(data['edge_index']['edge_index']) + adj = sp.coo_matrix((np.ones(data['edge_index']['edge_index'].shape[1]), (edge_index_np[0, :], edge_index_np[1, :])), + shape=(data['x'].shape[0], data['x'].shape[0]), + dtype=np.float32) + A2 = adj.dot(adj) + A2 = A2.toarray() + A2_edge = tlx.convert_to_tensor(np.vstack((A2.nonzero()[0], A2.nonzero()[1]))) + + net = FatraGNNModel(args) + + dic_loss_func = DicLoss(net, tlx.losses.binary_cross_entropy) + enc_cla_loss_func = EncClaLoss(net, tlx.losses.binary_cross_entropy) + enc_loss_func = EncLoss(net, tlx.losses.binary_cross_entropy) + edt_loss_func = EdtLoss(net, tlx.losses.binary_cross_entropy) + ali_loss_func = AliLoss(net, tlx.losses.binary_cross_entropy) + + dic_opt = tlx.optimizers.Adam(lr=args.d_lr, weight_decay=args.d_wd) + dic_train_one_step = TrainOneStep(dic_loss_func, dic_opt, net.discriminator.trainable_weights) + + enc_cla_opt = tlx.optimizers.Adam(lr=args.c_lr, weight_decay=args.c_wd) + enc_cla_train_one_step = TrainOneStep(enc_cla_loss_func, enc_cla_opt, net.encoder.trainable_weights+net.classifier.trainable_weights) + + enc_opt = tlx.optimizers.Adam(lr=args.e_lr, weight_decay=args.e_wd) + enc_train_one_step = TrainOneStep(enc_loss_func, enc_opt, net.encoder.trainable_weights) + + edt_opt = tlx.optimizers.Adam(lr=args.g_lr, weight_decay=args.g_wd) + edt_train_one_step = TrainOneStep(edt_loss_func, edt_opt, net.graphEdit.trainable_weights) + + ali_opt = 
tlx.optimizers.Adam(lr=args.e_lr, weight_decay=args.e_wd) + ali_train_one_step = TrainOneStep(ali_loss_func, ali_opt, net.encoder.trainable_weights) + + tlx.set_seed(args.seed) + net.set_train() + for epoch in range(0, args.epochs): + print(f"======={epoch}=======") + # train discriminator to recognize the sensitive group + data['flag'] = 1 + for epoch_d in range(0, args.dic_epochs): + dic_loss = dic_train_one_step(data=data, label=data['y']) + + # train classifier and encoder + data['flag'] = 2 + for epoch_c in range(0, args.cla_epochs): + enc_cla_loss = enc_cla_train_one_step(data=data, label=data['y']) + + # train encoder to fool discriminator + data['flag'] = 3 + for epoch_g in range(0, args.g_epochs): + enc_loss = enc_train_one_step(data=data, label=data['y']) + + # train generator + data['flag'] = 4 + if epoch > args.start: + if epoch % 10 == 0: + if epoch % 20 == 0: + data['edge_index']['edge_index2'] = net.graphEdit.modify_structure1(data['edge_index']['edge_index'], A2_edge, data['sens'], data['x'].shape[0], args.drope_rate) + else: + data['edge_index']['edge_index2'] = net.graphEdit.modify_structure2(data['edge_index']['edge_index'], A2_edge, data['sens'], data['x'].shape[0], args.drope_rate) + else: + data['edge_index']['edge_index2'] = data['edge_index']['edge_index'] + + for epoch_g in range(0, args.dtb_epochs): + edt_loss = edt_train_one_step(data=data, label=data['y']) + + # shift align + data['flag'] = 5 + if epoch > args.start: + for epoch_a in range(0, args.a_epochs): + aliloss = ali_train_one_step(data=data, label=data['y']) + + acc = np.zeros([args.test_set_num]) + auc_roc = np.zeros([args.test_set_num]) + parity = np.zeros([args.test_set_num]) + equality = np.zeros([args.test_set_num]) + net.set_eval() + for i in range(args.test_set_num): + data_tem = data_test[i] + acc[i],auc_roc[i], parity[i], equality[i] = evaluate_ged3(net, data_tem['x'], data_tem['edge_index'], data_tem['y'], data_tem['test_mask'], data_tem['sens']) + return acc, auc_roc, parity, equality + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='bail') + parser.add_argument('--start', type=int, default=50) + parser.add_argument('--epochs', type=int, default=400) + parser.add_argument('--dic_epochs', type=int, default=5) + parser.add_argument('--dtb_epochs', type=int, default=5) + parser.add_argument('--cla_epochs', type=int, default=12) + parser.add_argument('--a_epochs', type=int, default=2) + parser.add_argument('--g_epochs', type=int, default=5) + parser.add_argument('--g_lr', type=float, default=0.05) + parser.add_argument('--g_wd', type=float, default=0.01) + parser.add_argument('--d_lr', type=float, default=0.001) + parser.add_argument('--d_wd', type=float, default=0) + parser.add_argument('--c_lr', type=float, default=0.001) + parser.add_argument('--c_wd', type=float, default=0.01) + parser.add_argument('--e_lr', type=float, default=0.005) + parser.add_argument('--e_wd', type=float, default=0) + parser.add_argument('--hidden', type=int, default=128) + parser.add_argument('--seed', type=int, default=3) + parser.add_argument('--top_k', type=int, default=10) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--drope_rate', type=float, default=0.1) + parser.add_argument("--dataset_path", type=str, default=r'', help="path to save dataset") + + args = parser.parse_args() + + if args.gpu >= 0: + tlx.set_device("GPU", args.gpu) + else: + tlx.set_device("CPU") + args.device = f'cuda:{args.gpu}' + + + fileNamePath 
= os.path.split(os.path.realpath(__file__))[0] + yamlPath = os.path.join(fileNamePath, 'config.yaml') + with open(yamlPath, 'r', encoding='utf-8') as f: + cont = f.read() + config_dict = yaml.safe_load(cont)[args.dataset] + for key, value in config_dict.items(): + args.__setattr__(key, value) + + print(args) + acc, auc_roc, parity, equality = main(args) + + for i in range(args.test_set_num): + print("===========test{}============".format(i+1)) + print('Acc: ', acc.T[i]) + print('auc_roc: ', auc_roc.T[i]) + print('parity: ', parity.T[i]) + print('equality: ', equality.T[i]) diff --git a/examples/fatragnn/readme.md b/examples/fatragnn/readme.md new file mode 100644 index 00000000..0fa5ee69 --- /dev/null +++ b/examples/fatragnn/readme.md @@ -0,0 +1,63 @@ +# Graph Fairness Learning under Distribution Shifts + +- Paper link: [https://arxiv.org/abs/2401.16784](https://arxiv.org/abs/2401.16784) +- Author's code repo: [https://github.com/BUPT-GAMMA/FatraGNN](https://github.com/BUPT-GAMMA/FatraGNN). Note that the original code is implemented with Torch for the paper. + +# Dataset Statics + + +| Dataset | # Nodes | # Edges | # Classes | +|----------|---------|---------|-----------| +| Bail_B0 | 4,686 | 153,942 | 2 | +| Bail_B1 | 2,214 | 49,124 | 2 | +| Bail_B2 | 2,395 | 88,091 | 2 | +| Bail_B3 | 1,536 | 57,838 | 2 | +| Bail_B4 | 1,193 | 30,319 | 2 | +| Credit_C0| 4,184 | 45,718 | 2 | +| Credit_C1| 2,541 | 18,949 | 2 | +| Credit_C2| 3,796 | 28,936 | 2 | +| Credit_C3| 2,068 | 15,314 | 2 | +| Credit_C4| 3,420 | 26,048 | 2 | + +Refer to [Credit](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Credit) and [Bail](https://gammagl.readthedocs.io/en/latest/api/gammagl.datasets.html#gammagl.datasets.Bail). + +Results +------- + + + + +```bash +TL_BACKEND="torch" python fatragnn_trainer.py --dataset credit --epochs 600 --g_epochs 5 --a_epochs 2 --cla_epochs 12 --dic_epochs 5 --dtb_epochs 5 --c_lr 0.01 --e_lr 0.01 +TL_BACKEND="torch" python fatragnn_trainer.py --dataset bail --epochs 400 --g_epochs 5 --a_epochs 4 --cla_epochs 10 --dic_epochs 8 --dtb_epochs 5 --c_lr 0.005 --e_lr 0.005 + + +TL_BACKEND="tensorflow" python fatragnn_trainer.py --dataset credit --epochs 600 --g_epochs 5 --a_epochs 2 --cla_epochs 12 --dic_epochs 5 --dtb_epochs 5 --c_lr 0.01 --e_lr 0.01 +TL_BACKEND="tensorflow" python fatragnn_trainer.py --dataset bail --epochs 400 --g_epochs 5 --a_epochs 4 --cla_epochs 10 --dic_epochs 8 --dtb_epochs 5 --c_lr 0.005 --e_lr 0.005 +``` +ACC: +| Dataset | Paper | Our(torch) | Our(tensorflow) | +| ---------- | ----------- | ---------------- | --------------- | +| Credit_C1 | 77.31±0.10 | 77.08(±0.08) | 77.06(±0.10) | +| Credit_C2 | 77.12±0.28 | 77.26(±0.13) | 77.22(±0.11) | +| Credit_C3 | 71.81±0.39 | 70.86(±0.15) | 71.02(±0.12) | +| Credit_C4 | 72.15±0.42 | 70.91(±0.10) | 71.08(±0.09) | +| Bail_B1 | 74.59±0.93 | 72.13(±0.97) | 72.08(±0.98) | +| Bail_B2 | 70.46±0.44 | 78.55(±0.94) | 79.02(±0.31) | +| Bail_B3 | 71.65±4.65 | 79.77(±0.70) | 78.96(±0.76) | +| Bail_B4 | 72.59±3.39 | 80.35(±1.73) | 79.91(±0.64) | + + + + +equality: +| Dataset | Paper | Our(torch) | Our(tensorflow) | +| ---------- | ---------- | ---------------- | --------------- | +| Credit_C1 | 0.71±0.03 | 0.53(±0.05) | 0.41(±0.02) | +| Credit_C2 | 0.95±0.7 | 0.13(±0.10) | 0.30(±0.39) | +| Credit_C3 | 0.81±0.56 | 1.81(±1.68) | 2.51(±1.92) | +| Credit_C4 | 1.16±0.13 | 0.14(±0.07) | 0.18(±0.13) | +| Bail_B1 | 2.38±3.19 | 4.38(±2.87) | 1.28(±1.04) | +| Bail_B2 | 0.43±1.14 | 4.48(±2.52) | 3.51(±1.92) | +| Bail_B3 
| 2.43±4.94 | 2.62(±2.55) | 2.13(±0.43) | +| Bail_B4 | 2.45±6.67 | 1.16(±1.40) | 3.03(±1.22) | diff --git a/gammagl/datasets/__init__.py b/gammagl/datasets/__init__.py index f1a64eab..e49f39f1 100644 --- a/gammagl/datasets/__init__.py +++ b/gammagl/datasets/__init__.py @@ -23,6 +23,8 @@ from .facebook import FacebookPagePage from .acm4heco import ACM4HeCo from .yelp import Yelp +from .bail import Bail +from .credit import Credit from .acm4dhn import ACM4DHN __all__ = [ @@ -50,6 +52,8 @@ 'FacebookPagePage', 'NGSIM_US_101', 'Yelp', + 'Bail', + 'Credit', 'ACM4DHN' ] diff --git a/gammagl/datasets/bail.py b/gammagl/datasets/bail.py new file mode 100644 index 00000000..75ea20d5 --- /dev/null +++ b/gammagl/datasets/bail.py @@ -0,0 +1,215 @@ +import pandas as pd +import numpy as np +import scipy.sparse as sp +from scipy.spatial import distance_matrix +import os.path as osp +from typing import Union, List, Tuple +from gammagl.data import download_url, InMemoryDataset, Graph +import tensorlayerx as tlx +from gammagl.utils.mask import index_to_mask + + +def sys_normalized_adjacency(adj): + adj = sp.coo_matrix(adj) + adj = adj + sp.eye(adj.shape[0]) + row_sum = np.array(adj.sum(1)) + row_sum = (row_sum == 0) * 1 + row_sum + d_inv_sqrt = np.power(row_sum, -0.5).flatten() + d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. + d_mat_inv_sqrt = sp.diags(d_inv_sqrt) + + return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo() + + +def build_relationship(x, thresh=0.25): + df_euclid = pd.DataFrame( + 1 / (1 + distance_matrix(x.T.T, x.T.T)), columns=x.T.columns, index=x.T.columns) + df_euclid = df_euclid.to_numpy() + idx_map = [] + for ind in range(df_euclid.shape[0]): + max_sim = np.sort(df_euclid[ind, :])[-2] + neig_id = np.where(df_euclid[ind, :] > thresh * max_sim)[0] + import random + random.seed(912) + random.shuffle(neig_id) + for neig in neig_id: + if neig != ind: + idx_map.append([ind, neig]) + idx_map = np.array(idx_map) + return idx_map + + +class Bail(InMemoryDataset): + r""" + The datasets "Bail-Bs" from the + `"Graph Fairness Learning under Distribution Shifts" + `_ paper. + Nodes represent defendants released on bail. + Training, validation and test splits are given by binary masks. + + Parameters + ---------- + root: str, optional + Root directory where the dataset should be saved. + transform: callable, optional + A function/transform that takes in an + :obj:`gammagl.data.Graph` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform: callable, optional + A function/transform that takes in + an :obj:`gammagl.data.Graph` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + + Tip + --- + .. 
list-table:: + :widths: 10 10 10 10 10 + :header-rows: 1 + + * - Name + - #nodes + - #edges + - #features + - #classes + * - Bail_B0 + - 4686 + - 153942 + - 18 + - 2 + * - Bail_B1 + - 2214 + - 49124 + - 18 + - 2 + * - Bail_B2 + - 2395 + - 88091 + - 18 + - 2 + * - Bail_B3 + - 1536 + - 57838 + - 18 + - 2 + * - Bail_B4 + - 1193 + - 30319 + - 18 + - 2 + """ + + url = 'https://raw.githubusercontent.com/liushiliushi/FatraGNN/main/dataset' + + def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None, force_reload: bool = False): + self.name = 'bail' + self.top_k = 10 + self.strlist = ['_B0', '_B1','_B2', '_B3', '_B4'] + super(Bail, self).__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) + self.data = self.load_data(self.processed_paths[0]) + + @property + def raw_dir(self) -> str: + return osp.join(self.root, self.name, 'raw') + + @property + def processed_dir(self) -> str: + return osp.join(self.root, self.name, 'processed') + + @property + def raw_file_names(self) -> List[str]: + self.strlist = ['_B0', '_B1','_B2', '_B3', '_B4'] + feature_names = [f'{self.name}{name}.csv' for name in self.strlist] + edge_names = [f'{self.name}{name}_edges.txt' for name in self.strlist] + return feature_names + edge_names + + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + return tlx.BACKEND + '_data.pt' + + @property + def num_classes(self) -> int: + return super().num_classes() + + def download(self): + for name in self.raw_file_names: + download_url(f'{self.url}/{self.name}/{name}', self.raw_dir) + + def process(self): + sens_attr = "WHITE" + predict_attr='RECID' + label_number=100 + + data_list = [] + for i in self.strlist: + + idx_features_labels = pd.read_csv(osp.join(self.raw_dir, "{}.csv".format(self.name + i))) + if 'Unnamed: 0' in idx_features_labels.columns: + idx_features_labels = idx_features_labels.drop(['Unnamed: 0'], axis=1) + + header = list(idx_features_labels.columns) + header.remove(predict_attr) + + edges_unordered = np.genfromtxt(osp.join(self.raw_dir, "{}_edges.txt".format(self.name + i))).astype('int') + features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32) + labels = idx_features_labels[predict_attr].values + + idx = np.arange(features.shape[0]) + idx_map = {j: i for i, j in enumerate(idx)} + edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), + dtype=int).reshape(edges_unordered.shape) + adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), + shape=(labels.shape[0], labels.shape[0]), + dtype=np.float32) + + adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) + adj = adj + sp.eye(adj.shape[0]) + adj_norm = sys_normalized_adjacency(adj) + + edge_index = tlx.convert_to_tensor(np.vstack((adj_norm.row, adj_norm.col)).astype(np.int64)) + features = tlx.convert_to_tensor(np.array(features.todense()).astype(np.float32)) + + import random + random.seed(20) + label_idx_0 = np.where(labels == 0)[0] + label_idx_1 = np.where(labels == 1)[0] + + labels = tlx.convert_to_tensor(labels.astype(np.float32)) + random.shuffle(label_idx_0) + random.shuffle(label_idx_1) + idx_train = np.append(label_idx_0[:min(int(0.5 * len(label_idx_0)), label_number // 2)], + label_idx_1[:min(int(0.5 * len(label_idx_1)), label_number // 2)]) + idx_val = np.append(label_idx_0[int(0.5 * len(label_idx_0)):int(0.75 * len( + label_idx_0))], label_idx_1[int(0.5 * len(label_idx_1)):int(0.75 * len(label_idx_1))]) + idx_test = np.append(label_idx_0[int( + 0.75 * 
len(label_idx_0)):], label_idx_1[int(0.75 * len(label_idx_1)):]) + + sens = idx_features_labels[sens_attr].values.astype(int) + sens = tlx.convert_to_tensor(sens) + train_mask = index_to_mask(tlx.convert_to_tensor(idx_train), features.shape[0]) + val_mask = index_to_mask(tlx.convert_to_tensor(idx_val), features.shape[0]) + test_mask = index_to_mask(tlx.convert_to_tensor(idx_test), features.shape[0]) + + sens_idx = 0 + x_max, x_min = tlx.reduce_max(features, axis=0), tlx.reduce_min(features, axis=0) + + norm_features = 2 * (features - x_min)/(x_max - x_min) - 1 + norm_features = tlx.convert_to_numpy(norm_features) + features = tlx.convert_to_numpy(features) + norm_features[:, sens_idx] = features[:, sens_idx] + features = norm_features + features = tlx.convert_to_tensor(features) + corr = pd.DataFrame(np.array(tlx.to_device(features, 'cpu'))).corr() + corr_matrix = corr[sens_idx].to_numpy() + corr_idx = np.argsort(-np.abs(corr_matrix)) + corr_idx = tlx.convert_to_tensor(corr_idx[:self.top_k]) + + data = Graph(x=features, edge_index=edge_index, adj=adj, y=labels, + train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, sens=sens) + data_list.append(data) + self.save_data(data_list, self.processed_paths[0]) + diff --git a/gammagl/datasets/credit.py b/gammagl/datasets/credit.py new file mode 100644 index 00000000..2cf0c363 --- /dev/null +++ b/gammagl/datasets/credit.py @@ -0,0 +1,215 @@ +import pandas as pd +import numpy as np +import scipy.sparse as sp +from scipy.spatial import distance_matrix +import os.path as osp +from typing import Union, List, Tuple +from gammagl.data import download_url, InMemoryDataset, Graph +import tensorlayerx as tlx +from gammagl.utils.mask import index_to_mask + + +def sys_normalized_adjacency(adj): + adj = sp.coo_matrix(adj) + adj = adj + sp.eye(adj.shape[0]) + row_sum = np.array(adj.sum(1)) + row_sum = (row_sum == 0) * 1 + row_sum + d_inv_sqrt = np.power(row_sum, -0.5).flatten() + d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. + d_mat_inv_sqrt = sp.diags(d_inv_sqrt) + + return d_mat_inv_sqrt.dot(adj).dot(d_mat_inv_sqrt).tocoo() + + +def build_relationship(x, thresh=0.25): + df_euclid = pd.DataFrame( + 1 / (1 + distance_matrix(x.T.T, x.T.T)), columns=x.T.columns, index=x.T.columns) + df_euclid = df_euclid.to_numpy() + idx_map = [] + for ind in range(df_euclid.shape[0]): + max_sim = np.sort(df_euclid[ind, :])[-2] + neig_id = np.where(df_euclid[ind, :] > thresh * max_sim)[0] + import random + random.seed(912) + random.shuffle(neig_id) + for neig in neig_id: + if neig != ind: + idx_map.append([ind, neig]) + idx_map = np.array(idx_map) + return idx_map + + +class Credit(InMemoryDataset): + r""" + The datasets "Bail-Bs" from the + `"Graph Fairness Learning under Distribution Shifts" + `_ paper. + Nodes represent credit card users. + Training, validation and test splits are given by binary masks. + + Parameters + ---------- + root: str, optional + Root directory where the dataset should be saved. + transform: callable, optional + A function/transform that takes in an + :obj:`gammagl.data.Graph` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform: callable, optional + A function/transform that takes in + an :obj:`gammagl.data.Graph` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + + Tip + --- + .. 
list-table:: + :widths: 10 10 10 10 10 + :header-rows: 1 + + * - Name + - #nodes + - #edges + - #classes + * - Credit_C0 + - 4184 + - 45718 + - 13 + - 2 + * - Credit_C1 + - 2541 + - 18949 + - 13 + - 2 + * - Credit_C2 + - 3796 + - 28936 + - 13 + - 2 + * - Credit_C3 + - 2068 + - 15314 + - 13 + - 2 + * - Credit_C4 + - 3420 + - 26048 + - 13 + - 2 + """ + + url = 'https://raw.githubusercontent.com/liushiliushi/FatraGNN/main/dataset' + + + def __init__(self, root=None, transform=None, pre_transform=None, pre_filter=None, force_reload: bool = False): + self.name = 'credit' + self.top_k = 10 + self.strlist = ['_C0', '_C1', '_C2', '_C3', '_C4'] + super(Credit, self).__init__(root, transform, pre_transform, pre_filter, force_reload = force_reload) + self.data = self.load_data(self.processed_paths[0]) + + @property + def raw_dir(self) -> str: + return osp.join(self.root, self.name, 'raw') + + @property + def processed_dir(self) -> str: + return osp.join(self.root, self.name, 'processed') + + @property + def raw_file_names(self) -> List[str]: + self.strlist = ['_C0', '_C1', '_C2', '_C3', '_C4'] + feature_names = [f'{self.name}{name}.csv' for name in self.strlist] + edge_names = [f'{self.name}{name}_edges.txt' for name in self.strlist] + return feature_names + edge_names + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple]: + return tlx.BACKEND + '_data.pt' + + @property + def num_classes(self) -> int: + return super().num_classes() + + def download(self): + for name in self.raw_file_names: + download_url(f'{self.url}/{self.name}/{name}', self.raw_dir) + + def process(self): + sens_attr="Age" + predict_attr="NoDefaultNextMonth" + label_number=6000 + + data_list = [] + for i in self.strlist: + + idx_features_labels = pd.read_csv(osp.join(self.raw_dir, "{}.csv".format(self.name + i))) + if 'Unnamed: 0' in idx_features_labels.columns: + idx_features_labels = idx_features_labels.drop(['Unnamed: 0'], axis=1) + + header = list(idx_features_labels.columns) + header.remove(predict_attr) + header.remove('Single') + + edges_unordered = np.genfromtxt(osp.join(self.raw_dir, "{}_edges.txt".format(self.name + i))).astype('int') + features = sp.csr_matrix(idx_features_labels[header], dtype=np.float32) + labels = idx_features_labels[predict_attr].values + + idx = np.arange(features.shape[0]) + idx_map = {j: i for i, j in enumerate(idx)} + edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), + dtype=int).reshape(edges_unordered.shape) + adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), + shape=(labels.shape[0], labels.shape[0]), + dtype=np.float32) + + adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj) + adj = adj + sp.eye(adj.shape[0]) + adj_norm = sys_normalized_adjacency(adj) + + edge_index = tlx.convert_to_tensor(np.vstack((adj_norm.row, adj_norm.col)).astype(np.int64)) + features = tlx.convert_to_tensor(np.array(features.todense()).astype(np.float32)) + + import random + random.seed(20) + label_idx_0 = np.where(labels == 0)[0] + label_idx_1 = np.where(labels == 1)[0] + + labels = tlx.convert_to_tensor(labels.astype(np.float32)) + random.shuffle(label_idx_0) + random.shuffle(label_idx_1) + idx_train = np.append(label_idx_0[:min(int(0.5 * len(label_idx_0)), label_number // 2)], + label_idx_1[:min(int(0.5 * len(label_idx_1)), label_number // 2)]) + idx_val = np.append(label_idx_0[int(0.5 * len(label_idx_0)):int(0.75 * len( + label_idx_0))], label_idx_1[int(0.5 * len(label_idx_1)):int(0.75 * len(label_idx_1))]) + idx_test = 
np.append(label_idx_0[int( + 0.75 * len(label_idx_0)):], label_idx_1[int(0.75 * len(label_idx_1)):]) + + sens = idx_features_labels[sens_attr].values.astype(int) + sens = tlx.convert_to_tensor(sens) + train_mask = index_to_mask(tlx.convert_to_tensor(idx_train), features.shape[0]) + val_mask = index_to_mask(tlx.convert_to_tensor(idx_val), features.shape[0]) + test_mask = index_to_mask(tlx.convert_to_tensor(idx_test), features.shape[0]) + + sens_idx = 1 + x_max, x_min = tlx.reduce_max(features, axis=0), tlx.reduce_min(features, axis=0) + + norm_features = 2 * (features - x_min)/(x_max - x_min) - 1 + norm_features = tlx.convert_to_numpy(norm_features) + features = tlx.convert_to_numpy(features) + norm_features[:, sens_idx] = features[:, sens_idx] + features = norm_features + features = tlx.convert_to_tensor(features) + corr = pd.DataFrame(np.array(tlx.to_device(features, 'cpu'))).corr() + corr_matrix = corr[sens_idx].to_numpy() + corr_idx = np.argsort(-np.abs(corr_matrix)) + corr_idx = tlx.convert_to_tensor(corr_idx[:self.top_k]) + + data = Graph(x=features, edge_index=edge_index, adj=adj, y=labels, + train_mask=train_mask, val_mask=val_mask, test_mask=test_mask, sens=sens) + data_list.append(data) + self.save_data(data_list, self.processed_paths[0]) + diff --git a/gammagl/models/__init__.py b/gammagl/models/__init__.py index 61b3e112..24ed6176 100644 --- a/gammagl/models/__init__.py +++ b/gammagl/models/__init__.py @@ -61,6 +61,7 @@ from .dhn import DHNModel from .dna import DNAModel from .dfad import DFADModel, DFADGenerator +from .fatragnn import FatraGNNModel, Graph_Editer __all__ = [ 'HeCo', @@ -128,7 +129,8 @@ 'DHNModel', 'DNAModel', 'DFADModel', - 'DFADGenerator' + 'DFADGenerator', + 'FatraGNNModel' ] classes = __all__ diff --git a/gammagl/models/fatragnn.py b/gammagl/models/fatragnn.py new file mode 100644 index 00000000..5e8d0ade --- /dev/null +++ b/gammagl/models/fatragnn.py @@ -0,0 +1,189 @@ +from gammagl.layers.conv import GCNConv +import tensorlayerx.nn as nn +import tensorlayerx as tlx +from gammagl.utils import mask_to_index +import numpy as np +import random + + +class GCN_encoder(tlx.nn.Module): + def __init__(self, args): + super(GCN_encoder, self).__init__() + self.args = args + self.conv = GCNConv(args.num_features, args.hidden) + + def forward(self, x, edge_index): + h = self.conv(x, edge_index, edge_weight=tlx.ones((edge_index.shape[1],), dtype=tlx.float32), num_nodes=x.shape[0]) + + return h + + +class MLP_discriminator(tlx.nn.Module): + def __init__(self, args): + super(MLP_discriminator, self).__init__() + self.args = args + self.lin = nn.Linear(in_features=args.hidden, out_features=1) + + def forward(self, h): + h = self.lin(h) + + return tlx.sigmoid(h) + + +class MLP_classifier(tlx.nn.Module): + def __init__(self, args): + super(MLP_classifier, self).__init__() + self.args = args + self.lin = nn.Linear(in_features=args.hidden, out_features=1) + + def forward(self, h): + h = self.lin(h) + + return h + + +class FatraGNNModel(tlx.nn.Module): + r"""FatraGNN from `"Graph Fairness Learning under Distribution Shifts" + `_ paper. + + Parameters + ---------- + in_features: int + input feature dimension. + hidden: int + hidden dimension. + out_features: int + number of output feature dimension. + drop_rate: float + dropout rate. 
+ """ + def __init__(self, args): + super(FatraGNNModel, self).__init__() + self.classifier = MLP_classifier(args) + self.graphEdit = Graph_Editer(1, args.num_features, args.device) + self.encoder = GCN_encoder(args) + self.discriminator = MLP_discriminator(args) + def forward(self, x, edge_index, flag): + if flag==0: + h = self.encoder(x, edge_index) + output = self.classifier(h) + return output + + if flag==1: + edge_index = edge_index['edge_index'] + h = self.encoder(x, edge_index) + output = self.discriminator(h) + + elif flag==2: + h = self.encoder(x, edge_index['edge_index']) + h = self.classifier(h) + output = tlx.sigmoid(h) + + elif flag==3: + h = self.encoder(x, edge_index['edge_index']) + output = self.discriminator(h) + + elif flag==4: + x2 = self.graphEdit(x) + h2 = self.encoder(x2, edge_index['edge_index2']) + h2 = tlx.l2_normalize(h2, axis=1) + output = self.classifier(h2) + + elif flag==5: + x2 = self.graphEdit(x) + h2 = self.encoder(x2, edge_index['edge_index2']) + h1 = self.encoder(x, edge_index['edge_index']) + h2 = tlx.l2_normalize(h2, axis=1) + h1 = tlx.l2_normalize(h1, axis=1) + output = { + 'h1': h1, + 'h2': h2, + } + + return output + +class Graph_Editer(tlx.nn.Module): + def __init__(self, n, a, device): + super(Graph_Editer, self).__init__() + self.transFeature = nn.Linear(in_features=a, out_features=a) + self.device = device + self.seed = 13 + + + def modify_structure1(self, edge_index, A2_edge, sens, nodes_num, drop=0.8, add=0.3): + random.seed(self.seed) + src_node, targ_node = edge_index[0], edge_index[1] + matching = tlx.gather(sens, src_node) == tlx.gather(sens, targ_node) + + yipei = mask_to_index(matching == False) + drop_index = tlx.convert_to_tensor(random.sample(range(yipei.shape[0]), int(yipei.shape[0] * drop))) + yipei_drop = tlx.gather(yipei, drop_index) + keep_indices0 = tlx.ones(src_node.shape, dtype=tlx.bool) + keep_indices = tlx.scatter_update(keep_indices0, yipei_drop, tlx.zeros((yipei_drop.shape), dtype=tlx.bool)) + n_src_node = src_node[keep_indices] + n_targ_node = targ_node[keep_indices] + + src_node2, targ_node2 = A2_edge[0], A2_edge[1] + matching2 = tlx.gather(sens, src_node2) == tlx.gather(sens, targ_node2) + matching3 = src_node2 == targ_node2 + tongpei = mask_to_index(tlx.logical_and(matching2 == True, matching3 == False) == True) + add_index = tlx.convert_to_tensor(random.sample(range(tongpei.shape[0]), int(yipei.shape[0] * drop))) + tongpei_add = tlx.gather(tongpei, add_index) + keep_indices0 = tlx.zeros(src_node2.shape, dtype=tlx.bool) + keep_indices = tlx.scatter_update(keep_indices0, tongpei_add, tlx.ones((tongpei_add.shape), dtype=tlx.bool)) + + a_src_node = src_node2[keep_indices] + a_targ_node = targ_node2[keep_indices] + + m_src_node = tlx.concat((a_src_node, n_src_node), axis=0) + m_targ_node = tlx.concat((a_targ_node, n_targ_node), axis=0) + n_edge_index = tlx.concat((tlx.expand_dims(m_src_node, axis=1), tlx.expand_dims(m_targ_node, axis=1)), axis=1) + return n_edge_index + + + def modify_structure2(self, edge_index, A2_edge, sens, nodes_num, drop=0.6, add=0.3): + random.seed(self.seed) + src_node, targ_node = edge_index[0], edge_index[1] + matching = tlx.gather(sens, src_node) == tlx.gather(sens, targ_node) + + + yipei = mask_to_index(matching == False) + yipei_np = tlx.convert_to_numpy(yipei) + np.random.shuffle(yipei_np) + yipei_shuffled = tlx.convert_to_tensor(yipei_np) + + drop_index = tlx.convert_to_tensor(random.sample(range(yipei_shuffled.shape[0]), int(yipei_shuffled.shape[0] * drop))) + yipei_drop = 
tlx.gather(yipei_shuffled, drop_index) + keep_indices0 = tlx.ones(src_node.shape, dtype=tlx.bool) + keep_indices = tlx.scatter_update(keep_indices0, yipei_drop, tlx.zeros((yipei_drop.shape), dtype=tlx.bool)) + n_src_node = src_node[keep_indices] + n_targ_node = targ_node[keep_indices] + + src_node2, targ_node2 = A2_edge[0], A2_edge[1] + matching2 = tlx.gather(sens, src_node2) != tlx.gather(sens, targ_node2) + matching3 = src_node2 == targ_node2 + tongpei = mask_to_index(tlx.logical_and(matching2 == True, matching3 == False) == True) + tongpei_np = tlx.convert_to_numpy(tongpei) + np.random.shuffle(tongpei_np) + tongpei_shuffled = tlx.convert_to_tensor(tongpei_np) + add_index = tlx.convert_to_tensor(random.sample(range(tongpei_shuffled.shape[0]), int(yipei_shuffled.shape[0] * drop))) + tongpei_add = tlx.gather(tongpei_shuffled, add_index) + keep_indices0 = tlx.zeros(src_node2.shape, dtype=tlx.bool) + keep_indices = tlx.scatter_update(keep_indices0, tongpei_add, tlx.ones((tongpei_add.shape), dtype=tlx.bool)) + + a_src_node = src_node2[keep_indices] + a_targ_node = targ_node2[keep_indices] + + m_src_node = tlx.concat((a_src_node, n_src_node), axis=0) + m_targ_node = tlx.concat((a_targ_node, n_targ_node), axis=0) + n_edge_index = tlx.concat((tlx.expand_dims(m_src_node, axis=1), tlx.expand_dims(m_targ_node, axis=1)), axis=1) + + return n_edge_index + + def forward(self, x): + x1 = x + 0.1 * self.transFeature(x) + + return x1 + + +
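
Once this patch is applied, the Bail/Credit datasets and the fairness metrics can be exercised directly, without running the full adversarial training loop. The snippet below is a minimal sketch, not part of the patch: it loads the Bail graphs with `gammagl.datasets.Bail` and evaluates the same statistical-parity and equal-opportunity metrics computed by `fair_metric()` in `fatragnn_trainer.py`. The `root='data'` path and the all-ones baseline predictor are illustrative assumptions; a trained `FatraGNNModel`'s predictions would normally be passed in their place.

```python
# Minimal usage sketch (assumes this patch is applied and raw files are downloadable).
import os
os.environ['TL_BACKEND'] = 'torch'   # any backend supported by tensorlayerx should work

import numpy as np
import tensorlayerx as tlx
from gammagl.datasets import Bail


def fair_metric(pred, labels, sens):
    # Statistical parity: |P(pred=1 | s=0) - P(pred=1 | s=1)|
    # Equal opportunity:  |P(pred=1 | s=0, y=1) - P(pred=1 | s=1, y=1)|
    # Equivalent to fair_metric() in fatragnn_trainer.py, written with means instead of sums.
    idx_s0, idx_s1 = sens == 0, sens == 1
    idx_s0_y1 = np.bitwise_and(idx_s0, labels == 1)
    idx_s1_y1 = np.bitwise_and(idx_s1, labels == 1)
    parity = abs(pred[idx_s0].mean() - pred[idx_s1].mean())
    equality = abs(pred[idx_s0_y1].mean() - pred[idx_s1_y1].mean())
    return parity, equality


graphs = Bail(root='data').data          # list of five Graph objects: Bail_B0 ... Bail_B4
g = graphs[0]                            # Bail_B0 is the training graph in the example trainer
labels = tlx.convert_to_numpy(g.y).astype(int)
sens = tlx.convert_to_numpy(g.sens).astype(int)

pred = np.ones_like(labels)              # placeholder predictions; substitute model outputs here
parity, equality = fair_metric(pred, labels, sens)
print(f'parity={parity:.4f}, equality={equality:.4f}')
```

For end-to-end training and the reported numbers, use the commands in `examples/fatragnn/readme.md`; `fatragnn_trainer.py` reads the per-dataset hyperparameters from `config.yaml` and overrides the argparse defaults before calling `main()`.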