diff --git a/.gitignore b/.gitignore index 80ef81a1..9d5e24ee 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,6 @@ Cargo.lock *.a *.dylib -constantine/out \ No newline at end of file +constantine/out +gnark/gnark-jni/*.h + diff --git a/gnark/gnark-jni/gnark-eip-196.go b/gnark/gnark-jni/gnark-eip-196.go index cd9c792b..3a14f3fe 100644 --- a/gnark/gnark-jni/gnark-eip-196.go +++ b/gnark/gnark-jni/gnark-eip-196.go @@ -31,6 +31,8 @@ const ( EIP196PreallocateForG2 = EIP196PreallocateForG1 * 2 // G2 points are encoded as 2 concatenated G1 points ) +var EIP196ScalarTwo = big.NewInt(2) + // bn254Modulus is the value 21888242871839275222246405745257275088696311157297823662689037894645226208583 var bn254Modulus = new(big.Int).SetBytes([]byte{ @@ -40,9 +42,6 @@ var bn254Modulus = new(big.Int).SetBytes([]byte{ 0x3c, 0x20, 0x8c, 0x16, 0xd8, 0x7c, 0xfd, 0x47, }) -// Predefine a zero slice of length 16 -var zeroEIP196Slice = make([]byte, 16) - //export eip196altbn128G1Add func eip196altbn128G1Add(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInputLen C.int, cOutputLen, cErrorLen *C.int) C.int { inputLen := int(cInputLen) @@ -77,7 +76,7 @@ func eip196altbn128G1Add(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInp if inputLen < 2*EIP196PreallocateForG1 { // if incomplete input is all zero, return p0 - if isAllZeroEIP196(input, 64) { + if isAllZeroEIP196(input, 64, 64) { ret := p0.Marshal() g1AffineEncode(ret, javaOutputBuf) *outputLen = EIP196PreallocateForG1 @@ -128,6 +127,12 @@ func eip196altbn128G1Mul(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInp // Convert input C pointers to Go slice input := (*[EIP196PreallocateForG1 + EIP196PreallocateForScalar]byte)(unsafe.Pointer(javaInputBuf))[:inputLen:inputLen] + // infinity check: + if isAllZeroEIP196(input, 0, 64) { + *outputLen = EIP196PreallocateForG1 + return 0 + } + // generate p0 g1 affine var p0 bn254.G1Affine err := safeUnmarshalEIP196(&p0, input, 0) @@ -154,8 +159,14 @@ func eip196altbn128G1Mul(javaInputBuf, javaOutputBuf, javaErrorBuf *C.char, cInp scalar := big.NewInt(0) scalar.SetBytes(scalarBytes[:]) - // multiply g1 point by scalar - result := p0.ScalarMultiplication(&p0, scalar) + var result *bn254.G1Affine + if scalar.Cmp(EIP196ScalarTwo) == 0 { + // if scalar == 2, double is faster + result = p0.Double(&p0) + } else { + // multiply g1 point by scalar + result = p0.ScalarMultiplication(&p0, scalar) + } // marshal the resulting point and encode directly to the output buffer ret := result.Marshal() @@ -326,9 +337,14 @@ func checkInFieldEIP196(data []byte) bool { } // isAllZero checks if all elements in the byte slice are zero -func isAllZeroEIP196(data []byte, offset int) bool { - if len(data) > 64 { - slice := data [offset:] +func isAllZeroEIP196(data []byte, offset, length int) bool { + + if len(data) > offset { + tail := offset + length + if len(data) < tail { + tail = len(data) + } + slice := data [offset:tail] for _, b := range slice { if b != 0 { return false