From dc2422636553f803fd9212dd5850357abc4402e8 Mon Sep 17 00:00:00 2001
From: "Zhang, Chaojun"
Date: Wed, 3 Jul 2024 10:29:52 +0000
Subject: [PATCH] Add attention explainer example

---
 examples/explain/attention_explainer.py | 65 ++++++++++++++++++++++++
 examples/explain/graphmask_explainer.py | 55 ++++++-----------------
 2 files changed, 74 insertions(+), 46 deletions(-)
 create mode 100644 examples/explain/attention_explainer.py

diff --git a/examples/explain/attention_explainer.py b/examples/explain/attention_explainer.py
new file mode 100644
index 0000000000000..65e7c0f3d15db
--- /dev/null
+++ b/examples/explain/attention_explainer.py
@@ -0,0 +1,65 @@
+import os.path as osp
+
+import torch
+import torch.nn.functional as F
+
+import torch_geometric
+from torch_geometric.datasets import Planetoid
+from torch_geometric.explain import AttentionExplainer, Explainer
+from torch_geometric.nn import GATConv
+
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')
+dataset = Planetoid(path, name='Cora')
+data = dataset[0].to(device)
+
+
+# GAT Node Classification =====================================================
+
+class GAT(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
+        self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False,
+                             dropout=0.6)
+
+    def forward(self, x, edge_index):
+        x = F.dropout(x, p=0.6, training=self.training)
+        x = F.elu(self.conv1(x, edge_index))
+        x = F.dropout(x, p=0.6, training=self.training)
+        x = self.conv2(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+
+model = GAT().to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+
+for epoch in range(1, 201):
+    model.train()
+    optimizer.zero_grad()
+    out = model(data.x, data.edge_index)
+    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
+    loss.backward()
+    optimizer.step()
+
+explainer = Explainer(
+    model=model,
+    algorithm=AttentionExplainer(),
+    explanation_type='model',
+    edge_mask_type='object',
+    model_config=dict(
+        mode='multiclass_classification',
+        task_level='node',
+        return_type='log_probs',
+    ),
+)
+
+node_index = torch.tensor([10, 20])
+explanation = explainer(data.x, data.edge_index, index=node_index)
+print(f'Generated explanations in {explanation.available_explanations}')
diff --git a/examples/explain/graphmask_explainer.py b/examples/explain/graphmask_explainer.py
index 88eb8dbe9c4ca..c7de53cf75428 100644
--- a/examples/explain/graphmask_explainer.py
+++ b/examples/explain/graphmask_explainer.py
@@ -3,16 +3,23 @@
 import torch
 import torch.nn.functional as F
 
+import torch_geometric
 from torch_geometric.datasets import Planetoid
 from torch_geometric.explain import Explainer, GraphMaskExplainer
-from torch_geometric.nn import GATConv, GCNConv
+from torch_geometric.nn import GCNConv
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch_geometric.is_xpu_available():
+    device = torch.device('xpu')
+else:
+    device = torch.device('cpu')
 
 path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')
 dataset = Planetoid(path, name='Cora')
 data = dataset[0].to(device)
+
 
 # GCN Node Classification =====================================================
 
 
@@ -56,47 +63,3 @@ def forward(self, x, edge_index):
 node_index = 10
 explanation = explainer(data.x, data.edge_index, index=node_index)
 print(f'Generated explanations in {explanation.available_explanations}')
-
-# GAT Node Classification =====================================================
-
-
-class GAT(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.conv1 = GATConv(dataset.num_features, 8, heads=8)
-        self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False)
-
-    def forward(self, x, edge_index):
-        x = self.conv1(x, edge_index).relu()
-        x = F.dropout(x, training=self.training)
-        x = self.conv2(x, edge_index)
-        return F.log_softmax(x, dim=1)
-
-
-model = GAT().to(device)
-optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
-
-for epoch in range(1, 201):
-    model.train()
-    optimizer.zero_grad()
-    out = model(data.x, data.edge_index)
-    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
-    loss.backward()
-    optimizer.step()
-
-explainer = Explainer(
-    model=model,
-    algorithm=GraphMaskExplainer(2, epochs=5),
-    explanation_type='model',
-    node_mask_type='attributes',
-    edge_mask_type='object',
-    model_config=dict(
-        mode='multiclass_classification',
-        task_level='node',
-        return_type='log_probs',
-    ),
-)
-
-node_index = torch.tensor([10, 20])
-explanation = explainer(data.x, data.edge_index, index=node_index)
-print(f'Generated explanations in {explanation.available_explanations}')
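
A quick way to sanity-check the new example is sketched below. It is not part of the patch: it assumes examples/explain/attention_explainer.py has just been run so that its final `explanation` object is still in scope, and the output file name is only illustrative. AttentionExplainer aggregates attention coefficients into an edge mask, so that is the artifact worth inspecting here.

    # Sketch only: assumes `explanation` from the end of the new
    # attention_explainer.py example is in scope.

    # AttentionExplainer produces an edge mask (aggregated attention scores)
    # but no node feature mask, so `available_explanations` lists 'edge_mask'.
    edge_mask = explanation.edge_mask
    print(f'Edge mask of size {edge_mask.size(0)}, '
          f'{int((edge_mask > 0).sum())} edges with non-zero attribution')

    # Plot the explanation subgraph, with edge opacity scaled by the mask
    # (needs a plotting backend such as networkx/matplotlib or graphviz).
    explanation.visualize_graph('attention_subgraph.pdf')

Because no node feature mask is produced, `Explanation.visualize_feature_importance()` does not apply to this explainer; the edge-level view above is the relevant output.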