
backpropagation for sparse semi-structured #126420

Open
bsulyok opened this issue May 16, 2024 · 1 comment
Assignees
Labels
module: sparse (Related to torch.sparse), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


bsulyok commented May 16, 2024

🚀 The feature, motivation and pitch

I am trying to train ultra-sparse linear layers with as few as 0.1% nonzero elements. Forward propagation succeeds, but propagating the loss backward fails.
I understand that this data structure is a work in progress, and that the main use case is to train a dense linear module and then replace it with a sparse one for a leaner inference model.
Nonetheless, I see sparse training as a logical next step in research and implementation. The current next-best option is masking dense tensors, which significantly hurts training performance and limits model size; a rough sketch of that workaround is below.
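For concreteness, the masked-dense workaround looks roughly like this (a minimal sketch; MaskedLinear is a hand-rolled helper and the names are illustrative, not a torch API):

import torch
from torch import nn
import torch.nn.functional as F

class MaskedLinear(nn.Module):  # illustrative helper, not part of torch
    def __init__(self, in_features: int, out_features: int, density: float = 0.001):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        # Fixed 0/1 sparsity pattern stored as a (non-trainable) buffer.
        self.register_buffer('mask', (torch.rand(out_features, in_features) < density).float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gradients only reach the unmasked entries, but the matmul and the
        # weight storage stay fully dense, which is the performance problem.
        return F.linear(x, self.weight * self.mask)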

Here's a snippet of the code I am trying to run, which fails:

from torchvision.transforms import v2 as T
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torch
from torch.sparse import to_sparse_semi_structured
from torch import nn
from pathlib import Path
from torch.optim import SGD

test_dataloader = DataLoader(
    MNIST(
        root=Path('data'),
        train=False,
        download=True,
        transform=T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=(0.1307,), std=(0.3081,)),
            T.ToDtype(torch.float16)
        ])
    ),
    batch_size=8
)

class SparseLinear(nn.Module):
    """Linear layer whose weight is stored as a sparse semi-structured tensor."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        density: float = 0.1,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Random 0/1 mask with roughly `density` nonzeros, cast to fp16 and
        # compressed into the semi-structured sparse layout.
        self.weight = nn.Parameter(to_sparse_semi_structured(
            (torch.rand(self.out_features, self.in_features) < density).half().cuda()
        ))
        # Optional bias, kept dense.
        self.bias = nn.Parameter(torch.zeros(out_features, dtype=torch.float16, device='cuda')) if bias else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)


model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=28*28, out_features=512, dtype=torch.float16),
        nn.ReLU(),
        SparseLinear(512, 256, density=0.1, bias=False),
        nn.ReLU(),
        nn.Linear(in_features=256, out_features=10, dtype=torch.float16),
        nn.Softmax(dim=1)  # note: CrossEntropyLoss below also applies softmax internally
)

device = torch.device(0)
model = model.to(device)


loss_function = nn.CrossEntropyLoss()
optimizer = SGD(params=model.parameters(), lr=1e-4)

x, y_true = next(iter(test_dataloader))

optimizer.zero_grad()
y_pred = model(x.to(device))
loss = loss_function(y_pred, F.one_hot(y_true, num_classes=10).half().to(device))
loss.backward()

This results in the following error:

SparseSemiStructuredTensorCUTLASS matmul: operation is not supported

Alternatives

No response

Additional context

No response

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @jcaip

@mikaylagawarecki added the module: sparse and triaged labels on May 20, 2024
@jcaip jcaip self-assigned this May 20, 2024
Contributor

jcaip commented May 21, 2024

Hey @bsulyok, you'll be happy to hear that we've added prototype support for training with semi-structured sparsity here :)

This uses a slightly different user API: we've created a SemiStructuredSparseLinear drop-in replacement for nn.Linear, instead of a pure tensor subclass. This is because we sometimes want to do activation sparsity, and also because we handle autograd support with torch.autograd.Function.

There are some meaningful differences between this code and to_sparse_semi_structured. Namely, we apply sparsity within a 4x4 tile so that we can accelerate both the forward and backward passes: the forward pass computes W x, while the backward pass computes W^T (dL/dy) to get dL/dx. So the weight needs to be 2:4 sparse in both directions.
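As a rough illustration (a sketch only, using dense matmuls rather than our actual kernels), a custom torch.autograd.Function for a sparse linear layer ends up reading the weight in both orientations, which is why one-directional 2:4 sparsity isn't enough:

import torch

class SparseLinearFn(torch.autograd.Function):  # illustrative name, not the prototype's API
    @staticmethod
    def forward(ctx, x, W):
        ctx.save_for_backward(x, W)
        return x @ W.t()            # forward: multiply by W^T  (y = x W^T)

    @staticmethod
    def backward(ctx, grad_y):
        x, W = ctx.saved_tensors
        grad_x = grad_y @ W         # backward: multiply by W, the other orientation
        grad_W = grad_y.t() @ x     # weight gradient stays dense
        return grad_x, grad_W

# Both `x @ W.t()` and `grad_y @ W` need to hit a fast sparse kernel, so the
# compressed weight has to be 2:4 sparse whichever way it is read.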

Additionally, we've written fast sparsification kernels that do runtime sparsity for training. These kernels do 2:4 pruning + compression very quickly at runtime, which makes distributed support much simpler. Additionally, you'll need cuSPARSELt support to see e2e speedups; CUTLASS is not sufficient.
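As a side note, if you want to check which backend you're actually getting, the tensor subclass currently exposes a _FORCE_CUTLASS toggle (prototype API; the exact name and default may change):

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# Prototype toggle: True forces the CUTLASS backend, False lets the subclass
# pick cuSPARSELt when your PyTorch build ships with it.
SparseSemiStructuredTensor._FORCE_CUTLASS = False

A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()  # valid 2:4 pattern
A_sparse = to_sparse_semi_structured(A)
print(type(A_sparse))  # ...CUTLASS vs ...CUSPARSELT tells you which path you got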

I am writing a blog post about this that should be publicly available shortly; I'll share it here when it's up.
We're also thinking about eventually upstreaming this into PyTorch core.
