Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor, perf: prover/crypto/sis improvements #554

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions prover/crypto/ringsis/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
**/*.txt
5 changes: 0 additions & 5 deletions prover/crypto/ringsis/gen.go

This file was deleted.

201 changes: 12 additions & 189 deletions prover/crypto/ringsis/ringis.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,10 @@
package ringsis

import (
"bytes"
"encoding/binary"
"io"
"math"
"runtime"
"sync"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/fft"
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr/sis"
"github.com/consensys/linea-monorepo/prover/maths/common/smartvectors"
"github.com/consensys/linea-monorepo/prover/maths/field"
"github.com/consensys/linea-monorepo/prover/utils"
"github.com/consensys/linea-monorepo/prover/utils/parallel"

"github.com/consensys/linea-monorepo/prover/crypto/ringsis/ringsis_32_8"
"github.com/consensys/linea-monorepo/prover/crypto/ringsis/ringsis_64_16"
"github.com/consensys/linea-monorepo/prover/crypto/ringsis/ringsis_64_8"
)

const (
Expand All @@ -29,21 +15,12 @@ const (
// Key encapsulates the public parameters of an instance of the ring-SIS hash
// instance.
type Key struct {
// lock guards the access to the SIS key and prevents the user from hashing
// concurrently with the same SIS key.
lock *sync.Mutex
// gnarkInternal stores the SIS key itself and some precomputed domain
// twiddles.
gnarkInternal *sis.RSis
// Params provides the parameters of the ring-SIS instance (logTwoBound,
// degree etc)
Params
// twiddleCosets stores the list of twiddles that we use to implement the
// SIS parameters. The twiddleAreInternally are only used when dealing with
// the parameters modulusDegree=64 and logTwoBound=8 and is passed as input
// to the specially unrolled [sis.FFT64] function. They are thus optionally
// constructed when [GenerateKey] is called.
twiddleCosets []field.Element
}

// GenerateKey generates a ring-SIS key from a set of a [Params] and a max
Expand All @@ -62,33 +39,10 @@ func GenerateKey(params Params, maxNumFieldToHash int) Key {
}

res := Key{
lock: &sync.Mutex{},
gnarkInternal: rsis,
Params: params,
}

// Optimization for these specific parameters
if params.LogTwoBound == 8 && 1<<params.LogTwoDegree == 64 {
res.twiddleCosets = ringsis_64_8.PrecomputeTwiddlesCoset(
rsis.Domain.Generator,
rsis.Domain.FrMultiplicativeGen,
)
}

if params.LogTwoBound == 16 && 1<<params.LogTwoDegree == 64 {
res.twiddleCosets = ringsis_64_16.PrecomputeTwiddlesCoset(
rsis.Domain.Generator,
rsis.Domain.FrMultiplicativeGen,
)
}

if params.LogTwoBound == 8 && 1<<params.LogTwoDegree == 32 {
res.twiddleCosets = ringsis_32_8.PrecomputeTwiddlesCoset(
rsis.Domain.Generator,
rsis.Domain.FrMultiplicativeGen,
)
}

return res
}

Expand All @@ -104,55 +58,32 @@ func (s *Key) Ag() [][]field.Element {
// It is equivalent to calling r.Write(element.Marshal()); outBytes = r.Sum(nil);
func (s *Key) Hash(v []field.Element) []field.Element {

// since hashing writes into internal buffers
// we need to guard against races conditions.
s.lock.Lock()
defer s.lock.Unlock()

// write the input as byte
s.gnarkInternal.Reset()
for i := range v {
_, err := s.gnarkInternal.Write(v[i].Marshal())
if err != nil {
panic(err)
}
}
sum := s.gnarkInternal.Sum(make([]byte, 0, field.Bytes*s.OutputSize()))

// unmarshal the result
var rlen [4]byte
if len(sum) > math.MaxUint32*fr.Bytes {
panic("slice too long")
}
binary.BigEndian.PutUint32(rlen[:], uint32(len(sum)/fr.Bytes)) // #nosec G115 -- Overflow checked
reader := io.MultiReader(bytes.NewReader(rlen[:]), bytes.NewReader(sum))
var result fr.Vector
_, err := result.ReadFrom(reader)
if err != nil {
sum := make([]field.Element, s.OutputSize())
if err := s.gnarkInternal.Hash(v, sum); err != nil {
panic(err)
}

return result
return sum
}

// LimbSplit breaks down the entries of `v` into short limbs representing
// `LogTwoBound` bits each. The function then flatten and flatten them in a
// vector, casted as field elements in Montgommery form.
func (s *Key) LimbSplit(vReg []field.Element) []field.Element {

writer := bytes.Buffer{}
for i := range vReg {
b := vReg[i].Bytes() // big endian serialization
writer.Write(b[:])
}

buf := writer.Bytes()
m := make([]field.Element, len(vReg)*s.NumLimbs())
sis.LimbDecomposeBytes(buf, m, s.LogTwoBound)

it := sis.NewLimbIterator(sis.NewVectorIterator(vReg), s.LogTwoBound/8)

// The limbs are in regular form, we reconvert them back into montgommery
// form
var ok bool
for i := range m {
m[i][0], ok = it.NextLimb()
if !ok {
// the rest is 0 we can stop (note that if we change the padding
// policy we may need to change this)
break
}
m[i] = field.MulR(m[i])
}

Expand Down Expand Up @@ -256,111 +187,3 @@ func (s *Key) FlattenedKey() []field.Element {
}
return res
}

// TransversalHash evaluates SIS hashes transversally over a list of smart-vectors.
// Each smart-vector is seen as the row of a matrix. All rows must have the same
// size or panic. The function returns the hash of the columns. The column hashes
// are concatenated into a single array.
//
// The function is optimize to deal with the ring-SIS instances parametrized by
//
// - modulus degree: 64 log2(bound): 8
// - modulus degree: 64 log2(bound): 16
// - modulus degree: 32 log2(bound): 8
func (s *Key) TransversalHash(v []smartvectors.SmartVector) []field.Element {

// numRows stores the number of rows in the matrix to hash it must be
// strictly positive and be within the bounds of MaxNumFieldHashable.
numRows := len(v)

if numRows == 0 {
utils.Panic("Attempted to transversally hash a matrix with no rows")
}

if numRows > s.MaxNumFieldHashable() {
utils.Panic("Attempted to hash %v rows, but the limit is %v", numRows, s.MaxNumFieldHashable())
}

// numCols stores the number of columns in the matrix to hash et must be
// positive and all the rows must have the same size.
numCols := v[0].Len()

if numCols == 0 {
utils.Panic("Provided a 0-colums matrix")
}

for i := range v {
if v[i].Len() != numCols {
utils.Panic("Unexpected : all inputs smart-vectors should have the same length the first one has length %v, but #%v has length %v",
numCols, i, v[i].Len())
}
}

if s.LogTwoBound == 8 && s.LogTwoDegree == 6 {
return ringsis_64_8.TransversalHash(
s.gnarkInternal.Ag,
v,
s.twiddleCosets,
s.gnarkInternal.Domain,
)
}

if s.LogTwoBound == 16 && s.LogTwoDegree == 6 {
return ringsis_64_16.TransversalHash(
s.gnarkInternal.Ag,
v,
s.twiddleCosets,
s.gnarkInternal.Domain,
)
}

if s.LogTwoBound == 8 && s.LogTwoDegree == 5 {
return ringsis_32_8.TransversalHash(
s.gnarkInternal.Ag,
v,
s.twiddleCosets,
s.gnarkInternal.Domain,
)
}

res := make([]field.Element, numCols*s.OutputSize())

// Will contain keys per threads
keys := make([]*Key, runtime.GOMAXPROCS(0))
buffers := make([][]field.Element, runtime.GOMAXPROCS(0))

parallel.ExecuteThreadAware(
numCols,
func(threadID int) {
keys[threadID] = s.CopyWithFreshBuffer()
buffers[threadID] = make([]field.Element, numRows)
},
func(col, threadID int) {
buffer := buffers[threadID]
key := keys[threadID]
for row := 0; row < numRows; row++ {
buffer[row] = v[row].Get(col)
}
copy(res[col*key.OutputSize():(col+1)*key.OutputSize()], key.Hash(buffer))
})

return res
}

// CopyWithFreshBuffer creates a copy of the key with fresh buffers. Shallow
// copies the the key itself.
func (s *Key) CopyWithFreshBuffer() *Key {

// Since hashing consumes and mutates the buffer stored internally in
// `gnarkInternal` go race had figured there might be a race condition
// possibility.
s.lock.Lock()
defer s.lock.Unlock()

clonedRsis := s.gnarkInternal.CopyWithFreshBuffer()
return &Key{
lock: &sync.Mutex{},
gnarkInternal: &clonedRsis,
Params: s.Params,
}
}
44 changes: 0 additions & 44 deletions prover/crypto/ringsis/ringsis_32_8/limb_decompose_test.go

This file was deleted.

Loading
Loading