Skip to content

Commit

Permalink
[FEAT] Full refactor of gnark circuits which reuse code now
Browse files Browse the repository at this point in the history
  • Loading branch information
Scratch-net committed Nov 21, 2024
1 parent 2615cfa commit 4630106
Show file tree
Hide file tree
Showing 41 changed files with 585 additions and 1,264 deletions.
Binary file modified bin/gnark/linux-arm64-libprove.so
Binary file not shown.
Binary file modified bin/gnark/linux-arm64-libverify.so
Binary file not shown.
Binary file modified bin/gnark/linux-x86_64-libprove.so
Binary file not shown.
Binary file modified bin/gnark/linux-x86_64-libverify.so
Binary file not shown.
249 changes: 6 additions & 243 deletions gnark/circuits/aesV2/aes.go
Original file line number Diff line number Diff line change
@@ -1,249 +1,12 @@
package aes_v2

import (
"errors"
import "github.com/consensys/gnark/frontend"

"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/std/lookup/logderivlookup"
)

const BLOCKS = 5
const NB = 4

type AESWrapper struct {
Key []frontend.Variable
Nonce [12]frontend.Variable `gnark:",public"`
Counter frontend.Variable `gnark:",public"`
In [BLOCKS * 16]frontend.Variable `gnark:",public"`
Out [BLOCKS * 16]frontend.Variable `gnark:",public"`
}

type AESGadget struct {
api frontend.API
sbox *logderivlookup.Table
RCon [11]frontend.Variable
t0, t1, t2, t3 *logderivlookup.Table
keySize int
}

// retuns AESGadget instance which can be used inside a circuit
func NewAESGadget(api frontend.API, keySize int) AESGadget {

t0 := logderivlookup.New(api)
t1 := logderivlookup.New(api)
t2 := logderivlookup.New(api)
t3 := logderivlookup.New(api)
sbox := logderivlookup.New(api)
for i := 0; i < 256; i++ {
t0.Insert(T[0][i])
t1.Insert(T[1][i])
t2.Insert(T[2][i])
t3.Insert(T[3][i])
sbox.Insert(sbox0[i])
}

RCon := [11]frontend.Variable{0x8d, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36}

return AESGadget{api: api, sbox: sbox, RCon: RCon, t0: t0, t1: t1, t2: t2, t3: t3, keySize: keySize}
}

func (aes *AESWrapper) Define(api frontend.API) error {
keySize := len(aes.Key)

if keySize != 16 && keySize != 32 {
return errors.New("key size must be 16 or 32")
}

counter := aes.Counter
var counterBlock [16]frontend.Variable

gAes := NewAESGadget(api, keySize)

for i := 0; i < 12; i++ {
counterBlock[i] = aes.Nonce[i]
}
for b := 0; b < BLOCKS; b++ {
gAes.createIV(counter, counterBlock[:])
// encrypt counter under key
keystream := gAes.Encrypt(aes.Key, counterBlock)

for i := 0; i < 16; i++ {
api.AssertIsEqual(aes.Out[b*16+i], gAes.VariableXor(keystream[i], aes.In[b*16+i], 8))
}
counter = api.Add(counter, 1)
}
api.AssertIsEqual(counter, api.Add(aes.Counter, BLOCKS))

return nil
}

// aes128 encrypt function
func (aes *AESGadget) SubBytes(state [16]frontend.Variable) (res [16]frontend.Variable) {
t := aes.Subws(aes.sbox, state[:]...)
copy(res[:], t)
return res
}

// xor on bits of two frontend.Variables
func (aes *AESGadget) VariableXor(a frontend.Variable, b frontend.Variable, size int) frontend.Variable {
bitsA := aes.api.ToBinary(a, size)
bitsB := aes.api.ToBinary(b, size)
x := make([]frontend.Variable, size)
for i := 0; i < size; i++ {
x[i] = aes.api.Xor(bitsA[i], bitsB[i])
}
return aes.api.FromBinary(x...)
}

