Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(constraints): add rpo constraints #257

Open
wants to merge 2 commits into
base: next
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions constraints/miden-vm/hash.air
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
mod HashChipletAir

### Constants and periodic columns ################################################################
use rpo::enforce_rpo_round

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would probably leave this blank line. I think we do it in most other places, because when there are multiple things in a section it helps readability

### Constants and periodic columns ################################################################
periodic_columns:
cycle_row_0: [1, 0, 0, 0, 0, 0, 0, 0]
cycle_row_6: [0, 0, 0, 0, 0, 0, 1, 0]
cycle_row_7: [0, 0, 0, 0, 0, 0, 0, 1]


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would leave the blank line above the comment

### Helper functions ##############################################################################

# Returns binary negation of the value.
Expand Down Expand Up @@ -83,21 +83,20 @@ fn get_f_out(s: vector[3]) -> scalar:
fn get_f_out_next(s: vector[3]) -> scalar:
return cycle_row_6 & binary_not(s[0]') & binary_not(s[1]')


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that there is an error in the function above this get_f_out_next and maybe in other functions. Functions cannot call evaluators or access the next row. These should be fixed as a separate PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed it in the hasher multiset check constraints PR (here)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed that there is an error in the function above this get_f_out_next and maybe in other functions. Functions cannot call evaluators or access the next row. These should be fixed as a separate PR

### Helper evaluators #############################################################################

# Enforces that column must be binary.
ev is_binary(main: [a]):
ev is_binary([a]):
enf a^2 = a


# Enforces that value in column is copied over to the next row.
ev is_unchanged(main: [column]):
ev is_unchanged([column]):
ev column' = column


# Enforce selector columns constraints
ev selector_columns(main: [s[3]]):
ev selector_columns([s[3]]):
let f_out = get_f_out(s)
let f_out_next = get_f_out_next(s)
let f_abp = get_f_abp(s)
Expand Down Expand Up @@ -126,7 +125,7 @@ ev selector_columns(main: [s[3]]):


# Enforce node index constraints
ev node_index(main: [s[3], i]):
ev node_index([s[3], i]):
let f_out = get_f_out(s)
let f_mp = get_f_mp(s)
let f_mv = get_f_mv(s)
Expand Down Expand Up @@ -156,7 +155,7 @@ ev node_index(main: [s[3], i]):


# Enforce hasher state constraints
ev hasher_state(main: [s[3], h[12], i]):
ev hasher_state([s[3], h[12], i]):
let f_mp = get_f_mp(s)
let f_mv = get_f_mv(s)
let f_mu = get_f_mu(s)
Expand All @@ -181,11 +180,10 @@ ev hasher_state(main: [s[3], h[12], i]):
is_unchanged(h[j + 4]) for j in 0..4 when !b & f_absorb_node
h[j + 8]' = h[j + 4] for j in 0..4 when b & f_absorb_node


### Hash Chiplet Air Constraints ##################################################################

# Enforces the constraints on the hash chiplet, given the columns of the hash execution trace.
ev hash_chiplet(main: [s[3], r, h[12], i]):
ev hash_chiplet([s[3], r, h[12], i]):
## Row address constraint ##
# TODO: Apply row address constraints:
# 1. Boundary constraint `enf r.first = 1`
Expand All @@ -198,7 +196,8 @@ ev hash_chiplet(main: [s[3], r, h[12], i]):
enf node_index([s, i])

## Hasher state constraints ##
# TODO: apply RPO constraints to the hasher state
enf enforce_rpo_round([h]) when !cycle_row_7

enf hasher_state([s, h, i])

# Multiset check constraints
Expand Down
91 changes: 91 additions & 0 deletions constraints/miden-vm/rpo.air
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
mod RpoAir

### Constants and periodic columns ################################################################

const STATE_WIDTH = 12

# MDS matrix used for computing the linear layer in a RPO round
const MDS = [
[7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8],
[8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22, 21],
[21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6, 22],
[22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7, 6],
[6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9, 7],
[7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10, 9],
[9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13, 10],
[10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26, 13],
[13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8, 26],
[26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23, 8],
[8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7, 23],
[23, 8, 26, 13, 10, 9, 7, 6, 22, 21, 8, 7]
]

periodic_columns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add a comment to explain the round constants and how they are laid out here


