Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow non-trivially-copyable comparators for SYCL #6939

Open
masterleinad opened this issue Apr 15, 2024 · 2 comments
Open

Allow non-trivially-copyable comparators for SYCL #6939

masterleinad opened this issue Apr 15, 2024 · 2 comments

Comments

@masterleinad
Copy link
Contributor

Related to kokkos/kokkos-core-wiki#504 and #6801 (comment).
Even with changes like

diff --git a/algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp b/algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp
index 47c96a0a0..cd3548d4a 100644
--- a/algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp
+++ b/algorithms/src/sorting/impl/Kokkos_SortByKeyImpl.hpp
@@ -128,9 +128,13 @@ void sort_by_key_onedpl(
   auto policy = oneapi::dpl::execution::make_device_policy(queue);
   const int n = keys.extent(0);
   if constexpr (sizeof...(MaybeComparator)==0)
-  oneapi::dpl::sort_by_key(policy, keys.data(), keys.data() + n, values.data());
+    oneapi::dpl::sort_by_key(policy, keys.data(), keys.data() + n, values.data());
   else {
-  Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem&
+    auto comparator =         std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
+    if constexpr(sycl::is_device_copyable_v<decltype(comparator)>){
+      oneapi::dpl::sort_by_key(policy, keys.data(), keys.data() + n, values.data(), comparator);
+    } else {
+	  Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem&
         indirectKernelMem = exec.impl_internal_space_instance()->get_indirect_kernel_mem();
 
     auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
@@ -142,9 +146,33 @@ void sort_by_key_onedpl(
 			       return comparator(lhs, rhs);
 			     });
     }
+  }
 }
 #endif
 
