Skip to content

Commit

Permalink
Add attention explainer example
Browse files Browse the repository at this point in the history
  • Loading branch information
chaojun-zhang committed Jul 3, 2024
1 parent f0ef2de commit dc24226
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 48 deletions.
4 changes: 2 additions & 2 deletions examples/compile/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def __init__(self, in_channels, hidden_channels, out_channels):

def forward(self, x, edge_index, edge_weight):
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv1(x, edge_index, edge_weight).relu()
# x = self.conv1(x, edge_index, edge_weight).relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
# x = self.conv2(x, edge_index)
return x


Expand Down
65 changes: 65 additions & 0 deletions examples/explain/attention_exaplainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os.path as osp

import torch
import torch.nn.functional as F

import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, AttentionExplainer
from torch_geometric.nn import GATConv

if torch.cuda.is_available():
device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
device = torch.device('xpu')
else:
device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora')
data = dataset[0].to(device)


# GAT Node Classification =====================================================

class GAT(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False, dropout=0.6)

def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return x


model = GAT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, 201):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

explainer = Explainer(
model=model,
algorithm=AttentionExplainer(),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)

node_index = torch.tensor([10, 20])
explanation = explainer(data.x, data.edge_index, index=node_index)
print(f'Generated explanations in {explanation.available_explanations}')
55 changes: 9 additions & 46 deletions examples/explain/graphmask_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@
import torch
import torch.nn.functional as F

import torch_geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.explain import Explainer, GraphMaskExplainer
from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.nn import GCNConv

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
device = torch.device('cuda')
elif torch_geometric.is_xpu_available():
device = torch.device('xpu')
else:
device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora')
data = dataset[0].to(device)


# GCN Node Classification =====================================================


Expand Down Expand Up @@ -56,47 +63,3 @@ def forward(self, x, edge_index):
node_index = 10
explanation = explainer(data.x, data.edge_index, index=node_index)
print(f'Generated explanations in {explanation.available_explanations}')

# GAT Node Classification =====================================================


class GAT(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GATConv(dataset.num_features, 8, heads=8)
self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False)

def forward(self, x, edge_index):
x = self.conv1(x, edge_index).relu()
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)


model = GAT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, 201):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

explainer = Explainer(
model=model,
algorithm=GraphMaskExplainer(2, epochs=5),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='multiclass_classification',
task_level='node',
return_type='log_probs',
),
)

node_index = torch.tensor([10, 20])
explanation = explainer(data.x, data.edge_index, index=node_index)
print(f'Generated explanations in {explanation.available_explanations}')

0 comments on commit dc24226

Please sign in to comment.