Add graph transformer #10313


Open · wants to merge 114 commits into master

Conversation

omarkhater

This PR introduces a modular and extensible GraphTransformer model for PyTorch Geometric, following the architecture described in "Transformer for Graphs: An Overview from Architecture Perspective".

Key Features

  • Flexible architecture: Supports bias providers, positional encoders, GNN integration (pre/post/parallel), and an optional super node (CLS token); see the illustrative sketch after this list.
  • Robust batching: Handles variable-sized graphs and key padding.
  • Extensible API: Easily supports new bias providers, positional encoders, and GNN blocks.
  • Comprehensive validation: Constructor checks for argument consistency and type safety.
  • Thorough testing: Unit tests cover forward/backward passes, hooks, encoders, bias, and edge cases.
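As a rough, self-contained illustration of what pre/post/parallel GNN integration and key padding over variable-sized graphs mean in practice, here is a toy hybrid block built from stock PyG/PyTorch modules. It is not the PR's GraphTransformer implementation; names such as `HybridBlock` are hypothetical.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_batch


class HybridBlock(nn.Module):
    """Toy illustration of pre/post/parallel GNN + attention composition.

    This is NOT the PR's GraphTransformer; it only shows what the three
    integration modes mean, built from stock PyG/PyTorch modules.
    """

    def __init__(self, dim: int, mode: str = "pre", heads: int = 4):
        super().__init__()
        assert mode in {"pre", "post", "parallel"}
        self.mode = mode
        self.gnn = GCNConv(dim, dim)
        self.attn = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                               batch_first=True)

    def _attend(self, x, batch):
        # Pad variable-sized graphs into a dense (num_graphs, max_nodes, dim)
        # tensor; `mask` marks the real (non-padded) node slots.
        dense, mask = to_dense_batch(x, batch)
        out = self.attn(dense, src_key_padding_mask=~mask)
        return out[mask]  # back to the flat (num_nodes, dim) layout

    def forward(self, x, edge_index, batch):
        if self.mode == "pre":        # GNN first, then attention
            return self._attend(self.gnn(x, edge_index).relu(), batch)
        if self.mode == "post":       # attention first, then GNN
            return self.gnn(self._attend(x, batch), edge_index).relu()
        # parallel: run both branches on the input and sum them
        return self.gnn(x, edge_index).relu() + self._attend(x, batch)


# Two 3-node graphs batched together.
x = torch.randn(6, 16)
edge_index = torch.tensor([[0, 1, 3, 4], [1, 2, 4, 5]])
batch = torch.tensor([0, 0, 0, 1, 1, 1])
out = HybridBlock(16, mode="parallel")(x, edge_index, batch)
print(out.shape)  # torch.Size([6, 16])
```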

Benchmarking

I have thoroughly tested this model on a variety of node and graph classification benchmarks, including CORA, CITESEER, PUBMED, MUTAG, and PROTEINS. Results and reproducible scripts are available in my benchmarking repository:
👉 omarkhater/benchmark_graph_transformer

The model achieves strong performance across all tested datasets and supports a wide range of configurations (see the benchmark repo for details).

Future Work

If desired, I am happy to continue working on this model to add regression task support and address any additional requirements from the PyG team.


Let me know if you have feedback or requests for further improvements!

…tests

- Introduce GraphTransformer module:
  - _readout() → global_mean_pool aggregation
  - classifier → nn.Linear(hidden_dim → num_classes)
- Add parameterized pytest to verify output shape across:
    * various num_graphs, num_nodes, feature_dims, num_classes
- Establishes TDD baseline for subsequent transformer layers
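For context, the `_readout()` → classifier pattern mentioned above boils down to the following minimal sketch (not the PR's exact code):

```python
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool

# Mean-pool node embeddings per graph, then project to class logits.
hidden_dim, num_classes = 32, 5
classifier = nn.Linear(hidden_dim, num_classes)

x = torch.randn(7, hidden_dim)               # 7 node embeddings across 2 graphs
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])  # node -> graph assignment
graph_emb = global_mean_pool(x, batch)       # shape: (2, hidden_dim)
logits = classifier(graph_emb)               # shape: (2, num_classes)
print(logits.shape)                          # torch.Size([2, 5])
```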
feat: add super-node ([CLS] token) readout support

- Introduce `use_super_node` flag and `cls_token` parameter
- Implement `_add_cls_token` and branch in `_readout`
- Add `test_super_node_readout` to verify zero-bias behavior
- Introduce `node_feature_encoder` argument (default Identity)
- Wire encoder into forward via `_encode_nodes`
- Add `test_node_feature_encoder_identity` to verify scaling logic
Introduce the core Transformer‐backbone scaffold without changing model behavior:

- Add `IdentityLayer(nn.Module)` that accepts `(x, batch)` and simply returns `x`.
- Create `EncoderLayers` container wrapping a `ModuleList` of `IdentityLayer` instances.
- Wire `self.encoder = EncoderLayers(num_encoder_layers)` into `GraphTransformer.__init__`.
- Update `forward()` to call `x = self.encoder(x, data.batch)` before readout.
- Preserve existing functionality: mean‐pool readout, super‐node, and node‐feature encoder all still work.
- Add `test_transformer_block_identity` to verify that replacing one stub layer with a custom module (e.g. `AddOneLayer`) correctly affects the output shape.
- Test that gradients flow properly through the GraphTransformer model
- Test that GraphTransformer can be traced with TorchScript
…p stubs

This commit completes Cycle 4-A (RED → GREEN → REFACTOR): the codebase now contains a real, pluggable encoder skeleton ready for multi-head attention and feed-forward logic in the next mini-cycle, while all existing behaviour, TorchScript traceability, and test guarantees remain intact.

Added
=====
* torch_geometric.contrib.nn.layers.transformer.GraphTransformerEncoderLayer
  * contains the two LayerNorm blocks that will precede MHA & FFN in the next iteration
  * forwards x, batch unchanged for now (keeps tests green)

Changed
=======
- GraphTransformer
  * builds either a single GraphTransformerEncoderLayer (num_encoder_layers == 0) or a full GraphTransformerEncoder stack.
  * keeps the existing public API and super-node/mean-pool read-outs.
- Test suite
  * rewired to use the new stack (model.encoder[0] instead of model.encoder.layers[0]); added test_encoder_layer_type, which asserts the first layer is a GraphTransformerEncoderLayer.
- GraphTransformerEncoder: a thin wrapper that holds an arbitrary number of identical encoder layers and exposes __len__, __getitem__, and __setitem__ so tests (and users) can hot-swap layers.

Removed
=======
Legacy scaffolding (IdentityLayer and EncoderLayers) that were only placeholders in Cycle 3.
- Introduces configurable num_heads, dropout, ffn_hidden_dim, and activation params to GraphTransformerEncoderLayer
- Adds a placeholder self.self_attn = nn.Identity() so the interface is ready for real multi-head attention in a future cycle
- Implements the full Add & Norm ➜ FFN ➜ Add & Norm pattern, meaning the layer now changes the tensor values (tests confirm) while preserving shape
- Exposes a JIT-ignored static helper _build_key_padding_mask (returns None for now) to prepare for masking logic
- Updates GraphTransformer to forward the new constructor args.
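A minimal sketch of the Add & Norm ➜ FFN ➜ Add & Norm pattern with an identity attention stub, as described in the bullets above (illustrative only; the PR's GraphTransformerEncoderLayer may differ):

```python
import torch
import torch.nn as nn


class EncoderLayerSketch(nn.Module):
    """Add & Norm -> FFN -> Add & Norm with an identity attention stub.

    Illustrative only; mirrors the structure described in the commit, not the
    PR's GraphTransformerEncoderLayer itself.
    """

    def __init__(self, hidden_dim=64, ffn_hidden_dim=256, dropout=0.1,
                 activation=None):
        super().__init__()
        self.self_attn = nn.Identity()   # placeholder until real MHA lands
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_hidden_dim),
            activation if activation is not None else nn.ReLU(),
            nn.Linear(ffn_hidden_dim, hidden_dim),
        )

    def forward(self, x):
        x = self.norm1(x + self.dropout(self.self_attn(x)))  # Add & Norm
        x = self.norm2(x + self.dropout(self.ffn(x)))        # FFN + Add & Norm
        return x


layer = EncoderLayerSketch()
print(layer(torch.randn(10, 64)).shape)  # torch.Size([10, 64])
```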
… mask

- adds nn.MultiheadAttention(batch_first=True) to encoder layer
- implements per-graph key_padding_mask so distinct graphs don’t cross-talk
- updates tests to confirm the layer now transforms identical rows
- refactors shape helpers for readability
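To make the key-padding idea concrete, here is a small standalone sketch (not the PR's code) that pads two graphs into one dense batch and masks the padded slots so they cannot be attended to:

```python
import torch
import torch.nn as nn
from torch_geometric.utils import to_dense_batch

x = torch.randn(5, 8)                   # 5 nodes with feature dim 8
batch = torch.tensor([0, 0, 0, 1, 1])   # graph sizes 3 and 2

dense, node_mask = to_dense_batch(x, batch)  # dense: (2, 3, 8), node_mask: (2, 3)
key_padding_mask = ~node_mask                # True = padded slot, ignored by MHA

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
out, _ = mha(dense, dense, dense, key_padding_mask=key_padding_mask)
out = out[node_mask]                         # back to a flat (5, 8) node tensor
print(out.shape)
```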
Changed test to verify encoder affects outputs rather than assuming direct pooling
Use Identity encoder for simpler test
Compare outputs with raw vs scaled features
Updated assertion to check that scaling changes output
Introduces an optional degree_encoder hook to enrich node embeddings with structural information (in-/out-degree, or any custom degree-based feature). When provided, the encoder’s output is added to the raw (optionally pre-encoded) node features before the transformer stack, enabling degree awareness without altering existing workflows.
Extract “sum up extra encoders” into a private helper to keep forward tidy
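One possible degree_encoder, in the spirit described above, is a simple degree-embedding lookup added onto the node features (illustrative sketch; the hook accepts any module and the PR does not prescribe this particular choice):

```python
import torch
import torch.nn as nn
from torch_geometric.utils import degree

# A hypothetical degree encoder: embed each node's (clamped) degree and add it
# to the node features before the transformer stack.
hidden_dim, max_degree = 16, 64
degree_embedding = nn.Embedding(max_degree + 1, hidden_dim)

x = torch.randn(3, hidden_dim)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

deg = degree(edge_index[0], num_nodes=x.size(0)).long().clamp(max=max_degree)
x = x + degree_embedding(deg)   # enrich node features with structural info
print(x.shape)                  # torch.Size([3, 16])
```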
Ensures logits differ when a structural mask is supplied.
Keeps all earlier behaviour while unlocking arbitrary structural masks.
Add attn_mask to AddOneLayer
* `_prepend_cls_token_flat` adds one learnable row per graph and returns the
  updated batch vector.
* `forward` uses the new helper so the encoder still receives a flat tensor.
* `_readout` now picks CLS rows via vectorised indexing
  (`first_idx = cumsum(graph_sizes)-graph_sizes`).
* Removes the temporary _add_cls_token call inside forward.

This lets the CLS token flow through attention while keeping the encoder API
unchanged; the new `test_cls_token_transformation` turns green.
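The vectorised CLS-row lookup in `_readout` amounts to the following (sketch, not the PR's code):

```python
import torch

# One CLS row is prepended per graph, so each graph's CLS row sits at the
# first index of that graph's block in the flat node tensor.
graph_sizes = torch.tensor([4, 3, 5])        # node counts including the CLS row
first_idx = torch.cumsum(graph_sizes, dim=0) - graph_sizes
print(first_idx)                             # tensor([0, 4, 7])

x = torch.randn(int(graph_sizes.sum()), 16)  # flat (num_nodes, hidden_dim) tensor
cls_rows = x[first_idx]                      # (num_graphs, hidden_dim) readout
print(cls_rows.shape)                        # torch.Size([3, 16])
```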
…mation

Previously, the CLS token was used directly in readout; the model now transforms it through the encoder.
omarkhater-school and others added 14 commits April 30, 2025 15:10
unify encoder invocation into single loop, removing separate single-layer vs. stack branches in _run_encoder
so that .to(device) picks it up automatically with the rest of the model.
…AttnSpatialBias

Add BaseBiasProvider._pad_to_tensor helper to zero-pad ragged (Ni×Ni)
matrices into a uniform (B×L×L) tensor.

Update GraphAttnSpatialBias._extract_raw_distances to slice out true
Ni×Ni blocks (rows and columns) from data.spatial_pos rather than
only row-slicing, enabling support for batched graphs of differing sizes.
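A minimal stand-in for the padding helper described above (illustrative only; the PR's BaseBiasProvider._pad_to_tensor may differ in its details):

```python
import torch

def pad_square_blocks(blocks):
    """Zero-pad a list of (Ni, Ni) matrices into one (B, L, L) tensor.

    Illustrative stand-in for the padding helper described above; the PR's
    BaseBiasProvider._pad_to_tensor may differ in its details.
    """
    L = max(block.size(0) for block in blocks)
    out = blocks[0].new_zeros(len(blocks), L, L)
    for i, block in enumerate(blocks):
        n = block.size(0)
        out[i, :n, :n] = block
    return out

blocks = [torch.ones(2, 2), torch.full((3, 3), 2.0)]
print(pad_square_blocks(blocks).shape)  # torch.Size([2, 3, 3])
```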

Add Tests

test_spatial_bias_handles_heterogeneous_blockdiag_spatial_pos:
verifies that a block-diagonal spatial_pos with unequal block sizes
no longer crashes but is correctly padded.

test_spatial_bias_supports_variable_graph_sizes:
confirms that arbitrary graph‐size lists can be passed and handled.

Update fixtures in conftest.py and test_spatial_bias.py to supply
variable‐sized and padded spatial_pos inputs for the new tests.
…iders

Refactor GraphAttnHopBias._extract_raw_distances to:

- read data.hop_dist and data.ptr,
- slice out each Ni×Ni block for graph i,
- return a List[Tensor] of per-graph blocks for uniform padding downstream.

Add a clear Google-style docstring explaining inputs, outputs, and errors.

Refactor GraphAttnEdgeBias._extract_raw_distances to:

- require data.ptr,
- if data.edge_dist is None, build zero-filled Ni×Ni blocks;
- else slice each diagonal Ni×Ni block from the flat matrix;
- return a List[Tensor] for consistent padding.

Expand the docstring to describe arguments, return value, and exceptions.

Leverage the existing BaseBiasProvider._pad_to_tensor to zero-pad ragged blocks into a single (B, L, L) tensor and then optionally super-node pad.

Add tests (test_hop_bias_handles_variable_graph_sizes and test_edge_bias_handles_variable_graph_sizes) with block-diagonal fixtures to verify correct slicing and padding of irregular batched inputs (see the slicing sketch below).
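The per-graph block slicing via `data.ptr` amounts to the following (illustrative sketch, not the PR's code):

```python
import torch

# Slice per-graph Ni x Ni blocks out of a block-diagonal distance matrix
# using the batch's `ptr` offsets.
ptr = torch.tensor([0, 2, 5])   # two graphs with 2 and 3 nodes
flat = torch.zeros(5, 5)
flat[:2, :2] = 1.0              # block belonging to graph 0
flat[2:, 2:] = 2.0              # block belonging to graph 1

blocks = [
    flat[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]]
    for i in range(ptr.numel() - 1)
]
print([tuple(b.shape) for b in blocks])  # [(2, 2), (3, 3)]
```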
- Add unit tests to reproduce the bug.
- Fix the bug when moving the model to GPU associated with positional encoders.
Resolves issue where GraphTransformer fails on datasets without node features (e.g., QM7B regression) with error 'linear(): argument 'input' (position 1) must be Tensor, not NoneType'.

Changes:
- Add robust None handling in _encode_and_apply_structural() method
- Add test_graph_transformer_none_features_handled() test case
@omarkhater requested review from wsad1 and rusty1s as code owners June 9, 2025 18:03
pre-commit-ci bot and others added 14 commits June 9, 2025 18:05
- Avoid function calls in argument defaults
- Raise a ValueError in _encode_nodes if input features are None or their dimension does not match the encoder's expected input
Handle Sequential and Identity node encoders for regression tasks

- If node_feature_encoder is nn.Identity, the input feature dimension must match hidden_dim; otherwise, a clear ValueError is raised.

- For other encoders (e.g., nn.Linear, nn.Sequential), the input feature dimension is checked recursively against the first Linear layer's in_features.

- Updated and parameterized the test suite to cover all relevant cases, including nn.Identity, nn.Linear, and nn.Sequential with both matching and mismatched input dimensions.
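The dimension check described above can be pictured like this (illustrative sketch; the PR's actual helper may be implemented differently):

```python
import torch.nn as nn

def expected_in_features(encoder, hidden_dim):
    """Return the input feature size an encoder expects.

    Illustrative sketch of the check described above: hidden_dim for
    nn.Identity, otherwise the in_features of the first nn.Linear found.
    The PR's actual helper may differ.
    """
    if isinstance(encoder, nn.Identity):
        return hidden_dim
    for module in encoder.modules():
        if isinstance(module, nn.Linear):
            return module.in_features
    raise ValueError("Could not infer the encoder's expected input size")

print(expected_in_features(nn.Identity(), 64))                    # 64
print(expected_in_features(nn.Sequential(nn.Linear(9, 64)), 64))  # 9
```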
Complexity before the change:

PS D:\projects\pytorch_geometric> radon cc -s -a torch_geometric/contrib/nn/models/graph_transformer.py
torch_geometric/contrib/nn/models/graph_transformer.py
    M 112:4 GraphTransformer._validate_init_args - D (23)
    M 52:4 GraphTransformer.__init__ - B (8)
    C 18:0 GraphTransformer - B (6)
    M 270:4 GraphTransformer._encode_nodes - B (6)
    M 335:4 GraphTransformer._encode_and_apply_structural - B (6)
    M 440:4 GraphTransformer._collect_attn_bias - B (6)
    M 258:4 GraphTransformer._find_in_features - A (5)
    M 410:4 GraphTransformer._run_encoder - A (5)
    M 469:4 GraphTransformer.__repr__ - A (5)
    M 309:4 GraphTransformer.forward - A (3)
    M 366:4 GraphTransformer._apply_gnn_if - A (3)
    M 199:4 GraphTransformer._readout - A (2)
    M 385:4 GraphTransformer._prepare_batch - A (2)
    M 402:4 GraphTransformer._is_parallel - A (2)
    M 219:4 GraphTransformer._prepend_cls_token_flat - A (1)

15 blocks (classes, functions, methods) analyzed.
Average complexity: B (5.533333333333333)

Complexity after the change:

torch_geometric/contrib/nn/models/graph_transformer.py
    M 52:4 GraphTransformer.__init__ - B (8)
    M 149:4 GraphTransformer._validate_dimensions - B (7)
    M 348:4 GraphTransformer._encode_nodes - B (6)
    M 413:4 GraphTransformer._encode_and_apply_structural - B (6)
    M 518:4 GraphTransformer._collect_attn_bias - B (6)
    M 336:4 GraphTransformer._find_in_features - A (5)
    M 488:4 GraphTransformer._run_encoder - A (5)
    M 547:4 GraphTransformer.__repr__ - A (5)
    C 18:0 GraphTransformer - A (4)
    M 205:4 GraphTransformer._validate_ffn_dim - A (4)
    M 239:4 GraphTransformer._validate_bias_providers - A (4)
    M 252:4 GraphTransformer._validate_gnn_config - A (4)
    M 176:4 GraphTransformer._validate_num_heads - A (3)
    M 191:4 GraphTransformer._validate_dropout - A (3)
    M 267:4 GraphTransformer._validate_positional_encoders - A (3)
    M 387:4 GraphTransformer.forward - A (3)
    M 444:4 GraphTransformer._apply_gnn_if - A (3)
    M 220:4 GraphTransformer._validate_activation - A (2)
    M 277:4 GraphTransformer._readout - A (2)
    M 463:4 GraphTransformer._prepare_batch - A (2)
    M 480:4 GraphTransformer._is_parallel - A (2)
    M 112:4 GraphTransformer._validate_init_args - A (1)
    M 167:4 GraphTransformer._validate_transformer_params - A (1)
    M 297:4 GraphTransformer._prepend_cls_token_flat - A (1)

24 blocks (classes, functions, methods) analyzed.
Average complexity: A (3.75)

MI Before:

torch_geometric\contrib\nn\models\graph_transformer.py - A (45.09)

MI After:

torch_geometric\contrib\nn\models\graph_transformer.py - A (43.53)
Use alias imports to avoid multi-line import formatting conflicts between
yapf and isort. The alias import pattern (import module as _alias;
Class = _alias.Class) prevents the formatting loop that occurred with
parenthesized multi-line imports exceeding the 79-character line limit.
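Concretely, the pattern looks like this (module path taken from the Added section above; illustrative only):

```python
# Alias import keeps the import statement on one short line; the public name
# is then bound explicitly, so neither yapf nor isort needs to rewrap it.
import torch_geometric.contrib.nn.layers.transformer as _transformer

GraphTransformerEncoderLayer = _transformer.GraphTransformerEncoderLayer
```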