[operator] add bspmm_sum operator #210

Merged: 4 commits, Jul 5, 2024
36 changes: 20 additions & 16 deletions gammagl/layers/conv/gat_conv.py
@@ -1,8 +1,7 @@
import tensorlayerx as tlx
from gammagl.layers.conv import MessagePassing
from gammagl.utils import segment_softmax


from gammagl.mpops import bspmm


class GATConv(MessagePassing):
@@ -79,10 +78,14 @@ def __init__(self,
self.linear = tlx.layers.Linear(out_features=self.out_channels * self.heads,
in_features=self.in_channels,
b_init=None)

init_weight = tlx.initializers.TruncatedNormal()
self.w = tlx.nn.Parameter(
init_weight((in_channels, self.out_channels * self.heads)))

initor = tlx.initializers.TruncatedNormal()
self.att_src = self._get_weights("att_src", shape=(1, self.heads, self.out_channels), init=initor, order=True)
self.att_dst = self._get_weights("att_dst", shape=(1, self.heads, self.out_channels), init=initor, order=True)
self.att = tlx.nn.Parameter(
initor((1, self.heads, self.out_channels * 2)))

self.leaky_relu = tlx.layers.LeakyReLU(negative_slope)
self.dropout = tlx.layers.Dropout(self.dropout_rate)
@@ -91,22 +94,23 @@ def __init__(self,
self.bias = self._get_weights("bias", shape=(self.heads * self.out_channels,), init=initor)
elif self.add_bias and not concat:
self.bias = self._get_weights("bias", shape=(self.out_channels,), init=initor)

def message(self, x, edge_index, edge_weight=None, num_nodes=None):

def forward(self, x, edge_index, num_nodes=None):
x = tlx.matmul(x, self.w)
x = tlx.reshape(x, shape=(-1, self.heads, self.out_channels))
node_src = edge_index[0, :]
node_dst = edge_index[1, :]
weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src)
weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst)
weight = self.leaky_relu(weight_src + weight_dst)
feat_src = tlx.gather(x, node_src)
feat_dst = tlx.gather(x, node_dst)
feat = tlx.concat((feat_src, feat_dst), axis=-1)
feat = tlx.reshape(feat, shape=(-1, self.heads, self.out_channels * 2))
e = tlx.reduce_sum(feat * self.att, axis = -1)

alpha = self.dropout(segment_softmax(weight, node_dst, num_nodes))
x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
return x * edge_weight if edge_weight else x
e = self.leaky_relu(e)
alpha = self.dropout(segment_softmax(e, node_dst, num_nodes))


def forward(self, x, edge_index, num_nodes=None):
x = tlx.reshape(self.linear(x), shape=(-1, self.heads, self.out_channels))
x = self.propagate(x, edge_index, num_nodes=num_nodes)
x = self.propagate(x, edge_index, num_nodes=num_nodes, edge_weight=alpha)
# x = bspmm(edge_index, weight=alpha, x=x, reduce='sum')

if self.concat:
x = tlx.reshape(x, (-1, self.heads * self.out_channels))
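
In the rewritten forward, the per-edge, per-head attention scores alpha are passed to propagate as edge_weight; the commented-out line above indicates that the same weighted aggregation could also be done directly with the new bspmm operator. A minimal sketch of that call (not part of the PR), assuming x has shape (num_nodes, heads, out_channels) and alpha has shape (num_edges, heads):

from gammagl.mpops import bspmm

def aggregate_with_bspmm(edge_index, alpha, x):
    # Sum alpha[e, h] * x[src, h, :] into each destination node, per head,
    # which is the reduction GATConv needs after computing attention.
    return bspmm(edge_index, weight=alpha, x=x, reduce='sum')
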
3 changes: 3 additions & 0 deletions gammagl/mpops/mindspore.py
@@ -60,3 +60,6 @@ def segment_max(x, segment_ids, num_segments=None):

def gspmm(index, weight=None, x=None, reduce='sum'):
pass

def bspmm(index, weight=None, x=None, reduce='sum'):
pass
3 changes: 3 additions & 0 deletions gammagl/mpops/paddle.py
@@ -222,3 +222,6 @@ def _scatter(x, index, updates, overwrite=True):

def gspmm(index, weight=None, x=None, reduce='sum'):
pass

def bspmm(index, weight=None, x=None, reduce='sum'):
pass
3 changes: 3 additions & 0 deletions gammagl/mpops/tensorflow.py
@@ -60,3 +60,6 @@ def segment_min(x, segment_ids, num_segments=None):

def gspmm(index, weight=None, x=None, reduce='sum'):
pass

def bspmm(index, weight=None, x=None, reduce='sum'):
pass
16 changes: 15 additions & 1 deletion gammagl/mpops/torch.py
@@ -1,7 +1,7 @@
import torch
use_ext = False
try:
from .torch_ext._torch_ext import c_segment_sum, c_segment_mean, c_segment_max, c_spmm_sum, c_spmm_mean, c_spmm_max
from .torch_ext._torch_ext import c_segment_sum, c_segment_mean, c_segment_max, c_spmm_sum, c_spmm_mean, c_spmm_max, c_bspmm_sum
use_ext = True
except:
pass
@@ -297,3 +297,17 @@ def gspmm(index, weight=None, x=None, reduce='sum'):
return c_spmm_max(index, weight, x)
else:
raise Exception("Unsupported reduce type, please choose from ['sum', 'mean', 'max'].")


def bspmm(index, weight=None, x=None, reduce='sum'):
    if weight is None:
weight = torch.ones(size=(index.shape[1], ), dtype=torch.float32)
if reduce == 'sum':
return c_bspmm_sum(index, weight, x)
# elif reduce == 'mean':
# return c_spmm_mean(index, weight, x)
# elif reduce == 'max':
# return c_spmm_max(index, weight, x)
else:
# raise Exception("Unsupported reduce type, please choose from ['sum', 'mean', 'max'].")
raise Exception("Unsupported reduce type, please choose from ['sum'].")
102 changes: 102 additions & 0 deletions gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp
@@ -0,0 +1,102 @@
#include "./bspmm_sum_cpu.h"
#include <torch/torch.h>
#include "ATen/core/TensorBody.h"

torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x){
if (!x.is_contiguous()) {
x = x.contiguous();
}
if (!weight.is_contiguous()) {
weight = weight.contiguous();
}
if (!index.is_contiguous()) {
index = index.contiguous();
}

int num_nodes = x.size(0);
int heads = x.size(1);
int out_channels = x.size(2);

torch::Tensor out = torch::zeros_like(x, x.options());
auto E = index.size(1);
// auto K = x.numel() / x.size(0);

auto index_data = index.data_ptr<int64_t>();
using scalar_t = float;
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
auto src = index_data[e];
auto dst = index_data[e + E];

for (auto h = 0; h < heads; ++h){
for (auto k = 0; k < out_channels; ++k){
#ifdef COMPILE_WITH_OMP
#pragma omp atomic
#endif
out_data[dst * out_channels * heads + h * out_channels + k] +=
weight_data[e * heads + h] * x_data[src * out_channels * heads + h * out_channels + k];
}
}
}
return out;
}

std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, torch::Tensor &grad) {
if (!grad.is_contiguous()) {
grad = grad.contiguous();
}
if (!weight.is_contiguous()) {
weight = weight.contiguous();
}
if (!index.is_contiguous()) {
index = index.contiguous();
}

int num_nodes = grad.size(0);
int heads = grad.size(1);
int out_channels = grad.size(2);

torch::Tensor grad_x = torch::zeros_like(grad, grad.options());
torch::Tensor grad_weight = torch::zeros_like(weight, weight.options());
auto E = index.size(1);
// auto K = grad.numel() / grad.size(0);

auto index_data = index.data_ptr<int64_t>();
using scalar_t = float;
auto grad_data = grad.data_ptr<scalar_t>();
auto grad_x_data = grad_x.data_ptr<scalar_t>();
auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();

  // Compute the gradients for the backward pass
#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
auto src = index_data[e];
auto dst = index_data[e + E];

for (auto h = 0; h < heads; ++h){
for (auto k = 0; k < out_channels; ++k){
#ifdef COMPILE_WITH_OMP
#pragma omp atomic
#endif
grad_x_data[src * out_channels * heads + h * out_channels + k] +=
weight_data[e * heads + h] * grad_data[dst * out_channels * heads + h * out_channels + k];

grad_weight_data[e * heads + h] += x_data[src * out_channels * heads + h * out_channels + k] *
grad_data[dst * out_channels * heads + h * out_channels + k];

}
}
}
// return {grad_x, grad_weight};
return std::make_tuple(grad_x, grad_weight);
}
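
For readability (not part of the PR), the forward kernel above amounts to the following loop in Python: for every edge e from src to dst and every head h, it accumulates weight[e, h] * x[src, h, :] into out[dst, h, :].

import torch

def bspmm_sum_reference(index, weight, x):
    # Loop form of bspmm_sum_cpu_forward, for clarity only (intentionally slow).
    out = torch.zeros_like(x)                 # (N, heads, out_channels)
    src, dst = index[0], index[1]             # each of shape (E,)
    for e in range(index.shape[1]):
        # weight[e] has shape (heads,); broadcast it over out_channels.
        out[dst[e]] += weight[e].unsqueeze(-1) * x[src[e]]
    return out
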
6 changes: 6 additions & 0 deletions gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h
@@ -0,0 +1,6 @@
#include <torch/torch.h>

torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight,
torch::Tensor &x);
std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x,
torch::Tensor &grad);
8 changes: 8 additions & 0 deletions gammagl/mpops/torch_ext/include/gspmm.h
@@ -24,3 +24,11 @@ class SpMMMax : public torch::autograd::Function<SpMMMax> {
static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext *ctx,
std::vector<torch::Tensor> grad_outs);
};

class BSpMMSum : public torch::autograd::Function<BSpMMSum> {
public:
static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index,
torch::Tensor weight, torch::Tensor x);
static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext *ctx,
std::vector<torch::Tensor> grad_outs);
};
52 changes: 52 additions & 0 deletions gammagl/mpops/torch_ext/src/gspmm.cpp
@@ -9,6 +9,8 @@
#include "../cpu/spmm_sum_cpu.h"
#include "../cpu/spmm_mean_cpu.h"
#include "../cpu/spmm_max_cpu.h"
#include "../cpu/bspmm_sum_cpu.h"

#ifdef COMPILE_WITH_CUDA
#include "../cuda/spmm_sum_cuda.h"
#endif
@@ -171,3 +173,53 @@ std::vector<torch::Tensor> SpMMMax::backward(torch::autograd::AutogradContext *c

return {torch::Tensor(), torch::Tensor(), grad_x};
}


torch::Tensor BSpMMSum::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index,
torch::Tensor weight, torch::Tensor x) {
ctx->save_for_backward({index, weight, x});
ctx->mark_non_differentiable({index, weight});
torch::Tensor out;
// CUDA
if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) {
// #ifdef COMPILE_WITH_CUDA
// out = bspmm_sum_cuda_forward(index, weight, x);
// #else
AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU.");
// #endif
}
// CPU
else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) {
out = bspmm_sum_cpu_forward(index, weight, x);
} else {
AT_ERROR("Tensor device inconsistent error.");
}

return out;
}

std::vector<torch::Tensor> BSpMMSum::backward(torch::autograd::AutogradContext *ctx, std::vector<torch::Tensor> grad_outs) {
auto saved = ctx->get_saved_variables();
auto index = saved[0], weight = saved[1], x = saved[2];
auto grad = grad_outs[0];
torch::Tensor grad_x, grad_weight;

// CUDA
if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) {
// #ifdef COMPILE_WITH_CUDA
// grad_x = bspmm_sum_cuda_backward(index, weight, grad);
// #else
AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU.");
// #endif
}
// CPU
else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) {
auto result = bspmm_sum_cpu_backward(index, weight, x, grad);
grad_x = std::get<0>(result);
grad_weight = std::get<1>(result);
} else {
AT_ERROR("Tensor device inconsistent error.");
}

return {torch::Tensor(), grad_weight, grad_x};
}
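
The backward pass returns an undefined tensor for index (which is marked non-differentiable) together with grad_weight and grad_x. Assuming the extension is built and all tensors live on the CPU, a quick smoke test against a pure-PyTorch reference could look like this (shapes are made up for illustration):

import torch
from gammagl.mpops import bspmm   # torch backend; requires the compiled extension

N, E, heads, C = 5, 8, 2, 3
index = torch.randint(0, N, (2, E), dtype=torch.int64)
weight = torch.rand(E, heads, dtype=torch.float32, requires_grad=True)
x = torch.rand(N, heads, C, dtype=torch.float32, requires_grad=True)

out = bspmm(index, weight=weight, x=x, reduce='sum')
# Pure-PyTorch reference: scatter the weighted source features onto the destination nodes.
ref = torch.zeros(N, heads, C).index_add(0, index[1], weight.unsqueeze(-1) * x[index[0]])
print(torch.allclose(out, ref, atol=1e-6))

out.sum().backward()
print(weight.grad.shape, x.grad.shape)   # expected: (E, heads) and (N, heads, C)
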
7 changes: 7 additions & 0 deletions gammagl/mpops/torch_ext/src/operators.cpp
@@ -39,11 +39,18 @@ torch::Tensor spmm_max(torch::Tensor index, torch::Tensor weight, torch::Tensor
return SpMMMax::apply(index, weight, x);
}

torch::Tensor bspmm_sum(torch::Tensor index, torch::Tensor weight,
torch::Tensor x) {
auto result = BSpMMSum::apply(index, weight, x);
return result;
}

PYBIND11_MODULE(_torch_ext, m) {
m.def("c_segment_max", segment_max);
m.def("c_segment_sum", segment_sum);
m.def("c_segment_mean", segment_mean);
m.def("c_spmm_sum", spmm_sum);
m.def("c_spmm_mean", spmm_mean);
m.def("c_spmm_max", spmm_max);
m.def("c_bspmm_sum", bspmm_sum);
}
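
After rebuilding the extension, the new binding is reachable directly on the compiled module, which is what gammagl/mpops/torch.py imports at the top of this PR. Going through bspmm is the normal route, but a direct call is a handy way to confirm the build picked up the operator (values below are made up):

import torch
from gammagl.mpops.torch_ext._torch_ext import c_bspmm_sum   # binding added by this PR

index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.int64)   # (2, E) with E = 3
weight = torch.rand(3, 2)                                          # (E, heads)
x = torch.rand(3, 2, 4)                                            # (N, heads, out_channels)
out = c_bspmm_sum(index, weight, x)                                # autograd-aware via BSpMMSum::apply
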