-
Notifications
You must be signed in to change notification settings - Fork 0
/
rsaDec.java
168 lines (130 loc) · 5.85 KB
/
rsaDec.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
package examples.generators;
import java.math.BigInteger;
import java.util.Random;
import java.nio.charset.StandardCharsets;
import util.Util;
import circuit.auxiliary.LongElement;
import circuit.config.Config;
import circuit.eval.CircuitEvaluator;
import circuit.structure.CircuitGenerator;
import circuit.structure.Wire;
import circuit.structure.WireArray;
import examples.gadgets.math.LongIntegerModGadget;
import examples.gadgets.math.ModGadget;
/*
256-bit p, q; 512-bit n
Constraints: about 2 million (one 512-bit modular exponentiation)
Proving time: 50 sec
*/
/**
 * jsnark circuit generator proving knowledge of an RSA private exponent {@code sk} such that
 * {@code c^sk mod n == message}.
 *
 * <p>Public inputs: {@code n}, {@code pk}, {@code message}, {@code c}. Prover witness: {@code sk}.
 *
 * <p>NOTE(review): {@code pk} is declared as a public input but no constraint in
 * {@link #buildCircuit()} references it, so the proof does not bind {@code sk} to {@code pk}.
 *
 * <p>Key material is kept in static fields that every constructor call overwrites, so only one
 * instance should be in use at a time.
 */
public class rsaDec extends CircuitGenerator {
    // Fixed seed: key generation is deterministic. Demo only -- real keys need SecureRandom.
    private final Random rand = new Random(0);

    private static BigInteger p;
    private static BigInteger q;
    private static BigInteger n;      // p * q
    private static BigInteger order;  // (p-1) * (q-1)
    public static BigInteger pk;      // public exponent e
    private static BigInteger sk;     // private exponent d = pk^-1 mod order
    // Sample ciphertext; main() decrypts it with sk to obtain the expected message.
    private static final BigInteger cipherText = new BigInteger("3814131678415201177304543701723813930960506828092320199734903490445051846672373530664772742052689180801312612221921878767990815583459655782999214922500027");
    public static BigInteger message;

    // Bit length used for every LongElement in the circuit (2 * prime size).
    private final int nLength;

    private LongElement nLongElement;
    private LongElement pkLongElement;
    private LongElement skLongElement;
    private LongElement mLongElement;
    private LongElement cLongElement;

    /**
     * Creates the generator and generates a fresh RSA key pair into the static key fields.
     *
     * @param circuitName name passed to {@link CircuitGenerator}
     * @param size bit length of each prime; the modulus gets exactly {@code 2 * size} bits
     */
    public rsaDec(String circuitName, int size) {
        super(circuitName);
        this.nLength = 2 * size;
        rsaSetup(size);
    }

    /**
     * Declares the inputs/witness and asserts {@code c^sk mod n == message}.
     */
    @Override
    protected void buildCircuit() {
        nLongElement = createLongElementInput(nLength, "n");
        pkLongElement = createLongElementInput(nLength, "pk");
        mLongElement = createLongElementInput(nLength, "message");
        cLongElement = createLongElementInput(nLength, "c");
        skLongElement = createLongElementProverWitness(nLength, "sk");

        LongElement calculated_mLongElement = pow(cLongElement, skLongElement, nLongElement);
        calculated_mLongElement.assertEquality(mLongElement);
    }

    /**
     * Computes {@code a^b mod n} on single field-element wires using right-to-left
     * square-and-multiply over the bits of {@code b}. Not called by {@link #buildCircuit()}.
     *
     * <p>NOTE(review): every reduction uses {@code ModGadget(..., 126)}; presumably that gadget
     * requires its operands to fit in 126 bits, which would limit this method to moduli of about
     * 63 bits -- confirm against ModGadget before relying on it.
     *
     * @param a base
     * @param b exponent (decomposed into {@code Config.LOG2_FIELD_PRIME} bits)
     * @param n modulus
     * @return wire holding {@code a^b mod n}
     */
    Wire pow(Wire a, Wire b, Wire n) {
        Wire[] bBitArray = b.getBitWires(Config.LOG2_FIELD_PRIME).asArray();
        Wire tmp = oneWire.mul(a);  // a^(2^i) at iteration i
        Wire result = oneWire;
        for (int i = 0; i < bBitArray.length; i++) {
            // Multiplier is tmp when the bit is 1, and 1 when the bit is 0.
            Wire check = bBitArray[i].mul(tmp).add(bBitArray[i].isEqualTo(zeroWire));
            result = new ModGadget(result.mul(check), n, 126).getOutputWires()[0];
            if (i + 1 < bBitArray.length) {
                // The square after the last bit would never be used; skip its constraints.
                tmp = new ModGadget(tmp.mul(tmp), n, 126).getOutputWires()[0];
            }
        }
        return result;
    }

