diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td index df2954bc27..b371188900 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td @@ -780,6 +780,202 @@ def GesvjOp : EnzymeXLA_Op<"lapack.gesvj", [Pure]> { }]; } +// Special Functions - Bessel Functions + +def BesselJ : EnzymeXLA_Op<"special.besselj", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Bessel function of the first kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselJX : EnzymeXLA_Op<"special.besseljx", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Scaled Bessel function of the first kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def SphericalBesselJ : EnzymeXLA_Op<"special.sphericalbesselj", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Spherical Bessel function of the first kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselY : EnzymeXLA_Op<"special.bessely", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Bessel function of the second kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselYX : EnzymeXLA_Op<"special.besselyx", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Scaled Bessel function of the second kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def SphericalBesselY : EnzymeXLA_Op<"special.sphericalbessely", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Spherical Bessel function of the second kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselH : EnzymeXLA_Op<"special.besselh", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Bessel function of the third kind (Hankel function) of order nu at z"; + + let description = [{ + Computes the Bessel function of the third kind, also known as the Hankel + function. The parameter k must be either 1 or 2, selecting between Hankel + functions of the first kind (H1) and second kind (H2). + }]; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$k, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def HankelH1X : EnzymeXLA_Op<"special.hankelh1x", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Scaled Hankel function of the first kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def HankelH2X : EnzymeXLA_Op<"special.hankelh2x", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Scaled Hankel function of the second kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselI : EnzymeXLA_Op<"special.besseli", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Modified Bessel function of the first kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselIX : EnzymeXLA_Op<"special.besselix", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Scaled modified Bessel function of the first kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselK : EnzymeXLA_Op<"special.besselk", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Modified Bessel function of the second kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def BesselKX : EnzymeXLA_Op<"special.besselkx", [Pure, AllTypesMatch<["z", "res"]>, Elementwise]> { + let summary = "Scaled modified Bessel function of the second kind of order nu at z"; + + let arguments = (ins + HLO_Tensor:$nu, + HLO_Tensor:$z + ); + + let results = (outs + HLO_Tensor:$res + ); +} + +def Jinc : EnzymeXLA_Op<"special.jinc", [Pure, SameOperandsAndResultType, Elementwise]> { + let summary = "Jinc function (sombrero/besinc): scaled Bessel function of the first kind divided by x"; + + let description = [{ + Computes the jinc function, also known as the sombrero or besinc function. + It is defined as J1(pi*x) / (2*x) where J1 is the Bessel function of the + first kind of order 1. At x=0, the function evaluates to pi/4. + }]; + + let arguments = (ins + HLO_Tensor:$x + ); + + let results = (outs + HLO_Tensor:$res + ); +} + // Machine Learning Ops def GeluOp: EnzymeXLA_Op<"ml.gelu", [Pure, SameOperandsAndResultType, Elementwise]> {