Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ww-rm committed Jun 9, 2024
1 parent ffa193a commit f230991
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 153 deletions.
92 changes: 34 additions & 58 deletions src/gmalglib/core/sm2curve.c
Original file line number Diff line number Diff line change
Expand Up @@ -191,30 +191,18 @@ void SM2ModP_MontMul(const SM2ModPMont* x, const SM2ModPMont* y, SM2ModPMont* z)
}
}

static
static
void SM2ModP_ToMont(const SM2ModP* x, SM2ModPMont* y)
{
SM2ModP_MontMul(x, CONSTS_MODP_R2, y);
}

static
static
void SM2ModP_FromMont(const SM2ModPMont* x, SM2ModP* y)
{
SM2ModP_MontMul(x, CONSTS_MODP_ONE, y);
}

static
void SM2ModP_MontAdd(const SM2ModPMont* x, const SM2ModPMont* y, SM2ModPMont* z)
{
SM2ModP_Add(x, y, z);
}

static
void SM2ModP_MontSub(const SM2ModPMont* x, const SM2ModPMont* y, SM2ModPMont* z)
{
SM2ModP_Sub(x, y, z);
}

static
void SM2ModP_MontPow(const SM2ModPMont* x, const UInt256* e, SM2ModPMont* y)
{
Expand Down Expand Up @@ -284,24 +272,12 @@ int SM2ModP_MontHasSqrt(const SM2ModPMont* x, SM2ModPMont* y)
return 1;
}

static
void SM2ModP_MontNeg(const SM2ModPMont* x, SM2ModPMont* y)
{
SM2ModP_Neg(x, y);
}

static
static
void SM2ModP_MontInv(const SM2ModPMont* x, SM2ModPMont* y)
{
SM2ModP_MontPow(x, CONSTS_P_MINUS_TWO, y);
}

static
void SM2ModP_MontDiv2(const SM2ModPMont* x, SM2ModPMont* y)
{
SM2ModP_Div2(x, y);
}

int SM2JacobPointMont_IsInf(const SM2JacobPointMont* X)
{
return UInt256_IsZero(&X->z);
Expand All @@ -324,14 +300,14 @@ int SM2JacobPointMont_IsOnCurve(const SM2JacobPointMont* X)
{
// left = y^2 + 3x
SM2ModP_MontMul(&X->y, &X->y, &left);
SM2ModP_MontAdd(&left, &X->x, &left);
SM2ModP_MontAdd(&left, &X->x, &left);
SM2ModP_MontAdd(&left, &X->x, &left);
SM2ModP_Add(&left, &X->x, &left);
SM2ModP_Add(&left, &X->x, &left);
SM2ModP_Add(&left, &X->x, &left);

// right = x^3 + b
SM2ModP_MontMul(&X->x, &X->x, &right);
SM2ModP_MontMul(&right, &X->x, &right);
SM2ModP_MontAdd(&right, CONSTS_MODP_MONT_B, &right);
SM2ModP_Add(&right, CONSTS_MODP_MONT_B, &right);
}
else
{
Expand All @@ -341,16 +317,16 @@ int SM2JacobPointMont_IsOnCurve(const SM2JacobPointMont* X)

// left = y^2 + 3xz^4
SM2ModP_MontMul(&X->x, &left, &tmp); // xz^4
SM2ModP_MontAdd(&tmp, &tmp, &left); // 2xz^4
SM2ModP_MontAdd(&left, &tmp, &left); // 3xz^4
SM2ModP_Add(&tmp, &tmp, &left); // 2xz^4
SM2ModP_Add(&left, &tmp, &left); // 3xz^4
SM2ModP_MontMul(&X->y, &X->y, &tmp); // y^2
SM2ModP_MontAdd(&tmp, &left, &left); // y^2 + 3xz^4
SM2ModP_Add(&tmp, &left, &left); // y^2 + 3xz^4

// right = x^3 + bz^6
SM2ModP_MontMul(&X->x, &X->x, &tmp); // x^2
SM2ModP_MontMul(&tmp, &X->x, &tmp); // x^3
SM2ModP_MontMul(&right, CONSTS_MODP_MONT_B, &right); // bz^6
SM2ModP_MontAdd(&tmp, &right, &right); // x^3 + bz^6
SM2ModP_Add(&tmp, &right, &right); // x^3 + bz^6
}

return UInt256_Cmp(&left, &right) == 0;
Expand Down Expand Up @@ -450,14 +426,14 @@ int SM2Point_IsOnCurve(const SM2Point* X)

// left = y^2 + 3x
SM2ModP_MontMul(&y, &y, &left);
SM2ModP_MontAdd(&left, &x, &left);
SM2ModP_MontAdd(&left, &x, &left);
SM2ModP_MontAdd(&left, &x, &left);
SM2ModP_Add(&left, &x, &left);
SM2ModP_Add(&left, &x, &left);
SM2ModP_Add(&left, &x, &left);

// right = x^3 + b
SM2ModP_MontMul(&x, &x, &right);
SM2ModP_MontMul(&right, &x, &right);
SM2ModP_MontAdd(&right, CONSTS_MODP_MONT_B, &right);
SM2ModP_Add(&right, CONSTS_MODP_MONT_B, &right);

return UInt256_Cmp(&left, &right) == 0;
}
Expand Down Expand Up @@ -536,9 +512,9 @@ int SM2Point_FromBytes(const uint8_t* bytes, uint64_t bytes_len, SM2Point* X)

