-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmore_math.move.template
198 lines (176 loc) · 6.91 KB
/
more_math.move.template
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
{{$typename := .UintTypeName}}// Auto generated from gen-move-math
// https://github.com/fardream/gen-move-math
// Manual edit with caution.
// Arguments: {{.Args}}
// Version: {{.Version}}
module {{.Address}}::{{.ModuleName}} {
/// Abort code used by leading_zeros / trailing_zeros when x == 0 in the 256-bit instantiation:
/// the answer would be 256, which does not fit in the u8 return type.
const E_WIDTH_OVERFLOW_U8: u64 = 1;
/// Abort code used by log_2 / log_2_with_precision when x is outside [2^n, 2^(n+1)).
const E_LOG_2_OUT_OF_RANGE: u64 = 2;
/// Abort code used by log_2 / log_2_with_precision when x is zero.
const E_ZERO_FOR_LOG_2: u64 = 3;
/// Shift amount used by mul_with_carry to split an operand into an upper and a lower half.
const HALF_SIZE: u8 = {{.HalfWidth}};
/// sqrt subtracts leading_zeros(x) from this value; presumably bit width minus one, so the
/// difference is the index of the highest set bit - TODO confirm against the generator.
const MAX_SHIFT_SIZE: u8 = {{.MaxShiftSize}};
/// Generator-supplied value with every bit set; add_with_carry and sub_with_borrow treat it as
/// the largest representable value.
const ALL_ONES: {{$typename}} = {{.AllOnes}};
/// Mask selecting the lowest HalfWidth bits.
const LOWER_ONES: {{$typename}} = (1 << {{.HalfWidth}}) - 1;
/// Mask selecting the HalfWidth bits directly above the ones covered by LOWER_ONES.
const UPPER_ONES: {{$typename}} = ((1 << {{.HalfWidth}}) - 1) << {{.HalfWidth}};
/// Wrapping addition of two {{$typename}} values - will never abort.
/// Returns `(sum, carry)`: `sum` is the result truncated to the width of the type and
/// `carry` is 1 when the true sum did not fit, 0 otherwise.
public fun add_with_carry(x: {{$typename}}, y: {{$typename}}): ({{$typename}}, {{$typename}}) {
    // Largest amount that can still be added to x without exceeding ALL_ONES.
    let headroom = ALL_ONES - x;
    if (y <= headroom) {
        // Fits: plain addition cannot overflow here.
        (x + y, 0)
    } else {
        // Does not fit: the wrapped result is y - (headroom + 1), evaluated so that no
        // intermediate value leaves the range of the type.
        (y - headroom - 1, 1)
    }
}
/// Wrapping subtraction `x - y` for {{$typename}} - will never abort.
/// Returns `(difference, borrow)`: `difference` is the result truncated to the width of the
/// type and `borrow` is 1 when `y` was larger than `x`, 0 otherwise.
public fun sub_with_borrow(x: {{$typename}}, y: {{$typename}}): ({{$typename}}, {{$typename}}) {
    if (x >= y) {
        // No borrow needed: plain subtraction cannot underflow.
        (x - y, 0)
    } else {
        // Borrow: wrap around the top of the range. The terms are ordered so that every
        // partial result stays within the type (y >= 1 here because x < y).
        let wrapped = (ALL_ONES - y) + 1 + x;
        (wrapped, 1)
    }
}
/// Full-width product `x * y`.
/// Returns `(lower, upper)`: the lower word and the upper word of the double-width result.
public fun mul_with_carry(x: {{$typename}}, y: {{$typename}}): ({{$typename}}, {{$typename}}) {
    // Schoolbook multiplication on half-size limbs, with h = HALF_SIZE:
    //   x = x_hi * 2^h + x_lo,   y = y_hi * 2^h + y_lo
    //   x * y = x_hi * y_hi * 2^(2h) + (x_hi * y_lo + x_lo * y_hi) * 2^h + x_lo * y_lo
    let x_lo = x & LOWER_ONES;
    let x_hi = (x & UPPER_ONES) >> HALF_SIZE;
    let y_lo = y & LOWER_ONES;
    let y_hi = (y & UPPER_ONES) >> HALF_SIZE;

    // The two cross terms; each one straddles the boundary between the lower and upper word.
    let cross_hl = x_hi * y_lo;
    let cross_lh = x_lo * y_hi;

    // Lower word: x_lo * y_lo plus the low halves of both cross terms moved into the upper
    // half of the word; each addition may carry into the upper word.
    let (partial, carry_first) = add_with_carry(x_lo * y_lo, (cross_hl & LOWER_ONES) << HALF_SIZE);
    let (low, carry_second) = add_with_carry(partial, (cross_lh & LOWER_ONES) << HALF_SIZE);

    // Upper word: x_hi * y_hi plus the high halves of both cross terms plus both carries.
    let high = x_hi * y_hi + (cross_hl >> HALF_SIZE) + (cross_lh >> HALF_SIZE) + carry_first + carry_second;
    (low, high)
}
/// Counts the leading (most significant) zero bits of a {{$typename}}.
/// For `x == 0` the generated code returns the full bit width, except in the 256-bit
/// instantiation, where it aborts with `E_WIDTH_OVERFLOW_U8` because 256 does not fit in
/// the `u8` return type.
public fun leading_zeros(x: {{$typename}}): u8 {
if (x == 0) {
{{if eq .UintWidth 256}} abort(E_WIDTH_OVERFLOW_U8)
{{else}} {{.UintWidth}}
{{end}} } else {
let n: u8 = 0;
// The generator unrolls one check per step: when the bits selected by the step's mask are
// all zero, the step's width is added to n and (for widths other than 1) x is shifted left
// by that width so the next check looks at the following bits.
// NOTE(review): presumably the masks cover the top Width bits with Width halving at each
// step (a binary search) - confirm against the generator's step table.
{{$base_type := .UintTypeName}}{{range .UnrolledForLeadingZeros}} if (x & {{.Ones}} == 0) {
{{if ne .Width 1}}// x's higher {{.Width}} bits are all zero, shift the lower part over
x = x << {{.Width}};
{{end}}n = n + {{.Width}};
};
{{end}}
n
}
}
/// Counts the trailing (least significant) zero bits of a {{$typename}}.
/// For `x == 0` the generated code returns the full bit width, except in the 256-bit
/// instantiation, where it aborts with `E_WIDTH_OVERFLOW_U8` because 256 does not fit in
/// the `u8` return type.
public fun trailing_zeros(x: {{$typename}}): u8 {
if (x == 0) {
{{if eq .UintWidth 256}} abort(E_WIDTH_OVERFLOW_U8)
{{else}} {{.UintWidth}}
{{end}} } else {
let n: u8 = 0;
// The generator unrolls one check per step (the same step table as leading_zeros, using the
// TrailingOnes mask): when the bits selected by the mask are all zero, the step's width is
// added to n and (for widths other than 1) x is shifted right by that width so the next
// check looks at the following bits.
// NOTE(review): presumably the masks cover the bottom Width bits with Width halving at each
// step (a binary search) - confirm against the generator's step table.
{{$base_type := .UintTypeName}}{{range .UnrolledForLeadingZeros}} if (x & {{.TrailingOnes}} == 0) {
{{if ne .Width 1}}// x's lower {{.Width}} bits are all zero, shift the higher part over
x = x >> {{.Width}};
{{end}}n = n + {{.Width}};
};
{{end}}
n
}
}
/// Integer square root: returns y = floor(sqrt(x)).
public fun sqrt(x: {{$typename}}): {{$typename}} {
    if (x <= 3) {
        // Small inputs are answered directly: 0 -> 0 and 1, 2, 3 -> 1.
        if (x == 0) {
            0
        } else {
            1
        }
    } else {
        // Iteration scheme reproduced from golang's big.Int.Sqrt: start from an estimate
        // derived from the position of the highest set bit and repeat
        // guess = (guess + x / guess) / 2 until the value stops decreasing.
        // NOTE(review): assumes MAX_SHIFT_SIZE is the bit width minus one, which makes the
        // starting estimate no smaller than the root - confirm against the generator.
        let half_log = (MAX_SHIFT_SIZE - leading_zeros(x)) >> 1;
        let guess = x >> ((half_log - 1) / 2);
        loop {
            let refined = (guess + x / guess) >> 1;
            if (refined >= guess) {
                // No further decrease: guess is the answer.
                break
            };
            guess = refined;
        };
        guess
    }
}
/// log_2 calculates log_2 z, where z is the ratio x / 2^n and must lie in [1, 2).
/// `x` is a fixed-point number with `n` fractional bits; the result is log_2 z in the same
/// fixed-point format, computed to `n` fractional bits.
/// This can be used to calculate log_2 for any positive number by bit shifting the binary
/// representation into the desired region. For a floating point number
/// (1.b1b2b3...b53 * 2^n), the mantissa can be used directly.
/// see: https://en.wikipedia.org/wiki/Binary_logarithm
/// Also note the last several digits of the result may not be precise.
///
/// Aborts with `E_ZERO_FOR_LOG_2` when `x == 0` and with `E_LOG_2_OUT_OF_RANGE` when `x` is
/// outside [2^n, 2^(n+1)). The algorithm squares values below 2^(n+1), so `n` must be small
/// enough for a (2n + 2)-bit product to fit in {{$typename}}; otherwise the multiplication aborts.
public fun log_2(x: {{$typename}}, n: u8): {{$typename}} {
    // Identical to the digit-by-digit loop of log_2_with_precision when the requested
    // precision equals n, so delegate instead of keeping a second copy of the loop.
    // The input checks and their abort codes are performed by the callee.
    log_2_with_precision(x, n, n)
}
/// log_2 calculates log_2 z, where z is the ratio x / 2^n and must lie in [1, 2).
/// `x` is a fixed-point number with `n` fractional bits and the result uses the same format.
/// Instead of always producing `n` fractional bits, the number of computed bits is controlled
/// by `precision`: calculation stops once `precision` bits have been produced. A `precision`
/// larger than `n` is treated as `n`, because bits below 2^-n cannot be represented in the result.
/// This can be used to calculate log_2 for any positive number by bit shifting the binary
/// representation into the desired region. For a floating point number
/// (1.b1b2b3...b53 * 2^n), the mantissa can be used directly.
/// see: https://en.wikipedia.org/wiki/Binary_logarithm
///
/// Aborts with `E_ZERO_FOR_LOG_2` when `x == 0` and with `E_LOG_2_OUT_OF_RANGE` when `x` is
/// outside [2^n, 2^(n+1)). The algorithm squares values below 2^(n+1), so `n` must be small
/// enough for a (2n + 2)-bit product to fit in {{$typename}}; otherwise the multiplication aborts.
public fun log_2_with_precision(x: {{$typename}}, n: u8, precision: u8): {{$typename}} {
    assert!(x != 0, E_ZERO_FOR_LOG_2);
    let one: {{$typename}} = 1 << n;
    let two: {{$typename}} = one << 1;
    assert!(x >= one && x < two, E_LOG_2_OUT_OF_RANGE);
    // Never produce more than n bits: past that point `one >> sum_m` is zero, so extra
    // iterations cannot change the result, and they could abort once the shift amount
    // reaches the bit width or the u8 counter overflows.
    let max_bits = if (precision < n) { precision } else { n };
    let z_2 = x;
    let r: {{$typename}} = 0;
    let sum_m: u8 = 0;
    loop {
        // z is exactly 1: every remaining bit of the logarithm is zero.
        if (z_2 == one) {
            break
        };
        // Square z, staying in the fixed-point format with n fractional bits.
        let z = (z_2 * z_2) >> n;
        sum_m = sum_m + 1;
        if (sum_m > max_bits) {
            break
        };
        if (z >= two) {
            // z^2 >= 2: the current bit of the result is 1; halve z back into [1, 2).
            r = r + (one >> sum_m);
            z_2 = z >> 1;
        } else {
            // z^2 < 2: the current bit of the result is 0.
            z_2 = z;
        };
    };
    r
}
{{if .DoTest}}
// Generated unit tests for sqrt: one test function per generator-supplied (Squared, Root) pair.
{{range $i, $t := .SqrtTestCases}} #[test]
fun test_sqrt_{{$typename}}_{{$i}}() {
let r = sqrt({{.Squared}});
// On mismatch the abort code is the computed root cast to u64.
assert!(r == {{.Root}}, (r as u64));
}
{{end}}
// Generated unit tests for log_2: one test function per generator-supplied (Value, N, Log2) triple.
{{range $i, $t := .Log2TestCases}} #[test]
fun test_log_2_{{$typename}}_{{$i}}() {
let r = log_2({{.Value}}, {{.N}});
// On mismatch the abort code is the computed logarithm cast to u64.
assert!(r == {{.Log2}}, (r as u64));
}
{{end}}{{end}}}