Add graph transformer #10313
Open: omarkhater wants to merge 114 commits into pyg-team:master from omarkhater:add_graph_transformer
Conversation
…tests
- Introduce GraphTransformer module:
  - `_readout()` → `global_mean_pool` aggregation
  - classifier → `nn.Linear(hidden_dim → num_classes)`
- Add parameterized pytest to verify output shape across various num_graphs, num_nodes, feature_dims, and num_classes
- Establishes a TDD baseline for subsequent transformer layers
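A minimal, hypothetical sketch of the baseline this commit establishes (the class name and signatures below are illustrative, not the PR's actual API): mean-pool node embeddings per graph, then project to class logits.

```python
import torch
import torch.nn as nn
from torch_geometric.nn import global_mean_pool


class TinyGraphTransformer(nn.Module):  # hypothetical name, not the PR's class
    def __init__(self, hidden_dim: int, num_classes: int) -> None:
        super().__init__()
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        # (num_nodes, hidden_dim) -> (num_graphs, hidden_dim) via mean pooling
        return global_mean_pool(x, batch)

    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        # (num_graphs, num_classes) logits
        return self.classifier(self._readout(x, batch))
```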
feat: add super-node ([CLS] token) readout support
- Introduce `use_super_node` flag and `cls_token` parameter
- Implement `_add_cls_token` and branch in `_readout`
- Add `test_super_node_readout` to verify zero-bias behavior
- Introduce `node_feature_encoder` argument (default: Identity)
- Wire the encoder into `forward` via `_encode_nodes`
- Add `test_node_feature_encoder_identity` to verify scaling logic
Introduce the core transformer-backbone scaffold without changing model behavior:
- Add `IdentityLayer(nn.Module)` that accepts `(x, batch)` and simply returns `x`.
- Create `EncoderLayers` container wrapping a `ModuleList` of `IdentityLayer` instances.
- Wire `self.encoder = EncoderLayers(num_encoder_layers)` into `GraphTransformer.__init__`.
- Update `forward()` to call `x = self.encoder(x, data.batch)` before readout.
- Preserve existing functionality: mean-pool readout, super-node, and node-feature encoder all still work.
- Add `test_transformer_block_identity` to verify that replacing one stub layer with a custom module (e.g. `AddOneLayer`) correctly affects the output shape.
- Test that gradients flow properly through the GraphTransformer model
- Test that GraphTransformer can be traced with TorchScript
…p stubs

This commit completes Cycle 4-A (RED → GREEN → REFACTOR): the codebase now contains a real, pluggable encoder skeleton ready for multi-head attention and feed-forward logic in the next mini-cycle, while all existing behaviour, TorchScript traceability, and test guarantees remain intact.

Added
- `torch_geometric.contrib.nn.layers.transformer.GraphTransformerEncoderLayer`
  - contains the two LayerNorm blocks that will precede MHA & FFN in the next iteration
  - forwards `x, batch` unchanged for now (keeps tests green)

Changed
- `GraphTransformer`
  - builds either a single `GraphTransformerEncoderLayer` (`num_encoder_layers == 0`) or a full `GraphTransformerEncoder` stack
  - keeps the existing public API and super-node/mean-pool read-outs
- Test suite
  - rewired to use the new stack (`model.encoder[0]` instead of `model.encoder.layers[0]`)
  - added `test_encoder_layer_type`, which asserts the first layer is a `GraphTransformerEncoderLayer`
- `GraphTransformerEncoder` – thin wrapper that holds an arbitrary number of identical encoder layers and exposes `__len__`, `__getitem__`, `__setitem__` so tests (and users) can hot-swap layers

Removed
- Legacy scaffolding `IdentityLayer` and `EncoderLayers` that were only placeholders in Cycle 3
- Introduces configurable num_heads, dropout, ffn_hidden_dim, and activation params to GraphTransformerEncoderLayer
- Adds a placeholder `self.self_attn = nn.Identity()` so the interface is ready for real multi-head attention in a future cycle
- Implements the full Add & Norm → FFN → Add & Norm pattern, meaning the layer now changes the tensor values (tests confirm) while preserving shape
- Exposes a JIT-ignored static helper `_build_key_padding_mask` (returns None for now) to prepare for masking logic
- Updates GraphTransformer to forward the new constructor args
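A rough sketch of the layer structure this commit describes, assuming a standard post-norm design; names and defaults are illustrative rather than the PR's exact code.

```python
import torch
import torch.nn as nn


class SketchEncoderLayer(nn.Module):  # illustrative, not the PR's exact layer
    def __init__(self, hidden_dim: int, ffn_hidden_dim: int,
                 dropout: float = 0.1) -> None:
        super().__init__()
        self.self_attn = nn.Identity()  # placeholder until real MHA lands
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_hidden_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x + self.self_attn(x))  # Add & Norm around (stub) attention
        x = self.norm2(x + self.ffn(x))        # Add & Norm around the FFN
        return x
```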
… mask
- adds `nn.MultiheadAttention(batch_first=True)` to the encoder layer
- implements a per-graph key_padding_mask so distinct graphs don't cross-talk
- updates tests to confirm the layer now transforms identical rows
- refactors shape helpers for readability
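One way to realise the per-graph masking described above is sketched below (an assumption, not necessarily the PR's implementation): densify the flat node tensor with `to_dense_batch` and mark the padding slots so `nn.MultiheadAttention` never mixes nodes from different graphs. The `mha` module is assumed to have been constructed with `batch_first=True`.

```python
import torch
from torch.nn import MultiheadAttention
from torch_geometric.utils import to_dense_batch


def attend_within_graphs(x: torch.Tensor, batch: torch.Tensor,
                         mha: MultiheadAttention) -> torch.Tensor:
    # Densify: (num_nodes, D) -> (B, L, D) plus a (B, L) validity mask.
    dense_x, node_mask = to_dense_batch(x, batch)
    # nn.MultiheadAttention ignores positions where key_padding_mask is True.
    key_padding_mask = ~node_mask
    out, _ = mha(dense_x, dense_x, dense_x, key_padding_mask=key_padding_mask)
    # Drop the padding rows again to return a flat (num_nodes, D) tensor.
    return out[node_mask]
```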
Changed the test to verify that the encoder affects outputs rather than assuming direct pooling:
- Use an Identity encoder for a simpler test
- Compare outputs with raw vs. scaled features
- Updated the assertion to check that scaling changes the output
Introduces an optional degree_encoder hook to enrich node embeddings with structural information (in-/out-degree, or any custom degree-based feature). When provided, the encoder’s output is added to the raw (optionally pre-encoded) node features before the transformer stack, enabling degree awareness without altering existing workflows.
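An illustrative sketch of what such a degree_encoder hook could look like (the class and its signature are hypothetical): embed the clamped node in-degree and add the result to the (pre-encoded) node features.

```python
import torch
import torch.nn as nn
from torch_geometric.utils import degree


class DegreeEncoder(nn.Module):  # hypothetical example of a degree_encoder
    def __init__(self, hidden_dim: int, max_degree: int = 64) -> None:
        super().__init__()
        self.embedding = nn.Embedding(max_degree + 1, hidden_dim)

    def forward(self, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        # In-degree per node, clamped to the embedding table's range.
        deg = degree(edge_index[1], num_nodes=num_nodes).long()
        deg = deg.clamp(max=self.embedding.num_embeddings - 1)
        return self.embedding(deg)  # (num_nodes, hidden_dim), added to node features
```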
Extract “sum up extra encoders” into a private helper to keep forward tidy
Ensures logits differ when a structural mask is supplied.
Keeps all earlier behaviour while unlocking arbitrary structural masks.
Add attn_mask to AddOneLayer
* `_prepend_cls_token_flat` adds one learnable row per graph and returns the updated batch vector.
* `forward` uses the new helper so the encoder still receives a flat tensor.
* `_readout` now picks CLS rows via vectorised indexing (`first_idx = cumsum(graph_sizes) - graph_sizes`).
* Removes the temporary `_add_cls_token` call inside `forward`.

This lets the CLS token flow through attention while keeping the encoder API unchanged; the new `test_cls_token_transformation` turns green.
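The vectorised CLS readout can be sketched as follows (the function name is hypothetical); because one CLS row is prepended per graph, the first row of each graph is its CLS token.

```python
import torch


def cls_readout(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # Rows per graph (CLS token + nodes), in graph order.
    graph_sizes = torch.bincount(batch)
    # Index of each graph's first row, i.e. its prepended CLS token.
    first_idx = torch.cumsum(graph_sizes, dim=0) - graph_sizes
    return x[first_idx]  # (num_graphs, hidden_dim)
```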
…mation the CLS token is used directly in readout, but our model now transforms it through the encoder.
Unify encoder invocation into a single loop, removing the separate single-layer vs. stack branches in `_run_encoder`
so that .to(device) picks it up automatically with the rest of the model.
They are now part of the centralized conftest
…AttnSpatialBias

- Add `BaseBiasProvider._pad_to_tensor` helper to zero-pad ragged (Ni×Ni) matrices into a uniform (B×L×L) tensor.
- Update `GraphAttnSpatialBias._extract_raw_distances` to slice out true Ni×Ni blocks (rows and columns) from `data.spatial_pos` rather than only row-slicing, enabling support for batched graphs of differing sizes.

Add tests:
- `test_spatial_bias_handles_heterogeneous_blockdiag_spatial_pos`: verifies that a block-diagonal `spatial_pos` with unequal block sizes no longer crashes but is correctly padded.
- `test_spatial_bias_supports_variable_graph_sizes`: confirms that arbitrary graph-size lists can be passed and handled.
- Update fixtures in `conftest.py` and `test_spatial_bias.py` to supply variable-sized and padded `spatial_pos` inputs for the new tests.
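A minimal sketch of the zero-padding step, assuming the ragged per-graph matrices arrive as a list of square tensors (the helper name mirrors the commit message, but the body is illustrative).

```python
from typing import List

import torch


def pad_to_tensor(blocks: List[torch.Tensor]) -> torch.Tensor:
    # Zero-pad ragged (Ni x Ni) matrices into a uniform (B, L, L) tensor,
    # where L is the largest Ni in the batch.
    max_n = max(block.size(0) for block in blocks)
    out = blocks[0].new_zeros(len(blocks), max_n, max_n)
    for i, block in enumerate(blocks):
        n = block.size(0)
        out[i, :n, :n] = block  # real values live in the top-left corner
    return out
```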
…iders

- Refactor `GraphAttnHopBias._extract_raw_distances` to:
  - read `data.hop_dist` and `data.ptr`,
  - slice out each Ni×Ni block for graph i,
  - return a `List[Tensor]` of per-graph blocks for uniform padding downstream.
  - Add a clear Google-style docstring explaining inputs, outputs, and errors.
- Refactor `GraphAttnEdgeBias._extract_raw_distances` to:
  - require `data.ptr`,
  - if `data.edge_dist` is None, build zero-filled Ni×Ni blocks; else slice each diagonal Ni×Ni block from the flat matrix,
  - return a `List[Tensor]` for consistent padding.
  - Expand the docstring to describe arguments, return value, and exceptions.
- Leverage the existing `BaseBiasProvider._pad_to_tensor` to zero-pad ragged blocks into a single (B, L, L) tensor and then optionally super-node pad.
- Add tests (`test_hop_bias_handles_variable_graph_sizes` and `test_edge_bias_handles_variable_graph_sizes`) with block-diagonal fixtures to verify correct slicing and padding of irregular batched inputs.
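The block-slicing idea can be sketched as below, using the CSR-style `ptr` vector that PyG batches carry (`ptr[i]:ptr[i+1]` indexes the nodes of graph `i`); the function name is illustrative.

```python
from typing import List

import torch


def extract_blocks(dist: torch.Tensor, ptr: torch.Tensor) -> List[torch.Tensor]:
    # Slice each graph's (Ni x Ni) block out of a block-diagonal matrix,
    # slicing both rows and columns so graphs never overlap.
    blocks = []
    for i in range(ptr.numel() - 1):
        start, end = int(ptr[i]), int(ptr[i + 1])
        blocks.append(dist[start:end, start:end])
    return blocks
```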
- Add unit tests to reproduce the bug.
- Fix the bug, associated with positional encoders, that occurred when moving the model to GPU.
Resolves an issue where GraphTransformer fails on datasets without node features (e.g., QM7B regression) with the error "linear(): argument 'input' (position 1) must be Tensor, not NoneType".

Changes:
- Add robust None handling in the `_encode_and_apply_structural()` method
- Add `test_graph_transformer_none_features_handled()` test case
for more information, see https://pre-commit.ci
- Avoid function calls in argument defaults
- Raise a ValueError in `_encode_nodes` if input features are None or their dimension does not match the encoder's expected input
for more information, see https://pre-commit.ci
Handle Sequential and Identity node encoders for regression tasks:
- If `node_feature_encoder` is `nn.Identity`, the input feature dimension must match `hidden_dim`; otherwise, a clear ValueError is raised.
- For other encoders (e.g., `nn.Linear`, `nn.Sequential`), the input feature dimension is checked recursively against the first Linear layer's `in_features`.
- Updated and parameterized the test suite to cover all relevant cases, including `nn.Identity`, `nn.Linear`, and `nn.Sequential` with both matching and mismatched input dimensions.
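A hedged sketch of the recursive `in_features` lookup described above (the helper name follows the commit message; the body is an assumption, not the PR's exact code).

```python
from typing import Optional

import torch.nn as nn


def find_in_features(module: nn.Module) -> Optional[int]:
    # Return the in_features of the first nn.Linear found (depth-first),
    # or None if the encoder contains no Linear layer (e.g. nn.Identity).
    if isinstance(module, nn.Linear):
        return module.in_features
    for child in module.children():
        found = find_in_features(child)
        if found is not None:
            return found
    return None
```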
…/pytorch_geometric into add_graph_transformer
…/pytorch_geometric into add_graph_transformer
for more information, see https://pre-commit.ci
Complexity before the change (radon cc -s -a torch_geometric/contrib/nn/models/graph_transformer.py):

- M 112:4 GraphTransformer._validate_init_args - D (23)
- M 52:4 GraphTransformer.__init__ - B (8)
- C 18:0 GraphTransformer - B (6)
- M 270:4 GraphTransformer._encode_nodes - B (6)
- M 335:4 GraphTransformer._encode_and_apply_structural - B (6)
- M 440:4 GraphTransformer._collect_attn_bias - B (6)
- M 258:4 GraphTransformer._find_in_features - A (5)
- M 410:4 GraphTransformer._run_encoder - A (5)
- M 469:4 GraphTransformer.__repr__ - A (5)
- M 309:4 GraphTransformer.forward - A (3)
- M 366:4 GraphTransformer._apply_gnn_if - A (3)
- M 199:4 GraphTransformer._readout - A (2)
- M 385:4 GraphTransformer._prepare_batch - A (2)
- M 402:4 GraphTransformer._is_parallel - A (2)
- M 219:4 GraphTransformer._prepend_cls_token_flat - A (1)

15 blocks (classes, functions, methods) analyzed. Average complexity: B (5.53)

Complexity after the change:

- M 52:4 GraphTransformer.__init__ - B (8)
- M 149:4 GraphTransformer._validate_dimensions - B (7)
- M 348:4 GraphTransformer._encode_nodes - B (6)
- M 413:4 GraphTransformer._encode_and_apply_structural - B (6)
- M 518:4 GraphTransformer._collect_attn_bias - B (6)
- M 336:4 GraphTransformer._find_in_features - A (5)
- M 488:4 GraphTransformer._run_encoder - A (5)
- M 547:4 GraphTransformer.__repr__ - A (5)
- C 18:0 GraphTransformer - A (4)
- M 205:4 GraphTransformer._validate_ffn_dim - A (4)
- M 239:4 GraphTransformer._validate_bias_providers - A (4)
- M 252:4 GraphTransformer._validate_gnn_config - A (4)
- M 176:4 GraphTransformer._validate_num_heads - A (3)
- M 191:4 GraphTransformer._validate_dropout - A (3)
- M 267:4 GraphTransformer._validate_positional_encoders - A (3)
- M 387:4 GraphTransformer.forward - A (3)
- M 444:4 GraphTransformer._apply_gnn_if - A (3)
- M 220:4 GraphTransformer._validate_activation - A (2)
- M 277:4 GraphTransformer._readout - A (2)
- M 463:4 GraphTransformer._prepare_batch - A (2)
- M 480:4 GraphTransformer._is_parallel - A (2)
- M 112:4 GraphTransformer._validate_init_args - A (1)
- M 167:4 GraphTransformer._validate_transformer_params - A (1)
- M 297:4 GraphTransformer._prepend_cls_token_flat - A (1)

24 blocks (classes, functions, methods) analyzed. Average complexity: A (3.75)

Maintainability Index before: torch_geometric\contrib\nn\models\graph_transformer.py - A (45.09)
Maintainability Index after: torch_geometric\contrib\nn\models\graph_transformer.py - A (43.53)
for more information, see https://pre-commit.ci
Use alias imports to avoid multi-line import formatting conflicts between the yapf and isort tools. The alias-import pattern (`import module as _alias; Class = _alias.Class`) prevents the formatting loop that occurred with parenthesized multi-line imports exceeding the 79-character line limit.
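For illustration, the alias-import pattern applied to the module this PR adds (the path is taken from the radon output above; treat the exact names as illustrative):

```python
# Short single-line import plus a rebinding...
import torch_geometric.contrib.nn.models.graph_transformer as _graph_transformer

GraphTransformer = _graph_transformer.GraphTransformer

# ...instead of a parenthesized multi-line import that yapf and isort kept
# re-wrapping past the 79-character line limit.
```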
This PR introduces a modular and extensible GraphTransformer model for PyTorch Geometric, following the architecture described in "Transformer for Graphs: An Overview from Architecture Perspective".
Key Features
Benchmarking
I have thoroughly tested this model on a variety of node and graph classification benchmarks, including CORA, CITESEER, PUBMED, MUTAG, and PROTEINS. Results and reproducible scripts are available in my benchmarking repository:
👉 omarkhater/benchmark_graph_transformer
The model achieves strong performance across all tested datasets and supports a wide range of configurations (see the benchmark repo for details).
Future Work
If desired, I am happy to continue working on this model to add regression task support and address any additional requirements from the PyG team.
Let me know if you have feedback or requests for further improvements!