Skip to content

Conversation

@mariogeiger
Copy link
Collaborator

@mariogeiger mariogeiger commented Feb 26, 2025

SegmentedPolynomial

This PR mainly introduce cue.SegmentedPolynomial. Here is an example of how it will be used as the new descriptor:

import cuequivariance as cue

eq_poly: cue.EquivariantPolynomial = cue.descriptors.symmetric_contraction(
    cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
    cue.Irreps(cue.O3, "32x0e + 32x1o"),
    [0, 1, 2, 3, 4],
)
poly: cue.SegmentedPolynomial = eq_poly.polynomial
print(poly)

This polynomial descriptor describe an operation with its input and output buffers and a list of operations and STP associated.

╭ a=[960:30⨯(32)] b=[224:7⨯(32)] -> C=[128:4⨯(32)]
│  []·a[u]➜C[u] ───────────────────── num_paths=1 u=32
│  []·a[u]·b[u]➜C[u] ──────────────── num_paths=4 u=32
│  []·a[u]·b[u]·b[u]➜C[u] ─────────── num_paths=25 u=32
│  []·a[u]·b[u]·b[u]·b[u]➜C[u] ────── num_paths=157 u=32
╰─ []·a[u]·b[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=1015 u=32

It has many useful transformation operations like for instance backward.

poly = poly.backward([True, True], [True])
print(poly)
╭ a=[960:30⨯(32)] b=[224:7⨯(32)] c=[128:4⨯(32)] -> D=[960:30⨯(32)] E=[224:7⨯(32)]
│  []·c[u]➜D[u] ───────────────────── num_paths=1 u=32
│  []·b[u]·c[u]➜D[u] ──────────────── num_paths=4 u=32
│  []·b[u]·b[u]·c[u]➜D[u] ─────────── num_paths=25 u=32
│  []·b[u]·b[u]·b[u]·c[u]➜D[u] ────── num_paths=157 u=32
│  []·b[u]·b[u]·b[u]·b[u]·c[u]➜D[u] ─ num_paths=1015 u=32
│  []·a[u]·c[u]➜E[u] ──────────────── num_paths=4 u=32
│  []·a[u]·b[u]·c[u]➜E[u] ─────────── num_paths=25 u=32
│  []·a[u]·b[u]·b[u]·c[u]➜E[u] ────── num_paths=157 u=32
╰─ []·a[u]·b[u]·b[u]·b[u]·c[u]➜E[u] ─ num_paths=1015 u=32
eq_poly: cue.EquivariantPolynomial = cue.descriptors.fully_connected_tensor_product(
    cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
    cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
    cue.Irreps(cue.O3, "32x0e + 32x1o + 32x1e"),
)
poly: cue.SegmentedPolynomial = eq_poly.polynomial
print(poly)

poly = poly.flatten_coefficient_modes()
print(poly)
╭ a=[360448:11⨯(32,32,32)] b=[224] c=[224] -> D=[224]
╰─ [ijk]·a[uvw]·b[iu]·c[jv]➜D[kw] ─ num_paths=11 i={1, 3} j={1, 3} k={1, 3} u=32 v=32 w=32

╭ a=[360448:11⨯(32,32,32)] b=[224:7⨯(32)] c=[224:7⨯(32)] -> D=[224:7⨯(32)]
╰─ []·a[uvw]·b[u]·c[v]➜D[w] ─ num_paths=43 u=32 v=32 w=32

Execution with cuex

import numpy as np
import cuequivariance_jax as cuex
import jax

inputs = [np.random.randn(10, ope.size) for ope in poly.inputs]
outputs = [
    jax.ShapeDtypeStruct(shape=(10, ope.size), dtype=np.float32) for ope in poly.outputs
]

cuex.segmented_polynomial(poly, inputs, outputs)

@mariogeiger mariogeiger requested a review from moradza March 11, 2025 15:39
@mariogeiger mariogeiger self-assigned this Mar 11, 2025
@mariogeiger mariogeiger merged commit 5861d94 into main Mar 16, 2025
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant