Skip to content

Commit

Permalink
Functor as reducer for TeamThreadRange, ThreadVectorRange and TeamVec…
Browse files Browse the repository at this point in the history
…torRange for SYCL
  • Loading branch information
ldh4 committed Apr 5, 2024
1 parent 1b893c2 commit 56716f9
Showing 1 changed file with 104 additions and 38 deletions.
142 changes: 104 additions & 38 deletions core/src/SYCL/Kokkos_SYCL_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,10 @@ class SYCLTeamMember {
team_reduce(reducer, reducer.reference());
}

template <typename ReducerType>
KOKKOS_INLINE_FUNCTION std::enable_if_t<is_reducer<ReducerType>::value>
team_reduce(ReducerType const& reducer,
typename ReducerType::value_type& value) const noexcept {
using value_type = typename ReducerType::value_type;
template <typename ReducerType, typename ValueType>
KOKKOS_INLINE_FUNCTION void team_reduce(ReducerType const& reducer,
ValueType& value) const noexcept {
using value_type = typename ValueType;

auto sg = m_item.get_sub_group();
const auto sub_group_range = sg.get_local_range()[0];
Expand All @@ -142,7 +141,7 @@ class SYCLTeamMember {

const auto n_subgroups = sg.get_group_range()[0];
if (n_subgroups == 1) {
reducer.reference() = value;
if constexpr (is_reducer_v<ReducerType>) reducer.reference() = value;
return;
}

Expand Down Expand Up @@ -196,7 +195,9 @@ class SYCLTeamMember {
}
sycl::group_barrier(m_item.get_group());

reducer.reference() = reduction_array[0];
if constexpr (is_reducer_v<ReducerType>)
reducer.reference() = reduction_array[0];
value = reduction_array[0];
// Make sure that the reduction array hasn't been modified in the meantime.
m_item.barrier(sycl::access::fence_space::local_space);
}
Expand Down Expand Up @@ -304,10 +305,9 @@ class SYCLTeamMember {
vector_reduce(reducer, reducer.reference());
}

template <typename ReducerType>
KOKKOS_INLINE_FUNCTION std::enable_if_t<is_reducer<ReducerType>::value>
vector_reduce(ReducerType const& reducer,
typename ReducerType::value_type& value) const {
template <typename ReducerType, typename ValueType>
KOKKOS_INLINE_FUNCTION void vector_reduce(ReducerType const& reducer,
ValueType& value) const {
const auto tidx1 = m_item.get_local_id(1);
const auto grange1 = m_item.get_local_range(1);

Expand All @@ -333,7 +333,7 @@ class SYCLTeamMember {

tmp2 = sg.shuffle(tmp, (sg.get_local_id() / grange1) * grange1);
value = tmp2;
reducer.reference() = tmp2;
if constexpr (is_reducer_v<ReducerType>) reducer.reference() = tmp2;
}

//----------------------------------------
Expand Down Expand Up @@ -553,20 +553,41 @@ KOKKOS_INLINE_FUNCTION std::enable_if_t<!Kokkos::is_reducer<ValueType>::value>
parallel_reduce(const Impl::TeamThreadRangeBoundariesStruct<
iType, Impl::SYCLTeamMember>& loop_boundaries,
const Closure& closure, ValueType& result) {
ValueType val;
Kokkos::Sum<ValueType> reducer(val);

reducer.init(reducer.reference());
using functor_analysis_type = typename Impl::FunctorAnalysis<
Impl::FunctorPatternInterface::REDUCE,
TeamPolicy<typename Impl::SYCLTeamMember::execution_space>, Closure,
ValueType>;

constexpr bool is_reducer_closure =
functor_analysis_type::has_join_member_function &&
functor_analysis_type::has_init_member_function;

using ReducerSelector =
typename Kokkos::Impl::if_c<is_reducer_closure, Closure,
Sum<ValueType>>::type;

auto run_closure = [&](ValueType& value) {
for (iType i = loop_boundaries.start +
loop_boundaries.member.item().get_local_id(0);
i < loop_boundaries.end;
i += loop_boundaries.member.item().get_local_range(0)) {
closure(i, value);
}
};

for (iType i = loop_boundaries.start +
loop_boundaries.member.item().get_local_id(0);
i < loop_boundaries.end;
i += loop_boundaries.member.item().get_local_range(0)) {
closure(i, val);
ValueType val;
if constexpr (is_reducer_closure) {
closure.init(val);
run_closure(val);
loop_boundaries.member.team_reduce(closure, val);
result = val;
} else {
ReducerSelector reducer(val);
reducer.init(reducer.reference());
run_closure(val);
loop_boundaries.member.team_reduce(reducer, val);
result = reducer.reference();
}

loop_boundaries.member.team_reduce(reducer, val);
result = reducer.reference();
}

/** \brief Inter-thread parallel exclusive prefix sum.
Expand Down Expand Up @@ -675,24 +696,46 @@ KOKKOS_INLINE_FUNCTION std::enable_if_t<!Kokkos::is_reducer<ValueType>::value>
parallel_reduce(const Impl::TeamVectorRangeBoundariesStruct<
iType, Impl::SYCLTeamMember>& loop_boundaries,
const Closure& closure, ValueType& result) {
ValueType val;
Kokkos::Sum<ValueType> reducer(val);
using functor_analysis_type = typename Impl::FunctorAnalysis<
Impl::FunctorPatternInterface::REDUCE,
TeamPolicy<typename Impl::SYCLTeamMember::execution_space>, Closure,
ValueType>;

reducer.init(reducer.reference());
constexpr bool is_reducer_closure =
functor_analysis_type::has_join_member_function &&
functor_analysis_type::has_init_member_function;

using ReducerSelector =
typename Kokkos::Impl::if_c<is_reducer_closure, Closure,
Sum<ValueType>>::type;

const iType tidx0 = loop_boundaries.member.item().get_local_id(0);
const iType tidx1 = loop_boundaries.member.item().get_local_id(1);

const iType grange0 = loop_boundaries.member.item().get_local_range(0);
const iType grange1 = loop_boundaries.member.item().get_local_range(1);

for (iType i = loop_boundaries.start + tidx0 * grange1 + tidx1;
i < loop_boundaries.end; i += grange0 * grange1)
closure(i, val);
auto run_closure = [&](ValueType& value) {
for (iType i = loop_boundaries.start + tidx0 * grange1 + tidx1;
i < loop_boundaries.end; i += grange0 * grange1)
closure(i, value);
};

loop_boundaries.member.vector_reduce(reducer);
loop_boundaries.member.team_reduce(reducer);
result = reducer.reference();
ValueType val;
if constexpr (is_reducer_closure) {
closure.init(val);
run_closure(val);
loop_boundaries.member.vector_reduce(closure, val);
loop_boundaries.member.team_reduce(closure, val);
result = val;
} else {
ReducerSelector reducer(val);
reducer.init(reducer.reference());
run_closure(val);
loop_boundaries.member.vector_reduce(reducer);
loop_boundaries.member.team_reduce(reducer);
result = reducer.reference();
}
}

//----------------------------------------------------------------------------
Expand Down Expand Up @@ -770,16 +813,39 @@ KOKKOS_INLINE_FUNCTION std::enable_if_t<!is_reducer<ValueType>::value>
parallel_reduce(Impl::ThreadVectorRangeBoundariesStruct<
iType, Impl::SYCLTeamMember> const& loop_boundaries,
Closure const& closure, ValueType& result) {
result = ValueType();
using functor_analysis_type = typename Impl::FunctorAnalysis<
Impl::FunctorPatternInterface::REDUCE,
TeamPolicy<typename Impl::SYCLTeamMember::execution_space>, Closure,
ValueType>;

constexpr bool is_reducer_closure =
functor_analysis_type::has_join_member_function &&
functor_analysis_type::has_init_member_function;

using ReducerSelector =
typename Kokkos::Impl::if_c<is_reducer_closure, Closure,
Sum<ValueType>>::type;

const iType tidx1 = loop_boundaries.member.item().get_local_id(1);
const int grange1 = loop_boundaries.member.item().get_local_range(1);

for (iType i = loop_boundaries.start + tidx1; i < loop_boundaries.end;
i += grange1)
closure(i, result);
auto run_closure = [&](ValueType& value) {
for (iType i = loop_boundaries.start + tidx1; i < loop_boundaries.end;
i += grange1)
closure(i, value);
};

loop_boundaries.member.vector_reduce(Kokkos::Sum<ValueType>(result));
ValueType val;
if constexpr (is_reducer_closure) {
closure.init(val);
run_closure(val);
loop_boundaries.member.vector_reduce(closure, val);
result = val;
} else {
result = ValueType();
run_closure(result);
loop_boundaries.member.vector_reduce(ReducerSelector(result));
}
}

//----------------------------------------------------------------------------
Expand Down

0 comments on commit 56716f9

Please sign in to comment.