func (aes *AESGadget) XorSubWords(a, b, c, d frontend.Variable, xk []frontend.Variable) []frontend.Variable {

aa := aes.t0.Lookup(a)[0]
bb := aes.t1.Lookup(b)[0]
cc := aes.t2.Lookup(c)[0]
dd := aes.t3.Lookup(d)[0]

t0 := aes.api.ToBinary(aa, 32)
t1 := aes.api.ToBinary(bb, 32)
t2 := aes.api.ToBinary(cc, 32)
t3 := aes.api.ToBinary(dd, 32)

t4 := append(aes.api.ToBinary(xk[0], 8), aes.api.ToBinary(xk[1], 8)...)
t4 = append(t4, aes.api.ToBinary(xk[2], 8)...)
t4 = append(t4, aes.api.ToBinary(xk[3], 8)...)

t := make([]frontend.Variable, 32)
for i := 0; i < 32; i++ {
t[i] = aes.api.Xor(t0[i], t1[i])
t[i] = aes.api.Xor(t[i], t2[i])
t[i] = aes.api.Xor(t[i], t3[i])
t[i] = aes.api.Xor(t[i], t4[i])
}

newWord := make([]frontend.Variable, 4)
newWord[0] = aes.api.FromBinary(t[:8]...)
newWord[1] = aes.api.FromBinary(t[8:16]...)
newWord[2] = aes.api.FromBinary(t[16:24]...)
newWord[3] = aes.api.FromBinary(t[24:32]...)
return newWord
}

func (aes *AESGadget) ShiftSub(state [16]frontend.Variable) []frontend.Variable {
t := make([]frontend.Variable, 16)
for i := 0; i < 16; i++ {
t[i] = state[byte_order[i]]
}
return aes.Subws(aes.sbox, t...)
}

// substitute word with naive lookup of sbox
func (aes *AESGadget) Subws(sbox *logderivlookup.Table, a ...frontend.Variable) []frontend.Variable {
return sbox.Lookup(a...)
}

func (aes *AESGadget) createIV(counter frontend.Variable, iv []frontend.Variable) {
aBits := aes.api.ToBinary(counter, 32)

for i := 0; i < 4; i++ {
iv[15-i] = aes.api.FromBinary(aBits[i*8 : i*8+8]...)
}
}

