Execution hang for float16 type with MHA block #1108

Open

Egor-Krivov opened this issue May 13, 2024 · 0 comments

@Egor-Krivov
Execution hangs with the float16 dtype on xpu.

It works on cuda, or with the float32 dtype; a short dtype/device toggle sketch follows the repro script below.

Code:

import math

import torch
import intel_extension_for_pytorch as ipex
import triton
import triton.language as tl

class Backend:
    device = 'xpu'

    def sync(self):
        torch.xpu.synchronize()

    def check_device(self, *args):
        return True

backend = Backend()
'''
Sources:
kernel_fma is based on the Triton matmul tutorial at https://github.com/openai/triton/blob/main/python/tutorials/03-matrix-multiplication.py
kernel_colsum is written from scratch
kernel_swiglu_fwd is written from scratch
kernel_swiglu_bwd is written from scratch
'''

# Implements Z = X @ Y + b
@triton.jit
def kernel_fma(
    x_ptrs, y_ptrs, b_ptrs, z_ptrs,
    M, N, K,
    stride_xm, stride_xk,
    stride_yk, stride_yn,
    stride_zm, stride_zn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    BIAS_REQD: tl.constexpr,
):
    # pid -> pid_m, pid_n
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    # block pointers for x and y
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_yn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptrs + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    y_ptrs = y_ptrs + (offs_k[:, None] * stride_yk + offs_yn[None, :] * stride_yn)
    # initialize accumulator to zeros in SLM
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # loop over k dim
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # load input blocks from HBM
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        y = tl.load(y_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # accumulate product of input blocks
        acc += tl.dot(x, y)
        # increment block pointers
        x_ptrs += BLOCK_SIZE_K * stride_xk
        y_ptrs += BLOCK_SIZE_K * stride_yk
    z = acc.to(tl.float16)
    # add bias to accumulator
    if BIAS_REQD:
        bias = tl.load(b_ptrs + offs_yn, mask=offs_yn < N, other=0.0)
        z += bias[None, :]
    # store output block to HBM
    offs_zm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_zn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    z_ptrs = z_ptrs + offs_zm[:, None] * stride_zm + offs_zn[None, :] * stride_zn
    z_mask = (offs_zm[:, None] < M) & (offs_zn[None, :] < N)
    tl.store(z_ptrs, z, mask=z_mask)


@triton.jit
def kernel_colsum(
    x_ptrs, s_ptrs,
    M, N,
    stride_m, stride_n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr
):
    pid = tl.program_id(axis=0)
    offs_m = tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    x_ptrs = x_ptrs + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
    # initialize accumulator to zeros in SLM
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for m in range(0, tl.cdiv(M, BLOCK_SIZE_M)):
        # load input block from HBM
        # x = tl.load(x_ptrs, mask=offs_m[:, None] + m * BLOCK_SIZE_M < M & offs_n < N, other=0.0)
        # x = tl.load(x_ptrs, mask=offs_m[:, None] < M - m * BLOCK_SIZE_M & offs_n < N, other=0.0)
        x = tl.load(x_ptrs, mask=offs_m[:, None] < M - m * BLOCK_SIZE_M, other=0.0)
        acc += x
        # increment block pointers
        x_ptrs += BLOCK_SIZE_M * stride_m
    r = tl.sum(acc, axis=0)
    s_ptrs += offs_n
    tl.store(s_ptrs, r)


@triton.jit
def kernel_swiglu_fwd(
    x_ptrs, y_ptrs, z_ptrs,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptrs + offsets, mask).to(tl.float32)
    y = tl.load(y_ptrs + offsets, mask).to(tl.float32)

    u = tl.sigmoid(x)
    v = x * u           # silu
    z = v * y           # swiglu

    z = z.to(tl.float16)
    tl.store(z_ptrs + offsets, z, mask)


@triton.jit
def kernel_swiglu_bwd(
    x_ptrs, y_ptrs, dz_ptrs, dx_ptrs, dy_ptrs,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptrs + offsets, mask).to(tl.float32)
    y = tl.load(y_ptrs + offsets, mask).to(tl.float32)
    dz = tl.load(dz_ptrs + offsets, mask).to(tl.float32)

    u = tl.sigmoid(x)
    v = x * u           # silu
    dy = dz * v
    dt = dz * y         # temp
    dx = dt * u * (1.0 + x * (1.0 - u))

    dx = dx.to(tl.float16)
    dy = dy.to(tl.float16)
    tl.store(dx_ptrs + offsets, dx, mask)
    tl.store(dy_ptrs + offsets, dy, mask)



def fused_mul_add(X, Y, b, transpose_x, transpose_y):
    if transpose_x:
        K, M = X.shape
        Xstride0, Xstride1 = X.stride(1), X.stride(0)
    else:
        M, K = X.shape
        Xstride0, Xstride1 = X.stride(0), X.stride(1)
    if transpose_y:
        N, _ = Y.shape
        Wstride0, Wstride1 = Y.stride(1), Y.stride(0)
    else:
        _, N = Y.shape
        Wstride0, Wstride1 = Y.stride(0), Y.stride(1)

    # Allocates output.
    Z = torch.empty((M, N), device=X.device, dtype=X.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    kernel_fma[grid](
        X, Y, b, Z,
        M, N, K,
        Xstride0, Xstride1,
        Wstride0, Wstride1,
        Z.stride(0), Z.stride(1),
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
        BIAS_REQD=b is not None,
    )

    return Z


# Implements Z = XY
class Matmul(torch.autograd.Function):

    @staticmethod
    def forward(X, Y):
        # Check constraints.
        assert X.shape[1] == Y.shape[0], "Incompatible dimensions for X and Y"
        assert X.is_contiguous(), "Matrix X must be contiguous"
        assert Y.is_contiguous(), "Matrix Y must be contiguous"

        return fused_mul_add(X, Y, None, transpose_x=False, transpose_y=False)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(*inputs)

    @staticmethod
    def backward(ctx, dZ):
        X, Y = ctx.saved_tensors

        # dX = dZ @ Y.T # (M x N) x (N x K)
        dX = fused_mul_add(dZ, Y, None, transpose_x=False, transpose_y=True)

        # dY = X.T @ dZ # (K x M) x (M x N)
        dY = fused_mul_add(X, dZ, None, transpose_x=True, transpose_y=False)

        return dX, dY
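

# Hedged sanity-check sketch (not part of the original report): the backward
# formulas above (dX = dZ @ Y.T, dY = X.T @ dZ) can be compared against eager
# PyTorch on a small float32 input. Tolerances are illustrative and loose,
# because kernel_fma rounds its accumulator to float16 before storing.
def check_matmul_backward(M=64, K=32, N=48, device=backend.device):
    X = torch.rand(M, K, device=device, dtype=torch.float32, requires_grad=True)
    Y = torch.rand(K, N, device=device, dtype=torch.float32, requires_grad=True)
    Z = Matmul.apply(X, Y)
    dZ = torch.rand_like(Z)
    Z.backward(dZ)
    dX_ref = dZ @ Y.detach().t()   # dX = dZ @ Y^T
    dY_ref = X.detach().t() @ dZ   # dY = X^T @ dZ
    assert torch.allclose(X.grad, dX_ref, atol=1e-2, rtol=1e-2)
    assert torch.allclose(Y.grad, dY_ref, atol=1e-2, rtol=1e-2)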



@triton.jit
def _softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
    # The rows of the softmax are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=0)
    # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)


@triton.jit
def _softmax_backward_kernel(grad_input, grad_output, output, grad_input_stride, grad_out_stride, output_row_stride,
                             n_cols, BLOCK_SIZE: tl.constexpr):
    # Parallelization across rows
    row_idx = tl.program_id(0)

    # Memory pointer calculations
    row_start_ptr = grad_input + row_idx * grad_input_stride
    grad_output_row_start_ptr = grad_output + row_idx * grad_out_stride
    output_row_start_ptr = output + row_idx * output_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)

    # Memory addresses of all the elements we want to load
    grad_output_ptrs = grad_output_row_start_ptr + col_offsets
    output_ptrs = output_row_start_ptr + col_offsets

    # Load relevant data
    o = tl.load(output_ptrs, mask=col_offsets < n_cols)
    g = tl.load(grad_output_ptrs, mask=col_offsets < n_cols)

    # Softmax backward (vector-Jacobian product)
    # Step 1: compute the intermediate sum s = sum_j g_j * o_j
    s = tl.sum(g * o, 0)

    # Step 2: compute the gradient dL/dx = o * (g - s)
    di = o * (g - s)

    grad_input_ptrs = row_start_ptr + col_offsets
    tl.store(grad_input_ptrs, di, mask=col_offsets < n_cols)
# We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.


class Softmax(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        n_rows, n_cols = x.shape
        # The block size is the smallest power of two greater than the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 4
        if BLOCK_SIZE >= 2048:
            num_warps = 8
        if BLOCK_SIZE >= 4096:
            num_warps = 16
        # Allocate output
        y = torch.empty_like(x)
        # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance
        # per row of the input matrix
        _softmax_kernel[(n_rows, )](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_out):
        (out,) = ctx.saved_tensors
        n_rows, n_cols = out.shape

        # The block size is the smallest power of two greater than the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # torch.empty is enough here; the kernel writes every element of grad_in
        grad_in = torch.empty_like(out)

        # Make sure that the tensors are contiguous
        grad_in, grad_out, out = map(lambda x: x.contiguous(), [
                                     grad_in, grad_out, out])
        _softmax_backward_kernel[(n_rows, )](
            grad_in, grad_out, out,
            grad_in.stride(0),
            grad_out.stride(0),
            out.stride(0),
            n_cols,
            BLOCK_SIZE,
        )
        return grad_in.reshape_as(grad_out)
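

# Hedged sanity-check sketch (not part of the original report): the Triton
# softmax forward/backward above can be compared against torch.softmax on a
# small float32 input. The tolerance is illustrative; tl.exp is a fast,
# approximate exponential.
def check_softmax(n_rows=8, n_cols=100, device=backend.device):
    x = torch.rand(n_rows, n_cols, device=device, dtype=torch.float32, requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)
    y = Softmax.apply(x)
    y_ref = torch.softmax(x_ref, dim=1)
    assert torch.allclose(y, y_ref, atol=1e-4)
    g = torch.rand_like(y)
    y.backward(g)
    y_ref.backward(g)
    assert torch.allclose(x.grad, x_ref.grad, atol=1e-4)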



class TritonAttentionHead(torch.nn.Module):
    def __init__(self, config, name=None) -> None:
        super().__init__()

        self.config   = config
        self.name = name

        self.dK = self.config.emb_size // self.config.no_of_heads
        self.sqrt_dk = math.sqrt(self.dK)

        self.wQ = torch.nn.Parameter(torch.rand([self.config.emb_size, self.dK],
                    device=backend.device, dtype=self.config.dtype, requires_grad=True))
        self.wK = torch.nn.Parameter(torch.rand_like(self.wQ))
        self.wV = torch.nn.Parameter(torch.rand_like(self.wQ))

        self.matmul_triton = Matmul.apply
        self.softmax_triton = Softmax.apply

        if self.config.mask:
            self.mask = torch.triu(
                torch.ones(self.config.no_of_embs, self.config.no_of_embs, device=backend.device,
                           dtype=self.config.dtype, requires_grad=True),
                diagonal=1)
            self.mask[self.mask.bool()] = -float('inf')
            #print('Mask: ', self.mask)


    def forward(self, embs_for_q, embs_for_k, embs_for_v) -> torch.Tensor:
        assert backend.check_device(embs_for_q)

        q = self.matmul_triton(embs_for_q, self.wQ)
        k = self.matmul_triton(embs_for_k, self.wK)
        v = self.matmul_triton(embs_for_v, self.wV)

        # TODO: multiply by 1/sqrt(dK) and use Triton for the qkT computation
        qkT = torch.mm(q, torch.transpose(k, 0, 1))
        if self.config.mask:
            qkT = qkT + self.mask
        softmaxed = self.softmax_triton(qkT)
        output = self.matmul_triton(softmaxed, v)

        return output



class TritonMultiHeadAttention(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.attention_heads: list[TritonAttentionHead] = []
        for x in range(config.no_of_heads):
            self.attention_heads.append(TritonAttentionHead(config, f"attention_head_{x}"))

        self.dK = int(self.config.emb_size/self.config.no_of_heads)
        self.wO = torch.nn.Parameter(torch.rand([self.config.emb_size, self.config.emb_size],
                    device=backend.device, dtype=self.config.dtype, requires_grad=True))

        self.matmul_triton = Matmul.apply

    def forward(self, embs_for_q, embs_for_k, embs_for_v):
        list_of_z = []

        for x in range(self.config.no_of_heads):
            list_of_z.append(self.attention_heads[x](embs_for_q, embs_for_k, embs_for_v))
            print(x)

        backend.sync()
        print("sync finished")

        concatZ = torch.cat(list_of_z, dim=1)
        output = self.matmul_triton(concatZ, self.wO)

        return output


def fwd_bwd_triton_full_attention(
    embs: torch.Tensor, triton_multi_head_attention: TritonMultiHeadAttention
):
    pred_y = triton_multi_head_attention(embs, embs, embs)
    grad_output = torch.rand_like(pred_y)
    pred_y.backward(grad_output, retain_graph=True)
    return pred_y, grad_output


class Config:
    no_of_embs = 1024
    emb_size = 4096
    no_of_heads = 32
    mask = True
    separate_kernels = True
    dtype = torch.float16

def main():
    config = Config()

    embds = torch.rand(
        [config.no_of_embs, config.emb_size],
        device=backend.device,
        dtype=config.dtype,
        requires_grad=True,
    )
    head = TritonMultiHeadAttention(config)

    res = fwd_bwd_triton_full_attention(
        embds, head
    )
    print(res)

if __name__ == '__main__':
    main()
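
For reference, the hanging and working configurations described above differ only in `Config.dtype` and `Backend.device`. Below is a minimal toggle sketch, assuming the script is saved as `repro.py` (the module name is illustrative; switching to cuda would also require pointing `Backend.sync` at `torch.cuda.synchronize()`):

import torch
import repro

# Reported behaviour:
#   float16 + xpu  -> hangs
#   float32 + xpu  -> works
#   float16 + cuda -> works
repro.Config.dtype = torch.float32   # switch to the dtype reported to work
# repro.Backend.device = 'cuda'      # or switch the device instead
repro.main()
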
@vlad-penkin added this to the 4.0 [Performance] Core milestone on May 13, 2024
@vlad-penkin added the bug and tests: e2e labels on May 20, 2024