// compute y
SM2ModP_MontMul(&X->x, &X->x, &X->y);
SM2ModP_MontAdd(&X->y, CONSTS_MODP_MONT_A, &X->y);
SM2ModP_Add(&X->y, CONSTS_MODP_MONT_A, &X->y);
SM2ModP_MontMul(&X->x, &X->y, &X->y);
SM2ModP_MontAdd(&X->y, CONSTS_MODP_MONT_B, &X->y);
SM2ModP_Add(&X->y, CONSTS_MODP_MONT_B, &X->y);
if (!SM2ModP_MontHasSqrt(&X->y, &X->y))
return SM2CURVE_ERR_NOTONCURVE;

Expand Down Expand Up @@ -588,33 +564,33 @@ void _SM2JacobPointMont_Dbl(const SM2JacobPointMont* X, SM2JacobPointMont* Y)
// a == -3 (mod p)
// t1 = 3(x + z^2)(x - z^2)
SM2ModP_MontMul(&X->z, &X->z, &t1); // z^2
SM2ModP_MontSub(&X->x, &t1, &t2); // x - z^2
SM2ModP_MontAdd(&X->x, &t1, &t1); // x + z^2
SM2ModP_Sub(&X->x, &t1, &t2); // x - z^2
SM2ModP_Add(&X->x, &t1, &t1); // x + z^2
SM2ModP_MontMul(&t1, &t2, &t1); // (x + z^2)(x - z^2)
SM2ModP_MontAdd(&t1, &t1, &t2); // 2(x + z^2)(x - z^2)
SM2ModP_MontAdd(&t1, &t2, &t1); // 3(x + z^2)(x - z^2)
SM2ModP_Add(&t1, &t1, &t2); // 2(x + z^2)(x - z^2)
SM2ModP_Add(&t1, &t2, &t1); // 3(x + z^2)(x - z^2)

// z' = 2yz
SM2ModP_MontAdd(&X->y, &X->y, &t3); // 2y
SM2ModP_Add(&X->y, &X->y, &t3); // 2y
SM2ModP_MontMul(&t3, &X->z, &Y->z);

// t3 = 8y^4
SM2ModP_MontMul(&t3, &t3, &t2); // 4y^2
SM2ModP_MontMul(&t2, &t2, &t3); // 16y^4
SM2ModP_MontDiv2(&t3, &t3); // 8y^4
SM2ModP_Div2(&t3, &t3); // 8y^4

// t2 = 4xy^2
SM2ModP_MontMul(&X->x, &t2, &t2); // 4xy^2

// x' = t1^2 - 2t2
SM2ModP_MontMul(&t1, &t1, &Y->x);
SM2ModP_MontSub(&Y->x, &t2, &Y->x);
SM2ModP_MontSub(&Y->x, &t2, &Y->x);
SM2ModP_Sub(&Y->x, &t2, &Y->x);
SM2ModP_Sub(&Y->x, &t2, &Y->x);

// y' = t1(t2 - x') - t3
SM2ModP_MontSub(&t2, &Y->x, &Y->y);
SM2ModP_Sub(&t2, &Y->x, &Y->y);
SM2ModP_MontMul(&t1, &Y->y, &Y->y);
SM2ModP_MontSub(&Y->y, &t3, &Y->y);
SM2ModP_Sub(&Y->y, &t3, &Y->y);
}

void SM2JacobPointMont_Dbl(const SM2JacobPointMont* X, SM2JacobPointMont* Y)
Expand Down Expand Up @@ -666,27 +642,27 @@ void _SM2JacobPointMont_Add(const SM2JacobPointMont* X, const SM2JacobPointMont*

// t3 = t1 - t2
// t6 = t4 - t5
SM2ModP_MontSub(&t1, &t2, &t3);
SM2ModP_MontSub(&t4, &t5, &t6);
SM2ModP_Sub(&t1, &t2, &t3);
SM2ModP_Sub(&t4, &t5, &t6);

// z3 = z1 * z2 * t3
SM2ModP_MontMul(&X->z, &Y->z, &Z->z); // z1 * z2
SM2ModP_MontMul(&Z->z, &t3, &Z->z); // z1 * z2 * t3

// x3 = t6^2 - (t1 + t2) * t3^2
SM2ModP_MontMul(&t6, &t6, &Z->x); // t6^2
SM2ModP_MontAdd(&t1, &t2, &t2); // t1 + t2
SM2ModP_Add(&t1, &t2, &t2); // t1 + t2
SM2ModP_MontMul(&t3, &t3, &t5); // t3^2
SM2ModP_MontMul(&t2, &t5, &t2); // (t1 + t2) * t3^2
SM2ModP_MontSub(&Z->x, &t2, &Z->x); // t6^2 - (t1 + t2) * t3^2
SM2ModP_Sub(&Z->x, &t2, &Z->x); // t6^2 - (t1 + t2) * t3^2

// y3 = t6 * (t1 * t3^2 - x3) - t4 * t3^3
SM2ModP_MontMul(&t1, &t5, &Z->y); // t1 * t3^2
SM2ModP_MontSub(&Z->y, &Z->x, &Z->y); // t1 * t3^2 - x3
SM2ModP_Sub(&Z->y, &Z->x, &Z->y); // t1 * t3^2 - x3
SM2ModP_MontMul(&t6, &Z->y, &Z->y); // t6 * (t1 * t3^2 - x3)
SM2ModP_MontMul(&t4, &t3, &t3); // t4 * t3
SM2ModP_MontMul(&t3, &t5, &t3); // t4 * t3^3
SM2ModP_MontSub(&Z->y, &t3, &Z->y); // t6 * (t1 * t3^2 - x3) - t4 * t3^3
SM2ModP_Sub(&Z->y, &t3, &Z->y); // t6 * (t1 * t3^2 - x3) - t4 * t3^3
}

void SM2JacobPointMont_Add(const SM2JacobPointMont* X, const SM2JacobPointMont* Y, SM2JacobPointMont* Z)
Expand Down Expand Up @@ -793,7 +769,7 @@ void SM2JacobPointMont_Mul(const UInt256* k, const SM2JacobPointMont* X, SM2Jaco
void SM2JacobPointMont_Neg(const SM2JacobPointMont* X, SM2JacobPointMont* Y)
{
Y->x = X->x;
SM2ModP_MontNeg(&X->y, &Y->y);
SM2ModP_Neg(&X->y, &Y->y);
Y->z = X->z;
}

Expand Down
Loading

0 comments on commit f230991

Please sign in to comment.