+/*
+
+  // Can't use Experimental::begin/end here since the oneDPL then assumes that
+  // the data is on the host.
+  auto queue  = space.sycl_queue();
+  auto policy = oneapi::dpl::execution::make_device_policy(queue);
+  const int n = view.extent(0);
+   if constexpr (sizeof...(MaybeComparator)==0)
+    oneapi::dpl::sort(policy, view.data(), view.data() + n);
+  else {
+    auto comparator =         std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
+    if constexpr(sycl::is_device_copyable_v<decltype(comparator)>){
+      oneapi::dpl::sort(policy, view.data(), view.data() + n, comparator);
+    } else {
+      SYCLSortWrapper<decltype(comparator), view_value_type> functor_wrapper(comparator);
+      static_assert(sycl::is_device_copyable_v<decltype(functor_wrapper)>);
+      ::sycl::detail::CheckDeviceCopyable<int>();
+      //sycl::detail::CheckDeviceCopyable<decltype(functor_wrapper)>();
+      oneapi::dpl::sort(policy, view.data(), view.data() + n, functor_wrapper);
+    }
+  }
+*/
+
 template <typename ExecutionSpace, typename PermutationView, typename ViewType>
 void applyPermutation(const ExecutionSpace& space,
                       const PermutationView& permutation,
diff --git a/algorithms/src/sorting/impl/Kokkos_SortImpl.hpp b/algorithms/src/sorting/impl/Kokkos_SortImpl.hpp
index c7bafc121..2237b8485 100644
--- a/algorithms/src/sorting/impl/Kokkos_SortImpl.hpp
+++ b/algorithms/src/sorting/impl/Kokkos_SortImpl.hpp
@@ -184,6 +184,60 @@ void sort_cudathrust(const Cuda& space,
 }
 #endif
 
+
+template <typename Functor, typename ValueType>
+class SYCLSortWrapper {
+  Functor m_functor;
+
+ public:
+  bool operator()(const ValueType& lhs, const ValueType& rhs) const {
+      return m_functor(lhs, rhs);
+  }
+
+  
+    SYCLSortWrapper(const Functor& f) : m_functor(f) {}
+   // { std::memcpy(&m_functor, &f, sizeof(m_functor)); }
+/*
+    SYCLSortWrapper(const SYCLSortWrapper& other) {
+      std::memcpy(&m_functor, &other.m_functor, sizeof(m_functor));
+    }
+    SYCLSortWrapper(SYCLSortWrapper&& other) {
+      std::memcpy(&m_functor, &other.m_functor, sizeof(m_functor));
+    }
+    SYCLSortWrapper& operator=(const SYCLSortWrapper& other) {
+      std::memcpy(&m_functor, &other.m_functor, sizeof(m_functor));
+      return *this;
+    }
+    SYCLSortWrapper& operator=(SYCLSortWrapper&& other) {
+      std::memcpy(&m_functor, &other.m_functor, sizeof(m_functor));
+      return *this;
+    }
+    ~SYCLSortWrapper(){};*/
+
+  const Functor& get_functor() const { return m_functor; }
+};
+
+
+}
+}
+
+template <typename ValueType, typename FunctorWrapper>
+struct MyComparator{
+  bool operator()(const ValueType& lhs, const ValueType& rhs) const {
+      return m_functor_wrapper.get_functor()(lhs, rhs);
+  }
+  static_assert(sycl::is_device_copyable_v<FunctorWrapper>);
+
+  FunctorWrapper m_functor_wrapper;
+};
+
+template <typename Functor, typename ValueType>
+struct
+sycl::is_device_copyable<Kokkos::Impl::SYCLSortWrapper<Functor, ValueType>>: std::true_type {};
+
+namespace Kokkos {
+namespace Impl {
+
 #if defined(KOKKOS_ENABLE_ONEDPL)
 template <class DataType, class... Properties, class... MaybeComparator>
 void sort_onedpl(const Kokkos::Experimental::SYCL& space,
@@ -222,23 +276,20 @@ void sort_onedpl(const Kokkos::Experimental::SYCL& space,
   auto policy = oneapi::dpl::execution::make_device_policy(queue);
   const int n = view.extent(0);
    if constexpr (sizeof...(MaybeComparator)==0)
-  oneapi::dpl::sort(policy, view.data(), view.data() + n);
+    oneapi::dpl::sort(policy, view.data(), view.data() + n);
   else {
-  Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem&
-        indirectKernelMem = space.impl_internal_space_instance()->get_indirect_kernel_mem();
-
-    auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
-        std::forward<MaybeComparator>(maybeComparator)..., indirectKernelMem);
-    oneapi::dpl::sort(policy, view.data(), view.data() + n,
-                             [functor_wrapper](const view_value_type& lhs, const view_value_type& rhs)
-                             {
-			      const auto& comparator = functor_wrapper.get_functor();
-                               return comparator(lhs, rhs);
-                             });
+    auto comparator =         std::get<0>(std::tuple<MaybeComparator...>(maybeComparator...));
+    if constexpr(sycl::is_device_copyable_v<decltype(comparator)>){
+      oneapi::dpl::sort(policy, view.data(), view.data() + n, comparator);
+    } else {
+      SYCLSortWrapper<decltype(comparator), view_value_type> functor_wrapper(comparator);
+      SYCLSortWrapper<decltype(functor_wrapper), view_value_type> double_wrapper(functor_wrapper);
+      static_assert(sycl::is_device_copyable_v<decltype(functor_wrapper)>);
+      //::sycl::detail::CheckDeviceCopyable<int>();
+      //::sycl::detail::CheckDeviceCopyable<decltype(double_wrapper)>();
+      oneapi::dpl::sort(policy, view.data(), view.data() + n, double_wrapper);
     }
-
-  oneapi::dpl::sort(policy, view.data(), view.data() + n,
-                    std::forward<MaybeComparator>(maybeComparator)...);
+  }
 }
 #endif

I struggle to convince the SYCL compiler to accept custom (non-trivially-copyable) comparators; compilation fails with:

/usr/bin/compiler/../../include/sycl/types.hpp:2572:17: error: static assertion failed due to requirement 'is_device_copyable_v<(lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1816:20)> || detail::IsDeprecatedDeviceCopyable<(lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1816:20), void>::value': The specified type is not device copyable
 2572 |   static_assert(is_device_copyable_v<FieldT> ||
      |                 ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 2573 |                     detail::IsDeprecatedDeviceCopyable<FieldT>::value,
      |                     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/usr/bin/compiler/../../include/sycl/types.hpp:2605:7: note: in instantiation of template class 'sycl::detail::CheckFieldsAreDeviceCopyable<(lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83), 4>' requested here
 2605 |     : CheckFieldsAreDeviceCopyable<FuncT, __builtin_num_fields(FuncT)>,
      |       ^ 
/usr/bin/compiler/../../include/sycl/types.hpp:2613:7: note: in instantiation of template class 'sycl::detail::CheckDeviceCopyable<(lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83)>' requested here
 2613 |     : CheckDeviceCopyable<KernelType> {};
      |       ^
/usr/bin/compiler/../../include/sycl/handler.hpp:1652:5: note: in instantiation of template class 'sycl::detail::CheckDeviceCopyable<sycl::detail::RoundedRangeKernel<sycl::item<1, true>, 1, (lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83)>>' requested here
 1652 |     detail::CheckDeviceCopyable<KernelType>();
      |     ^
/usr/bin/compiler/../../include/sycl/handler.hpp:1694:5: note: in instantiation of function template specialization 'sycl::handler::unpack<sycl::detail::RoundedRangeKernel<sycl::item<1, true>, 1, (lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83)>, sycl::detail::RoundedRangeKernel<sycl::item<1, true>, 1, (lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83)>, sycl::ext::oneapi::experimental::properties<std::tuple<>>, false, (lambda at /usr/bin/compiler/../../include/sycl/handler.hpp:1697:21)>' requested here
 1694 |     unpack<KernelName, KernelType, PropertiesT,
      |     ^
/usr/bin/compiler/../../include/sycl/handler.hpp:1293:7: note: in instantiation of function template specialization 'sycl::handler::kernel_parallel_for_wrapper<sycl::detail::RoundedRangeKernel<sycl::item<1, true>, 1, (lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83)>, sycl::item<1, true>, sycl::detail::RoundedRangeKernel<sycl::item<1, true>, 1, (lambda at /usr/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h:1578:83)>, sycl::ext::oneapi::experimental::properties<std::tuple<>>>' requested here
 1293 |       kernel_parallel_for_wrapper<KName, TransformedArgType, decltype(Wrapper),
      |       ^
/usr/bin/compiler/../../include/sycl/handler.hpp:2332:5: note: (skipping 7 contexts in backtrace; use -ftemplate-backtrace-limit=0 to see all)
 2332 |     parallel_for_lambda_impl<KernelName, KernelType, 1, PropertiesT>(
      |     ^
/home/darndt/kokkos/algorithms/unit_tests/../src/sorting/Kokkos_SortPublicAPI.hpp:114:11: note: in instantiation of function template specialization 'Kokkos::Impl::sort_device_view_with_comparator<(lambda at /home/darndt/kokkos/algorithms/unit_tests/../src/sorting/./impl/Kokkos_SortByKeyImpl.hpp:227:24), unsigned int *, Kokkos::Experimental::SYCL>' requested here
  114 |     Impl::sort_device_view_with_comparator(exec, view, comparator);
      |           ^
/home/darndt/kokkos/algorithms/unit_tests/../src/sorting/./impl/Kokkos_SortByKeyImpl.hpp:226:13: note: in instantiation of function template specialization 'Kokkos::sort<Kokkos::Experimental::SYCL, (lambda at /home/darndt/kokkos/algorithms/unit_tests/../src/sorting/./impl/Kokkos_SortByKeyImpl.hpp:227:24), unsigned int *, Kokkos::Experimental::SYCL>' requested here
  226 |     Kokkos::sort(
      |             ^
/home/darndt/kokkos/algorithms/unit_tests/../src/sorting/./impl/Kokkos_SortByKeyImpl.hpp:272:5: note: in instantiation of function template specialization 'Kokkos::Impl::sort_by_key_via_sort<Kokkos::Experimental::SYCL, int *, Kokkos::Experimental::SYCL, float *, Kokkos::Experimental::SYCL>' requested here
  272 |     sort_by_key_via_sort(exec, keys, values);
      |     ^
/home/darndt/kokkos/algorithms/unit_tests/../src/sorting/Kokkos_SortByKeyPublicAPI.hpp:66:19: note: in instantiation of function template specialization 'Kokkos::Impl::sort_by_key_device_view_without_comparator<int *, Kokkos::Experimental::SYCL, float *, Kokkos::Experimental::SYCL>' requested here
   66 |   ::Kokkos::Impl::sort_by_key_device_view_without_comparator(exec, keys,
      |                   ^
/home/darndt/kokkos/algorithms/unit_tests/TestSortByKey.hpp:87:29: note: in instantiation of function template specialization 'Kokkos::Experimental::sort_by_key<Kokkos::Experimental::SYCL, int *, Kokkos::Experimental::SYCL, float *, Kokkos::Experimental::SYCL>' requested here
   87 |       Kokkos::Experimental::sort_by_key(ExecutionSpace(), keys, values));
      |                             ^
@dalg24
Copy link
Member

dalg24 commented Apr 15, 2024

Are you able to reproduce with a small program that does not use Kokkos and calls oneapi::dpl::sort_by_key directly?

@masterleinad
Copy link
Contributor Author

masterleinad commented Apr 16, 2024

#include <sycl/sycl.hpp>

#include <oneapi/dpl/algorithm>
#include <oneapi/dpl/execution>

#include <cstdio>
#include <cstring>
#include <iostream>
#include <type_traits>

// When the special members are defined (i.e. HIDE_SPECIAL_MEMBERS is left undefined), we get an error.
//#define HIDE_SPECIAL_MEMBERS

// A comparator supplied by the user: we cannot modify it, and it does not
// specialize sycl::is_device_copyable. This particular functor is in fact
// device-copyable, but that is not true of user comparators in general.
struct UserComparator {
  int* keys;

  // Orders two indices by the key values stored at those positions.
  bool operator()(int lhs, int rhs) const {
    return keys[lhs] < keys[rhs];
  }
};

// Wrapper with which we try to guarantee that the functor argument handed to
// the oneDPL sort is device-copyable, whatever the wrapped comparator is.
template <typename Functor, typename ValueType>
class CompareWrapper {
 public:
  CompareWrapper(const Functor& functor) : m_functor(functor) {}

#ifndef HIDE_SPECIAL_MEMBERS
  // User-provided special members. Each one transfers the wrapped functor
  // bytewise instead of going through the functor's own copy/move operations.
  CompareWrapper(const CompareWrapper& source) { copy_bytes_from(source); }

  CompareWrapper(CompareWrapper&& source) { copy_bytes_from(source); }

  CompareWrapper& operator=(const CompareWrapper& source) {
    copy_bytes_from(source);
    return *this;
  }

  CompareWrapper& operator=(CompareWrapper&& source) {
    copy_bytes_from(source);
    return *this;
  }

  ~CompareWrapper() {}
#endif

  // Forwards the comparison to the wrapped functor.
  bool operator()(const ValueType& lhs, const ValueType& rhs) const {
    return m_functor(lhs, rhs);
  }

 private:
#ifndef HIDE_SPECIAL_MEMBERS
  // Overwrites the stored functor with the raw bytes of source's functor.
  void copy_bytes_from(const CompareWrapper& source) {
    std::memcpy(&m_functor, &source.m_functor, sizeof(Functor));
  }
#endif

  Functor m_functor;
};

#ifndef HIDE_SPECIAL_MEMBERS
// Declare every CompareWrapper instantiation device-copyable by specializing
// sycl::is_device_copyable to std::true_type. Guarded by the same macro as the
// wrapper's user-provided special members, so it is only emitted when those
// members are compiled in.
template<typename Functor, typename ValueType>
struct
sycl::is_device_copyable<CompareWrapper<Functor, ValueType>> : std::true_type {};
#endif

// Reproducer entry point: allocates two device arrays, wraps a UserComparator
// in a CompareWrapper and passes it to oneapi::dpl::sort with a device policy.
// NOTE: neither allocation is initialized before the sort is invoked.
int main(int argc, char* argv[]) {
  sycl::queue queue;

  const int n = 10;

  int* values_ptr = sycl::malloc_device<int>(n, queue);
  int* keys_ptr = sycl::malloc_device<int>(n, queue);

  UserComparator user_comparator{keys_ptr};
  CompareWrapper<UserComparator, int> comparator{user_comparator};
  auto policy = oneapi::dpl::execution::make_device_policy(queue);
  // Compile-time check on the wrapper type itself, before it reaches oneDPL.
  static_assert(sycl::is_device_copyable_v<decltype(comparator)>);
  oneapi::dpl::sort(policy, values_ptr, values_ptr + n, comparator);

  // Wait for all submitted work before releasing the device allocations,
  // which the original version leaked.
  queue.wait();
  sycl::free(values_ptr, queue);
  sycl::free(keys_ptr, queue);
  return 0;
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants