Skip to content

Commit 4a03485

Browse files
committed
update
1 parent c8cd4de commit 4a03485

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ jobs:
6464
if: steps.changed-files-specific.outputs.only_changed != 'true'
6565
timeout-minutes: 10
6666
run: |
67-
pytest --cov --cov-report=xml --durations 10
67+
FULL_TEST=1 pytest --cov --cov-report=xml --durations 10
6868
6969
- name: Upload coverage
7070
if: ${{ steps.changed-files-specific.outputs.only_changed != 'true' && runner.os == 'Linux' }}

test/nn/conv/test_message_passing.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,24 @@ def test_pickle(tmp_path):
740740

741741
model = torch.load(path)
742742
torch.jit.script(model)
743+
744+
745+
class MyOptionalEdgeAttrConv(MessagePassing):
746+
def __init__(self):
747+
super().__init__()
748+
749+
def forward(self, x, edge_index, edge_attr=None):
750+
return self.propagate(edge_index, x=x, edge_attr=edge_attr)
751+
752+
def message(self, x_j, edge_attr=None):
753+
return x_j if edge_attr is None else x_j * edge_attr.view(-1, 1)
754+
755+
756+
def test_my_optional_edge_attr_conv():
757+
conv = MyOptionalEdgeAttrConv()
758+
759+
x = torch.randn(4, 8)
760+
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
761+
762+
out = conv(x, edge_index)
763+
assert out.size() == (4, 8)

torch_geometric/nn/conv/collect.jinja

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,6 @@ def {{collect_name}}(
9696
else:
9797
raise NotImplementedError
9898

99-
{%- if 'edge_weight' in collect_param_dict and
100-
collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
101-
assert edge_weight is not None
102-
{%- elif 'edge_attr' in collect_param_dict and
103-
collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}
104-
assert edge_attr is not None
105-
{%- elif 'edge_type' in collect_param_dict and
106-
collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}
107-
assert edge_type is not None
108-
{%- endif %}
109-
11099
# Collect user-defined arguments:
111100
{%- for name in collect_param_dict %}
112101
{%- if (name.endswith('_i') or name.endswith('_j')) and

0 commit comments

Comments
 (0)