Skip to content

Commit

Permalink
just for you to see
Browse files Browse the repository at this point in the history
  • Loading branch information
romintomasetti committed Apr 12, 2024
1 parent 4911491 commit 6b04674
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 24 deletions.
17 changes: 7 additions & 10 deletions core/src/Cuda/Kokkos_Cuda_Graph_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,13 @@ struct GraphImpl<Kokkos::Cuda> {

using node_details_t = GraphNodeBackendSpecificDetails<Kokkos::Cuda>;

void _instantiate_graph() {
constexpr size_t error_log_size = 256;
cudaGraphNode_t error_node = nullptr;
char error_log[error_log_size];
public:
/// Instantiate the executable graph @c m_graph_exec from @c m_graph.
///
/// Any @p args are perfectly forwarded, after @c &m_graph_exec and
/// @c m_graph, to @c cuda_graph_instantiate_wrapper of the execution space's
/// internal instance. The wrapper's return value is handed to
/// @c KOKKOS_IMPL_CUDA_SAFE_CALL.
template <typename... Args>
void instantiate_graph(Args&&... args) {
KOKKOS_IMPL_CUDA_SAFE_CALL(
(m_execution_space.impl_internal_space_instance()
->cuda_graph_instantiate_wrapper(&m_graph_exec, m_graph,
&error_node, error_log,
error_log_size)));
// TODO @graphs print out errors
std::forward<Args>(args)...)));
}

public:
Expand Down Expand Up @@ -163,7 +160,7 @@ struct GraphImpl<Kokkos::Cuda> {

void submit() {
if (!bool(m_graph_exec)) {
_instantiate_graph();
instantiate_graph();
}
KOKKOS_IMPL_CUDA_SAFE_CALL(
(m_execution_space.impl_internal_space_instance()
Expand Down Expand Up @@ -201,8 +198,8 @@ struct GraphImpl<Kokkos::Cuda> {
aggregate_kernel_impl_t{});
}

/// Accessors for the @c CUDA handles held by this object.
/// The non-const forms return a reference to the member itself, so callers can
/// take its address or overwrite it.
cudaGraph_t& get_cuda_graph() { return m_graph; }
cudaGraphExec_t& get_cuda_graph_exec() { return m_graph_exec; }
/// The const forms return the handle by value and leave the member untouched.
cudaGraph_t get_cuda_graph() const { return m_graph; }
cudaGraphExec_t get_cuda_graph_exec() const { return m_graph_exec; }
};

} // end namespace Impl
Expand Down
35 changes: 28 additions & 7 deletions core/src/Cuda/Kokkos_Cuda_Instance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,15 +453,36 @@ class CudaInternal {
return cudaFuncSetAttribute(entry, attr, value);
}

/// Instantiate an executable graph from @p graph, selecting the @c CUDA entry
/// point from the optional trailing argument.
///
/// The @c CUDA graph instantiation API changed in @c CUDA 12. All signatures
/// now take 3 arguments:
///  - cudaGraphExec_t*
///  - cudaGraph_t
///  - either some flag as an unsigned long long or a
///    cudaGraphInstantiateParams*
///
/// Dispatch on @p args:
///  - no argument   -> @c cudaGraphInstantiate
///  - a pointer     -> @c cudaGraphInstantiateWithParams
///  - anything else -> @c cudaGraphInstantiateWithFlags
///
/// @tparam setCudaDevice when true, @c set_cuda_device() is called before the
///                       @c CUDA call.
/// @param pGraphExec     forwarded as the first argument of the @c CUDA call.
/// @param graph          forwarded as the second argument of the @c CUDA call.
/// @param args           zero or one extra argument (enforced at compile time).
/// @return the @c cudaError_t returned by the selected @c CUDA call.
///
/// Note that this implementation is just to illustrate the idea developed in
/// https://github.com/kokkos/kokkos/pull/6904#discussion_r1562293918.
template <bool setCudaDevice = true, typename... Args>
cudaError_t cuda_graph_instantiate_wrapper(cudaGraphExec_t* pGraphExec,
                                           cudaGraph_t graph,
                                           Args&&... args) const {
  static_assert(sizeof...(Args) <= 1, "Only one optional parameter supported.");

  if constexpr (setCudaDevice) set_cuda_device();

  if constexpr (sizeof...(Args) == 0) {
    printf("> Calling cudaGraphInstantiate.\n");
    return cudaGraphInstantiate(pGraphExec, graph);
  } else if constexpr ((std::is_pointer_v<std::remove_reference_t<Args>> &&
                        ...)) {
    // Strip the reference first: a forwarding reference deduces `T*&` for an
    // lvalue pointer, which std::is_pointer_v alone would reject.
    printf("> Calling cudaGraphInstantiateWithParams.\n");
    return cudaGraphInstantiateWithParams(pGraphExec, graph,
                                          std::forward<Args>(args)...);
  } else {
    // Flags may be an unsigned long long; widen explicitly so the printf
    // conversion specifier always matches the argument type.
    printf("> Calling cudaGraphInstantiateWithFlags %llu.\n",
           static_cast<unsigned long long>(args)...);
    return cudaGraphInstantiateWithFlags(pGraphExec, graph,
                                         std::forward<Args>(args)...);
  }
}

// Resizing of reduction related scratch spaces
Expand Down
8 changes: 1 addition & 7 deletions core/unit_test/cuda/TestCuda_Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@ TEST(TEST_CATEGORY, cuda_graph_instantiate_and_debug_dot_print) {
ASSERT_EQ(graph_ptr_impl->get_cuda_graph_exec(), nullptr);

//! Instantiate the graph manually.
constexpr size_t error_log_size = 256;
cudaGraphNode_t error_node = nullptr;
char error_log[error_log_size];
cudaGraphInstantiate(&graph_ptr_impl->get_cuda_graph_exec(),
graph_ptr_impl->get_cuda_graph(), &error_node, error_log,
error_log_size);
ASSERT_EQ(error_node, nullptr) << error_log;
graph_ptr_impl->instantiate_graph(cudaGraphInstantiateFlagAutoFreeOnLaunch);

/// At this stage, the @c Cuda "executable" graph should not be null,
/// because it has been instantiated.
Expand Down

0 comments on commit 6b04674

Please sign in to comment.