From c42271d4bd53932d139b3d6d16d7d11088828327 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 8 Jul 2024 11:53:47 +0000 Subject: [PATCH] update --- CHANGELOG.md | 1 + torch_geometric/nn/conv/collect.jinja | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d1b64a1ab1dc..4d9a45259a8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Allow optional but untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494)) - Added support for modifying `filename` of the stored partitioned file in `ClusterLoader` ([#9448](https://github.com/pyg-team/pytorch_geometric/pull/9448)) - Support other than two-dimensional inputs in `AttentionalAggregation` ([#9433](https://github.com/pyg-team/pytorch_geometric/pull/9433)) - Improved model performance of the `examples/ogbn_papers_100m.py` script ([#9386](https://github.com/pyg-team/pytorch_geometric/pull/9386), [#9445](https://github.com/pyg-team/pytorch_geometric/pull/9445)) diff --git a/torch_geometric/nn/conv/collect.jinja b/torch_geometric/nn/conv/collect.jinja index 58a42007a819..480b10ec109d 100644 --- a/torch_geometric/nn/conv/collect.jinja +++ b/torch_geometric/nn/conv/collect.jinja @@ -96,6 +96,20 @@ def {{collect_name}}( else: raise NotImplementedError +{%- if 'edge_weight' in collect_param_dict and + collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %} + if torch.jit.is_scripting(): + assert edge_weight is not None +{%- elif 'edge_attr' in collect_param_dict and + collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %} + if torch.jit.is_scripting(): + assert edge_attr is not None +{%- elif 'edge_type' in collect_param_dict and + collect_param_dict['edge_type'].type_repr.endswith('Tensor') %} + if torch.jit.is_scripting(): + assert edge_type is not None +{%- endif %} + # Collect user-defined arguments: {%- for name in collect_param_dict %} {%- if (name.endswith('_i') or name.endswith('_j')) and