
Dilated layer takes more than k neighbours #96

Open

zademn opened this issue Jun 5, 2022 · 5 comments

Comments

zademn commented Jun 5, 2022

The Dilated layer doesn't take k into account: it strides over the whole edge list, so it can return more neighbours per node than intended.

import torch

# Dilated as defined in this repo
t = torch.tensor([
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],  # source nodes
    [0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5],  # 6 neighbours per node
])

res = Dilated(k=2, dilation=2)(t)
print(res)  # 3 neighbours per node are taken even though the constructor specified k=2
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 2, 4, 0, 2, 4]])
lightaime (Owner) commented

The Dilated class only dilates the provided knn edge_index; it doesn't build the knn graph itself. You would first need to find the knn graph and then dilate it, as DilatedKnnGraph does. Sorry for the confusion.

class DilatedKnnGraph(nn.Module):
    """
    Find the neighbors' indices based on dilated knn
    """
    def __init__(self, k=9, dilation=1, stochastic=False, epsilon=0.0, knn='matrix'):
        super(DilatedKnnGraph, self).__init__()
        self.dilation = dilation
        self.stochastic = stochastic
        self.epsilon = epsilon
        self.k = k
        self._dilated = Dilated(k, dilation, stochastic, epsilon)
        if knn == 'matrix':
            self.knn = knn_graph_matrix
        else:
            self.knn = knn_graph

    def forward(self, x, batch):
        # build a knn graph with k * dilation neighbors per node,
        # then let Dilated keep every dilation-th edge
        edge_index = self.knn(x, self.k * self.dilation, batch)
        return self._dilated(edge_index, batch)
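
A minimal usage sketch (a hypothetical single-graph batch; DilatedKnnGraph and its knn backend as defined above):

import torch

x = torch.randn(8, 3)                     # 8 points, 3 features each
batch = torch.zeros(8, dtype=torch.long)  # all points belong to one graph
knn = DilatedKnnGraph(k=2, dilation=2)
edge_index = knn(x, batch)                # k*dilation = 4 neighbors found, every 2nd kept
print(edge_index.shape)                   # expected: 2 edges per node, torch.Size([2, 16])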


zademn commented Jun 5, 2022

I'm not even sure taking [::d] is the right way to go. The following example shows the misalignment:

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
res = Dilated(k=2, dilation=2)(t)
print(res)
# tensor([[0, 0, 0, 1, 1, 2, 2, 2],
#         [0, 2, 4, 1, 3, 0, 2, 4]])

For node 1, [1, 0] and [1, 2] are the expected edges, but we get [1, 1] and [1, 3].


lightaime commented Jun 5, 2022

We should first build a knn graph that has k*d neighbors for each node and then use [::d] to get the dilated graph:

edge_index = self.knn(x, self.k * self.dilation, batch)

So this case won't happen. But you are right: it is not ideal if the provided graph doesn't have exactly k*d neighbors for each node.
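
To illustrate why the stride is safe under that assumption, here is a small sketch (plain PyTorch, values chosen for illustration):

import torch

k, d = 2, 2

# each node has exactly k*d = 4 neighbors, as knn(x, k*d, batch) guarantees
t = torch.tensor([
    [0, 0, 0, 0, 1, 1, 1, 1],
    [0, 1, 2, 3, 0, 1, 2, 3],
])
print(t[:, ::d])
# tensor([[0, 0, 1, 1],
#         [0, 2, 0, 2]])
# exactly k neighbors per node; with 5 neighbors per node the stride would
# drift across node boundaries, which is the misalignment shown above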


zademn commented Jun 5, 2022

A possible solution would be (using einops):

import torch
from einops import rearrange

t = torch.tensor(
    [
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2],
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4],
    ]
)
k = 2
d = 2

u, counts = torch.unique(t[0], return_counts=True)
# number of neighbours actually found per node (assumed equal across nodes);
# this could also be passed in as a parameter
k_constructed = counts[0].item()
# group the flat edge list into one row of neighbours per node
res1 = rearrange(t, "e (n2 k_constructed) -> e n2 k_constructed", k_constructed=k_constructed)
print(res1)

# tensor([[[0, 0, 0, 0, 0],
#          [1, 1, 1, 1, 1],
#          [2, 2, 2, 2, 2]],

#         [[0, 1, 2, 3, 4],
#          [0, 1, 2, 3, 4],
#          [0, 1, 2, 3, 4]]])
res2 = res1[:, :, ::d]  # take every d-th neighbour within each node's row
print(res2)
# tensor([[[0, 0, 0],
#          [1, 1, 1],
#          [2, 2, 2]],

#         [[0, 2, 4],
#          [0, 2, 4],
#          [0, 2, 4]]])
res3 = res2[:, :, :k]  # keep the first k dilated neighbours
print(res3)
# tensor([[[0, 0],
#          [1, 1],
#          [2, 2]],

#         [[0, 2],
#          [0, 2],
#          [0, 2]]])
res4 = rearrange(res3, "e d1 d2 -> e (d1 d2)")  # flatten back to an edge list
print(res4)
# tensor([[0, 0, 1, 1, 2, 2],
#         [0, 2, 0, 2, 0, 2]])
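
The same group-stride-slice can be written without the einops dependency; a sketch under the same equal-neighbour-count assumption (dilate_edge_index is a hypothetical helper, not part of the repo):

import torch

def dilate_edge_index(t: torch.Tensor, k: int, d: int) -> torch.Tensor:
    _, counts = torch.unique(t[0], return_counts=True)
    n_found = counts[0].item()  # assumes every node found the same number of neighbours
    grouped = t.view(2, -1, n_found)                    # (2, num_nodes, neighbours)
    return grouped[:, :, ::d][:, :, :k].reshape(2, -1)  # stride, truncate, flatten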

lightaime (Owner) commented

Thanks for the suggestion @zademn. That is definitely a good idea for handling the more general case. But in our example, we always build knn graphs with k*d neighbors. To keep it simple, we prefer to leave it as it is.
