Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HIP: Use builtin atomic for compare_exchange #7000

Merged
merged 3 commits into from
May 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
145 changes: 28 additions & 117 deletions tpls/desul/include/desul/atomics/Compare_Exchange_HIP.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ SPDX-License-Identifier: (BSD-3-Clause)
#ifndef DESUL_ATOMICS_COMPARE_EXCHANGE_HIP_HPP_
#define DESUL_ATOMICS_COMPARE_EXCHANGE_HIP_HPP_

#include <desul/atomics/Adapt_HIP.hpp>
#include <desul/atomics/Common.hpp>
#include <desul/atomics/Lock_Array_HIP.hpp>
#include <desul/atomics/Thread_Fence_HIP.hpp>
Expand All @@ -17,130 +18,40 @@ SPDX-License-Identifier: (BSD-3-Clause)
namespace desul {
namespace Impl {

// Relaxed 32-bit compare-exchange: routes the payload through the hardware
// atomicCAS on unsigned int, reinterpreting the bits in both directions.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4, T> device_atomic_compare_exchange(
    T* const dest, T compare, T value, MemoryOrderRelaxed, MemoryScope) {
  static_assert(sizeof(unsigned int) == 4,
                "this function assumes an unsigned int is 32-bit");
  unsigned int* const dest_bits = reinterpret_cast<unsigned int*>(dest);
  unsigned int previous = atomicCAS(dest_bits,
                                    reinterpret_cast<unsigned int&>(compare),
                                    reinterpret_cast<unsigned int&>(value));
  // Hand the observed 32-bit pattern back to the caller as a T.
  return reinterpret_cast<T&>(previous);
}
// Relaxed 64-bit compare-exchange: same bit-punning scheme as the 32-bit
// overload, but through the unsigned long long atomicCAS.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 8, T> device_atomic_compare_exchange(
    T* const dest, T compare, T value, MemoryOrderRelaxed, MemoryScope) {
  static_assert(sizeof(unsigned long long int) == 8,
                "this function assumes an unsigned long long is 64-bit");
  unsigned long long int* const dest_bits =
      reinterpret_cast<unsigned long long int*>(dest);
  unsigned long long int previous =
      atomicCAS(dest_bits,
                reinterpret_cast<unsigned long long int&>(compare),
                reinterpret_cast<unsigned long long int&>(value));
  // Hand the observed 64-bit pattern back to the caller as a T.
  return reinterpret_cast<T&>(previous);
}
// Trait gating the builtin (__hip_atomic_*) fast path: true only for
// trivially copyable types whose size is a natively supported width
// (1, 4, or 8 bytes) and whose alignment equals that size. Types failing
// this test take the lock-based fallback overloads instead.
// NOTE(review): 2-byte types are not listed — presumably intentional so they
// use the fallback path; confirm against upstream desul.
template <class T>
struct atomic_exchange_available_hip {
  constexpr static bool value =
      ((sizeof(T) == 1 && alignof(T) == 1) || (sizeof(T) == 4 && alignof(T) == 4) ||
       (sizeof(T) == 8 && alignof(T) == 8)) &&
      std::is_trivially_copyable<T>::value;
};

// Release compare-exchange for natively supported widths: perform the
// relaxed CAS, then emit a release fence to publish prior writes.
// Fix: call the device_-prefixed relaxed overload and fence helpers directly,
// matching the sibling overloads in this file (which use
// device_atomic_compare_exchange / device_atomic_thread_fence) instead of the
// unprefixed names.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T>
device_atomic_compare_exchange(
    T* const dest, T compare, T value, MemoryOrderRelease, MemoryScope) {
  T return_val = device_atomic_compare_exchange(
      dest, compare, value, MemoryOrderRelaxed(), MemoryScope());
  device_atomic_thread_fence(MemoryOrderRelease(), MemoryScope());
  return return_val;
}

// Acquire compare-exchange for natively supported widths.
// Fix: call the device_-prefixed relaxed overload and fence helpers directly,
// consistent with the other overloads in this file, instead of the
// unprefixed atomic_* names.
// NOTE(review): the original placed the acquire fence BEFORE the relaxed CAS
// (matching the acq_rel/seq_cst overloads' convention here); order preserved —
// confirm this fence placement against upstream desul semantics.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T>
device_atomic_compare_exchange(
    T* const dest, T compare, T value, MemoryOrderAcquire, MemoryScope) {
  device_atomic_thread_fence(MemoryOrderAcquire(), MemoryScope());
  T return_val = device_atomic_compare_exchange(
      dest, compare, value, MemoryOrderRelaxed(), MemoryScope());
  return return_val;
}

template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T>
// Defect in this span: the diff scrape interleaved the deleted AcqRel
// overload with the newly added builtin-based definition. Reconstructed here
// is the added code: a single compare-exchange covering all memory orders via
// the __hip_atomic_compare_exchange_strong builtin, enabled only for types
// the builtin supports (see atomic_exchange_available_hip).
template <class T, class MemoryOrder, class MemoryScope>
__device__ std::enable_if_t<atomic_exchange_available_hip<T>::value, T>
device_atomic_compare_exchange(
    T* const dest, T compare, T value, MemoryOrder, MemoryScope) {
  // On failure the builtin stores the observed value into `compare`; on
  // success `compare` retains its original (matching) value — so returning
  // `compare` yields the pre-operation contents of *dest in both cases.
  (void)__hip_atomic_compare_exchange_strong(
      dest,
      &compare,
      value,
      HIPMemoryOrder<MemoryOrder>::value,
      HIPMemoryOrder<cmpexch_failure_memory_order<MemoryOrder>>::value,
      HIPMemoryScope<MemoryScope>::value);
  return compare;
}

// Relaxed 32-bit exchange: bit-puns the payload through the hardware
// atomicExch on unsigned int.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4, T> device_atomic_exchange(
    T* const dest, T value, MemoryOrderRelaxed, MemoryScope) {
  static_assert(sizeof(unsigned int) == 4,
                "this function assumes an unsigned int is 32-bit");
  unsigned int previous = atomicExch(reinterpret_cast<unsigned int*>(dest),
                                     reinterpret_cast<unsigned int&>(value));
  // Reinterpret the swapped-out bits back to T for the caller.
  return reinterpret_cast<T&>(previous);
}
// Relaxed 64-bit exchange: same scheme as the 32-bit overload, through the
// unsigned long long atomicExch.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 8, T> device_atomic_exchange(
    T* const dest, T value, MemoryOrderRelaxed, MemoryScope) {
  static_assert(sizeof(unsigned long long int) == 8,
                "this function assumes an unsigned long long is 64-bit");
  unsigned long long int previous =
      atomicExch(reinterpret_cast<unsigned long long int*>(dest),
                 reinterpret_cast<unsigned long long int&>(value));
  // Reinterpret the swapped-out bits back to T for the caller.
  return reinterpret_cast<T&>(previous);
}

// NOTE(review): despite being named device_atomic_exchange, this overload
// takes a `compare` argument and performs a compare-exchange — looks like a
// copy/paste slip preserved from history; confirm against upstream desul
// before relying on it. Kept byte-identical here.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T> device_atomic_exchange(
    T* const dest, T compare, T value, MemoryOrderRelease, MemoryScope) {
  // Relaxed CAS followed by a release fence to publish prior writes.
  T return_val = device_atomic_compare_exchange(
      dest, compare, value, MemoryOrderRelaxed(), MemoryScope());
  device_atomic_thread_fence(MemoryOrderRelease(), MemoryScope());
  // reinterpret_cast<T&> of a T lvalue is an identity conversion here.
  return reinterpret_cast<T&>(return_val);
}

// Acquire exchange: acquire fence, then the relaxed hardware exchange.
// NOTE(review): the extra /*compare*/ parameter is accepted but ignored —
// presumably a leftover from the mis-named release overload above; confirm
// against upstream desul. Kept byte-identical here.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T> device_atomic_exchange(
    T* const dest, T /*compare*/, T value, MemoryOrderAcquire, MemoryScope) {
  device_atomic_thread_fence(MemoryOrderAcquire(), MemoryScope());
  T return_val =
      device_atomic_exchange(dest, value, MemoryOrderRelaxed(), MemoryScope());
  // reinterpret_cast<T&> of a T lvalue is an identity conversion here.
  return reinterpret_cast<T&>(return_val);
}

// Acq-rel exchange: bracket the relaxed hardware exchange with an acquire
// fence before and a release fence after.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T> device_atomic_exchange(
    T* const dest, T value, MemoryOrderAcqRel, MemoryScope) {
  device_atomic_thread_fence(MemoryOrderAcquire(), MemoryScope());
  const T old =
      device_atomic_exchange(dest, value, MemoryOrderRelaxed(), MemoryScope());
  device_atomic_thread_fence(MemoryOrderRelease(), MemoryScope());
  return old;
}

// Seq-cst exchange: implemented identically to the acq-rel overload —
// acquire fence, relaxed hardware exchange, release fence.
template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T> device_atomic_exchange(
    T* const dest, T value, MemoryOrderSeqCst, MemoryScope) {
  device_atomic_thread_fence(MemoryOrderAcquire(), MemoryScope());
  const T old =
      device_atomic_exchange(dest, value, MemoryOrderRelaxed(), MemoryScope());
  device_atomic_thread_fence(MemoryOrderRelease(), MemoryScope());
  return old;
}

template <class T, class MemoryScope>
__device__ std::enable_if_t<sizeof(T) == 4 || sizeof(T) == 8, T>
device_atomic_compare_exchange(
T* const dest, T compare, T value, MemoryOrderSeqCst, MemoryScope) {
device_atomic_thread_fence(MemoryOrderAcquire(), MemoryScope());
T return_val = device_atomic_compare_exchange(
dest, compare, value, MemoryOrderRelaxed(), MemoryScope());
device_atomic_thread_fence(MemoryOrderRelease(), MemoryScope());
// Builtin exchange for natively supported types: maps the desul memory-order
// and scope tags onto the __hip_atomic_exchange builtin's arguments and
// returns the value previously stored at *dest.
template <class T, class MemoryOrder, class MemoryScope>
__device__ std::enable_if_t<atomic_exchange_available_hip<T>::value, T>
device_atomic_exchange(T* const dest, T value, MemoryOrder, MemoryScope) {
  return __hip_atomic_exchange(dest,
                               value,
                               HIPMemoryOrder<MemoryOrder>::value,
                               HIPMemoryScope<MemoryScope>::value);
}

template <class T, class MemoryOrder, class MemoryScope>
__device__ std::enable_if_t<(sizeof(T) != 8) && (sizeof(T) != 4), T>
__device__ std::enable_if_t<!atomic_exchange_available_hip<T>::value, T>
device_atomic_compare_exchange(
T* const dest, T compare, T value, MemoryOrder, MemoryScope scope) {
// This is a way to avoid deadlock in a warp or wave front
Expand Down Expand Up @@ -169,7 +80,7 @@ device_atomic_compare_exchange(
}

template <class T, class MemoryOrder, class MemoryScope>
__device__ std::enable_if_t<(sizeof(T) != 8) && (sizeof(T) != 4), T>
__device__ std::enable_if_t<!atomic_exchange_available_hip<T>::value, T>
device_atomic_exchange(T* const dest, T value, MemoryOrder, MemoryScope scope) {
// This is a way to avoid deadlock in a warp or wave front
T return_val;
Expand Down