Differentiable barcode vectorization
This tutorial gives a brief overview of the functionality offered by the torchph.nn.SLayerExponential module. It assumes familiarity with standard PyTorch functionality.
Note that torchph.nn.SLayerExponential is just one structure element; others are available as well (see the documentation).
[1]:
from shared_code import check_torchph_availability
check_torchph_availability()
[3]:
from torchph.nn import SLayerExponential
# create an instance with 3 structure elements over R^2
sl = SLayerExponential(3, 2)
nn.SLayerExponential is a torch.nn.Module …
[4]:
import torch
isinstance(sl, torch.nn.Module)
[4]:
True
… now we can do all the beautiful stuff inherited from torch.nn.Module, e.g.,
[5]:
for p in sl.parameters():
    print(p)
Parameter containing:
tensor([[0.6355, 0.3604],
        [0.3162, 0.9167],
        [0.4922, 0.9822]], requires_grad=True)
Parameter containing:
tensor([[3., 3.],
        [3., 3.],
        [3., 3.]], requires_grad=True)
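Because both parameters are ordinary trainable tensors, gradients flow through the layer. Here is a minimal sketch of a backward pass; the single random 5-point multiset used as input is made up purely for illustration:
[ ]:
out = sl([torch.rand(5, 2)])  # one made-up random multiset of 5 points in R^2
out.sum().backward()          # backpropagate a toy scalar loss
print(sl.centers.grad)        # gradients w.r.t. the centers
print(sl.sharpness.grad)      # gradients w.r.t. the sharpness values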
The module has two parameters:
1. centers: controls the centers of the structure elements.
2. sharpness: controls how tight the Gaussians are; the higher the value, the tighter.
Both can be initialized using the centers_init and sharpness_init keyword arguments, respectively.
[6]:
# here is an initialization example
centers_init = torch.Tensor(
    [
        [0, 0],
        [0.5, 0.5],
        [1, 1]
    ]
)

sharpness_init = torch.Tensor(
    [
        [1, 1],
        [2, 2],
        [3, 3]
    ]
)

sl = SLayerExponential(3, 2,
                       centers_init=centers_init,
                       sharpness_init=sharpness_init)
print(sl.centers)
print(sl.sharpness)
Parameter containing:
tensor([[0.0000, 0.0000],
        [0.5000, 0.5000],
        [1.0000, 1.0000]], requires_grad=True)
Parameter containing:
tensor([[1., 1.],
        [2., 2.],
        [3., 3.]], requires_grad=True)
The simplest input form for nn.SLayerExponential is a list of torch.Tensor objects, which are treated as a batch.
[7]:
# As an example, we create a batch of multisets
mset_1 = [[0, 0]]
mset_2 = [[0, 0], [0, 0]]
mset_3 = [[1, 1], [0, 0]]
mset_4 = [[0, 0], [1, 1]]
batch = [mset_1, mset_2, mset_3, mset_4]
batch = [torch.Tensor(x) for x in batch]
output = sl(batch)
print(output.size())
torch.Size([4, 3])
As we can see, the output dimensionality is (4, 3), since we have a batch of size 4 and 3 structure elements. In other words, output[i, j] = “evaluation of structure element j on mset_i”.
Let's take a look …
[8]:
print(output)
tensor([[1.0000e+00, 1.3534e-01, 1.5230e-08],
        [2.0000e+00, 2.7067e-01, 3.0460e-08],
        [1.1353e+00, 2.7067e-01, 1.0000e+00],
        [1.1353e+00, 2.7067e-01, 1.0000e+00]], grad_fn=<SqueezeBackward0>)
We observe the following:
- The j-th structure element approximates the multiplicity function of the given input at the point sl.centers[j]. E.g., the output for mset_1, output[0, :], is approximately (1, 0, 0).
- sl.sharpness[j] controls how much points that do not lie exactly on sl.centers[j] contribute, depending on their distance to sl.centers[j].
- The input is interpreted as a set, i.e., the layer is permutation invariant: mset_3 and mset_4 do not differ as multisets, and accordingly output[2, :] == output[3, :].
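The printed values are consistent with each structure element evaluating exp(-Σ_d sharpness[j,d]² · (x_d - centers[j,d])²) for every input point x and summing over the multiset (e.g., exp(-2) ≈ 1.3534e-01 and exp(-18) ≈ 1.5230e-08 in the first row). Assuming this formula (check the torchph source for the authoritative definition), here is a minimal sketch recomputing output[0, :] by hand:
[ ]:
x = batch[0]                                    # mset_1, shape (1, 2)
diff = x.unsqueeze(1) - sl.centers              # shape (1, 3, 2)
expo = (sl.sharpness ** 2 * diff ** 2).sum(-1)  # shape (1, 3)
manual = torch.exp(-expo).sum(0)                # sum over the points of the multiset
print(manual)                                   # should match output[0, :]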
Maybe this becomes clearer if we increase the sharpness of our structure elements a “little” …
[9]:
sl = SLayerExponential(3, 2,
                       centers_init=centers_init,
                       sharpness_init=10*sharpness_init)
print(sl(batch))
tensor([[1., 0., 0.],
        [2., 0., 0.],
        [1., 0., 1.],
        [1., 0., 1.]], grad_fn=<SqueezeBackward0>)
Below is a small toy model illustrating the application of SLayerExponential:
[10]:
class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.slayer = SLayerExponential(50, 2)
        self.linear = torch.nn.Linear(50, 10)

    def forward(self, inp):
        x = self.slayer(inp)
        x = self.linear(x)
        return x
[11]:
model = ToyModel()
inp = [torch.rand(10,2), torch.rand(20,2), torch.rand(30,2)]
out = model(inp)
print(out.size())
torch.Size([3, 10])
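Since every step in the model is differentiable, it can be trained end-to-end with standard PyTorch machinery. Here is a minimal sketch of one optimization step; the integer class labels are made up purely for illustration:
[ ]:
labels = torch.tensor([0, 1, 2])  # hypothetical targets for the 3 multisets
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = torch.nn.functional.cross_entropy(model(inp), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())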
For more information about alternative structure elements, e.g.,
- torchph.nn.SLayerRational
- torchph.nn.SLayerRationalHat
see the documentation.