From 1da6dde25ac3ef8891a497d2e6ffcbad664ee9d9 Mon Sep 17 00:00:00 2001 From: mdheller <21163552+mdheller@users.noreply.github.com> Date: Mon, 12 Jan 2026 10:37:43 -0500 Subject: [PATCH] Align Rust envelope IDs and add fixture round-trip checks --- .github/workflows/ci.yml | 27 + Makefile | 20 + README.md | 16 +- docs/REPOSITORY_GUIDE.md | 1 + docs/THEORY.md | 31 +- docs/audit_tritrpc_v1_parity.md | 25 + docs/integration_readiness_checklist.md | 50 ++ go/tritrpcv1/avrodec.go | 436 ++++++++++ go/tritrpcv1/avroenc.go | 261 +++--- go/tritrpcv1/cmd/trpc/main.go | 400 ++++----- go/tritrpcv1/envelope.go | 2 - go/tritrpcv1/envelope_decode.go | 143 ++++ go/tritrpcv1/fixtures_test.go | 277 +++---- go/tritrpcv1/go.sum | 1 + go/tritrpcv1/pathb_dec.go | 68 +- go/tritrpcv1/tleb3.go | 40 +- go/tritrpcv1/tritpack243.go | 3 +- go/tritrpcv1/tritrpcv1_test.go | 7 +- rust/tritrpc_v1/Cargo.toml | 1 + rust/tritrpc_v1/src/bin/trpc.rs | 77 +- rust/tritrpc_v1/src/lib.rs | 1014 ++++++++++++++++++----- rust/tritrpc_v1/tests/fixtures.rs | 136 +-- rust/tritrpc_v1/tests/vectors.rs | 9 +- spec/README-full-spec.md | 17 +- 24 files changed, 2239 insertions(+), 823 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 Makefile create mode 100644 docs/audit_tritrpc_v1_parity.md create mode 100644 docs/integration_readiness_checklist.md create mode 100644 go/tritrpcv1/avrodec.go create mode 100644 go/tritrpcv1/envelope_decode.go create mode 100644 go/tritrpcv1/go.sum diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..d88fc42 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,27 @@ +name: ci + +on: + push: + pull_request: + +jobs: + verify: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: '1.21' + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + - name: Install Python deps + run: python -m pip install --upgrade pip cryptography + - name: Verify + run: make verify diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..544b6a8 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +.PHONY: verify fmt rust-fmt go-fmt rust-test go-test fixtures + +verify: fmt rust-test go-test fixtures + +fmt: rust-fmt go-fmt + +rust-fmt: + cd rust/tritrpc_v1 && cargo fmt --check + +go-fmt: + cd go/tritrpcv1 && test -z "$$(gofmt -l .)" + +rust-test: + cd rust/tritrpc_v1 && cargo test + +go-test: + cd go/tritrpcv1 && go test + +fixtures: + python tools/verify_fixtures_strict.py diff --git a/README.md b/README.md index f4afa38..056791d 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ byte-transport, along with authenticated envelope framing. The focus of this rep - **Theory & conceptual model:** `docs/THEORY.md` - **Full specification:** `spec/README-full-spec.md` - **Reference implementation:** `reference/tritrpc_v1.py` +- **Integration readiness checklist:** `docs/integration_readiness_checklist.md` - **Fixtures (canonical vectors):** `fixtures/` - **Rust port:** `rust/` - **Go port:** `go/` @@ -43,8 +44,7 @@ TritRPC v1 is built on these conceptual layers: - **Path-A** payloads are encoded with Avro Binary Encoding (used by the reference implementation and most fixtures). - **Path-B** payloads are ternary-native (toy subset fixtures demonstrate this). 
-- **AEAD integrity** uses XChaCha20-Poly1305 (or a deterministic MAC fallback) with - 24-byte nonces for authenticated frames. +- **AEAD integrity** uses XChaCha20-Poly1305 with 24-byte nonces for authenticated frames. For complete detail, read `docs/THEORY.md` and the full spec. @@ -68,8 +68,8 @@ A more detailed guide lives in `docs/REPOSITORY_GUIDE.md`. At a glance: ### Fixture verification -- Rust: `cargo test -p tritrpc_v1` validates AEAD tags using `.nonces` and checks Avro - payload bytes for key frames. +- Rust: `cargo test -p tritrpc_v1` validates AEAD tags, schema/context IDs, and full-frame + repack determinism using `.nonces`. - Go: `cd go/tritrpcv1 && go test` performs the same validations. ### CLI tools @@ -105,8 +105,8 @@ See `fixtures/vectors_hex_pathB.txt` (+ `.nonces`). These use ternary-native enc ## CI -A GitHub Actions workflow is included in `.github/workflows/ci.yml` to run `cargo test` -and `go test` on push/PR. +A GitHub Actions workflow runs `make verify` (format checks + tests + fixture verification) +on push/PR. ## Release workflow @@ -116,8 +116,8 @@ and `go test` on push/PR. ## Repack check -CI job `repack-check` rebuilds a canonical AddVertex_a frame with Rust and Go CLIs and -compares it against the fixtures. +Repack determinism is verified in the fixture tests by re-encoding envelopes and comparing +full-frame bytes to fixture vectors. ## Pre-commit hook (strict AEAD verification) diff --git a/docs/REPOSITORY_GUIDE.md b/docs/REPOSITORY_GUIDE.md index eeb66fd..f9fc0f4 100644 --- a/docs/REPOSITORY_GUIDE.md +++ b/docs/REPOSITORY_GUIDE.md @@ -8,6 +8,7 @@ folder. It complements the top-level README by giving concrete navigation hints. - `README.md`: Project summary, build instructions, and high-level navigation. - `docs/`: Conceptual and procedural documentation. - `THEORY.md`: Theory and conceptual model for TritRPC v1. + - `integration_readiness_checklist.md`: Pre-integration verification checklist. - `REPOSITORY_GUIDE.md`: This file. - `spec/`: Specification material. - `README-full-spec.md`: The full spec text (repo copy). diff --git a/docs/THEORY.md b/docs/THEORY.md index 5d7ccab..5445bb8 100644 --- a/docs/THEORY.md +++ b/docs/THEORY.md @@ -71,34 +71,27 @@ changes. The protocol authenticates frames using an **AEAD lane**: -- The preferred suite in the reference implementation is **XChaCha20-Poly1305**. -- Some fixtures allow a deterministic MAC fallback when AEAD primitives are not available. -- The AEAD tag is computed over the envelope's AAD (associated data) and the payload, using - a 24-byte nonce. This makes integrity checks deterministic and replay-resistant. +- The suite used across fixtures and ports is **XChaCha20-Poly1305**. +- The AEAD tag is computed with **empty plaintext** and **AAD equal to the envelope bytes + before the final tag field**, using a 24-byte nonce. This includes payload and AUX data. +- Deterministic MAC fallback support exists only in the Python reference and is + **non-normative for Go/Rust**; fixtures are generated with XChaCha20-Poly1305. A strict verification mode is used by fixtures and tooling to ensure tags remain correct if any portion of the envelope or payload changes. ## 7. Streaming and rolling nonces -For streaming sequences of frames, TritRPC uses **rolling nonces**: - -- A base nonce is derived or agreed upon. -- Each subsequent frame increments or derives the next nonce in a deterministic way. -- This keeps AEAD authentication safe across a stream while retaining deterministic test - fixtures. 
+Rolling nonces are described in early sketches, but **Go/Rust ports and fixtures use +explicit per-frame nonces** stored in `fixtures/*.nonces`. Rolling nonce derivation is not +implemented in those ports. ## 8. AUX structures -AUX structures are optional fields that can be inserted into an envelope for additional -metadata. The reference implementation includes: - -- **Trace**: tracing and correlation metadata -- **Sig**: placeholder for signature material -- **PoE (Proof-of-Execution)**: a strict-initial placeholder for execution proofs - -These are designed to be extensible so that additional metadata can be added without breaking -existing envelope parsing. +AUX structures are optional byte fields that can be inserted into an envelope for additional +metadata. The current fixtures do **not** include AUX data. The Python reference contains +toy encoders for Trace/Sig/PoE, but Go/Rust treat AUX as an opaque byte slice and do not yet +define structured decoding. ## 9. Hypergraph service model diff --git a/docs/audit_tritrpc_v1_parity.md b/docs/audit_tritrpc_v1_parity.md new file mode 100644 index 0000000..ad4e6d6 --- /dev/null +++ b/docs/audit_tritrpc_v1_parity.md @@ -0,0 +1,25 @@ +# TritRPC v1 Repo Health + Parity Audit + +Status: **Completed (post-fix)**. This report captures the protocol-relevant issues found +and the corrective actions taken. + +## Findings (protocol-relevant) + +| ID | Severity | Issue | Evidence | Status | +| --- | --- | --- | --- | --- | +| A1 | **High** | Rust envelope builder previously used zeroed `schema_id`/`context_id`, diverging from fixtures. | `rust/tritrpc_v1/src/lib.rs` now defines canonical constants and uses them in the envelope builder (L125-L193). | **Fixed** | +| A2 | **High** | Fixture verification did not assert full-frame byte equality or repack determinism. | Go fixture tests re-encode and compare full frames (`go/tritrpcv1/fixtures_test.go` L89-L102); Rust tests do the same (`rust/tritrpc_v1/tests/fixtures.rs` L81-L105). | **Fixed** | +| A3 | **Medium** | Go/Rust lacked structured envelope decoders; tests used ad-hoc field splits only. | Envelope decoders exist in Go (`go/tritrpcv1/envelope_decode.go` L24-L143) and Rust (`rust/tritrpc_v1/src/lib.rs` L219-L341). | **Fixed** | +| A4 | **Medium** | Path-A Avro subset had no decode/round-trip checks. | HGRequest/HGResponse decoders/encoders added in Go (`go/tritrpcv1/avrodec.go` L1-L321) and Rust (`rust/tritrpc_v1/src/lib.rs` L563-L844) with round-trip tests (`go/tritrpcv1/fixtures_test.go` L129-L153; `rust/tritrpc_v1/tests/fixtures.rs` L131-L147). | **Fixed** | +| A5 | **Medium** | Spec/docs claimed MAC fallback, rolling nonces, and AUX structures were supported across ports. | Spec/docs now reflect current port behavior (`spec/README-full-spec.md` L7-L23; `docs/THEORY.md` L70-L94). | **Fixed** | +| A6 | **Low** | README referenced CI/repack-check jobs that did not exist. | Added `Makefile` + CI workflow and updated README (`Makefile` L1-L20; `.github/workflows/ci.yml` L1-L27; `README.md` L106-L120). | **Fixed** | +| A7 | **Low** | Tag comparisons used direct equality; constant-time compare was missing. | Constant-time comparisons now used in Go/Rust fixture checks (`go/tritrpcv1/fixtures_test.go` L117-L122; `rust/tritrpc_v1/tests/fixtures.rs` L123-L125). | **Fixed** | + +## Scope checklist coverage + +- **A) Cross-language envelope parity:** Schema/context constants aligned; AAD definition enforced. +- **B) Strict fixture verification completeness:** Full-frame repack checks added and required. 
+- **C) Decoder symmetry / round-trip reliability:** HGRequest/HGResponse decode → encode checks added for Path‑A fixtures. +- **D) Spec alignment and missing normative text:** Spec/docs updated to match port behavior (no MAC fallback, explicit nonces, AUX opaque). +- **E) Cryptography invariants:** Tag size + nonce size checked; constant-time tag comparison enforced in verification. +- **F) Tooling / CI ergonomics:** Added `make verify` and CI workflow. diff --git a/docs/integration_readiness_checklist.md b/docs/integration_readiness_checklist.md new file mode 100644 index 0000000..a34dbcd --- /dev/null +++ b/docs/integration_readiness_checklist.md @@ -0,0 +1,50 @@ +# TritRPC v1 Integration Readiness Checklist + +This checklist documents the **minimum verification steps** required before integrating a +policy/view (permissions + privacy) AUX bundle. It reflects the current behavior of the +Go/Rust ports and fixtures. + +## ✅ Protocol invariants (must hold) + +- Schema/context IDs are canonical and match fixtures. +- AEAD tags are computed with **empty plaintext** and **AAD = envelope bytes before the tag + field** (payload + AUX included). +- Fixtures are the source of truth; repacking must reproduce **identical bytes**. +- Nonces are deterministic and pulled from `fixtures/*.nonces`. + +## ✅ Local verification (single command) + +Run from repo root: + +```bash +make verify +``` + +This runs: +- Rust format check + tests +- Go format check + tests +- Fixture AEAD verification script + +## ✅ Per-language commands (if running manually) + +```bash +cd rust/tritrpc_v1 +cargo fmt --check +cargo test +``` + +```bash +cd go/tritrpcv1 +gofmt -l . +go test +``` + +```bash +python tools/verify_fixtures_strict.py +``` + +## ✅ Readiness gates + +- All commands above must succeed on a clean checkout. +- Fixture repack tests must pass (full-frame byte equality). +- Envelope decode → encode stability holds for HGRequest/HGResponse Path-A payloads. 
diff --git a/go/tritrpcv1/avrodec.go b/go/tritrpcv1/avrodec.go new file mode 100644 index 0000000..36ff8b5 --- /dev/null +++ b/go/tritrpcv1/avrodec.go @@ -0,0 +1,436 @@ +package tritrpcv1 + +import "errors" + +type Vertex struct { + Vid string + Label *string + Attr map[string]string +} + +type Hyperedge struct { + Eid string + Members []string + Weight *int64 + Attr map[string]string +} + +type HGRequest struct { + Op int32 + Vertex *Vertex + Hyperedge *Hyperedge + Vid *string + Eid *string + K *int32 +} + +type HGResponse struct { + Ok bool + Err *string + Vertices []Vertex + Edges []Hyperedge +} + +func decVarint(buf []byte, off int) (uint64, int, error) { + var out uint64 + var shift uint + for { + if off >= len(buf) { + return 0, 0, errors.New("EOF in varint") + } + b := buf[off] + off++ + out |= uint64(b&0x7F) << shift + if (b & 0x80) == 0 { + break + } + shift += 7 + if shift > 63 { + return 0, 0, errors.New("varint overflow") + } + } + return out, off, nil +} + +func decLong(buf []byte, off int) (int64, int, error) { + u, no, err := decVarint(buf, off) + if err != nil { + return 0, 0, err + } + val := int64(u>>1) ^ -int64(u&1) + return val, no, nil +} + +func decInt(buf []byte, off int) (int32, int, error) { + v, no, err := decLong(buf, off) + return int32(v), no, err +} + +func decBool(buf []byte, off int) (bool, int, error) { + if off >= len(buf) { + return false, 0, errors.New("EOF in bool") + } + return buf[off] != 0, off + 1, nil +} + +func decString(buf []byte, off int) (string, int, error) { + l, no, err := decLong(buf, off) + if err != nil { + return "", 0, err + } + if l < 0 { + return "", 0, errors.New("negative string length") + } + end := no + int(l) + if end > len(buf) { + return "", 0, errors.New("string length exceeds buffer") + } + return string(buf[no:end]), end, nil +} + +func decArrayStrings(buf []byte, off int) ([]string, int, error) { + count, no, err := decLong(buf, off) + if err != nil { + return nil, 0, err + } + if count == 0 { + return []string{}, no, nil + } + if count < 0 { + return nil, 0, errors.New("negative array block count") + } + out := make([]string, 0, count) + for i := int64(0); i < count; i++ { + s, n2, err := decString(buf, no) + if err != nil { + return nil, 0, err + } + no = n2 + out = append(out, s) + } + endCount, endOff, err := decLong(buf, no) + if err != nil { + return nil, 0, err + } + if endCount != 0 { + return nil, 0, errors.New("non-zero array terminator") + } + return out, endOff, nil +} + +func decMapStrings(buf []byte, off int) (map[string]string, int, error) { + count, no, err := decLong(buf, off) + if err != nil { + return nil, 0, err + } + if count == 0 { + return map[string]string{}, no, nil + } + if count < 0 { + return nil, 0, errors.New("negative map block count") + } + out := map[string]string{} + for i := int64(0); i < count; i++ { + k, n2, err := decString(buf, no) + if err != nil { + return nil, 0, err + } + v, n3, err := decString(buf, n2) + if err != nil { + return nil, 0, err + } + no = n3 + out[k] = v + } + endCount, endOff, err := decLong(buf, no) + if err != nil { + return nil, 0, err + } + if endCount != 0 { + return nil, 0, errors.New("non-zero map terminator") + } + return out, endOff, nil +} + +func decUnionIndex(buf []byte, off int) (int64, int, error) { + return decLong(buf, off) +} + +func decVertex(buf []byte, off int) (*Vertex, int, error) { + vid, no, err := decString(buf, off) + if err != nil { + return nil, 0, err + } + idx, no, err := decUnionIndex(buf, no) + if err != nil { + return nil, 0, err + 
} + var label *string + if idx == 1 { + s, n2, err := decString(buf, no) + if err != nil { + return nil, 0, err + } + label = &s + no = n2 + } else if idx != 0 { + return nil, 0, errors.New("invalid union index for label") + } + attr, no, err := decMapStrings(buf, no) + if err != nil { + return nil, 0, err + } + return &Vertex{Vid: vid, Label: label, Attr: attr}, no, nil +} + +func decHyperedge(buf []byte, off int) (*Hyperedge, int, error) { + eid, no, err := decString(buf, off) + if err != nil { + return nil, 0, err + } + members, no, err := decArrayStrings(buf, no) + if err != nil { + return nil, 0, err + } + idx, no, err := decUnionIndex(buf, no) + if err != nil { + return nil, 0, err + } + var weight *int64 + if idx == 1 { + w, n2, err := decLong(buf, no) + if err != nil { + return nil, 0, err + } + weight = &w + no = n2 + } else if idx != 0 { + return nil, 0, errors.New("invalid union index for weight") + } + attr, no, err := decMapStrings(buf, no) + if err != nil { + return nil, 0, err + } + return &Hyperedge{Eid: eid, Members: members, Weight: weight, Attr: attr}, no, nil +} + +func DecodeHGRequest(buf []byte) (HGRequest, error) { + op, off, err := decInt(buf, 0) + if err != nil { + return HGRequest{}, err + } + idxV, off, err := decUnionIndex(buf, off) + if err != nil { + return HGRequest{}, err + } + var vtx *Vertex + if idxV == 1 { + vtx, off, err = decVertex(buf, off) + if err != nil { + return HGRequest{}, err + } + } else if idxV != 0 { + return HGRequest{}, errors.New("invalid union index for vertex") + } + idxE, off, err := decUnionIndex(buf, off) + if err != nil { + return HGRequest{}, err + } + var edge *Hyperedge + if idxE == 1 { + edge, off, err = decHyperedge(buf, off) + if err != nil { + return HGRequest{}, err + } + } else if idxE != 0 { + return HGRequest{}, errors.New("invalid union index for hyperedge") + } + idxVid, off, err := decUnionIndex(buf, off) + if err != nil { + return HGRequest{}, err + } + var vid *string + if idxVid == 1 { + s, n2, err := decString(buf, off) + if err != nil { + return HGRequest{}, err + } + vid = &s + off = n2 + } else if idxVid != 0 { + return HGRequest{}, errors.New("invalid union index for vid") + } + idxEid, off, err := decUnionIndex(buf, off) + if err != nil { + return HGRequest{}, err + } + var eid *string + if idxEid == 1 { + s, n2, err := decString(buf, off) + if err != nil { + return HGRequest{}, err + } + eid = &s + off = n2 + } else if idxEid != 0 { + return HGRequest{}, errors.New("invalid union index for eid") + } + idxK, off, err := decUnionIndex(buf, off) + if err != nil { + return HGRequest{}, err + } + var k *int32 + if idxK == 1 { + kv, n2, err := decInt(buf, off) + if err != nil { + return HGRequest{}, err + } + k = &kv + off = n2 + } else if idxK != 0 { + return HGRequest{}, errors.New("invalid union index for k") + } + if off != len(buf) { + return HGRequest{}, errors.New("extra bytes after HGRequest") + } + return HGRequest{Op: op, Vertex: vtx, Hyperedge: edge, Vid: vid, Eid: eid, K: k}, nil +} + +func EncodeHGRequest(req HGRequest) ([]byte, error) { + switch req.Op { + case 0: + if req.Vertex == nil { + return nil, errors.New("missing vertex") + } + if len(req.Vertex.Attr) != 0 { + return nil, errors.New("vertex attr not supported in request encoder") + } + return EncHGRequestAddVertex(req.Vertex.Vid, req.Vertex.Label), nil + case 1: + if req.Hyperedge == nil { + return nil, errors.New("missing hyperedge") + } + if len(req.Hyperedge.Attr) != 0 { + return nil, errors.New("hyperedge attr not supported in request 
encoder") + } + return EncHGRequestAddHyperedge(req.Hyperedge.Eid, req.Hyperedge.Members, req.Hyperedge.Weight), nil + case 2: + if req.Vid == nil { + return nil, errors.New("missing vid") + } + return EncHGRequestRemoveVertex(*req.Vid), nil + case 3: + if req.Eid == nil { + return nil, errors.New("missing eid") + } + return EncHGRequestRemoveHyperedge(*req.Eid), nil + case 4: + if req.Vid == nil { + return nil, errors.New("missing vid") + } + k := int32(1) + if req.K != nil { + k = *req.K + } + return EncHGRequestQueryNeighbors(*req.Vid, k), nil + case 5: + if req.Vid == nil { + return nil, errors.New("missing vid") + } + k := int32(1) + if req.K != nil { + k = *req.K + } + return EncHGRequestGetSubgraph(*req.Vid, k), nil + default: + return nil, errors.New("unsupported op") + } +} + +func DecodeHGResponse(buf []byte) (HGResponse, error) { + ok, off, err := decBool(buf, 0) + if err != nil { + return HGResponse{}, err + } + idxErr, off, err := decUnionIndex(buf, off) + if err != nil { + return HGResponse{}, err + } + var errStr *string + if idxErr == 1 { + s, n2, err := decString(buf, off) + if err != nil { + return HGResponse{}, err + } + errStr = &s + off = n2 + } else if idxErr != 0 { + return HGResponse{}, errors.New("invalid union index for err") + } + vcount, off, err := decLong(buf, off) + if err != nil { + return HGResponse{}, err + } + vertices := []Vertex{} + if vcount < 0 { + return HGResponse{}, errors.New("negative vertices block") + } + if vcount > 0 { + for i := int64(0); i < vcount; i++ { + v, n2, err := decVertex(buf, off) + if err != nil { + return HGResponse{}, err + } + if len(v.Attr) != 0 { + return HGResponse{}, errors.New("vertex attr not supported in response fixtures") + } + vertices = append(vertices, *v) + off = n2 + } + endCount, endOff, err := decLong(buf, off) + if err != nil { + return HGResponse{}, err + } + if endCount != 0 { + return HGResponse{}, errors.New("non-zero vertices terminator") + } + off = endOff + } + ecount, off, err := decLong(buf, off) + if err != nil { + return HGResponse{}, err + } + edges := []Hyperedge{} + if ecount < 0 { + return HGResponse{}, errors.New("negative edges block") + } + if ecount > 0 { + for i := int64(0); i < ecount; i++ { + e, n2, err := decHyperedge(buf, off) + if err != nil { + return HGResponse{}, err + } + if len(e.Attr) != 0 { + return HGResponse{}, errors.New("edge attr not supported in response fixtures") + } + edges = append(edges, *e) + off = n2 + } + endCount, endOff, err := decLong(buf, off) + if err != nil { + return HGResponse{}, err + } + if endCount != 0 { + return HGResponse{}, errors.New("non-zero edges terminator") + } + off = endOff + } + if off != len(buf) { + return HGResponse{}, errors.New("extra bytes after HGResponse") + } + return HGResponse{Ok: ok, Err: errStr, Vertices: vertices, Edges: edges}, nil +} + +func EncodeHGResponse(resp HGResponse) ([]byte, error) { + return EncHGResponse(resp.Ok, resp.Err, resp.Vertices, resp.Edges), nil +} diff --git a/go/tritrpcv1/avroenc.go b/go/tritrpcv1/avroenc.go index 329cd61..10fc53b 100644 --- a/go/tritrpcv1/avroenc.go +++ b/go/tritrpcv1/avroenc.go @@ -1,135 +1,200 @@ - package tritrpcv1 func zigzag(n int64) uint64 { - return uint64((n << 1) ^ (n >> 63)) + return uint64((n << 1) ^ (n >> 63)) } func EncVarint(u uint64) []byte { - out := []byte{} - for (u & ^uint64(0x7F)) != 0 { - out = append(out, byte(u&0x7F)|0x80) - u >>= 7 - } - out = append(out, byte(u)) - return out + out := []byte{} + for (u & ^uint64(0x7F)) != 0 { + out = append(out, 
byte(u&0x7F)|0x80) + u >>= 7 + } + out = append(out, byte(u)) + return out } func EncLong(n int64) []byte { return EncVarint(zigzag(n)) } func EncInt(n int32) []byte { return EncLong(int64(n)) } -func EncBool(v bool) []byte { if v { return []byte{1} } ; return []byte{0} } +func EncBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} +} func EncString(s string) []byte { - b := []byte(s) - out := EncLong(int64(len(b))) - out = append(out, b...) - return out + b := []byte(s) + out := EncLong(int64(len(b))) + out = append(out, b...) + return out } func EncBytes(b []byte) []byte { - out := EncLong(int64(len(b))) - out = append(out, b...) - return out + out := EncLong(int64(len(b))) + out = append(out, b...) + return out } func EncArray(items [][]byte) []byte { - if len(items) == 0 { return []byte{0} } - out := EncLong(int64(len(items))) - for _, it := range items { out = append(out, it...) } - out = append(out, 0) - return out + if len(items) == 0 { + return []byte{0} + } + out := EncLong(int64(len(items))) + for _, it := range items { + out = append(out, it...) + } + out = append(out, 0) + return out } func EncMap(m map[string]string) []byte { - if len(m) == 0 { return []byte{0} } - out := EncLong(int64(len(m))) - for k, v := range m { - out = append(out, EncString(k)...) - out = append(out, EncString(v)...) - } - out = append(out, 0) - return out + if len(m) == 0 { + return []byte{0} + } + out := EncLong(int64(len(m))) + for k, v := range m { + out = append(out, EncString(k)...) + out = append(out, EncString(v)...) + } + out = append(out, 0) + return out } func EncUnion(index int64, payload []byte) []byte { - out := EncLong(index) - out = append(out, payload...) - return out + out := EncLong(index) + out = append(out, payload...) + return out } func EncEnum(index int32) []byte { return EncInt(index) } // Control func EncHello(modes, suites, comp []string, contextURI *string) []byte { - arr := func(ss []string) []byte { - chunks := make([][]byte, 0, len(ss)) - for _, s := range ss { chunks = append(chunks, EncString(s)) } - return EncArray(chunks) - } - out := []byte{} - out = append(out, arr(modes)...) - out = append(out, arr(suites)...) - out = append(out, arr(comp)...) - if contextURI == nil { - out = append(out, EncUnion(0, []byte{})...) - } else { - out = append(out, EncUnion(1, EncString(*contextURI))...) - } - return out + arr := func(ss []string) []byte { + chunks := make([][]byte, 0, len(ss)) + for _, s := range ss { + chunks = append(chunks, EncString(s)) + } + return EncArray(chunks) + } + out := []byte{} + out = append(out, arr(modes)...) + out = append(out, arr(suites)...) + out = append(out, arr(comp)...) + if contextURI == nil { + out = append(out, EncUnion(0, []byte{})...) + } else { + out = append(out, EncUnion(1, EncString(*contextURI))...) + } + return out } func EncChoose(mode, suite, comp string) []byte { - out := []byte{} - out = append(out, EncString(mode)...) - out = append(out, EncString(suite)...) - out = append(out, EncString(comp)...) - return out + out := []byte{} + out = append(out, EncString(mode)...) + out = append(out, EncString(suite)...) + out = append(out, EncString(comp)...) + return out } // Hypergraph func EncVertex(vid string, label *string, attr map[string]string) []byte { - out := []byte{} - out = append(out, EncString(vid)...) - if label == nil { - out = append(out, EncUnion(0, []byte{})...) - } else { - out = append(out, EncUnion(1, EncString(*label))...) - } - out = append(out, EncMap(attr)...) 
- return out + out := []byte{} + out = append(out, EncString(vid)...) + if label == nil { + out = append(out, EncUnion(0, []byte{})...) + } else { + out = append(out, EncUnion(1, EncString(*label))...) + } + out = append(out, EncMap(attr)...) + return out } func EncHyperedge(eid string, members []string, weight *int64, attr map[string]string) []byte { - out := []byte{} - out = append(out, EncString(eid)...) - items := make([][]byte, 0, len(members)) - for _, m := range members { items = append(items, EncString(m)) } - out = append(out, EncArray(items)...) - if weight == nil { - out = append(out, EncUnion(0, []byte{})...) - } else { - out = append(out, EncUnion(1, EncLong(*weight))...) - } - out = append(out, EncMap(attr)...) - return out + out := []byte{} + out = append(out, EncString(eid)...) + items := make([][]byte, 0, len(members)) + for _, m := range members { + items = append(items, EncString(m)) + } + out = append(out, EncArray(items)...) + if weight == nil { + out = append(out, EncUnion(0, []byte{})...) + } else { + out = append(out, EncUnion(1, EncLong(*weight))...) + } + out = append(out, EncMap(attr)...) + return out } func EncHGRequestAddVertex(vid string, label *string) []byte { - out := []byte{} - out = append(out, EncEnum(0)...) - out = append(out, EncUnion(1, EncVertex(vid, label, map[string]string{}))...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(0, []byte{})...) - return out + out := []byte{} + out = append(out, EncEnum(0)...) + out = append(out, EncUnion(1, EncVertex(vid, label, map[string]string{}))...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + return out } func EncHGRequestAddHyperedge(eid string, members []string, weight *int64) []byte { - out := []byte{} - out = append(out, EncEnum(1)...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(1, EncHyperedge(eid, members, weight, map[string]string{}))...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(0, []byte{})...) - return out + out := []byte{} + out = append(out, EncEnum(1)...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncHyperedge(eid, members, weight, map[string]string{}))...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + return out +} +func EncHGRequestRemoveVertex(vid string) []byte { + out := []byte{} + out = append(out, EncEnum(2)...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncString(vid))...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + return out +} +func EncHGRequestRemoveHyperedge(eid string) []byte { + out := []byte{} + out = append(out, EncEnum(3)...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncString(eid))...) + out = append(out, EncUnion(0, []byte{})...) + return out } func EncHGRequestQueryNeighbors(vid string, k int32) []byte { - out := []byte{} - out = append(out, EncEnum(4)...) - out = append(out, EncUnion(0, []byte{})...) 
- out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(1, EncString(vid))...) - out = append(out, EncUnion(0, []byte{})...) - out = append(out, EncUnion(1, EncInt(k))...) - return out + out := []byte{} + out = append(out, EncEnum(4)...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncString(vid))...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncInt(k))...) + return out +} +func EncHGRequestGetSubgraph(vid string, k int32) []byte { + out := []byte{} + out = append(out, EncEnum(5)...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncString(vid))...) + out = append(out, EncUnion(0, []byte{})...) + out = append(out, EncUnion(1, EncInt(k))...) + return out +} + +func EncHGResponse(ok bool, err *string, vertices []Vertex, edges []Hyperedge) []byte { + out := []byte{} + out = append(out, EncBool(ok)...) + if err == nil { + out = append(out, EncUnion(0, []byte{})...) + } else { + out = append(out, EncUnion(1, EncString(*err))...) + } + vbytes := make([][]byte, 0, len(vertices)) + for _, v := range vertices { + vbytes = append(vbytes, EncVertex(v.Vid, v.Label, v.Attr)) + } + out = append(out, EncArray(vbytes)...) + ebytes := make([][]byte, 0, len(edges)) + for _, e := range edges { + ebytes = append(ebytes, EncHyperedge(e.Eid, e.Members, e.Weight, e.Attr)) + } + out = append(out, EncArray(ebytes)...) + return out } diff --git a/go/tritrpcv1/cmd/trpc/main.go b/go/tritrpcv1/cmd/trpc/main.go index 2f9840f..4a07275 100644 --- a/go/tritrpcv1/cmd/trpc/main.go +++ b/go/tritrpcv1/cmd/trpc/main.go @@ -1,257 +1,179 @@ - package main import ( - "encoding/json" - "encoding/hex" - "flag" - "fmt" - "io/ioutil" - "os" + "bufio" + "crypto/subtle" + "encoding/hex" + "encoding/json" + "flag" + "fmt" + "os" + "strings" - tr "github.com/example/tritrpcv1" - "golang.org/x/crypto/chacha20poly1305" + tr "github.com/example/tritrpcv1" + "golang.org/x/crypto/chacha20poly1305" ) func main() { - if len(os.Args) < 2 { - fmt.Println("Usage: trpc pack|verify ...") - os.Exit(1) - } - switch os.Args[1] { - case "pack": - fs := flag.NewFlagSet("pack", flag.ExitOnError) - svc := fs.String("service", "", "service") - method := fs.String("method", "", "method") - jsonPath := fs.String("json", "", "json path (request/response)") - nonceHex := fs.String("nonce", "", "24-byte nonce hex") - keyHex := fs.String("key", "", "32-byte key hex") - fs.Parse(os.Args[2:]) - if *svc == "" || *method == "" || *jsonPath == "" || *nonceHex == "" || *keyHex == "" { - fs.Usage(); os.Exit(1) - } - jb, _ := ioutil.ReadFile(*jsonPath) - // For brevity, treat payload as request AddVertex 'a' if not parsing JSON - payload := buildFromJSON(*method, jb) - key, _ := hex.DecodeString(*keyHex) - nonce, _ := hex.DecodeString(*nonceHex) - var k [32]byte; copy(k[:], key[:32]) - var n [24]byte; copy(n[:], nonce[:24]) - frame, _, _ := tr.EnvelopeWithTag(*svc, *method, payload, nil, k, n) - fmt.Println(hex.EncodeToString(frame)) - case "verify": - fs := flag.NewFlagSet("verify", flag.ExitOnError) - fixtures := fs.String("fixtures", "", "fixtures file") - nonces := fs.String("nonces", "", "nonces file") - fs.Parse(os.Args[2:]) - if *fixtures == "" || *nonces == "" { fs.Usage(); os.Exit(1) } - // reuse test logic by re-implementing minimal verifier here: - pairs := readPairs(*fixtures) - nmap := readNonces(*nonces) - key := [32]byte{} - for _, p := range pairs 
{ - name := string(p[0]) - frame := p[1] - fields := splitFields(frame) - flags := fields[3] - if aeadBit(flags) { - aad, tag := aadAndTag(frame) - nonce := nmap[name] - a, _ := chacha20poly1305.NewX(key[:]) - ct := a.Seal(nil, nonce, []byte{}, aad) - if hex.EncodeToString(ct[len(ct)-16:]) != hex.EncodeToString(tag) { - fmt.Println("tag mismatch for", name); os.Exit(2) - } - } - } - fmt.Println("OK:", *fixtures) - default: - fmt.Println("Usage: trpc pack|verify ..."); os.Exit(1) - } + if len(os.Args) < 2 { + fmt.Println("Usage: trpc pack|verify ...") + os.Exit(1) + } + switch os.Args[1] { + case "pack": + fs := flag.NewFlagSet("pack", flag.ExitOnError) + svc := fs.String("service", "", "service") + method := fs.String("method", "", "method") + jsonPath := fs.String("json", "", "json path (request/response)") + nonceHex := fs.String("nonce", "", "24-byte nonce hex") + keyHex := fs.String("key", "", "32-byte key hex") + fs.Parse(os.Args[2:]) + if *svc == "" || *method == "" || *jsonPath == "" || *nonceHex == "" || *keyHex == "" { + fs.Usage() + os.Exit(1) + } + jb, err := os.ReadFile(*jsonPath) + if err != nil { + fmt.Println("read error:", err) + os.Exit(1) + } + payload := buildFromJSON(*method, jb) + key, _ := hex.DecodeString(*keyHex) + nonce, _ := hex.DecodeString(*nonceHex) + var k [32]byte + copy(k[:], key[:32]) + var n [24]byte + copy(n[:], nonce[:24]) + frame, _, _ := tr.EnvelopeWithTag(*svc, *method, payload, nil, k, n) + fmt.Println(hex.EncodeToString(frame)) + case "verify": + fs := flag.NewFlagSet("verify", flag.ExitOnError) + fixtures := fs.String("fixtures", "", "fixtures file") + nonces := fs.String("nonces", "", "nonces file") + fs.Parse(os.Args[2:]) + if *fixtures == "" || *nonces == "" { + fs.Usage() + os.Exit(1) + } + pairs := readPairs(*fixtures) + nmap := readNonces(*nonces) + key := [32]byte{} + for _, p := range pairs { + name := string(p[0]) + frame := p[1] + env, err := tr.DecodeEnvelope(frame) + if err != nil { + fmt.Println("decode error for", name, ":", err) + os.Exit(2) + } + if env.AeadOn { + aad, err := tr.AADBeforeTag(frame, env) + if err != nil { + fmt.Println("aad error for", name, ":", err) + os.Exit(2) + } + nonce := nmap[name] + a, _ := chacha20poly1305.NewX(key[:]) + ct := a.Seal(nil, nonce, []byte{}, aad) + computed := ct[len(ct)-16:] + if subtle.ConstantTimeCompare(computed, env.Tag) != 1 { + fmt.Println("tag mismatch for", name) + os.Exit(2) + } + } + } + fmt.Println("OK:", *fixtures) + default: + fmt.Println("Usage: trpc pack|verify ...") + os.Exit(1) + } } -// helpers (copied from tests) - -package tritrpcv1 - -import ( - "encoding/json" - "bufio" - "encoding/hex" - "os" - "strings" - "testing" - "golang.org/x/crypto/chacha20poly1305" -) - func readPairs(path string) [][2][]byte { - f, _ := os.Open(path); defer f.Close() - sc := bufio.NewScanner(f) - out := make([][2][]byte, 0) - for sc.Scan() { - ln := sc.Text() - if ln == "" || strings.HasPrefix(ln, "#") { continue } - parts := strings.SplitN(ln, " ", 2) - name := []byte(parts[0]) - b, _ := hex.DecodeString(parts[1]) - out = append(out, [2][]byte{name, b}) - } - return out + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + sc := bufio.NewScanner(f) + out := make([][2][]byte, 0) + for sc.Scan() { + ln := sc.Text() + if ln == "" || strings.HasPrefix(ln, "#") { + continue + } + parts := strings.SplitN(ln, " ", 2) + name := []byte(parts[0]) + b, _ := hex.DecodeString(parts[1]) + out = append(out, [2][]byte{name, b}) + } + return out } func readNonces(path string) 
map[string][]byte { - f, _ := os.Open(path); defer f.Close() - sc := bufio.NewScanner(f) - out := map[string][]byte{} - for sc.Scan() { - ln := sc.Text() - if ln == "" { continue } - parts := strings.SplitN(ln, " ", 2) - key := parts[0] - b, _ := hex.DecodeString(parts[1]) - out[key] = b - } - return out -} - -func tleb3DecodeLen(buf []byte, offset int) (val uint64, newOff int) { - // read one byte; unpack trits; parse tritlets until C=0; compute used bytes by re-pack - trits := []byte{} - off := offset - for { - b := buf[off]; off++ - ts, _ := TritUnpack243([]byte{b}) - trits = append(trits, ts...) - if len(trits) < 3 { continue } - v := uint64(0) - used := 0 - for j := 0; j < len(trits)/3; j++ { - c, p1, p0 := trits[3*j], trits[3*j+1], trits[3*j+2] - digit := uint64(p1)*3 + uint64(p0) - // base-9 little-endian - mul := uint64(1) - for k:=0; k 0 { - pack := TritPack243(trits[:used]) - usedBytes := len(pack) - return v, offset + usedBytes - 1 + (off - offset) - } - } + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + sc := bufio.NewScanner(f) + out := map[string][]byte{} + for sc.Scan() { + ln := sc.Text() + if ln == "" { + continue + } + parts := strings.SplitN(ln, " ", 2) + key := parts[0] + b, _ := hex.DecodeString(parts[1]) + out[key] = b + } + return out } -func splitFields(buf []byte) [][]byte { - fields := [][]byte{} - off := 0 - for off < len(buf) { - l, no := tleb3DecodeLen(buf, off) - lo := int(l) - valStart := no - valEnd := valStart + lo - fields = append(fields, buf[valStart:valEnd]) - off = valEnd - } - return fields -} - -func aeadBit(flags []byte) bool { - ts, _ := TritUnpack243(flags) - return len(ts) >= 1 && ts[0] == 2 -} - -func TestFixturesAEADAndPayloads(t *testing.T) { - sets := [][2]string{ - {"fixtures/vectors_hex.txt","fixtures/vectors_hex.txt.nonces"}, - {"fixtures/vectors_hex_stream_avrochunk.txt","fixtures/vectors_hex_stream_avrochunk.txt.nonces"}, - {"fixtures/vectors_hex_unary_rich.txt","fixtures/vectors_hex_unary_rich.txt.nonces"}, - {"fixtures/vectors_hex_stream_avronested.txt","fixtures/vectors_hex_stream_avronested.txt.nonces"}, - } - key := [32]byte{} - for _, s := range sets { - pairs := readPairs(s[0]) - nonces := readNonces(s[1]) - for _, p := range pairs { - name := string(p[0]) - frame := p[1] - fields := splitFields(frame) - if len(fields) < 9 { t.Fatalf("too few fields for %s", name) } - flags := fields[3] - if aeadBit(flags) { - // find start of last field (tag) - off := 0 - lastStart := 0 - idx := 0 - for off < len(frame) { - l, no := tleb3DecodeLen(frame, off) - lastStart = off - off = no + int(l) - idx++ - } - aad := frame[:lastStart] - tag := fields[len(fields)-1] - n := nonces[name] - a, _ := chacha20poly1305.NewX(key[:]) - ct := a.Seal(nil, n, []byte{}, aad) - if hex.EncodeToString(ct[len(ct)-16:]) != hex.EncodeToString(tag) { - t.Fatalf("tag mismatch for %s", name) - } - // Payload checks - if strings.HasSuffix(name, "hyper.v1.AddVertex_a.REQ") || strings.HasSuffix(name, "hyper.v1.AddVertex_a") { - payload := fields[8] - la := "A" - want := EncHGRequestAddVertex("a", &la) - if hex.EncodeToString(payload) != hex.EncodeToString(want) { - t.Fatalf("payload mismatch %s", name) - } - } - if strings.HasSuffix(name, "hyper.v1.AddHyperedge_e1_ab.REQ") || strings.HasSuffix(name, "hyper.v1.AddHyperedge_e1_ab") { - payload := fields[8] - var w int64 = 1 - want := EncHGRequestAddHyperedge("e1", []string{"a","b"}, &w) - if hex.EncodeToString(payload) != hex.EncodeToString(want) { - t.Fatalf("payload mismatch %s", name) - } 
- } - if strings.HasSuffix(name, "hyper.v1.QueryNeighbors_a_k1.REQ") || strings.HasSuffix(name, "hyper.v1.QueryNeighbors_a_k1") { - payload := fields[8] - want := EncHGRequestQueryNeighbors("a", 1) - if hex.EncodeToString(payload) != hex.EncodeToString(want) { - t.Fatalf("payload mismatch %s", name) - } - } - } - } - } -} - - /* JSON-driven payload builder (subset) */ func buildFromJSON(method string, jb []byte) []byte { - type Vtx struct{ Vid, Label string } - type Edge struct{ Eid string; Members []string; Weight int64 } - type Req struct { - Op string - Vertex *Vtx - Edge *Edge - Vid string - Eid string - K int32 - } - var r Req - _ = json.Unmarshal(jb, &r) - switch r.Op { - case "AddVertex": - return tr.EncHGRequestAddVertex(r.Vertex.Vid, strPtr(r.Vertex.Label)) - case "AddHyperedge": - return tr.EncHGRequestAddHyperedge(r.Edge.Eid, r.Edge.Members, &r.Edge.Weight) - case "QueryNeighbors": - return tr.EncHGRequestQueryNeighbors(r.Vid, r.K) - default: - return tr.EncHGRequestAddVertex("a", strPtr("A")) - } + type Vtx struct{ Vid, Label string } + type Edge struct { + Eid string + Members []string + Weight int64 + } + type Req struct { + Op string + Vertex *Vtx + Edge *Edge + Vid string + Eid string + K int32 + } + var r Req + if err := json.Unmarshal(jb, &r); err != nil { + return tr.EncHGRequestAddVertex("a", strPtr("A")) + } + switch r.Op { + case "AddVertex": + if r.Vertex == nil { + return tr.EncHGRequestAddVertex("a", strPtr("A")) + } + return tr.EncHGRequestAddVertex(r.Vertex.Vid, strPtr(r.Vertex.Label)) + case "AddHyperedge": + if r.Edge == nil { + return tr.EncHGRequestAddHyperedge("e1", []string{"a", "b"}, int64Ptr(1)) + } + return tr.EncHGRequestAddHyperedge(r.Edge.Eid, r.Edge.Members, int64Ptr(r.Edge.Weight)) + case "QueryNeighbors": + return tr.EncHGRequestQueryNeighbors(r.Vid, r.K) + case "RemoveVertex": + return tr.EncHGRequestRemoveVertex(r.Vid) + case "RemoveHyperedge": + return tr.EncHGRequestRemoveHyperedge(r.Eid) + case "GetSubgraph": + return tr.EncHGRequestGetSubgraph(r.Vid, r.K) + default: + return tr.EncHGRequestAddVertex("a", strPtr("A")) + } } + +func strPtr(s string) *string { return &s } +func int64Ptr(v int64) *int64 { return &v } diff --git a/go/tritrpcv1/envelope.go b/go/tritrpcv1/envelope.go index 51811a6..619fed6 100644 --- a/go/tritrpcv1/envelope.go +++ b/go/tritrpcv1/envelope.go @@ -1,11 +1,9 @@ - package tritrpcv1 import ( "golang.org/x/crypto/chacha20poly1305" ) - var SCHEMA_ID_BYTES = []byte{178, 171, 129, 69, 136, 249, 156, 135, 93, 55, 187, 117, 70, 208, 223, 67, 105, 194, 139, 197, 246, 12, 227, 138, 102, 7, 218, 196, 104, 3, 67, 82} var CONTEXT_ID_BYTES = []byte{230, 87, 44, 14, 97, 143, 24, 213, 114, 212, 194, 150, 157, 180, 144, 150, 89, 240, 158, 174, 243, 46, 198, 111, 187, 128, 75, 173, 157, 137, 170, 205} var MAGIC_B2 = []byte{0xF3, 0x2A} diff --git a/go/tritrpcv1/envelope_decode.go b/go/tritrpcv1/envelope_decode.go new file mode 100644 index 0000000..986da43 --- /dev/null +++ b/go/tritrpcv1/envelope_decode.go @@ -0,0 +1,143 @@ +package tritrpcv1 + +import ( + "errors" +) + +type Envelope struct { + Magic []byte + Version []byte + Mode []byte + Flags []byte + Schema []byte + Context []byte + Service string + Method string + Payload []byte + Aux []byte + Tag []byte + AeadOn bool + Compress bool + TagStart int +} + +func DecodeEnvelope(frame []byte) (*Envelope, error) { + off := 0 + readField := func() ([]byte, int, int, error) { + l, no, err := TLEB3DecodeLen(frame, off) + if err != nil { + return nil, 0, 0, err + } + valStart := no + valEnd 
:= valStart + int(l) + if valEnd > len(frame) { + return nil, 0, 0, errors.New("field length exceeds frame") + } + start := off + off = valEnd + return frame[valStart:valEnd], valEnd, start, nil + } + + magic, _, _, err := readField() + if err != nil { + return nil, err + } + version, _, _, err := readField() + if err != nil { + return nil, err + } + mode, _, _, err := readField() + if err != nil { + return nil, err + } + flags, _, _, err := readField() + if err != nil { + return nil, err + } + schema, _, _, err := readField() + if err != nil { + return nil, err + } + context, _, _, err := readField() + if err != nil { + return nil, err + } + svcBytes, _, _, err := readField() + if err != nil { + return nil, err + } + methodBytes, _, _, err := readField() + if err != nil { + return nil, err + } + payload, _, _, err := readField() + if err != nil { + return nil, err + } + + trits, _ := TritUnpack243(flags) + aeadOn := len(trits) > 0 && trits[0] == 2 + compress := len(trits) > 1 && trits[1] == 2 + + var aux []byte + var tag []byte + tagStart := -1 + + if off < len(frame) { + if aeadOn { + field, _, start, err := readField() + if err != nil { + return nil, err + } + if off < len(frame) { + tagField, _, tagFieldStart, err := readField() + if err != nil { + return nil, err + } + aux = append([]byte{}, field...) + tag = append([]byte{}, tagField...) + tagStart = tagFieldStart + } else { + tag = append([]byte{}, field...) + tagStart = start + } + } else { + field, _, _, err := readField() + if err != nil { + return nil, err + } + aux = append([]byte{}, field...) + } + } + + if off != len(frame) { + return nil, errors.New("extra bytes after envelope parsing") + } + + return &Envelope{ + Magic: append([]byte{}, magic...), + Version: append([]byte{}, version...), + Mode: append([]byte{}, mode...), + Flags: append([]byte{}, flags...), + Schema: append([]byte{}, schema...), + Context: append([]byte{}, context...), + Service: string(svcBytes), + Method: string(methodBytes), + Payload: append([]byte{}, payload...), + Aux: aux, + Tag: tag, + AeadOn: aeadOn, + Compress: compress, + TagStart: tagStart, + }, nil +} + +func AADBeforeTag(frame []byte, env *Envelope) ([]byte, error) { + if !env.AeadOn || len(env.Tag) == 0 { + return frame, nil + } + if env.TagStart < 0 || env.TagStart > len(frame) { + return nil, errors.New("invalid tag start") + } + return frame[:env.TagStart], nil +} diff --git a/go/tritrpcv1/fixtures_test.go b/go/tritrpcv1/fixtures_test.go index 21024d4..a7d1ac8 100644 --- a/go/tritrpcv1/fixtures_test.go +++ b/go/tritrpcv1/fixtures_test.go @@ -1,158 +1,157 @@ - package tritrpcv1 import ( - "bufio" - "encoding/hex" - "os" - "strings" - "testing" - "golang.org/x/crypto/chacha20poly1305" + "bufio" + "crypto/subtle" + "encoding/hex" + "golang.org/x/crypto/chacha20poly1305" + "os" + "strings" + "testing" ) func readPairs(path string) [][2][]byte { - f, _ := os.Open(path); defer f.Close() - sc := bufio.NewScanner(f) - out := make([][2][]byte, 0) - for sc.Scan() { - ln := sc.Text() - if ln == "" || strings.HasPrefix(ln, "#") { continue } - parts := strings.SplitN(ln, " ", 2) - name := []byte(parts[0]) - b, _ := hex.DecodeString(parts[1]) - out = append(out, [2][]byte{name, b}) - } - return out + f, _ := os.Open(path) + defer f.Close() + sc := bufio.NewScanner(f) + out := make([][2][]byte, 0) + for sc.Scan() { + ln := sc.Text() + if ln == "" || strings.HasPrefix(ln, "#") { + continue + } + parts := strings.SplitN(ln, " ", 2) + name := []byte(parts[0]) + b, _ := hex.DecodeString(parts[1]) + out = 
append(out, [2][]byte{name, b}) + } + return out } func readNonces(path string) map[string][]byte { - f, _ := os.Open(path); defer f.Close() - sc := bufio.NewScanner(f) - out := map[string][]byte{} - for sc.Scan() { - ln := sc.Text() - if ln == "" { continue } - parts := strings.SplitN(ln, " ", 2) - key := parts[0] - b, _ := hex.DecodeString(parts[1]) - out[key] = b - } - return out -} - -func tleb3DecodeLen(buf []byte, offset int) (val uint64, newOff int) { - // read one byte; unpack trits; parse tritlets until C=0; compute used bytes by re-pack - trits := []byte{} - off := offset - for { - b := buf[off]; off++ - ts, _ := TritUnpack243([]byte{b}) - trits = append(trits, ts...) - if len(trits) < 3 { continue } - v := uint64(0) - used := 0 - for j := 0; j < len(trits)/3; j++ { - c, p1, p0 := trits[3*j], trits[3*j+1], trits[3*j+2] - digit := uint64(p1)*3 + uint64(p0) - // base-9 little-endian - mul := uint64(1) - for k:=0; k 0 { - pack := TritPack243(trits[:used]) - usedBytes := len(pack) - return v, offset + usedBytes - 1 + (off - offset) - } - } + f, _ := os.Open(path) + defer f.Close() + sc := bufio.NewScanner(f) + out := map[string][]byte{} + for sc.Scan() { + ln := sc.Text() + if ln == "" { + continue + } + parts := strings.SplitN(ln, " ", 2) + key := parts[0] + b, _ := hex.DecodeString(parts[1]) + out[key] = b + } + return out } func splitFields(buf []byte) [][]byte { - fields := [][]byte{} - off := 0 - for off < len(buf) { - l, no := tleb3DecodeLen(buf, off) - lo := int(l) - valStart := no - valEnd := valStart + lo - fields = append(fields, buf[valStart:valEnd]) - off = valEnd - } - return fields + fields := [][]byte{} + off := 0 + for off < len(buf) { + l, no, err := TLEB3DecodeLen(buf, off) + if err != nil { + panic(err) + } + lo := int(l) + valStart := no + valEnd := valStart + lo + fields = append(fields, buf[valStart:valEnd]) + off = valEnd + } + return fields } func aeadBit(flags []byte) bool { - ts, _ := TritUnpack243(flags) - return len(ts) >= 1 && ts[0] == 2 + ts, _ := TritUnpack243(flags) + return len(ts) >= 1 && ts[0] == 2 } func TestFixturesAEADAndPayloads(t *testing.T) { - sets := [][2]string{ - {"fixtures/vectors_hex.txt","fixtures/vectors_hex.txt.nonces"}, - {"fixtures/vectors_hex_stream_avrochunk.txt","fixtures/vectors_hex_stream_avrochunk.txt.nonces"}, - {"fixtures/vectors_hex_unary_rich.txt","fixtures/vectors_hex_unary_rich.txt.nonces"}, - {"fixtures/vectors_hex_stream_avronested.txt","fixtures/vectors_hex_stream_avronested.txt.nonces"}, - } - key := [32]byte{} - for _, s := range sets { - pairs := readPairs(s[0]) - nonces := readNonces(s[1]) - for _, p := range pairs { - name := string(p[0]) - frame := p[1] - fields := splitFields(frame) - if len(fields) < 9 { t.Fatalf("too few fields for %s", name) } - flags := fields[3] - if aeadBit(flags) { - // find start of last field (tag) - off := 0 - lastStart := 0 - idx := 0 - for off < len(frame) { - l, no := tleb3DecodeLen(frame, off) - lastStart = off - off = no + int(l) - idx++ - } - aad := frame[:lastStart] - tag := fields[len(fields)-1] - n := nonces[name] - a, _ := chacha20poly1305.NewX(key[:]) - strict := os.Getenv("STRICT_AEAD") == "1" - ct := a.Seal(nil, n, []byte{}, aad) - if hex.EncodeToString(ct[len(ct)-16:]) != hex.EncodeToString(tag) { - if strict { t.Fatalf("strict AEAD tag mismatch for %s", name) } - t.Fatalf("tag mismatch for %s", name) - } - // Payload checks - if strings.HasSuffix(name, "hyper.v1.AddVertex_a.REQ") || strings.HasSuffix(name, "hyper.v1.AddVertex_a") { - payload := fields[8] - la := "A" - 
want := EncHGRequestAddVertex("a", &la) - if hex.EncodeToString(payload) != hex.EncodeToString(want) { - t.Fatalf("payload mismatch %s", name) - } - } - if strings.HasSuffix(name, "hyper.v1.AddHyperedge_e1_ab.REQ") || strings.HasSuffix(name, "hyper.v1.AddHyperedge_e1_ab") { - payload := fields[8] - var w int64 = 1 - want := EncHGRequestAddHyperedge("e1", []string{"a","b"}, &w) - if hex.EncodeToString(payload) != hex.EncodeToString(want) { - t.Fatalf("payload mismatch %s", name) - } - } - if strings.HasSuffix(name, "hyper.v1.QueryNeighbors_a_k1.REQ") || strings.HasSuffix(name, "hyper.v1.QueryNeighbors_a_k1") { - payload := fields[8] - want := EncHGRequestQueryNeighbors("a", 1) - if hex.EncodeToString(payload) != hex.EncodeToString(want) { - t.Fatalf("payload mismatch %s", name) - } - } - } - } - } + sets := [][2]string{ + {"fixtures/vectors_hex.txt", "fixtures/vectors_hex.txt.nonces"}, + {"fixtures/vectors_hex_stream_avrochunk.txt", "fixtures/vectors_hex_stream_avrochunk.txt.nonces"}, + {"fixtures/vectors_hex_unary_rich.txt", "fixtures/vectors_hex_unary_rich.txt.nonces"}, + {"fixtures/vectors_hex_stream_avronested.txt", "fixtures/vectors_hex_stream_avronested.txt.nonces"}, + } + key := [32]byte{} + for _, s := range sets { + pairs := readPairs(s[0]) + nonces := readNonces(s[1]) + for _, p := range pairs { + name := string(p[0]) + frame := p[1] + fields := splitFields(frame) + if len(fields) < 9 { + t.Fatalf("too few fields for %s", name) + } + env, err := DecodeEnvelope(frame) + if err != nil { + t.Fatalf("decode error %s: %v", name, err) + } + if hex.EncodeToString(env.Schema) != hex.EncodeToString(SCHEMA_ID_32) { + t.Fatalf("schema id mismatch %s", name) + } + if hex.EncodeToString(env.Context) != hex.EncodeToString(CONTEXT_ID_32) { + t.Fatalf("context id mismatch %s", name) + } + repacked := BuildEnvelope(env.Service, env.Method, env.Payload, env.Aux, env.Tag, env.AeadOn, env.Compress) + if hex.EncodeToString(repacked) != hex.EncodeToString(frame) { + t.Fatalf("repack mismatch %s", name) + } + flags := fields[3] + if aeadBit(flags) { + aad, err := AADBeforeTag(frame, env) + if err != nil { + t.Fatalf("aad error %s: %v", name, err) + } + tag := env.Tag + n := nonces[name] + if len(n) != 24 { + t.Fatalf("nonce size mismatch %s", name) + } + if len(tag) != 16 { + t.Fatalf("tag size mismatch %s", name) + } + a, _ := chacha20poly1305.NewX(key[:]) + strict := os.Getenv("STRICT_AEAD") == "1" + ct := a.Seal(nil, n, []byte{}, aad) + computed := ct[len(ct)-16:] + if subtle.ConstantTimeCompare(computed, tag) != 1 { + if strict { + t.Fatalf("strict AEAD tag mismatch for %s", name) + } + t.Fatalf("tag mismatch for %s", name) + } + } + + if strings.HasSuffix(env.Method, ".REQ") { + req, err := DecodeHGRequest(env.Payload) + if err != nil { + t.Fatalf("decode request %s: %v", name, err) + } + recoded, err := EncodeHGRequest(req) + if err != nil { + t.Fatalf("re-encode request %s: %v", name, err) + } + if hex.EncodeToString(recoded) != hex.EncodeToString(env.Payload) { + t.Fatalf("HGRequest round-trip mismatch %s", name) + } + } + if strings.HasSuffix(env.Method, ".RSP") { + resp, err := DecodeHGResponse(env.Payload) + if err != nil { + t.Fatalf("decode response %s: %v", name, err) + } + recoded, err := EncodeHGResponse(resp) + if err != nil { + t.Fatalf("re-encode response %s: %v", name, err) + } + if hex.EncodeToString(recoded) != hex.EncodeToString(env.Payload) { + t.Fatalf("HGResponse round-trip mismatch %s", name) + } + } + } + } } diff --git a/go/tritrpcv1/go.sum b/go/tritrpcv1/go.sum new file 
mode 100644 index 0000000..b6e3d53 --- /dev/null +++ b/go/tritrpcv1/go.sum @@ -0,0 +1 @@ +golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= diff --git a/go/tritrpcv1/pathb_dec.go b/go/tritrpcv1/pathb_dec.go index 0d1dfb1..2ef3686 100644 --- a/go/tritrpcv1/pathb_dec.go +++ b/go/tritrpcv1/pathb_dec.go @@ -2,40 +2,48 @@ package tritrpcv1 // Minimal Path-B decoders for strings and union index (subset used in fixtures) func PBDecodeLen(buf []byte, off int) (int, int) { - // TLEB3 decode for length: reuse TLEB3 decoder by repacking; here we assume small inputs and just reuse TritUnpack on a byte-by-byte basis - // NOTE: For production, implement a proper scanner. - trits := []byte{} - start := off - for { - b := buf[off]; off++ - ts, _ := TritUnpack243([]byte{b}) - trits = append(trits, ts...) - if len(trits) >= 3 { - v := uint64(0); used := 0 - for j:=0; j0 { - pack := TritPack243(trits[:used]) - usedBytes := len(pack) - newOff := start + usedBytes - return int(v), newOff - } - } - } + // TLEB3 decode for length: reuse TLEB3 decoder by repacking; here we assume small inputs and just reuse TritUnpack on a byte-by-byte basis + // NOTE: For production, implement a proper scanner. + trits := []byte{} + start := off + for { + b := buf[off] + off++ + ts, _ := TritUnpack243([]byte{b}) + trits = append(trits, ts...) + if len(trits) >= 3 { + v := uint64(0) + used := 0 + for j := 0; j < len(trits)/3; j++ { + c, p1, p0 := trits[3*j], trits[3*j+1], trits[3*j+2] + digit := uint64(p1)*3 + uint64(p0) + mul := uint64(1) + for k := 0; k < j; k++ { + mul *= 9 + } + v += digit * mul + if c == 0 { + used = (j + 1) * 3 + break + } + } + if used > 0 { + pack := TritPack243(trits[:used]) + usedBytes := len(pack) + newOff := start + usedBytes + return int(v), newOff + } + } + } } func PBDecodeString(buf []byte, off int) (string, int) { - l, o2 := PBDecodeLen(buf, off) - s := string(buf[o2:o2+l]) - return s, o2+l + l, o2 := PBDecodeLen(buf, off) + s := string(buf[o2 : o2+l]) + return s, o2 + l } func PBDecodeUnionIndex(buf []byte, off int) (int, int) { - l, o2 := PBDecodeLen(buf, off) - return l, o2 + l, o2 := PBDecodeLen(buf, off) + return l, o2 } diff --git a/go/tritrpcv1/tleb3.go b/go/tritrpcv1/tleb3.go index 22fbcab..d84cd13 100644 --- a/go/tritrpcv1/tleb3.go +++ b/go/tritrpcv1/tleb3.go @@ -1,6 +1,7 @@ - package tritrpcv1 +import "errors" + func TLEB3EncodeLen(n uint64) []byte { var digits []byte if n == 0 { @@ -23,3 +24,40 @@ func TLEB3EncodeLen(n uint64) []byte { } return TritPack243(trits) } + +func TLEB3DecodeLen(buf []byte, offset int) (val uint64, newOff int, err error) { + trits := []byte{} + off := offset + for { + if off >= len(buf) { + return 0, 0, errors.New("EOF in TLEB3") + } + b := buf[off] + off++ + ts, _ := TritUnpack243([]byte{b}) + trits = append(trits, ts...) 
+ if len(trits) < 3 { + continue + } + v := uint64(0) + used := 0 + for j := 0; j < len(trits)/3; j++ { + c, p1, p0 := trits[3*j], trits[3*j+1], trits[3*j+2] + digit := uint64(p1)*3 + uint64(p0) + mul := uint64(1) + for k := 0; k < j; k++ { + mul *= 9 + } + v += digit * mul + if c == 0 { + used = (j + 1) * 3 + break + } + } + if used > 0 { + pack := TritPack243(trits[:used]) + usedBytes := len(pack) + return v, offset + usedBytes - 1 + (off - offset), nil + } + } +} diff --git a/go/tritrpcv1/tritpack243.go b/go/tritrpcv1/tritpack243.go index 4a9aa61..c0ae57a 100644 --- a/go/tritrpcv1/tritpack243.go +++ b/go/tritrpcv1/tritpack243.go @@ -1,4 +1,3 @@ - package tritrpcv1 import "fmt" @@ -47,7 +46,7 @@ func TritUnpack243(bytes []byte) ([]byte, error) { if i >= len(bytes) { return nil, fmt.Errorf("truncated tail") } - k := int(b-243+1) + k := int(b - 243 + 1) val := int(bytes[i]) i++ group := make([]byte, k) diff --git a/go/tritrpcv1/tritrpcv1_test.go b/go/tritrpcv1/tritrpcv1_test.go index 5dfe59f..76227a0 100644 --- a/go/tritrpcv1/tritrpcv1_test.go +++ b/go/tritrpcv1/tritrpcv1_test.go @@ -1,21 +1,20 @@ - package tritrpcv1 import "testing" func TestMicroVectors(t *testing.T) { - b := TritPack243([]byte{2,1,0,0,2}) + b := TritPack243([]byte{2, 1, 0, 0, 2}) if len(b) != 1 || b[0] != 0xBF { t.Fatalf("pack fail, got %x", b) } - b2 := TritPack243([]byte{2,2,1}) + b2 := TritPack243([]byte{2, 2, 1}) if len(b2) != 2 || b2[0] != 0xF5 || b2[1] != 0x19 { t.Fatalf("tail fail, got %x", b2) } } func TestTleb3EncodeLen(t *testing.T) { - for _, n := range []uint64{0,1,2,3,8,9,10,123,4096,65535} { + for _, n := range []uint64{0, 1, 2, 3, 8, 9, 10, 123, 4096, 65535} { enc := TLEB3EncodeLen(n) if len(enc) == 0 { t.Fatalf("empty encoding for %d", n) diff --git a/rust/tritrpc_v1/Cargo.toml b/rust/tritrpc_v1/Cargo.toml index 01ca49c..69b48e6 100644 --- a/rust/tritrpc_v1/Cargo.toml +++ b/rust/tritrpc_v1/Cargo.toml @@ -11,3 +11,4 @@ chacha20poly1305 = { version = "0.10", features = ["xchacha20poly1305"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" hex = "0.4" +subtle = "2.5" diff --git a/rust/tritrpc_v1/src/bin/trpc.rs b/rust/tritrpc_v1/src/bin/trpc.rs index 5cafc17..b250dc6 100644 --- a/rust/tritrpc_v1/src/bin/trpc.rs +++ b/rust/tritrpc_v1/src/bin/trpc.rs @@ -1,16 +1,15 @@ - use std::env; use std::fs; use std::process::exit; -use tritrpc_v1::{envelope, avroenc}; +use tritrpc_v1::{avroenc, envelope}; -fn hex_to_bytes(s:&str)->Vec{ +fn hex_to_bytes(s: &str) -> Vec { let s = s.trim(); let mut out = Vec::new(); let mut it = s.as_bytes().chunks(2); for ch in it { let hh = std::str::from_utf8(ch).unwrap(); - out.push(u8::from_str_radix(hh,16).unwrap()); + out.push(u8::from_str_radix(hh, 16).unwrap()); } out } @@ -22,7 +21,10 @@ fn usage() { fn main() { let args: Vec = env::args().collect(); - if args.len() < 2 { usage(); exit(1); } + if args.len() < 2 { + usage(); + exit(1); + } match args[1].as_str() { "pack" => { let mut svc = String::new(); @@ -33,16 +35,39 @@ fn main() { let mut i = 2; while i < args.len() { match args[i].as_str() { - "--service" => { i+=1; svc = args[i].clone(); } - "--method" => { i+=1; m = args[i].clone(); } - "--json" => { i+=1; jsonp = args[i].clone(); } - "--nonce" => { i+=1; nonce_hex = args[i].clone(); } - "--key" => { i+=1; key_hex = args[i].clone(); } + "--service" => { + i += 1; + svc = args[i].clone(); + } + "--method" => { + i += 1; + m = args[i].clone(); + } + "--json" => { + i += 1; + jsonp = args[i].clone(); + } + "--nonce" => { + i += 1; + nonce_hex = 
args[i].clone(); + } + "--key" => { + i += 1; + key_hex = args[i].clone(); + } _ => {} } - i+=1; + i += 1; + } + if svc.is_empty() + || m.is_empty() + || jsonp.is_empty() + || nonce_hex.is_empty() + || key_hex.is_empty() + { + usage(); + exit(2); } - if svc.is_empty() || m.is_empty() || jsonp.is_empty() || nonce_hex.is_empty() || key_hex.is_empty() { usage(); exit(2); } let js = fs::read_to_string(&jsonp).expect("read json"); let v: serde_json::Value = serde_json::from_str(&js).expect("json"); let payload = if m.ends_with(".REQ") || m.ends_with(".Req") || m.ends_with(".Request") { @@ -55,8 +80,10 @@ fn main() { }; let keyb = hex_to_bytes(&key_hex); let nonceb = hex_to_bytes(&nonce_hex); - let mut key = [0u8;32]; key.copy_from_slice(&keyb[..32]); - let mut nonce = [0u8;24]; nonce.copy_from_slice(&nonceb[..24]); + let mut key = [0u8; 32]; + key.copy_from_slice(&keyb[..32]); + let mut nonce = [0u8; 24]; + nonce.copy_from_slice(&nonceb[..24]); let (frame, _tag) = envelope::envelope_with_tag(&svc, &m, &payload, None, &key, &nonce); println!("{}", hex::encode(frame)); } @@ -66,16 +93,28 @@ fn main() { let mut i = 2; while i < args.len() { match args[i].as_str() { - "--fixtures" => { i+=1; fixtures = args[i].clone(); } - "--nonces" => { i+=1; nonces = args[i].clone(); } + "--fixtures" => { + i += 1; + fixtures = args[i].clone(); + } + "--nonces" => { + i += 1; + nonces = args[i].clone(); + } _ => {} } - i+=1; + i += 1; + } + if fixtures.is_empty() || nonces.is_empty() { + usage(); + exit(3); } - if fixtures.is_empty() || nonces.is_empty() { usage(); exit(3); } let out = tritrpc_v1_tests::verify_file(&fixtures, &nonces); println!("{}", out); } - _ => { usage(); exit(4); } + _ => { + usage(); + exit(4); + } } } diff --git a/rust/tritrpc_v1/src/lib.rs b/rust/tritrpc_v1/src/lib.rs index 5bf2f7f..c82a90d 100644 --- a/rust/tritrpc_v1/src/lib.rs +++ b/rust/tritrpc_v1/src/lib.rs @@ -1,11 +1,10 @@ - pub mod tritpack243 { pub fn pack(trits: &[u8]) -> Vec { let mut out: Vec = Vec::new(); let mut i: usize = 0; while i + 5 <= trits.len() { let mut val: u32 = 0; - for &t in &trits[i..i+5] { + for &t in &trits[i..i + 5] { assert!(t <= 2, "invalid trit"); val = val * 3 + t as u32; } @@ -28,7 +27,8 @@ pub mod tritpack243 { let mut trits: Vec = Vec::new(); let mut i: usize = 0; while i < bytes.len() { - let b = bytes[i]; i += 1; + let b = bytes[i]; + i += 1; if b <= 242 { let mut val = b as u32; let mut group = [0u8; 5]; @@ -38,9 +38,12 @@ pub mod tritpack243 { } trits.extend_from_slice(&group); } else if (243..=246).contains(&b) { - if i >= bytes.len() { return Err("truncated tail marker".into()); } + if i >= bytes.len() { + return Err("truncated tail marker".into()); + } let k = (b - 243 + 1) as usize; - let mut val = bytes[i] as u32; i += 1; + let mut val = bytes[i] as u32; + i += 1; let mut group = vec![0u8; k]; for j in (0..k).rev() { group[j] = (val % 3) as u8; @@ -59,7 +62,9 @@ pub mod tleb3 { use super::tritpack243; pub fn encode_len(mut n: u64) -> Vec { let mut digits: Vec = Vec::new(); - if n == 0 { digits.push(0); } else { + if n == 0 { + digits.push(0); + } else { while n > 0 { digits.push((n % 9) as u8); n /= 9; @@ -67,10 +72,12 @@ pub mod tleb3 { } let mut trits: Vec = Vec::new(); for (i, d) in digits.iter().enumerate() { - let c = if i < digits.len()-1 { 2 } else { 0 }; + let c = if i < digits.len() - 1 { 2 } else { 0 }; let p1 = d / 3; let p0 = d % 3; - trits.push(c); trits.push(*p1); trits.push(*p0); + trits.push(c); + trits.push(*p1); + trits.push(*p0); } tritpack243::pack(&trits) } @@ -78,21 
+85,26 @@ pub mod tleb3 { pub fn decode_len(bytes: &[u8], mut offset: usize) -> Result<(u64, usize), String> { let mut trits: Vec = Vec::new(); loop { - if offset >= bytes.len() { return Err("EOF in TLEB3".into()); } - let b = bytes[offset]; offset += 1; + if offset >= bytes.len() { + return Err("EOF in TLEB3".into()); + } + let b = bytes[offset]; + offset += 1; let ts = super::tritpack243::unpack(&[b])?; trits.extend_from_slice(&ts); - if trits.len() < 3 { continue; } + if trits.len() < 3 { + continue; + } let mut val: u64 = 0; let mut used_trits: usize = 0; - for j in 0..(trits.len()/3) { - let c = trits[3*j] as u64; - let p1 = trits[3*j+1] as u64; - let p0 = trits[3*j+2] as u64; - let digit = p1*3 + p0; + for j in 0..(trits.len() / 3) { + let c = trits[3 * j] as u64; + let p1 = trits[3 * j + 1] as u64; + let p0 = trits[3 * j + 2] as u64; + let digit = p1 * 3 + p0; val += digit * 9u64.pow(j as u32); if c == 0 { - used_trits = (j+1)*3; + used_trits = (j + 1) * 3; break; } } @@ -106,11 +118,21 @@ pub mod tleb3 { } pub mod envelope { - use super::{tritpack243, tleb3}; + use super::{tleb3, tritpack243}; use chacha20poly1305::aead::{Aead, KeyInit}; use chacha20poly1305::XChaCha20Poly1305; - const MAGIC_B2: [u8;2] = [0xF3, 0x2A]; + const MAGIC_B2: [u8; 2] = [0xF3, 0x2A]; + pub const SCHEMA_ID_32: [u8; 32] = [ + 0xb2, 0xab, 0x81, 0x45, 0x88, 0xf9, 0x9c, 0x87, 0x5d, 0x37, 0xbb, 0x75, 0x46, 0xd0, 0xdf, + 0x43, 0x69, 0xc2, 0x8b, 0xc5, 0xf6, 0x0c, 0xe3, 0x8a, 0x66, 0x07, 0xda, 0xc4, 0x68, 0x03, + 0x43, 0x52, + ]; + pub const CONTEXT_ID_32: [u8; 32] = [ + 0xe6, 0x57, 0x2c, 0x0e, 0x61, 0x8f, 0x18, 0xd5, 0x72, 0xd4, 0xc2, 0x96, 0x9d, 0xb4, 0x90, + 0x96, 0x59, 0xf0, 0x9e, 0xae, 0xf3, 0x2e, 0xc6, 0x6f, 0xbb, 0x80, 0x4b, 0xad, 0x9d, 0x89, + 0xaa, 0xcd, + ]; fn len_prefix(b: &[u8]) -> Vec { tleb3::encode_len(b.len() as u64) @@ -120,43 +142,187 @@ pub mod envelope { tritpack243::pack(ts) } - pub fn flags_trits(aead: bool, compress: bool) -> [u8;3] { - [ - if aead {2} else {0}, - if compress {2} else {0}, - 0 - ] + pub fn flags_trits(aead: bool, compress: bool) -> [u8; 3] { + [if aead { 2 } else { 0 }, if compress { 2 } else { 0 }, 0] } - pub fn build(service:&str, method:&str, payload:&[u8], aux: Option<&[u8]>, aead_tag: Option<&[u8]>, aead_on: bool, compress: bool) -> Vec { + pub fn build( + service: &str, + method: &str, + payload: &[u8], + aux: Option<&[u8]>, + aead_tag: Option<&[u8]>, + aead_on: bool, + compress: bool, + ) -> Vec { let mut out: Vec = Vec::new(); - out.extend(len_prefix(&MAGIC_B2)); out.extend(MAGIC_B2); - let ver = pack_trits(&[1]); out.extend(len_prefix(&ver)); out.extend(ver); - let mode = pack_trits(&[0]); out.extend(len_prefix(&mode)); out.extend(mode); - let flags = pack_trits(&super::envelope::flags_trits(aead_on, compress)); out.extend(len_prefix(&flags)); out.extend(flags); - let schema = vec![0u8;32]; out.extend(len_prefix(&schema)); out.extend(&schema); - let context = vec![0u8;32]; out.extend(len_prefix(&context)); out.extend(&context); - let svc = service.as_bytes(); out.extend(len_prefix(svc)); out.extend(svc); - let m = method.as_bytes(); out.extend(len_prefix(m)); out.extend(m); - out.extend(len_prefix(payload)); out.extend(payload); - if let Some(auxb) = aux { out.extend(len_prefix(auxb)); out.extend(auxb); } - if let Some(tag) = aead_tag { out.extend(len_prefix(tag)); out.extend(tag); } + out.extend(len_prefix(&MAGIC_B2)); + out.extend(MAGIC_B2); + let ver = pack_trits(&[1]); + out.extend(len_prefix(&ver)); + out.extend(ver); + let mode = pack_trits(&[0]); + 
out.extend(len_prefix(&mode)); + out.extend(mode); + let flags = pack_trits(&super::envelope::flags_trits(aead_on, compress)); + out.extend(len_prefix(&flags)); + out.extend(flags); + let schema = SCHEMA_ID_32; + out.extend(len_prefix(&schema)); + out.extend(&schema); + let context = CONTEXT_ID_32; + out.extend(len_prefix(&context)); + out.extend(&context); + let svc = service.as_bytes(); + out.extend(len_prefix(svc)); + out.extend(svc); + let m = method.as_bytes(); + out.extend(len_prefix(m)); + out.extend(m); + out.extend(len_prefix(payload)); + out.extend(payload); + if let Some(auxb) = aux { + out.extend(len_prefix(auxb)); + out.extend(auxb); + } + if let Some(tag) = aead_tag { + out.extend(len_prefix(tag)); + out.extend(tag); + } out } - pub fn envelope_with_tag(service:&str, method:&str, payload:&[u8], aux: Option<&[u8]>, key:&[u8;32], nonce:&[u8;24]) -> (Vec, Vec) { + pub fn envelope_with_tag( + service: &str, + method: &str, + payload: &[u8], + aux: Option<&[u8]>, + key: &[u8; 32], + nonce: &[u8; 24], + ) -> (Vec, Vec) { let aad = build(service, method, payload, aux, None, true, false); let aead = XChaCha20Poly1305::new(key.into()); - let ct = aead.encrypt(nonce.into(), chacha20poly1305::aead::Payload { msg: b"", aad: &aad }).expect("encrypt"); - let tag = ct[ct.len()-16..].to_vec(); + let ct = aead + .encrypt( + nonce.into(), + chacha20poly1305::aead::Payload { + msg: b"", + aad: &aad, + }, + ) + .expect("encrypt"); + let tag = ct[ct.len() - 16..].to_vec(); let frame = build(service, method, payload, aux, Some(&tag), true, false); (frame, tag) } + + #[derive(Debug, Clone)] + pub struct DecodedEnvelope { + pub magic: Vec, + pub version: Vec, + pub mode: Vec, + pub flags: Vec, + pub schema: Vec, + pub context: Vec, + pub service: String, + pub method: String, + pub payload: Vec, + pub aux: Option>, + pub tag: Option>, + pub aead_on: bool, + pub compress: bool, + pub tag_start: Option, + } + + fn read_field(frame: &[u8], off: usize) -> Result<(Vec, usize, usize), String> { + let (len, new_off) = tleb3::decode_len(frame, off)?; + let l = len as usize; + let val_start = new_off; + let val_end = val_start + l; + if val_end > frame.len() { + return Err("field length exceeds frame".into()); + } + Ok((frame[val_start..val_end].to_vec(), val_end, off)) + } + + pub fn decode(frame: &[u8]) -> Result { + let mut off = 0usize; + let (magic, off1, _) = read_field(frame, off)?; + off = off1; + let (version, off2, _) = read_field(frame, off)?; + off = off2; + let (mode, off3, _) = read_field(frame, off)?; + off = off3; + let (flags, off4, _) = read_field(frame, off)?; + off = off4; + let (schema, off5, _) = read_field(frame, off)?; + off = off5; + let (context, off6, _) = read_field(frame, off)?; + off = off6; + let (svc, off7, _) = read_field(frame, off)?; + off = off7; + let (method, off8, _) = read_field(frame, off)?; + off = off8; + let (payload, off9, _) = read_field(frame, off)?; + off = off9; + + let trits = tritpack243::unpack(&flags)?; + let aead_on = trits.get(0) == Some(&2u8); + let compress = trits.get(1) == Some(&2u8); + + let mut aux: Option> = None; + let mut tag: Option> = None; + let mut tag_start: Option = None; + + let remaining = frame.len().saturating_sub(off); + if remaining > 0 { + if aead_on { + // If two fields remain, treat as aux + tag. If one remains, tag only. 
+ let (first, off10, start10) = read_field(frame, off)?; + off = off10; + if off < frame.len() { + let (tag_val, off11, start11) = read_field(frame, off)?; + off = off11; + aux = Some(first); + tag = Some(tag_val); + tag_start = Some(start11); + } else { + tag = Some(first); + tag_start = Some(start10); + } + } else { + let (aux_val, off10, _) = read_field(frame, off)?; + off = off10; + aux = Some(aux_val); + } + } + if off != frame.len() { + return Err("extra bytes after envelope parsing".into()); + } + Ok(DecodedEnvelope { + magic, + version, + mode, + flags, + schema, + context, + service: String::from_utf8(svc).map_err(|_| "service not utf8")?, + method: String::from_utf8(method).map_err(|_| "method not utf8")?, + payload, + aux, + tag, + aead_on, + compress, + tag_start, + }) + } } pub mod avroenc { // Avro subset encoders: zigzag, varint, string, bytes, array, map, union, enum, records for control+HG - fn zigzag(n: i64) -> u64 { ((n << 1) ^ (n >> 63)) as u64 } + fn zigzag(n: i64) -> u64 { + ((n << 1) ^ (n >> 63)) as u64 + } pub fn enc_varint(mut u: u64) -> Vec { let mut out = Vec::new(); while (u & !0x7F) != 0 { @@ -166,9 +332,19 @@ pub mod avroenc { out.push(u as u8); out } - pub fn enc_long(n: i64) -> Vec { enc_varint(zigzag(n)) } - pub fn enc_int(n: i32) -> Vec { enc_long(n as i64) } - pub fn enc_bool(v: bool) -> Vec { if v { vec![1] } else { vec![0] } } + pub fn enc_long(n: i64) -> Vec { + enc_varint(zigzag(n)) + } + pub fn enc_int(n: i32) -> Vec { + enc_long(n as i64) + } + pub fn enc_bool(v: bool) -> Vec { + if v { + vec![1] + } else { + vec![0] + } + } pub fn enc_string(s: &str) -> Vec { let b = s.as_bytes(); let mut out = enc_long(b.len() as i64); @@ -180,16 +356,22 @@ pub mod avroenc { out.extend_from_slice(b); out } - pub fn enc_array(items: &[T], f: fn(&T)->Vec) -> Vec { - if items.is_empty() { return vec![0]; } + pub fn enc_array(items: &[T], f: fn(&T) -> Vec) -> Vec { + if items.is_empty() { + return vec![0]; + } let mut out = Vec::new(); out.extend(enc_long(items.len() as i64)); - for it in items { out.extend(f(it)); } + for it in items { + out.extend(f(it)); + } out.push(0); out } pub fn enc_map(m: &[(&str, &str)]) -> Vec { - if m.is_empty() { return vec![0]; } + if m.is_empty() { + return vec![0]; + } let mut out = Vec::new(); out.extend(enc_long(m.len() as i64)); for (k, v) in m { @@ -204,10 +386,17 @@ pub mod avroenc { out.extend(payload); out } - pub fn enc_enum(index: i32) -> Vec { enc_int(index) } + pub fn enc_enum(index: i32) -> Vec { + enc_int(index) + } // Control - pub fn enc_Hello(modes:&[&str], suites:&[&str], comp:&[&str], context_uri: Option<&str>) -> Vec { + pub fn enc_Hello( + modes: &[&str], + suites: &[&str], + comp: &[&str], + context_uri: Option<&str>, + ) -> Vec { let mut out = Vec::new(); out.extend(enc_array(modes, |s| enc_string(s))); out.extend(enc_array(suites, |s| enc_string(s))); @@ -218,14 +407,14 @@ pub mod avroenc { } out } - pub fn enc_Choose(mode:&str, suite:&str, comp:&str) -> Vec { + pub fn enc_Choose(mode: &str, suite: &str, comp: &str) -> Vec { let mut out = Vec::new(); out.extend(enc_string(mode)); out.extend(enc_string(suite)); out.extend(enc_string(comp)); out } - pub fn enc_Error(code:i32, msg:&str, details: Option<&[u8]>) -> Vec { + pub fn enc_Error(code: i32, msg: &str, details: Option<&[u8]>) -> Vec { let mut out = Vec::new(); out.extend(enc_int(code)); out.extend(enc_string(msg)); @@ -237,7 +426,7 @@ pub mod avroenc { } // Hypergraph - pub fn enc_Vertex(vid:&str, label: Option<&str>, attrs: &[(&str,&str)]) -> Vec { + pub fn 
enc_Vertex(vid: &str, label: Option<&str>, attrs: &[(&str, &str)]) -> Vec { let mut out = Vec::new(); out.extend(enc_string(vid)); match label { @@ -247,7 +436,12 @@ pub mod avroenc { out.extend(enc_map(attrs)); out } - pub fn enc_Hyperedge(eid:&str, members:&[&str], weight: Option, attrs:&[(&str,&str)]) -> Vec { + pub fn enc_Hyperedge( + eid: &str, + members: &[&str], + weight: Option, + attrs: &[(&str, &str)], + ) -> Vec { let mut out = Vec::new(); out.extend(enc_string(eid)); out.extend(enc_array(members, |s| enc_string(s))); @@ -258,7 +452,7 @@ pub mod avroenc { out.extend(enc_map(attrs)); out } - pub fn enc_HGRequest_AddVertex(vid:&str, label: Option<&str>) -> Vec { + pub fn enc_HGRequest_AddVertex(vid: &str, label: Option<&str>) -> Vec { let mut out = Vec::new(); out.extend(enc_enum(0)); out.extend(enc_union(1, enc_Vertex(vid, label, &[]))); @@ -268,7 +462,7 @@ pub mod avroenc { out.extend(enc_union(0, vec![])); // k null out } - pub fn enc_HGRequest_AddHyperedge(eid:&str, members:&[&str], weight: Option) -> Vec { + pub fn enc_HGRequest_AddHyperedge(eid: &str, members: &[&str], weight: Option) -> Vec { let mut out = Vec::new(); out.extend(enc_enum(1)); out.extend(enc_union(0, vec![])); // vertex null @@ -278,7 +472,27 @@ pub mod avroenc { out.extend(enc_union(0, vec![])); // k null out } - pub fn enc_HGRequest_QueryNeighbors(vid:&str, k:i32) -> Vec { + pub fn enc_HGRequest_RemoveVertex(vid: &str) -> Vec { + let mut out = Vec::new(); + out.extend(enc_enum(2)); + out.extend(enc_union(0, vec![])); // vertex null + out.extend(enc_union(0, vec![])); // edge null + out.extend(enc_union(1, enc_string(vid))); + out.extend(enc_union(0, vec![])); // eid null + out.extend(enc_union(0, vec![])); // k null + out + } + pub fn enc_HGRequest_RemoveHyperedge(eid: &str) -> Vec { + let mut out = Vec::new(); + out.extend(enc_enum(3)); + out.extend(enc_union(0, vec![])); // vertex null + out.extend(enc_union(0, vec![])); // edge null + out.extend(enc_union(0, vec![])); // vid null + out.extend(enc_union(1, enc_string(eid))); + out.extend(enc_union(0, vec![])); // k null + out + } + pub fn enc_HGRequest_QueryNeighbors(vid: &str, k: i32) -> Vec { let mut out = Vec::new(); out.extend(enc_enum(4)); out.extend(enc_union(0, vec![])); @@ -288,7 +502,7 @@ pub mod avroenc { out.extend(enc_union(1, enc_int(k))); out } - pub fn enc_HGRequest_GetSubgraph(vid:&str, k:i32) -> Vec { + pub fn enc_HGRequest_GetSubgraph(vid: &str, k: i32) -> Vec { let mut out = Vec::new(); out.extend(enc_enum(5)); out.extend(enc_union(0, vec![])); @@ -298,7 +512,12 @@ pub mod avroenc { out.extend(enc_union(1, enc_int(k))); out } - pub fn enc_HGResponse(ok: bool, err: Option<&str>, vertices:&[(&str, Option<&str>)], edges:&[(&str, Vec<&str>, Option)]) -> Vec { + pub fn enc_HGResponse( + ok: bool, + err: Option<&str>, + vertices: &[(&str, Option<&str>)], + edges: &[(&str, Vec<&str>, Option)], + ) -> Vec { let mut out = Vec::new(); out.extend(enc_bool(ok)); match err { @@ -306,22 +525,34 @@ pub mod avroenc { Some(e) => out.extend(enc_union(1, enc_string(e))), } // vertices - let vbytes = vertices.iter().map(|(vid,l)| enc_Vertex(vid, *l, &[])).collect::>(); + let vbytes = vertices + .iter() + .map(|(vid, l)| enc_Vertex(vid, *l, &[])) + .collect::>(); let mut arr = Vec::new(); - if vbytes.is_empty() { arr.push(0); } - else { + if vbytes.is_empty() { + arr.push(0); + } else { arr.extend(enc_long(vbytes.len() as i64)); - for vb in vbytes { arr.extend(vb); } + for vb in vbytes { + arr.extend(vb); + } arr.push(0); } out.extend(arr); // edges - let 
ebytes = edges.iter().map(|(eid,mem,w)| enc_Hyperedge(eid, mem, *w, &[])).collect::>(); + let ebytes = edges + .iter() + .map(|(eid, mem, w)| enc_Hyperedge(eid, mem, *w, &[])) + .collect::>(); let mut arr2 = Vec::new(); - if ebytes.is_empty() { arr2.push(0); } - else { + if ebytes.is_empty() { + arr2.push(0); + } else { arr2.extend(enc_long(ebytes.len() as i64)); - for eb in ebytes { arr2.extend(eb); } + for eb in ebytes { + arr2.extend(eb); + } arr2.push(0); } out.extend(arr2); @@ -329,36 +560,452 @@ pub mod avroenc { } } +pub mod avrodec { + use super::avroenc; + + fn zigzag_inv(u: u64) -> i64 { + ((u >> 1) as i64) ^ (-((u & 1) as i64)) + } + + pub fn dec_varint(bytes: &[u8], mut off: usize) -> Result<(u64, usize), String> { + let mut shift = 0u32; + let mut out = 0u64; + loop { + if off >= bytes.len() { + return Err("EOF in varint".into()); + } + let b = bytes[off]; + off += 1; + out |= ((b & 0x7F) as u64) << shift; + if (b & 0x80) == 0 { + break; + } + shift += 7; + if shift > 63 { + return Err("varint overflow".into()); + } + } + Ok((out, off)) + } + + pub fn dec_long(bytes: &[u8], off: usize) -> Result<(i64, usize), String> { + let (u, new_off) = dec_varint(bytes, off)?; + Ok((zigzag_inv(u), new_off)) + } + + pub fn dec_int(bytes: &[u8], off: usize) -> Result<(i32, usize), String> { + let (v, new_off) = dec_long(bytes, off)?; + Ok((v as i32, new_off)) + } + + pub fn dec_bool(bytes: &[u8], off: usize) -> Result<(bool, usize), String> { + if off >= bytes.len() { + return Err("EOF in bool".into()); + } + Ok((bytes[off] != 0, off + 1)) + } + + pub fn dec_string(bytes: &[u8], off: usize) -> Result<(String, usize), String> { + let (len, mut new_off) = dec_long(bytes, off)?; + if len < 0 { + return Err("negative string length".into()); + } + let l = len as usize; + let end = new_off + l; + if end > bytes.len() { + return Err("string length exceeds buffer".into()); + } + let s = std::str::from_utf8(&bytes[new_off..end]) + .map_err(|_| "invalid utf8")? 
+ .to_string(); + new_off = end; + Ok((s, new_off)) + } + + pub fn dec_bytes(bytes: &[u8], off: usize) -> Result<(Vec, usize), String> { + let (len, mut new_off) = dec_long(bytes, off)?; + if len < 0 { + return Err("negative bytes length".into()); + } + let l = len as usize; + let end = new_off + l; + if end > bytes.len() { + return Err("bytes length exceeds buffer".into()); + } + let out = bytes[new_off..end].to_vec(); + new_off = end; + Ok((out, new_off)) + } + + pub fn dec_array_strings(bytes: &[u8], mut off: usize) -> Result<(Vec, usize), String> { + let (count, mut new_off) = dec_long(bytes, off)?; + if count == 0 { + return Ok((Vec::new(), new_off)); + } + if count < 0 { + return Err("negative array block count".into()); + } + let mut out = Vec::new(); + for _ in 0..count { + let (s, n2) = dec_string(bytes, new_off)?; + new_off = n2; + out.push(s); + } + let (end_count, end_off) = dec_long(bytes, new_off)?; + if end_count != 0 { + return Err("non-zero array terminator".into()); + } + Ok((out, end_off)) + } + + pub fn dec_map_strings( + bytes: &[u8], + mut off: usize, + ) -> Result<(Vec<(String, String)>, usize), String> { + let (count, mut new_off) = dec_long(bytes, off)?; + if count == 0 { + return Ok((Vec::new(), new_off)); + } + if count < 0 { + return Err("negative map block count".into()); + } + let mut out = Vec::new(); + for _ in 0..count { + let (k, o1) = dec_string(bytes, new_off)?; + let (v, o2) = dec_string(bytes, o1)?; + new_off = o2; + out.push((k, v)); + } + let (end_count, end_off) = dec_long(bytes, new_off)?; + if end_count != 0 { + return Err("non-zero map terminator".into()); + } + Ok((out, end_off)) + } + + pub fn dec_union_index(bytes: &[u8], off: usize) -> Result<(i64, usize), String> { + dec_long(bytes, off) + } + + #[derive(Debug, Clone)] + pub struct Vertex { + pub vid: String, + pub label: Option, + pub attr: Vec<(String, String)>, + } + + #[derive(Debug, Clone)] + pub struct Hyperedge { + pub eid: String, + pub members: Vec, + pub weight: Option, + pub attr: Vec<(String, String)>, + } + + #[derive(Debug, Clone)] + pub struct HGRequest { + pub op: i32, + pub vertex: Option, + pub hyperedge: Option, + pub vid: Option, + pub eid: Option, + pub k: Option, + } + + #[derive(Debug, Clone)] + pub struct HGResponse { + pub ok: bool, + pub err: Option, + pub vertices: Vec<(String, Option)>, + pub edges: Vec<(String, Vec, Option)>, + } + + pub fn dec_vertex(bytes: &[u8], off: usize) -> Result<(Vertex, usize), String> { + let (vid, mut o1) = dec_string(bytes, off)?; + let (idx, mut o2) = dec_union_index(bytes, o1)?; + let label = if idx == 0 { + None + } else if idx == 1 { + let (s, o3) = dec_string(bytes, o2)?; + o2 = o3; + Some(s) + } else { + return Err("invalid union index for label".into()); + }; + let (attr, o4) = dec_map_strings(bytes, o2)?; + Ok((Vertex { vid, label, attr }, o4)) + } + + pub fn dec_hyperedge(bytes: &[u8], off: usize) -> Result<(Hyperedge, usize), String> { + let (eid, mut o1) = dec_string(bytes, off)?; + let (members, mut o2) = dec_array_strings(bytes, o1)?; + let (idx, mut o3) = dec_union_index(bytes, o2)?; + let weight = if idx == 0 { + None + } else if idx == 1 { + let (w, o4) = dec_long(bytes, o3)?; + o3 = o4; + Some(w) + } else { + return Err("invalid union index for weight".into()); + }; + let (attr, o5) = dec_map_strings(bytes, o3)?; + Ok(( + Hyperedge { + eid, + members, + weight, + attr, + }, + o5, + )) + } + + pub fn dec_hg_request(bytes: &[u8]) -> Result { + let (op, mut off) = dec_int(bytes, 0)?; + let (idx_v, mut off2) = 
dec_union_index(bytes, off)?; + let mut vertex = None; + if idx_v == 1 { + let (v, o3) = dec_vertex(bytes, off2)?; + vertex = Some(v); + off2 = o3; + } + let (idx_e, mut off3) = dec_union_index(bytes, off2)?; + let mut hyperedge = None; + if idx_e == 1 { + let (e, o4) = dec_hyperedge(bytes, off3)?; + hyperedge = Some(e); + off3 = o4; + } + let (idx_vid, mut off4) = dec_union_index(bytes, off3)?; + let mut vid = None; + if idx_vid == 1 { + let (s, o5) = dec_string(bytes, off4)?; + vid = Some(s); + off4 = o5; + } + let (idx_eid, mut off5) = dec_union_index(bytes, off4)?; + let mut eid = None; + if idx_eid == 1 { + let (s, o6) = dec_string(bytes, off5)?; + eid = Some(s); + off5 = o6; + } + let (idx_k, mut off6) = dec_union_index(bytes, off5)?; + let mut k = None; + if idx_k == 1 { + let (kv, o7) = dec_int(bytes, off6)?; + k = Some(kv); + off6 = o7; + } + if off6 != bytes.len() { + return Err("extra bytes after HGRequest".into()); + } + Ok(HGRequest { + op, + vertex, + hyperedge, + vid, + eid, + k, + }) + } + + pub fn enc_hg_request(req: &HGRequest) -> Result, String> { + match req.op { + 0 => { + let v = req.vertex.as_ref().ok_or("missing vertex")?; + if !v.attr.is_empty() { + return Err("vertex attr not supported in encoder".into()); + } + Ok(avroenc::enc_HGRequest_AddVertex(&v.vid, v.label.as_deref())) + } + 1 => { + let e = req.hyperedge.as_ref().ok_or("missing hyperedge")?; + if !e.attr.is_empty() { + return Err("hyperedge attr not supported in encoder".into()); + } + let members = e.members.iter().map(|s| s.as_str()).collect::>(); + Ok(avroenc::enc_HGRequest_AddHyperedge( + &e.eid, &members, e.weight, + )) + } + 2 => { + let vid = req.vid.as_ref().ok_or("missing vid")?; + Ok(avroenc::enc_HGRequest_RemoveVertex(vid)) + } + 3 => { + let eid = req.eid.as_ref().ok_or("missing eid")?; + Ok(avroenc::enc_HGRequest_RemoveHyperedge(eid)) + } + 4 => { + let vid = req.vid.as_ref().ok_or("missing vid")?; + let k = req.k.unwrap_or(1); + Ok(avroenc::enc_HGRequest_QueryNeighbors(vid, k)) + } + 5 => { + let vid = req.vid.as_ref().ok_or("missing vid")?; + let k = req.k.unwrap_or(1); + Ok(avroenc::enc_HGRequest_GetSubgraph(vid, k)) + } + _ => Err("unsupported op".into()), + } + } + + pub fn dec_hg_response(bytes: &[u8]) -> Result { + let (ok, mut off) = dec_bool(bytes, 0)?; + let (idx_err, mut off2) = dec_union_index(bytes, off)?; + let err = if idx_err == 0 { + None + } else if idx_err == 1 { + let (s, o3) = dec_string(bytes, off2)?; + off2 = o3; + Some(s) + } else { + return Err("invalid union index for err".into()); + }; + let (vcount, mut off3) = dec_long(bytes, off2)?; + let mut vertices = Vec::new(); + if vcount < 0 { + return Err("negative vertices block".into()); + } + if vcount == 0 { + // ok + } else { + for _ in 0..vcount { + let (v, o4) = dec_vertex(bytes, off3)?; + if !v.attr.is_empty() { + return Err("vertex attr not supported in response fixtures".into()); + } + off3 = o4; + vertices.push((v.vid, v.label)); + } + let (endc, o5) = dec_long(bytes, off3)?; + if endc != 0 { + return Err("non-zero vertices terminator".into()); + } + off3 = o5; + } + let (ecount, mut off4) = dec_long(bytes, off3)?; + let mut edges = Vec::new(); + if ecount < 0 { + return Err("negative edges block".into()); + } + if ecount == 0 { + // ok + } else { + for _ in 0..ecount { + let (e, o5) = dec_hyperedge(bytes, off4)?; + if !e.attr.is_empty() { + return Err("edge attr not supported in response fixtures".into()); + } + off4 = o5; + edges.push((e.eid, e.members, e.weight)); + } + let (endc, o6) = dec_long(bytes, off4)?; 
+ if endc != 0 { + return Err("non-zero edges terminator".into()); + } + off4 = o6; + } + if off4 != bytes.len() { + return Err("extra bytes after HGResponse".into()); + } + Ok(HGResponse { + ok, + err, + vertices, + edges, + }) + } + + pub fn enc_hg_response(resp: &HGResponse) -> Result, String> { + let vertices = resp + .vertices + .iter() + .map(|(vid, label)| (vid.as_str(), label.as_deref())) + .collect::>(); + let edges = resp + .edges + .iter() + .map(|(eid, members, weight)| { + let members_ref = members.iter().map(|s| s.as_str()).collect::>(); + (eid.as_str(), members_ref, *weight) + }) + .collect::>(); + Ok(avroenc::enc_HGResponse( + resp.ok, + resp.err.as_deref(), + &vertices, + &edges, + )) + } +} pub mod tritrpc_v1_tests { - use super::{tleb3, tritpack243, avroenc}; + use super::envelope; use chacha20poly1305::aead::{Aead, KeyInit}; use chacha20poly1305::XChaCha20Poly1305; use std::collections::HashMap; use std::fs; + use subtle::ConstantTimeEq; pub fn verify_file(fx: &str, nonces_path: &str) -> String { - let key = [0u8;32]; + let key = [0u8; 32]; let pairs = read_pairs(fx); let nonces = read_nonces(nonces_path); let mut ok = 0usize; for (name, frame) in pairs { - let fields = split_fields(&frame); - let flags = &fields[3]; - if aead_bit(flags) { - let tag = fields.last().unwrap(); - let aad = aad_before_last(&frame); + let decoded = envelope::decode(&frame).expect("decode envelope"); + assert_eq!( + decoded.schema.as_slice(), + envelope::SCHEMA_ID_32.as_slice(), + "schema id mismatch {}", + name + ); + assert_eq!( + decoded.context.as_slice(), + envelope::CONTEXT_ID_32.as_slice(), + "context id mismatch {}", + name + ); + let repacked = envelope::build( + &decoded.service, + &decoded.method, + &decoded.payload, + decoded.aux.as_deref(), + decoded.tag.as_deref(), + decoded.aead_on, + decoded.compress, + ); + assert_eq!(repacked, frame, "repack mismatch {}", name); + if decoded.aead_on { + let tag = decoded.tag.as_ref().expect("missing tag"); let nonce = nonces.get(&name).expect("nonce missing"); + assert_eq!(nonce.len(), 24, "nonce size mismatch {}", name); + assert_eq!(tag.len(), 16, "tag size mismatch {}", name); + let aad_start = decoded.tag_start.expect("tag start missing"); + let aad = &frame[..aad_start]; let aead = XChaCha20Poly1305::new(&key.into()); - let ct = aead.encrypt(nonce.as_slice().into(), chacha20poly1305::aead::Payload{ msg: b"", aad }).unwrap(); - assert_eq!(&ct[ct.len()-16..], tag.as_slice(), "tag mismatch {}", name); + let ct = aead + .encrypt( + nonce.as_slice().into(), + chacha20poly1305::aead::Payload { msg: b"", aad }, + ) + .unwrap(); + let computed = &ct[ct.len() - 16..]; + assert!( + computed.ct_eq(tag.as_slice()).into(), + "tag mismatch {}", + name + ); } ok += 1; } format!("Verified {} frames in {}", ok, fx) } - fn read_pairs(path:&str)->Vec<(String, Vec)>{ + fn read_pairs(path: &str) -> Vec<(String, Vec)> { let txt = fs::read_to_string(path).expect("read fixtures"); txt.lines() .filter(|l| !l.is_empty() && !l.starts_with('#')) @@ -371,7 +1018,7 @@ pub mod tritrpc_v1_tests { }) .collect() } - fn read_nonces(path:&str)->HashMap>{ + fn read_nonces(path: &str) -> HashMap> { let txt = fs::read_to_string(path).expect("read nonces"); txt.lines() .filter(|l| !l.is_empty()) @@ -380,34 +1027,8 @@ pub mod tritrpc_v1_tests { let name = it.next().unwrap().to_string(); let hexs = it.next().unwrap(); (name, hex::decode(hexs).unwrap()) - }).collect() - } - fn split_fields(buf: &[u8]) -> Vec> { - let mut off = 0usize; - let mut fields: Vec> = Vec::new(); - while 
off < buf.len() { - let (len, new_off) = super::tleb3::decode_len(buf, off).unwrap(); - let l = len as usize; - let val_off = new_off; - let val_end = val_off + l; - fields.push(buf[val_off..val_end].to_vec()); - off = val_end; - } - fields - } - fn aead_bit(flags_bytes: &[u8]) -> bool { - let trits = super::tritpack243::unpack(flags_bytes).unwrap(); - trits.get(0) == Some(&2u8) - } - fn aad_before_last(frame: &[u8]) -> &[u8] { - let mut off = 0usize; - let mut last_start = 0usize; - while off < frame.len() { - let (len, new_off) = super::tleb3::decode_len(frame, off).unwrap(); - last_start = off; - off = new_off + len as usize; - } - &frame[..last_start] + }) + .collect() } } @@ -415,7 +1036,7 @@ pub mod avroenc_json { use super::avroenc; use serde_json::Value; - pub fn enc_HGRequest(v:&Value) -> Vec { + pub fn enc_HGRequest(v: &Value) -> Vec { let op = v["op"].as_str().unwrap(); match op { "AddVertex" => { @@ -426,7 +1047,12 @@ pub mod avroenc_json { } "AddHyperedge" => { let eid = v["edge"]["eid"].as_str().unwrap(); - let members = v["edge"]["members"].as_array().unwrap().iter().map(|x| x.as_str().unwrap()).collect::>(); + let members = v["edge"]["members"] + .as_array() + .unwrap() + .iter() + .map(|x| x.as_str().unwrap()) + .collect::>(); avroenc::enc_HGRequest_AddHyperedge(eid, &members, Some(1)) } "QueryNeighbors" => { @@ -440,108 +1066,72 @@ pub mod avroenc_json { avroenc::enc_HGRequest_GetSubgraph(vid, k) } "RemoveVertex" => { - // simple: not used in CLI pack example - avroenc::enc_HGRequest_GetSubgraph("a", 1) + let vid = v["vid"].as_str().unwrap_or("a"); + avroenc::enc_HGRequest_RemoveVertex(vid) } - _ => avroenc::enc_HGRequest_GetSubgraph("a", 1) + "RemoveHyperedge" => { + let eid = v["eid"].as_str().unwrap_or("e1"); + avroenc::enc_HGRequest_RemoveHyperedge(eid) + } + _ => avroenc::enc_HGRequest_GetSubgraph("a", 1), } } pub fn enc_HGResponse_json(v: &Value) -> Vec { let ok = v["ok"].as_bool().unwrap_or(true); let err = v.get("err").and_then(|e| e.as_str()); - let vertices = v["vertices"].as_array().unwrap_or(&vec![]).iter().map(|x| { - (x["vid"].as_str().unwrap(), x.get("label").and_then(|l| l.as_str())) - }).collect::>(); - let edges = v["edges"].as_array().unwrap_or(&vec![]).iter().map(|x| { - let eid = x["eid"].as_str().unwrap(); - let members = x["members"].as_array().unwrap().iter().map(|m| m.as_str().unwrap()).collect::>(); - let weight = x.get("weight").and_then(|w| w.as_i64()); - (eid, members, weight) - }).collect::>(); + let vertices = v["vertices"] + .as_array() + .unwrap_or(&vec![]) + .iter() + .map(|x| { + ( + x["vid"].as_str().unwrap(), + x.get("label").and_then(|l| l.as_str()), + ) + }) + .collect::>(); + let edges = v["edges"] + .as_array() + .unwrap_or(&vec![]) + .iter() + .map(|x| { + let eid = x["eid"].as_str().unwrap(); + let members = x["members"] + .as_array() + .unwrap() + .iter() + .map(|m| m.as_str().unwrap()) + .collect::>(); + let weight = x.get("weight").and_then(|w| w.as_i64()); + (eid, members, weight) + }) + .collect::>(); super::avroenc::enc_HGResponse(ok, err, &vertices, &edges) } } - -pub mod avrodec { - // Minimal Avro Binary decoders for our subset (string, int/long, array<>, map, union, enum) - pub fn dec_varint(mut it: &mut &[u8]) -> u64 { - let mut val: u64 = 0; - let mut shift = 0; - loop { - let b = it[0]; *it = &it[1..]; - val |= ((b & 0x7F) as u64) << shift; - if (b & 0x80) == 0 { break; } - shift += 7; - } - val - } - pub fn dec_long(it: &mut &[u8]) -> i64 { - let u = dec_varint(it); - // zigzag inverse - ((u >> 1) as i64) ^ (-((u & 
1) as i64)) - } - pub fn dec_int(it: &mut &[u8]) -> i32 { dec_long(it) as i32 } - pub fn dec_string(it: &mut &[u8]) -> String { - let len = dec_long(it) as usize; - let s = std::str::from_utf8(&it[..len]).unwrap(); - *it = &it[len..]; - s.to_string() - } - pub fn dec_bytes(it: &mut &[u8]) -> Vec { - let len = dec_long(it) as usize; - let b = it[..len].to_vec(); - *it = &it[len..]; - b - } - pub fn dec_union_index(it: &mut &[u8]) -> i64 { dec_long(it) } - pub fn dec_enum_index(it: &mut &[u8]) -> i32 { dec_int(it) } - pub fn dec_array(it: &mut &[u8], mut f: impl FnMut(&mut &[u8])->T) -> Vec { - let mut out = Vec::new(); - let mut count = dec_long(it); - if count == 0 { return out; } - while count != 0 { - if count < 0 { let _ = dec_long(it); count = -count; } - for _ in 0..count { out.push(f(it)); } - count = dec_long(it); - } - out - } - pub fn dec_map(it: &mut &[u8], mut fv: impl FnMut(&mut &[u8])->V) -> std::collections::BTreeMap { - let mut out = std::collections::BTreeMap::new(); - let mut count = dec_long(it); - if count == 0 { return out; } - while count != 0 { - if count < 0 { let _ = dec_long(it); count = -count; } - for _ in 0..count { - let k = dec_string(it); - let v = fv(it); - out.insert(k, v); - } - count = dec_long(it); - } - out - } -} - pub mod pathb { use super::tleb3; use super::tritpack243; pub fn bt_encode(mut n: i64) -> Vec { let mut digits: Vec = vec![]; - if n == 0 { digits.push(0); } - else { + if n == 0 { + digits.push(0); + } else { while n != 0 { let mut rem = (n % 3) as i8; n /= 3; - if rem == 2 { rem = -1; n += 1; } + if rem == 2 { + rem = -1; + n += 1; + } digits.push(rem); } digits.reverse(); } - let trits: Vec = digits.into_iter().map(|d| (d+1) as u8).collect(); + let trits: Vec = digits.into_iter().map(|d| (d + 1) as u8).collect(); let mut out = tleb3::encode_len(trits.len() as u64); out.extend(tritpack243::pack(&trits)); out @@ -553,28 +1143,39 @@ pub mod pathb { out } - pub fn enc_enum(index: u64) -> Vec { tleb3::encode_len(index) } - pub fn enc_union_index(index: u64) -> Vec { tleb3::encode_len(index) } + pub fn enc_enum(index: u64) -> Vec { + tleb3::encode_len(index) + } + pub fn enc_union_index(index: u64) -> Vec { + tleb3::encode_len(index) + } - pub fn enc_array(items:&[T], f: fn(&T)->Vec) -> Vec { - if items.is_empty() { return vec![0]; } + pub fn enc_array(items: &[T], f: fn(&T) -> Vec) -> Vec { + if items.is_empty() { + return vec![0]; + } let mut out = tleb3::encode_len(items.len() as u64); - for it in items { out.extend(f(it)); } - out.push(0); out + for it in items { + out.extend(f(it)); + } + out.push(0); + out } - pub fn enc_map(m:&[(&str,&str)]) -> Vec { - if m.is_empty() { return vec![0]; } + pub fn enc_map(m: &[(&str, &str)]) -> Vec { + if m.is_empty() { + return vec![0]; + } let mut out = tleb3::encode_len(m.len() as u64); - for (k,v) in m { + for (k, v) in m { out.extend(enc_string(k)); out.extend(enc_string(v)); } - out.push(0); out + out.push(0); + out } } - pub mod pathb_dec { use super::{tleb3, tritpack243}; @@ -584,22 +1185,27 @@ pub mod pathb_dec { (val as usize, new_off) } - pub fn dec_string(bytes:&[u8], off: usize) -> (String, usize) { + pub fn dec_string(bytes: &[u8], off: usize) -> (String, usize) { let (l, o2) = dec_len(bytes, off); - let s = std::str::from_utf8(&bytes[o2..o2+l]).unwrap().to_string(); - (s, o2+l) + let s = std::str::from_utf8(&bytes[o2..o2 + l]).unwrap().to_string(); + (s, o2 + l) } - pub fn dec_union_index(bytes:&[u8], off: usize) -> (u64, usize) { + pub fn dec_union_index(bytes: &[u8], off: usize) -> (u64, 
usize) { let (u, o2) = super::tleb3::decode_len(bytes, off).unwrap(); (u, o2) } - pub fn dec_vertex(bytes:&[u8], off: usize) -> ((String, Option), usize) { + pub fn dec_vertex(bytes: &[u8], off: usize) -> ((String, Option), usize) { let (vid, o2) = dec_string(bytes, off); let (uix, o3) = dec_union_index(bytes, o2); - let (label, o4) = if uix == 0 { (None, o3) } else { let (s, p) = dec_string(bytes, o3); (Some(s), p) }; + let (label, o4) = if uix == 0 { + (None, o3) + } else { + let (s, p) = dec_string(bytes, o3); + (Some(s), p) + }; // skip attr map (length + entries) — for fixtures attr is empty (0x00) - ( (vid, label), o4 + 1 ) + ((vid, label), o4 + 1) } } diff --git a/rust/tritrpc_v1/tests/fixtures.rs b/rust/tritrpc_v1/tests/fixtures.rs index 5b94d76..c5e96e4 100644 --- a/rust/tritrpc_v1/tests/fixtures.rs +++ b/rust/tritrpc_v1/tests/fixtures.rs @@ -1,11 +1,11 @@ - -use std::fs; -use std::collections::HashMap; -use tritrpc_v1::{tleb3, tritpack243, envelope, avroenc}; use chacha20poly1305::aead::{Aead, KeyInit}; use chacha20poly1305::XChaCha20Poly1305; +use std::collections::HashMap; +use std::fs; +use subtle::ConstantTimeEq; +use tritrpc_v1::{avrodec, avroenc, envelope, tleb3, tritpack243}; -fn read_pairs(path:&str)->Vec<(String, Vec)>{ +fn read_pairs(path: &str) -> Vec<(String, Vec)> { let txt = fs::read_to_string(path).expect("read fixtures"); txt.lines() .filter(|l| !l.is_empty() && !l.starts_with('#')) @@ -19,7 +19,7 @@ fn read_pairs(path:&str)->Vec<(String, Vec)>{ .collect() } -fn read_nonces(path:&str)->HashMap>{ +fn read_nonces(path: &str) -> HashMap> { let txt = fs::read_to_string(path).expect("read nonces"); txt.lines() .filter(|l| !l.is_empty()) @@ -28,7 +28,8 @@ fn read_nonces(path:&str)->HashMap>{ let name = it.next().unwrap().to_string(); let hexs = it.next().unwrap(); (name, hex::decode(hexs).unwrap()) - }).collect() + }) + .collect() } fn split_fields(mut buf: &[u8]) -> Vec> { @@ -53,63 +54,98 @@ fn aead_bit(flags_bytes: &[u8]) -> bool { #[test] fn verify_all_frames_and_payloads() { let sets = vec![ - ("fixtures/vectors_hex.txt","fixtures/vectors_hex.txt.nonces"), - ("fixtures/vectors_hex_stream_avrochunk.txt","fixtures/vectors_hex_stream_avrochunk.txt.nonces"), - ("fixtures/vectors_hex_unary_rich.txt","fixtures/vectors_hex_unary_rich.txt.nonces"), - ("fixtures/vectors_hex_stream_avronested.txt","fixtures/vectors_hex_stream_avronested.txt.nonces"), + ( + "fixtures/vectors_hex.txt", + "fixtures/vectors_hex.txt.nonces", + ), + ( + "fixtures/vectors_hex_stream_avrochunk.txt", + "fixtures/vectors_hex_stream_avrochunk.txt.nonces", + ), + ( + "fixtures/vectors_hex_unary_rich.txt", + "fixtures/vectors_hex_unary_rich.txt.nonces", + ), + ( + "fixtures/vectors_hex_stream_avronested.txt", + "fixtures/vectors_hex_stream_avronested.txt.nonces", + ), ]; - let key = [0u8;32]; + let key = [0u8; 32]; for (fx, nx) in sets { let pairs = read_pairs(fx); let nonces = read_nonces(nx); for (name, frame) in pairs { let fields = split_fields(&frame); assert!(fields.len() >= 9, "{}", name); + let decoded = envelope::decode(&frame).expect("decode envelope"); + assert_eq!( + decoded.schema.as_slice(), + envelope::SCHEMA_ID_32.as_slice(), + "schema id mismatch {}", + name + ); + assert_eq!( + decoded.context.as_slice(), + envelope::CONTEXT_ID_32.as_slice(), + "context id mismatch {}", + name + ); + + let repacked = envelope::build( + &decoded.service, + &decoded.method, + &decoded.payload, + decoded.aux.as_deref(), + decoded.tag.as_deref(), + decoded.aead_on, + decoded.compress, + ); + 
assert_eq!(repacked, frame, "repack mismatch {}", name); + let flags = &fields[3]; let has_aead = aead_bit(flags); if has_aead { - // last field is tag - let tag = fields.last().unwrap(); - // AAD is everything before the last field; reconstruct by slicing - // We reconstruct by walking lengths: easier approach is to compute tag by encrypting empty with aad=the AAD bytes. - // AAD bytes are frame[.. frame.len() - (lenprefix(tag)+tag.len())], but we don't have lenprefix length. - // Instead, recompute by removing the final length+value pair using TLEB3 decode traversal. - // We'll rebuild the traversal to find the starting index of last field. - // Implementation: walk again until we reach the final field, computing offsets. - let mut off = 0usize; - let mut last_start = 0usize; - let mut idx = 0usize; - while off < frame.len() { - let (len, new_off) = tleb3::decode_len(&frame, off).unwrap(); - last_start = off; - off = new_off + len as usize; - idx += 1; - } - // now AAD is frame[..last_start] - let aad = &frame[..last_start]; + let tag = decoded.tag.as_ref().expect("missing tag"); + assert_eq!(tag.len(), 16, "tag size mismatch {}", name); let nonce = nonces.get(&name).expect("nonce missing"); - let strict = std::env::var("STRICT_AEAD").ok().as_deref()==Some("1"); + assert_eq!(nonce.len(), 24, "nonce size mismatch {}", name); + let aad_start = decoded.tag_start.expect("tag start missing"); + let aad = &frame[..aad_start]; + let strict = std::env::var("STRICT_AEAD").ok().as_deref() == Some("1"); let aead = XChaCha20Poly1305::new(&key.into()); - let ct = aead.encrypt(nonce.as_slice().into(), chacha20poly1305::aead::Payload{ msg: b"", aad }).unwrap(); - assert_eq!(&ct[ct.len()-16..], tag.as_slice(), "tag mismatch for {}", name); - - // Payload check for a few known names - if name.ends_with("hyper.v1.AddVertex_a.REQ") || name.ends_with("hyper.v1.AddVertex_a") { - let payload = &fields[8]; - let want = avroenc::enc_HGRequest_AddVertex("a", Some("A")); - assert_eq!(payload, &want, "payload mismatch {}", name); - } - if name.ends_with("hyper.v1.AddHyperedge_e1_ab.REQ") || name.ends_with("hyper.v1.AddHyperedge_e1_ab") { - let payload = &fields[8]; - let want = avroenc::enc_HGRequest_AddHyperedge("e1", &["a","b"], Some(1)); - assert_eq!(payload, &want, "payload mismatch {}", name); - } - if name.ends_with("hyper.v1.QueryNeighbors_a_k1.REQ") || name.ends_with("hyper.v1.QueryNeighbors_a_k1") { - let payload = &fields[8]; - let want = avroenc::enc_HGRequest_QueryNeighbors("a", 1); - assert_eq!(payload, &want, "payload mismatch {}", name); + let ct = aead + .encrypt( + nonce.as_slice().into(), + chacha20poly1305::aead::Payload { msg: b"", aad }, + ) + .unwrap(); + let computed = &ct[ct.len() - 16..]; + let matches = computed.ct_eq(tag.as_slice()).into(); + assert!(matches, "tag mismatch for {}", name); + if strict { + assert!(matches, "strict tag mismatch for {}", name); } } + + if decoded.method.ends_with(".REQ") { + let parsed = avrodec::dec_hg_request(&decoded.payload).expect("decode HGRequest"); + let recoded = avrodec::enc_hg_request(&parsed).expect("re-encode HGRequest"); + assert_eq!( + recoded, decoded.payload, + "HGRequest round-trip mismatch {}", + name + ); + } + if decoded.method.ends_with(".RSP") { + let parsed = avrodec::dec_hg_response(&decoded.payload).expect("decode HGResponse"); + let recoded = avrodec::enc_hg_response(&parsed).expect("re-encode HGResponse"); + assert_eq!( + recoded, decoded.payload, + "HGResponse round-trip mismatch {}", + name + ); + } } } } diff --git 
a/rust/tritrpc_v1/tests/vectors.rs b/rust/tritrpc_v1/tests/vectors.rs index 2dd6066..89dcd21 100644 --- a/rust/tritrpc_v1/tests/vectors.rs +++ b/rust/tritrpc_v1/tests/vectors.rs @@ -1,17 +1,16 @@ - -use tritrpc_v1::{tritpack243, tleb3}; +use tritrpc_v1::{tleb3, tritpack243}; #[test] fn micro_vectors() { - let b = tritpack243::pack(&[2,1,0,0,2]); + let b = tritpack243::pack(&[2, 1, 0, 0, 2]); assert_eq!(b, vec![0xBF]); - let b2 = tritpack243::pack(&[2,2,1]); + let b2 = tritpack243::pack(&[2, 2, 1]); assert_eq!(b2, vec![0xF5, 0x19]); } #[test] fn tleb3_roundtrip() { - for &n in [0u64,1,2,3,8,9,10,123,4096,65535].iter() { + for &n in [0u64, 1, 2, 3, 8, 9, 10, 123, 4096, 65535].iter() { let enc = tleb3::encode_len(n); let (dec, _) = tleb3::decode_len(&enc, 0).unwrap(); assert_eq!(dec, n); diff --git a/spec/README-full-spec.md b/spec/README-full-spec.md index 4364183..a1ab442 100644 --- a/spec/README-full-spec.md +++ b/spec/README-full-spec.md @@ -6,10 +6,21 @@ Strict-Initial (PoE), and the reference hypergraph service. For convenience, the **reference implementation** in `reference/tritrpc_v1.py` adheres to: - Path-A (Avro Binary Encoding) for payloads, -- AEAD lane (XChaCha20-Poly1305 if available; else BLAKE2b MAC fallback labeled as such), +- AEAD lane (XChaCha20-Poly1305), - HELLO/CHOOSE negotiation examples, -- Streaming with rolling 24-byte nonces, -- AUX structures: Trace, Sig (placeholder), PoE. +- AUX structures: Trace, Sig (placeholder), PoE (toy subset). + +**Port note (Go/Rust):** only **XChaCha20-Poly1305** is implemented. The BLAKE2b MAC fallback +described in early drafts remains **reference-only and non-normative** for this repository’s +ports and fixtures. Rolling nonces are not implemented in Go/Rust; fixtures rely on explicit +per-frame nonces in `*.nonces`. +The AUX field is treated as an opaque byte slice in Go/Rust; fixtures currently omit AUX. + +### AEAD AAD definition (normative for ports + fixtures) + +When AEAD is enabled, the tag is computed using **empty plaintext** with **AAD equal to the +exact envelope bytes up to (but not including) the length prefix of the final tag field**. +This means the AAD covers all prior fields, including payload and AUX if present. See `fixtures/` for **canonical hex vectors** generated by this reference.
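
To make the AAD rule above concrete, here is a minimal, non-normative sketch (not part of the patch) of recomputing a frame's tag in Rust. It assumes the `envelope::decode` helper and its `tag_start` offset introduced by this change, plus the `chacha20poly1305` crate usage already shown in the patch; the fixture tests pair it with an all-zero 32-byte key and the per-frame nonces read from the `*.nonces` files.

```rust
use chacha20poly1305::aead::{Aead, KeyInit};
use chacha20poly1305::XChaCha20Poly1305;
use tritrpc_v1::envelope;

/// Illustrative helper (name is ours, not from the patch): recompute the 16-byte
/// tag for an AEAD-enabled frame, following the AAD definition above.
fn recompute_tag(frame: &[u8], key: &[u8; 32], nonce: &[u8; 24]) -> Result<Vec<u8>, String> {
    let decoded = envelope::decode(frame)?;
    // AAD = all envelope bytes up to, but not including, the length prefix of the tag field.
    let aad_end = decoded.tag_start.ok_or("frame has no AEAD tag field")?;
    let aad = &frame[..aad_end];
    let aead = XChaCha20Poly1305::new(key.into());
    // Empty plaintext, so the returned "ciphertext" is exactly the 16-byte Poly1305 tag.
    let ct = aead
        .encrypt(
            nonce.into(),
            chacha20poly1305::aead::Payload { msg: b"", aad },
        )
        .map_err(|_| "aead encrypt failed".to_string())?;
    Ok(ct[ct.len() - 16..].to_vec())
}
```

Comparing the result (in constant time) against the decoded tag field reproduces the check performed by the Go and Rust fixture tests in this patch.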