-
Notifications
You must be signed in to change notification settings - Fork 26
/
recover_from_samples.go
109 lines (97 loc) · 3.14 KB
/
recover_from_samples.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package kzg
import (
"fmt"
"github.com/protolambda/go-kzg/bls"
)
// ShiftPoly scales the coefficients of poly in place: coefficient i is
// multiplied by (1/shiftFactor)**i, where shiftFactor is 5. For a polynomial
// p(x) this produces p(x/5), so evaluating the result over the FFT evaluation
// domain samples p on a scaled copy (coset) of that domain.
// UnshiftPoly applies the inverse scaling (multiplying by shiftFactor**i).
//
// NOTE(review): the original comment calls 5 a "primitive root of unity";
// presumably it is chosen so that the coset avoids the domain points where
// the zero polynomial vanishes — confirm it is a multiplicative generator of Fr.
func (fs *FFTSettings) ShiftPoly(poly []bls.Fr) {
	var shiftFactor bls.Fr
	bls.AsFr(&shiftFactor, 5)
	var invFactor bls.Fr
	bls.InvModFr(&invFactor, &shiftFactor)

	// factorPower holds invFactor**i for the coefficient currently being scaled.
	var factorPower bls.Fr
	bls.CopyFr(&factorPower, &bls.ONE)

	// tmp is used so the destination of MulModFr is never also one of its operands.
	var tmp bls.Fr
	for i := range poly {
		bls.CopyFr(&tmp, &poly[i])
		bls.MulModFr(&poly[i], &tmp, &factorPower)
		// TODO: pre-compute all these shift scalars
		bls.CopyFr(&tmp, &factorPower)
		bls.MulModFr(&factorPower, &tmp, &invFactor)
	}
}
// UnshiftPoly rescales the coefficients of poly in place, multiplying
// coefficient i by 5**i. This undoes the (1/5)**i scaling applied by ShiftPoly.
func (fs *FFTSettings) UnshiftPoly(poly []bls.Fr) {
	var base bls.Fr
	bls.AsFr(&base, 5) // primitive root of unity

	// power tracks base**idx; scratch keeps each MulModFr free of aliasing
	// between its destination and its operands.
	var power, scratch bls.Fr
	bls.CopyFr(&power, &bls.ONE)

	for idx := range poly {
		coeff := &poly[idx]
		bls.CopyFr(&scratch, coeff)
		bls.MulModFr(coeff, &scratch, &power)

		// TODO: pre-compute all these shift scalars
		bls.CopyFr(&scratch, &power)
		bls.MulModFr(&power, &scratch, &base)
	}
}
// RecoverPolyFromSamples reconstructs the full evaluation vector of a
// polynomial from a partial set of evaluations. A nil entry in samples marks
// a missing evaluation.
//
// Outline: let E be the data polynomial and Z the zero polynomial that
// vanishes exactly at the missing indices. E·Z is known at every index (it is
// zero where a sample is missing), so it can be interpolated. Both E·Z and Z
// are then moved onto a coset of the domain (ShiftPoly), where Z has no
// zeros. There they are divided pointwise to obtain E, and finally moved back
// (UnshiftPoly) and evaluated.
//
// zeroPolyFn is called as zeroPolyFn(missingIndices, len(samples)) and must
// return (zeroEval, zeroPoly). zeroEval must hold, for every index, a value
// that is zero exactly where the sample is missing. zeroPoly is shifted and
// evaluated by FFT, so it is used as Z in coefficient form. Note that zeroPoly
// is scaled in place.
//
// It returns an error if an FFT fails, if zeroPolyFn's output is inconsistent
// with samples, or if the reconstructed data disagrees with any provided
// sample.
func (fs *FFTSettings) RecoverPolyFromSamples(samples []*bls.Fr, zeroPolyFn ZeroPolyFn) ([]bls.Fr, error) {
	// TODO: using a single additional temporary array, all the FFTs can run in-place.
	missingIndices := make([]uint64, 0, len(samples))
	for i, s := range samples {
		if s == nil {
			missingIndices = append(missingIndices, uint64(i))
		}
	}

	zeroEval, zeroPoly := zeroPolyFn(missingIndices, uint64(len(samples)))
	// Without this check a short zeroEval would cause an index-out-of-range panic below.
	if len(zeroEval) < len(samples) {
		return nil, fmt.Errorf("bad zero eval: got %d evaluations, expected at least %d", len(zeroEval), len(samples))
	}

	// Build (E·Z) in evaluation form. At the same time, check that Z vanishes
	// exactly at the missing indices. A mismatch means zeroPolyFn does not match
	// the samples. It is reported as an error rather than a panic, because this
	// function already has an error return.
	polyEvaluationsWithZero := make([]bls.Fr, len(samples))
	for i, s := range samples {
		if (s == nil) != bls.EqualZero(&zeroEval[i]) {
			return nil, fmt.Errorf("bad zero eval at index %d", i)
		}
		if s == nil {
			bls.CopyFr(&polyEvaluationsWithZero[i], &bls.ZERO)
		} else {
			bls.MulModFr(&polyEvaluationsWithZero[i], s, &zeroEval[i])
		}
	}

	// Interpolate E·Z into coefficient form.
	polyWithZero, err := fs.FFT(polyEvaluationsWithZero, true)
	if err != nil {
		return nil, err
	}

	// shift in-place, so the division below happens on a coset where Z is non-zero
	fs.ShiftPoly(polyWithZero)
	shiftedPolyWithZero := polyWithZero
	fs.ShiftPoly(zeroPoly)
	shiftedZeroPoly := zeroPoly

	evalShiftedPolyWithZero, err := fs.FFT(shiftedPolyWithZero, false)
	if err != nil {
		return nil, err
	}
	evalShiftedZeroPoly, err := fs.FFT(shiftedZeroPoly, false)
	if err != nil {
		return nil, err
	}
	// Without this check a short zero poly would cause an index-out-of-range panic in the division loop.
	if len(evalShiftedZeroPoly) < len(evalShiftedPolyWithZero) {
		return nil, fmt.Errorf("zero poly evaluates to %d values, expected at least %d", len(evalShiftedZeroPoly), len(evalShiftedPolyWithZero))
	}

	// Pointwise (E·Z)/Z = E on the coset. The quotient overwrites the dividend's
	// slot. Copy the dividend first so DivModFr never gets the same pointer as
	// destination and operand, matching the copy-before-multiply convention used
	// by ShiftPoly/UnshiftPoly.
	evalShiftedReconstructedPoly := evalShiftedPolyWithZero
	var dividend bls.Fr
	for i := range evalShiftedReconstructedPoly {
		bls.CopyFr(&dividend, &evalShiftedPolyWithZero[i])
		bls.DivModFr(&evalShiftedReconstructedPoly[i], &dividend, &evalShiftedZeroPoly[i])
	}

	shiftedReconstructedPoly, err := fs.FFT(evalShiftedReconstructedPoly, true)
	if err != nil {
		return nil, err
	}
	// Undo the coset shift to get the coefficients of E itself.
	fs.UnshiftPoly(shiftedReconstructedPoly)
	reconstructedPoly := shiftedReconstructedPoly

	reconstructedData, err := fs.FFT(reconstructedPoly, false)
	if err != nil {
		return nil, err
	}

	// Sanity check: the reconstruction must reproduce every provided sample.
	for i, s := range samples {
		if s != nil && !bls.EqualFr(&reconstructedData[i], s) {
			return nil, fmt.Errorf("failed to reconstruct data correctly, changed value at index %d. Expected: %s, got: %s", i, bls.FrStr(s), bls.FrStr(&reconstructedData[i]))
		}
	}
	return reconstructedData, nil
}