17 #ifndef KOKKOS_MATHEMATICAL_FUNCTIONS_HPP 
   18 #define KOKKOS_MATHEMATICAL_FUNCTIONS_HPP 
   19 #ifndef KOKKOS_IMPL_PUBLIC_INCLUDE 
   20 #define KOKKOS_IMPL_PUBLIC_INCLUDE 
   21 #define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS 
   24 #include <Kokkos_Macros.hpp> 
   27 #include <type_traits> 
   29 #ifdef KOKKOS_ENABLE_SYCL 
   31 #if __has_include(<sycl/sycl.hpp>) 
   32 #include <sycl/sycl.hpp> 
   34 #include <CL/sycl.hpp> 
   41 template <
class T, 
bool = std::is_
integral_v<T>>
 
   46 struct promote<T, false> {};
 
   48 struct promote<long double> {
 
   49   using type = 
long double;
 
   52 struct promote<double> {
 
   56 struct promote<float> {
 
   60 using promote_t = 
typename promote<T>::type;
 
   61 template <
class T, 
class U,
 
   62           bool = std::is_arithmetic_v<T>&& std::is_arithmetic_v<U>>
 
   64   using type = decltype(promote_t<T>() + promote_t<U>());
 
   66 template <
class T, 
class U>
 
   67 struct promote_2<T, U, false> {};
 
   68 template <
class T, 
class U>
 
   69 using promote_2_t = 
typename promote_2<T, U>::type;
 
   70 template <
class T, 
class U, 
class V,
 
   71           bool = std::is_arithmetic_v<T>&& std::is_arithmetic_v<U>&&
 
   72               std::is_arithmetic_v<V>>
 
   74   using type = decltype(promote_t<T>() + promote_t<U>() + promote_t<V>());
 
   76 template <
class T, 
class U, 
class V>
 
   77 struct promote_3<T, U, V, false> {};
 
   78 template <
class T, 
class U, 
class V>
 
   79 using promote_3_t = 
typename promote_3<T, U, V>::type;
 
   84 #if defined(KOKKOS_ENABLE_SYCL) 
   85 #define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE sycl 
   87 #if (defined(KOKKOS_COMPILER_NVCC) || defined(KOKKOS_COMPILER_NVHPC)) && \ 
   88     defined(__GNUC__) && (__GNUC__ < 6) && !defined(__clang__) 
   89 #define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE 
   91 #define KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE std 
   95 #define KOKKOS_IMPL_MATH_UNARY_FUNCTION(FUNC)                                  \ 
   96   KOKKOS_INLINE_FUNCTION float FUNC(float x) {                                 \ 
   97     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  100   KOKKOS_INLINE_FUNCTION double FUNC(double x) {                               \ 
  101     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  104   inline long double FUNC(long double x) {                                     \ 
  108   KOKKOS_INLINE_FUNCTION float FUNC##f(float x) {                              \ 
  109     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  112   inline long double FUNC##l(long double x) {                                  \ 
  117   KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, double> FUNC( \ 
  119     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  120     return FUNC(static_cast<double>(x));                                       \ 
  126 #if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA) 
  127 #define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC)                               \ 
  128   KOKKOS_INLINE_FUNCTION bool FUNC(float x) { return ::FUNC(x); }            \ 
  129   KOKKOS_INLINE_FUNCTION bool FUNC(double x) { return ::FUNC(x); }           \ 
  130   inline bool FUNC(long double x) {                                          \ 
  135   KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, bool> FUNC( \ 
  137     return ::FUNC(static_cast<double>(x));                                   \ 
  140 #define KOKKOS_IMPL_MATH_UNARY_PREDICATE(FUNC)                               \ 
  141   KOKKOS_INLINE_FUNCTION bool FUNC(float x) {                                \ 
  142     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \ 
  145   KOKKOS_INLINE_FUNCTION bool FUNC(double x) {                               \ 
  146     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \ 
  149   inline bool FUNC(long double x) {                                          \ 
  154   KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, bool> FUNC( \ 
  156     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                        \ 
  157     return FUNC(static_cast<double>(x));                                     \ 
  161 #define KOKKOS_IMPL_MATH_BINARY_FUNCTION(FUNC)                                 \ 
  162   KOKKOS_INLINE_FUNCTION float FUNC(float x, float y) {                        \ 
  163     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  166   KOKKOS_INLINE_FUNCTION double FUNC(double x, double y) {                     \ 
  167     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  170   inline long double FUNC(long double x, long double y) {                      \ 
  174   KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y) {                     \ 
  175     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  178   inline long double FUNC##l(long double x, long double y) {                   \ 
  182   template <class T1, class T2>                                                \ 
  183   KOKKOS_INLINE_FUNCTION                                                       \ 
  184       std::enable_if_t<std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> && \ 
  185                            !std::is_same_v<T1, long double> &&                 \ 
  186                            !std::is_same_v<T2, long double>,                   \ 
  187                        Kokkos::Impl::promote_2_t<T1, T2>>                      \ 
  189     using Promoted = Kokkos::Impl::promote_2_t<T1, T2>;                        \ 
  190     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                          \ 
  191     return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y));           \ 
  193   template <class T1, class T2>                                                \ 
  194   inline std::enable_if_t<std::is_arithmetic_v<T1> &&                          \ 
  195                               std::is_arithmetic_v<T2> &&                      \ 
  196                               (std::is_same_v<T1, long double> ||              \ 
  197                                std::is_same_v<T2, long double>),               \ 
  200     using Promoted = Kokkos::Impl::promote_2_t<T1, T2>;                        \ 
  201     static_assert(std::is_same_v<Promoted, long double>);                      \ 
  203     return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y));           \ 
  206 #define KOKKOS_IMPL_MATH_TERNARY_FUNCTION(FUNC)                             \ 
  207   KOKKOS_INLINE_FUNCTION float FUNC(float x, float y, float z) {            \ 
  208     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \ 
  209     return FUNC(x, y, z);                                                   \ 
  211   KOKKOS_INLINE_FUNCTION double FUNC(double x, double y, double z) {        \ 
  212     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \ 
  213     return FUNC(x, y, z);                                                   \ 
  215   inline long double FUNC(long double x, long double y, long double z) {    \ 
  217     return FUNC(x, y, z);                                                   \ 
  219   KOKKOS_INLINE_FUNCTION float FUNC##f(float x, float y, float z) {         \ 
  220     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \ 
  221     return FUNC(x, y, z);                                                   \ 
  223   inline long double FUNC##l(long double x, long double y, long double z) { \ 
  225     return FUNC(x, y, z);                                                   \ 
  227   template <class T1, class T2, class T3>                                   \ 
  228   KOKKOS_INLINE_FUNCTION std::enable_if_t<                                  \ 
  229       std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&               \ 
  230           std::is_arithmetic_v<T3> && !std::is_same_v<T1, long double> &&   \ 
  231           !std::is_same_v<T2, long double> &&                               \ 
  232           !std::is_same_v<T3, long double>,                                 \ 
  233       Kokkos::Impl::promote_3_t<T1, T2, T3>>                                \ 
  234   FUNC(T1 x, T2 y, T3 z) {                                                  \ 
  235     using Promoted = Kokkos::Impl::promote_3_t<T1, T2, T3>;                 \ 
  236     using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::FUNC;                       \ 
  237     return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y),         \ 
  238                 static_cast<Promoted>(z));                                  \ 
  240   template <class T1, class T2, class T3>                                   \ 
  241   inline std::enable_if_t<std::is_arithmetic_v<T1> &&                       \ 
  242                               std::is_arithmetic_v<T2> &&                   \ 
  243                               std::is_arithmetic_v<T3> &&                   \ 
  244                               (std::is_same_v<T1, long double> ||           \ 
  245                                std::is_same_v<T2, long double> ||           \ 
  246                                std::is_same_v<T3, long double>),            \ 
  248   FUNC(T1 x, T2 y, T3 z) {                                                  \ 
  249     using Promoted = Kokkos::Impl::promote_3_t<T1, T2, T3>;                 \ 
  250     static_assert(std::is_same_v<Promoted, long double>);                   \ 
  252     return FUNC(static_cast<Promoted>(x), static_cast<Promoted>(y),         \ 
  253                 static_cast<Promoted>(z));                                  \ 
  257 KOKKOS_INLINE_FUNCTION 
int abs(
int n) {
 
  258   using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
 
  261 KOKKOS_INLINE_FUNCTION 
long abs(
long n) {
 
  263 #if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700 
  264   return n > 0 ? n : -n;
 
  266   using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
 
  270 KOKKOS_INLINE_FUNCTION 
long long abs(
long long n) {
 
  272 #if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700 
  273   return n > 0 ? n : -n;
 
  275   using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
 
  279 KOKKOS_INLINE_FUNCTION 
float abs(
float x) {
 
  280 #ifdef KOKKOS_ENABLE_SYCL 
  281   return sycl::fabs(x);  
 
  283   using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
 
  287 KOKKOS_INLINE_FUNCTION 
double abs(
double x) {
 
  288 #ifdef KOKKOS_ENABLE_SYCL 
  289   return sycl::fabs(x);  
 
  291   using KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE::abs;
 
  295 inline long double abs(
long double x) {
 
  299 KOKKOS_IMPL_MATH_UNARY_FUNCTION(fabs)
 
  300 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmod)
 
  301 KOKKOS_IMPL_MATH_BINARY_FUNCTION(remainder)
 
  303 KOKKOS_IMPL_MATH_TERNARY_FUNCTION(fma)
 
  304 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmax)
 
  305 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fmin)
 
  306 KOKKOS_IMPL_MATH_BINARY_FUNCTION(fdim)
 
  307 #ifndef KOKKOS_ENABLE_SYCL 
  308 KOKKOS_INLINE_FUNCTION 
float nanf(
char const* arg) { return ::nanf(arg); }
 
  309 KOKKOS_INLINE_FUNCTION 
double nan(
char const* arg) { return ::nan(arg); }
 
  315 KOKKOS_INLINE_FUNCTION 
float nanf(
char const*) { 
return sycl::nan(0u); }
 
  316 KOKKOS_INLINE_FUNCTION 
double nan(
char const*) { 
return sycl::nan(0ul); }
 
  /// Long-double NaN factory: hands the tag string straight to the C library's
  /// `::nanl`. Declared plain `inline` (no KOKKOS_INLINE_FUNCTION), like the
  /// other long double overloads in this header.
  inline long double nanl(char const* arg) {
    return ::nanl(arg);
  }
 
  320 KOKKOS_IMPL_MATH_UNARY_FUNCTION(exp)
 
  322 #if defined(KOKKOS_COMPILER_NVHPC) && KOKKOS_COMPILER_NVHPC < 230700 
  323 KOKKOS_INLINE_FUNCTION 
float exp2(
float val) {
 
  324   constexpr 
float ln2 = 0.693147180559945309417232121458176568L;
 
  325   return exp(ln2 * val);
 
  327 KOKKOS_INLINE_FUNCTION 
double exp2(
double val) {
 
  328   constexpr 
double ln2 = 0.693147180559945309417232121458176568L;
 
  329   return exp(ln2 * val);
 
  331 inline long double exp2(
long double val) {
 
  332   constexpr 
long double ln2 = 0.693147180559945309417232121458176568L;
 
  333   return exp(ln2 * val);
 
  336 KOKKOS_INLINE_FUNCTION 
double exp2(T val) {
 
  337   constexpr 
double ln2 = 0.693147180559945309417232121458176568L;
 
  338   return exp(ln2 * static_cast<double>(val));
 
  341 KOKKOS_IMPL_MATH_UNARY_FUNCTION(exp2)
 
  343 KOKKOS_IMPL_MATH_UNARY_FUNCTION(expm1)
 
  344 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log)
 
  345 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log10)
 
  346 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log2)
 
  347 KOKKOS_IMPL_MATH_UNARY_FUNCTION(log1p)
 
  349 KOKKOS_IMPL_MATH_BINARY_FUNCTION(pow)
 
  350 KOKKOS_IMPL_MATH_UNARY_FUNCTION(sqrt)
 
  351 KOKKOS_IMPL_MATH_UNARY_FUNCTION(cbrt)
 
  352 KOKKOS_IMPL_MATH_BINARY_FUNCTION(hypot)
 
  353 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || \ 
  354     defined(KOKKOS_ENABLE_SYCL) 
  355 KOKKOS_INLINE_FUNCTION 
float hypot(
float x, 
float y, 
float z) {
 
  356   return sqrt(x * x + y * y + z * z);
 
  358 KOKKOS_INLINE_FUNCTION 
double hypot(
double x, 
double y, 
double z) {
 
  359   return sqrt(x * x + y * y + z * z);
 
  361 inline long double hypot(
long double x, 
long double y, 
long double z) {
 
  362   return sqrt(x * x + y * y + z * z);
 
  364 KOKKOS_INLINE_FUNCTION 
float hypotf(
float x, 
float y, 
float z) {
 
  365   return sqrt(x * x + y * y + z * z);
 
  367 inline long double hypotl(
long double x, 
long double y, 
long double z) {
 
  368   return sqrt(x * x + y * y + z * z);
 
  371     class T1, 
class T2, 
class T3,
 
  372     class Promoted = std::enable_if_t<
 
  373         std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
 
  374             std::is_arithmetic_v<T3> && !std::is_same_v<T1, long double> &&
 
  375             !std::is_same_v<T2, long double> &&
 
  376             !std::is_same_v<T3, long double>,
 
  377         Impl::promote_3_t<T1, T2, T3>>>
 
  378 KOKKOS_INLINE_FUNCTION Promoted hypot(T1 x, T2 y, T3 z) {
 
  379   return hypot(static_cast<Promoted>(x), static_cast<Promoted>(y),
 
  380                static_cast<Promoted>(z));
 
  383     class T1, 
class T2, 
class T3,
 
  384     class = std::enable_if_t<
 
  385         std::is_arithmetic_v<T1> && std::is_arithmetic_v<T2> &&
 
  386         std::is_arithmetic_v<T3> &&
 
  387         (std::is_same_v<T1, long double> || std::is_same_v<T2, long double> ||
 
  388          std::is_same_v<T3, long double>)>>
 
  389 inline long double hypot(T1 x, T2 y, T3 z) {
 
  390   return hypot(static_cast<long double>(x), static_cast<long double>(y),
 
  391                static_cast<long double>(z));
 
  394 KOKKOS_IMPL_MATH_TERNARY_FUNCTION(hypot)
 
  397 KOKKOS_IMPL_MATH_UNARY_FUNCTION(sin)
 
  398 KOKKOS_IMPL_MATH_UNARY_FUNCTION(cos)
 
  399 KOKKOS_IMPL_MATH_UNARY_FUNCTION(tan)
 
  400 KOKKOS_IMPL_MATH_UNARY_FUNCTION(asin)
 
  401 KOKKOS_IMPL_MATH_UNARY_FUNCTION(acos)
 
  402 KOKKOS_IMPL_MATH_UNARY_FUNCTION(atan)
 
  403 KOKKOS_IMPL_MATH_BINARY_FUNCTION(atan2)
 
  405 KOKKOS_IMPL_MATH_UNARY_FUNCTION(sinh)
 
  406 KOKKOS_IMPL_MATH_UNARY_FUNCTION(cosh)
 
  407 KOKKOS_IMPL_MATH_UNARY_FUNCTION(tanh)
 
  408 KOKKOS_IMPL_MATH_UNARY_FUNCTION(asinh)
 
  409 KOKKOS_IMPL_MATH_UNARY_FUNCTION(acosh)
 
  410 KOKKOS_IMPL_MATH_UNARY_FUNCTION(atanh)
 
  412 KOKKOS_IMPL_MATH_UNARY_FUNCTION(erf)
 
  413 KOKKOS_IMPL_MATH_UNARY_FUNCTION(erfc)
 
  414 KOKKOS_IMPL_MATH_UNARY_FUNCTION(tgamma)
 
  415 KOKKOS_IMPL_MATH_UNARY_FUNCTION(lgamma)
 
  417 KOKKOS_IMPL_MATH_UNARY_FUNCTION(ceil)
 
  418 KOKKOS_IMPL_MATH_UNARY_FUNCTION(floor)
 
  419 KOKKOS_IMPL_MATH_UNARY_FUNCTION(trunc)
 
  420 KOKKOS_IMPL_MATH_UNARY_FUNCTION(round)
 
  424 #ifndef KOKKOS_ENABLE_SYCL  // FIXME_SYCL 
  425 KOKKOS_IMPL_MATH_UNARY_FUNCTION(nearbyint)
 
  437 KOKKOS_IMPL_MATH_UNARY_FUNCTION(logb)
 
  438 KOKKOS_IMPL_MATH_BINARY_FUNCTION(nextafter)
 
  440 KOKKOS_IMPL_MATH_BINARY_FUNCTION(copysign)
 
  443 KOKKOS_IMPL_MATH_UNARY_PREDICATE(isfinite)
 
  444 KOKKOS_IMPL_MATH_UNARY_PREDICATE(isinf)
 
  445 KOKKOS_IMPL_MATH_UNARY_PREDICATE(isnan)
 
  447 KOKKOS_IMPL_MATH_UNARY_PREDICATE(signbit)
 
  455 #undef KOKKOS_IMPL_MATH_FUNCTIONS_NAMESPACE 
  456 #undef KOKKOS_IMPL_MATH_UNARY_FUNCTION 
  457 #undef KOKKOS_IMPL_MATH_UNARY_PREDICATE 
  458 #undef KOKKOS_IMPL_MATH_BINARY_FUNCTION 
  459 #undef KOKKOS_IMPL_MATH_TERNARY_FUNCTION 
  462 KOKKOS_INLINE_FUNCTION 
float rsqrt(
float val) {
 
  463 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) 
  464   KOKKOS_IF_ON_DEVICE(return ::rsqrtf(val);)
 
  465   KOKKOS_IF_ON_HOST(
return 1.0f / Kokkos::sqrt(val);)
 
  466 #elif defined(KOKKOS_ENABLE_SYCL)
 
  467   KOKKOS_IF_ON_DEVICE(return sycl::rsqrt(val);)
 
  468   KOKKOS_IF_ON_HOST(return 1.0f / Kokkos::sqrt(val);)
 
  470   return 1.0f / Kokkos::sqrt(val);
 
  473 KOKKOS_INLINE_FUNCTION 
double rsqrt(
double val) {
 
  474 #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) 
  475   KOKKOS_IF_ON_DEVICE(return ::rsqrt(val);)
 
  476   KOKKOS_IF_ON_HOST(
return 1.0 / Kokkos::sqrt(val);)
 
  477 #elif defined(KOKKOS_ENABLE_SYCL)
 
  478   KOKKOS_IF_ON_DEVICE(return sycl::rsqrt(val);)
 
  479   KOKKOS_IF_ON_HOST(return 1.0 / Kokkos::sqrt(val);)
 
  481   return 1.0 / Kokkos::sqrt(val);
 
  484 inline long double rsqrt(
long double val) { 
return 1.0l / Kokkos::sqrt(val); }
 
  485 KOKKOS_INLINE_FUNCTION 
float rsqrtf(
float x) { 
return Kokkos::rsqrt(x); }
 
  486 inline long double rsqrtl(
long double x) { 
return Kokkos::rsqrt(x); }
 
  488 KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_integral_v<T>, 
double> rsqrt(
 
  490   return Kokkos::rsqrt(static_cast<double>(x));
 
  495 #ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS 
  496 #undef KOKKOS_IMPL_PUBLIC_INCLUDE 
  497 #undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_MATHFUNCTIONS