func (aes *AESGadget) Encrypt(key []frontend.Variable, pt [16]frontend.Variable) [16]frontend.Variable {
keySize := aes.keySize
rounds := 10
if keySize == 32 {
rounds = 14
}

// expand key
xk := aes.ExpandKey(key)
var state [16]frontend.Variable
for i := 0; i < 16; i++ {
state[i] = aes.VariableXor(xk[i], pt[i], 8)
}

var t0, t1, t2, t3 []frontend.Variable
// iterate rounds
for i := 1; i < rounds; i++ {
k := i * 16
t0 = aes.XorSubWords(state[0], state[5], state[10], state[15], xk[k+0:k+4])
t1 = aes.XorSubWords(state[4], state[9], state[14], state[3], xk[k+4:k+8])
t2 = aes.XorSubWords(state[8], state[13], state[2], state[7], xk[k+8:k+12])
t3 = aes.XorSubWords(state[12], state[1], state[6], state[11], xk[k+12:k+16])

copy(state[:4], t0)
copy(state[4:8], t1)
copy(state[8:12], t2)
copy(state[12:16], t3)
}

copy(state[:], aes.ShiftSub(state))

k := rounds * 16

for i := 0; i < 4; i++ {
state[i+0] = aes.VariableXor(state[i+0], xk[k+i+0], 8)
state[i+4] = aes.VariableXor(state[i+4], xk[k+i+4], 8)
state[i+8] = aes.VariableXor(state[i+8], xk[k+i+8], 8)
state[i+12] = aes.VariableXor(state[i+12], xk[k+i+12], 8)
}

return state
type AESCircuit struct {
AESBaseCircuit
Out [BLOCKS * 16]frontend.Variable `gnark:",public"`
}

func (aes *AESGadget) ExpandKey(key []frontend.Variable) []frontend.Variable {

keySize := aes.keySize
rounds := 10
if keySize == 32 {
rounds = 14
}

var nWords = NB * (rounds + 1)

expand := make([]frontend.Variable, nWords*4)
i := 0

for i < keySize {
expand[i] = key[i]
expand[i+1] = key[i+1]
expand[i+2] = key[i+2]
expand[i+3] = key[i+3]

i += 4
}

for i < (nWords * 4) {
t0 := expand[i-4]
t1 := expand[i-3]
t2 := expand[i-2]
t3 := expand[i-1]

if i%keySize == 0 {
// rotation
t0, t1, t2, t3 = t1, t2, t3, t0

// sub words
tt := aes.Subws(aes.sbox, t0, t1, t2, t3)
t0, t1, t2, t3 = tt[0], tt[1], tt[2], tt[3]

t0 = aes.VariableXor(t0, aes.RCon[i/keySize], 8)
}

if rounds == 14 && i%keySize == 16 {
// sub words
tt := aes.Subws(aes.sbox, t0, t1, t2, t3)
t0, t1, t2, t3 = tt[0], tt[1], tt[2], tt[3]

}

expand[i] = aes.VariableXor(expand[i-keySize], t0, 8)
expand[i+1] = aes.VariableXor(expand[i-keySize+1], t1, 8)
expand[i+2] = aes.VariableXor(expand[i-keySize+2], t2, 8)
expand[i+3] = aes.VariableXor(expand[i-keySize+3], t3, 8)

i += 4
}

return expand
func (c *AESCircuit) Define(api frontend.API) error {
return c.AESBaseCircuit.Define(api, c.Out)
}
34 changes: 20 additions & 14 deletions gnark/circuits/aesV2/aes128_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ func TestAES128(t *testing.T) {
nonceAssign := StrToIntSlice(Nonce, true)

// witness values preparation
assignment := AESWrapper{
Key: make([]frontend.Variable, 16),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
Out: [BLOCKS * 16]frontend.Variable{},
assignment := AESCircuit{
AESBaseCircuit: AESBaseCircuit{
Key: make([]frontend.Variable, 16),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
},
Out: [BLOCKS * 16]frontend.Variable{},
}

// assign values here because required to use make in assignment
Expand All @@ -61,12 +63,14 @@ func TestAES128(t *testing.T) {
assignment.Nonce[i] = nonceAssign[i]
}

assert.CheckCircuit(&AESWrapper{
Key: make([]frontend.Variable, 16),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
Out: [BLOCKS * 16]frontend.Variable{},
assert.CheckCircuit(&AESCircuit{
AESBaseCircuit: AESBaseCircuit{
Key: make([]frontend.Variable, 16),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
},
Out: [BLOCKS * 16]frontend.Variable{},
}, test.WithValidAssignment(&assignment))
}

Expand Down Expand Up @@ -97,8 +101,10 @@ func mustHex(s string) []byte {
func TestCompile(t *testing.T) {
curve := ecc.BN254.ScalarField()

witness := AESWrapper{
Key: make([]frontend.Variable, 16),
witness := AESCircuit{
AESBaseCircuit: AESBaseCircuit{
Key: make([]frontend.Variable, 16),
},
}

r1css, err := frontend.Compile(curve, r1cs.NewBuilder, &witness)
Expand Down
34 changes: 20 additions & 14 deletions gnark/circuits/aesV2/aes256_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@ func TestAES256(t *testing.T) {
nonceAssign := StrToIntSlice(Nonce, true)

// witness values preparation
assignment := AESWrapper{
Key: make([]frontend.Variable, 32),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
Out: [BLOCKS * 16]frontend.Variable{},
assignment := AESCircuit{
AESBaseCircuit: AESBaseCircuit{
Key: make([]frontend.Variable, 32),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
},
Out: [BLOCKS * 16]frontend.Variable{},
}

// assign values here because required to use make in assignment
Expand All @@ -61,20 +63,24 @@ func TestAES256(t *testing.T) {
assignment.Nonce[i] = nonceAssign[i]
}

assert.CheckCircuit(&AESWrapper{
Key: make([]frontend.Variable, 32),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
Out: [BLOCKS * 16]frontend.Variable{},
assert.CheckCircuit(&AESCircuit{
AESBaseCircuit: AESBaseCircuit{
Key: make([]frontend.Variable, 32),
Counter: Counter,
Nonce: [12]frontend.Variable{},
In: [BLOCKS * 16]frontend.Variable{},
},
Out: [BLOCKS * 16]frontend.Variable{},
}, test.WithValidAssignment(&assignment))
}

func TestCompile256(t *testing.T) {
curve := ecc.BN254.ScalarField()

witness := AESWrapper{
Key: make([]frontend.Variable, 32),
witness := AESCircuit{
AESBaseCircuit: AESBaseCircuit{
Key: make([]frontend.Variable, 32),
},
}

r1css, err := frontend.Compile(curve, r1cs.NewBuilder, &witness)
Expand Down
Loading

0 comments on commit 4630106

Please sign in to comment.