# Round constants added to the hasher state in the first half of the RPO round
ark1_0: [5789762306288267264, 12987190162843097088, 18072785500942327808, 5674685213610122240, 4887609836208846848, 16308865189192448000, 7123075680859040768, 0]
ark1_1: [6522564764413702144, 653957632802705280, 6200974112677013504, 5759084860419474432, 3027115137917284352, 11977192855656443904, 1034205548717903104, 0]
ark1_2: [17809893479458207744, 4441654670647621120, 17682092219085883392, 13943282657648898048, 9595098600469471232, 12532242556065779712, 7717824418247931904, 0]
ark1_3: [107145243989736512, 4038207883745915904, 10599526828986757120, 1352748651966375424, 10528569829048483840, 14594890931430969344, 3019070937878604288, 0]
ark1_4: [6388978042437517312, 5613464648874829824, 975003873302957312, 17110913224029904896, 7864689113198940160, 7291784239689209856, 11403792746066868224, 0]
ark1_5: [15844067734406017024, 13222989726778339328, 8264241093196931072, 1003883795902368384, 17533723827845969920, 5514718540551361536, 10280580802233112576, 0]
ark1_6: [9975000513555218432, 3037761201230264320, 10065763900435474432, 4141870621881018368, 5781638039037711360, 10025733853830934528, 337153209462421248, 0]
ark1_7: [3344984123768313344, 16683759727265179648, 2181131744534710272, 8121410972417424384, 17024078752430718976, 7293794580341021696, 13333398568519923712, 0]
ark1_8: [9959189626657347584, 8337364536491240448, 6317303992309419008, 14300518605864919040, 109659393484013504, 6728552937464861696, 3596153696935337472, 0]
ark1_9: [12960773468763564032, 3227397518293416448, 1401440938888741632, 13712227150607669248, 7158933660534805504, 6332385040983343104, 8104208463525993472, 0]
ark1_10: [9602914297752487936, 8110510111539675136, 8884468225181997056, 17021852944633065472, 2955076958026921984, 13277683694236792832, 14345062289456084992, 0]
ark1_11: [16657542370200465408, 2872078294163232256, 13066900325715521536, 6252096473787587584, 7433723648458773504, 2600778905124452864, 17036731477169661952, 0]

# Round constants added to the hasher state in the second half of the RPO round
ark2_0: [6077062762357203968, 6202948458916100096, 8023374565629191168, 18389244934624493568, 6982293561042363392, 3736792340494631424, 17130398059294019584, 0]
ark2_1: [15277620170502010880, 17690140365333231616, 15013690343205953536, 16731736864863924224, 14065426295947720704, 577852220195055360, 519782857322262016, 0]
ark2_2: [5358738125714196480, 3595001575307484672, 4485500052507913216, 4440209734760478208, 16451845770444974080, 6689998335515780096, 9625384390925084672, 0]
ark2_3: [14233283787297595392, 373995945117666496, 12489737547229155328, 17208448209698889728, 7139138592091307008, 13886063479078012928, 1664893052631119104, 0]
ark2_4: [13792579614346651648, 1235734395091296000, 9500452585969031168, 8739495587021565952, 9012006439959783424, 14358505101923203072, 7629576092524553216, 0]
ark2_5: [11614812331536766976, 14172757457833930752, 2054001340201038848, 17000774922218162176, 14619614108529063936, 7744142531772273664, 3485239601103661568, 0]
ark2_6: [14871063686742261760, 707573103686350208, 12420704059284934656, 13533282547195531264, 1394813199588124416, 16135070735728404480, 9755891797164034048, 0]
ark2_7: [10148237148793042944, 15453217512188186624, 355990932618543744, 525402848358706240, 4635111139507788800, 12290902521256030208, 15218148195153268736, 0]
ark2_8: [4457428952329675776, 219777875004506016, 9071225051243524096, 16987541523062161408, 16217473952264204288, 12059913662657710080, 16460604813734957056, 0]
ark2_9: [15590786458219171840, 17876696346199468032, 12766199826003447808, 5466806524462796800, 10782018226466330624, 16456018495793752064, 9643968136937730048, 0]
ark2_10: [10063319113072093184, 17731621626449383424, 9045979173463557120, 14512769585918244864, 6844229992533661696, 4571485474751953408, 3611348709641382912, 0]
ark2_11: [14200078843431360512, 2897136237748376064, 12934431667190679552, 10973956031244050432, 7446486531695178752, 17200392109565784064, 18256379591337758720, 0]


### Helper functions ##############################################################################

fn apply_mds(state: vector[12]) -> vector[12]:
return [sum([state[i] * mds_row[i] for i in 0..STATE_WIDTH]) for mds_row in MDS]

### RPO Air Constraints ###########################################################################

ev enforce_rpo_round([h[12]]):
let ark1 = [ark1_0, ark1_1, ark1_2, ark1_3, ark1_4, ark1_5, ark1_6, ark1_7, ark1_8, ark1_9,
ark1_10, ark1_11]

let ark2 = [ark2_0, ark2_1, ark2_2, ark2_3, ark2_4, ark2_5, ark2_6, ark2_7, ark2_8, ark2_9,
ark2_10, ark2_11]

# compute the state that should result from applying the first 5 operations of the RPO round to
# the current hash state.

# 1. apply mds
let step1_initial = apply_mds(h)

# 2. add constants
let step1_with_constants = [step1_initial[i] + ark1[i] for i in 0..STATE_WIDTH]

# 3. apply sbox
let step1_with_sbox = [s^7 for s in step1_with_constants]

# 4. apply mds
let step1_with_mds = apply_mds(step1_with_sbox)

# 5. add constants
let step1 = [step1_with_mds[i] + ark2[i] for i in 0..STATE_WIDTH]

# compute the state that should result from applying the inverse of the last operation of the
# RPO round to the next step of the computation.
let step2 = [s'^7 for s in h]

# make sure that the results are equal.
enf step1[i] = step2[i] for i in 0..12