    /**
     * Computes {@code a^b mod n} on multi-chunk LongElements using right-to-left
     * square-and-multiply over the low {@code nLength} bits of {@code b}.
     *
     * @param a base (not reduced mod n before the first multiplication)
     * @param b exponent
     * @param n modulus
     * @return LongElement congruent to {@code a^b mod n}
     */
    LongElement pow(LongElement a, LongElement b, LongElement n) {
        Wire[] bBitWireArray = b.getBits(nLength).asArray();
        LongElement square = a;  // a^(2^i) mod n at iteration i
        LongElement result = new LongElement(oneWire.getBitWires(nLength));
        for (int i = 0; i < nLength; i++) {
            // Multiplier = bit * square + (1 - bit): square when the bit is set, 1 otherwise.
            LongElement bitIsZero = new LongElement(bBitWireArray[i].isEqualTo(zeroWire).getBitWires(1));
            LongElement bitTimesSquare = new LongElement(bBitWireArray[i].getBitWires(1)).mul(square);
            // NOTE(review): the boolean is presumably LongIntegerModGadget's "restrict remainder
            // < n" flag -- off here, on for the squaring below; confirm against the gadget.
            result = new LongIntegerModGadget(result.mul(bitTimesSquare.add(bitIsZero)), n, false).getRemainder();
            if (i + 1 < nLength) {
                // The square after the last bit would never be used; skip its constraints.
                square = new LongIntegerModGadget(square.mul(square), n, true).getRemainder();
            }
        }
        return result;
    }

    /**
     * Assigns the generated key material, ciphertext and decrypted message to the circuit
     * inputs and the witness.
     */
    @Override
    public void generateSampleInput(CircuitEvaluator circuitEvaluator) {
        circuitEvaluator.setWireValue(nLongElement, n, LongElement.CHUNK_BITWIDTH);
        circuitEvaluator.setWireValue(pkLongElement, pk, LongElement.CHUNK_BITWIDTH);
        circuitEvaluator.setWireValue(mLongElement, message, LongElement.CHUNK_BITWIDTH);
        circuitEvaluator.setWireValue(cLongElement, cipherText, LongElement.CHUNK_BITWIDTH);
        circuitEvaluator.setWireValue(skLongElement, sk, LongElement.CHUNK_BITWIDTH);
    }

    /**
     * Generates an RSA key pair into the static fields p, q, n, order, pk and sk.
     *
     * @param size bit length of each prime
     */
    private void rsaSetup(int size) {
        p = BigInteger.probablePrime(size, rand);
        q = BigInteger.probablePrime(size, rand);
        // A product of two size-bit primes has 2*size or 2*size-1 bits (never more), so the
        // previous "> size*2" test could not fire. Redraw until the modulus has exactly 2*size
        // bits and the primes are distinct.
        while (p.equals(q) || p.multiply(q).bitLength() != size * 2) {
            System.out.println("pick p q");
            p = BigInteger.probablePrime(size, rand);
            q = BigInteger.probablePrime(size, rand);
        }
        n = p.multiply(q);
        order = p.subtract(BigInteger.ONE).multiply(q.subtract(BigInteger.ONE));

        System.out.println("p bit len : " + Integer.toString(p.bitLength()));
        System.out.println("q bit len : " + Integer.toString(q.bitLength()));
        System.out.println("n bit len : " + Integer.toString(n.bitLength()));

        pk = new BigInteger(size * 2, rand).mod(order);
        // pk must be invertible mod order; pk == 1 is also rejected because it would make
        // encryption and decryption the identity map.
        while (pk.compareTo(BigInteger.ONE) <= 0 || !order.gcd(pk).equals(BigInteger.ONE)) {
            pk = new BigInteger(size * 2, rand).mod(order);
        }
        sk = pk.modInverse(order);
    }

    /**
     * Decrypts {@code ciphertext} with the generated private key and prints the result.
     *
     * @param ciphertext value to decrypt
     * @return {@code ciphertext^sk mod n}
     */
    private static BigInteger rsaDecryption(BigInteger ciphertext) {
        BigInteger m = ciphertext.modPow(sk, n);
        System.out.println("Dec m : " + m.toString());
        return m;
    }

    /**
     * Builds the keys (256-bit primes), computes the expected message, then generates,
     * evaluates and proves the circuit.
     */
    public static void main(String[] args) throws Exception {
        rsaDec generator = new rsaDec("RSA dec", 256);
        // Must run after the constructor: decryption uses the key generated there.
        message = rsaDecryption(cipherText);
        generator.generateCircuit();
        generator.evalCircuit();
        generator.prepFiles();
        generator.runLibsnark();
    }
}