Source code for torchph.nn.slayer

r"""

Implementation of **differentiable vectorization layers** for persistent
homology barcodes.

For a basic tutorial click `here <tutorials/SLayer.html>`_.
"""
import torch
import numpy as np
from torch import Tensor
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

from typing import List, Optional, Tuple

import warnings


# region helper functions


def prepare_batch(
        batch: List[Tensor],
        point_dim: Optional[int] = None
) -> Tuple[Tensor, Tensor, int, int]:
    """
    This method 'vectorizes' the multisets in order to take advantage
    of GPU processing.
    The policy is to pad all multisets in the batch to the highest
    number of points occurring in the batch, i.e.,
    ``max(t.size(0) for t in batch)``.

    Args:
        batch: The input batch to process as a list of tensors.
        point_dim: The dimension of the points the inputs consist of.

    Returns:
        A four-tuple consisting of
        (1) the constructed ``batch``, i.e., a tensor of size
        ``batch_size`` x ``n_max_points`` x ``point_dim``;
        (2) a tensor ``not_dummy`` of size
        ``batch_size`` x ``n_max_points``, where ``1`` at position
        (i, j) indicates a real point and ``0`` indicates a dummy
        point used for padding;
        (3) the max. number of points; and
        (4) the batch size.

    Example::

        >>> from torchph.nn.slayer import prepare_batch
        >>> import torch
        >>> x = [torch.rand(10, 2), torch.rand(20, 2)]
        >>> batch, not_dummy, max_pts, batch_size = prepare_batch(x)
    """
    if point_dim is None:
        point_dim = batch[0].size(1)
    assert all(x.size(1) == point_dim for x in batch if len(x) != 0)

    batch_size = len(batch)
    batch_max_points = max([t.size(0) for t in batch])
    input_device = batch[0].device

    if batch_max_points == 0:
        # If we are here, the batch consists only of empty diagrams.
        batch_max_points = 1

    # This mask is later used to zero out the dummy points in the output.
    not_dummy_points = torch.zeros(
        batch_size, batch_max_points, device=input_device)

    prepared_batch = []

    for i, multi_set in enumerate(batch):
        n_points = multi_set.size(0)

        prepared_dgm = torch.zeros(
            batch_max_points, point_dim, device=input_device)

        if n_points > 0:
            index_selection = torch.tensor(
                range(n_points), device=input_device)
            prepared_dgm.index_add_(0, index_selection, multi_set)
            not_dummy_points[i, :n_points] = 1

        prepared_batch.append(prepared_dgm)

    prepared_batch = torch.stack(prepared_batch)

    return prepared_batch, not_dummy_points, batch_max_points, batch_size
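
# A minimal sketch (not part of the original API surface) making the
# padding semantics above concrete: two diagrams with 3 and 5 points are
# padded to 5 points each, and ``not_dummy`` marks the real points.
#
#   >>> x = [torch.rand(3, 2), torch.rand(5, 2)]
#   >>> batch, not_dummy, max_pts, bs = prepare_batch(x)
#   >>> batch.shape
#   torch.Size([2, 5, 2])
#   >>> not_dummy
#   tensor([[1., 1., 1., 0., 0.],
#           [1., 1., 1., 1., 1.]])
#   >>> (max_pts, bs)
#   (5, 2)
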
def is_prepared_batch(input):
    if not (isinstance(input, tuple) and len(input) == 4):
        return False
    else:
        batch, not_dummy_points, max_points, batch_size = input
        return isinstance(batch, Tensor) \
            and isinstance(not_dummy_points, Tensor) \
            and max_points > 0 \
            and batch_size > 0


def is_list_of_tensors(input):
    try:
        return all([isinstance(x, Tensor) for x in input])
    except TypeError:
        return False


def prepare_batch_if_necessary(input, point_dimension=None):
    batch, not_dummy_points, max_points, batch_size = None, None, None, None

    if is_prepared_batch(input):
        batch, not_dummy_points, max_points, batch_size = input
    elif is_list_of_tensors(input):
        if point_dimension is None:
            point_dimension = input[0].size(1)

        batch, not_dummy_points, max_points, batch_size = prepare_batch(
            input,
            point_dimension)
    else:
        raise ValueError(
            'SLayer does not recognize input format! Expecting [Tensor] '
            'or prepared batch. Not {}'.format(input))

    return batch, not_dummy_points, max_points, batch_size


def parameter_init_from_arg(arg, size, default, scalar_is_valid=False):
    if isinstance(arg, (int, float)):
        if not scalar_is_valid:
            raise ValueError(
                'Scalar initialization values are not valid. '
                'Got {} expected Tensor of size {}.'.format(arg, size))
        return torch.Tensor(*size).fill_(arg)
    elif isinstance(arg, torch.Tensor):
        assert arg.size() == size
        return arg
    elif arg is None:
        if default in [torch.rand, torch.randn, torch.ones, torch.ones_like]:
            return default(*size)
        else:
            return default(size)
    else:
        raise ValueError(
            'Cannot handle parameter initialization. '
            'Got "{}"'.format(arg))


# endregion
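
# Hedged sketch of how ``parameter_init_from_arg`` resolves the three
# accepted argument kinds (None, scalar, tensor) to a tensor of the
# expected size:
#
#   >>> parameter_init_from_arg(None, (4, 2), torch.rand).shape
#   torch.Size([4, 2])
#   >>> parameter_init_from_arg(0.5, (3,), torch.ones, scalar_is_valid=True)
#   tensor([0.5000, 0.5000, 0.5000])
#   >>> parameter_init_from_arg(torch.zeros(4, 2), (4, 2), torch.rand).shape
#   torch.Size([4, 2])
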
class SLayerExponential(Module):
    """
    Proposed input layer for multisets [1].
    """

    def __init__(self,
                 n_elements: int,
                 point_dimension: int = 2,
                 centers_init: Optional[Tensor] = None,
                 sharpness_init: Optional[Tensor] = None):
        """
        Args:
            n_elements: Number of structure elements used.
            point_dimension: Dimensionality of the points of which the
                input multisets consist.
            centers_init: Initialization for the centers of the
                structure elements.
            sharpness_init: Initialization for the sharpness of the
                structure elements.
        """
        super().__init__()

        self.n_elements = n_elements
        self.point_dimension = point_dimension

        expected_init_size = (self.n_elements, self.point_dimension)

        centers_init = parameter_init_from_arg(
            centers_init,
            expected_init_size,
            torch.rand,
            scalar_is_valid=False)
        sharpness_init = parameter_init_from_arg(
            sharpness_init,
            expected_init_size,
            lambda size: torch.ones(*size) * 3)

        self.centers = Parameter(centers_init)
        self.sharpness = Parameter(sharpness_init)

    def forward(self, input) -> Tensor:
        batch, not_dummy_points, max_points, batch_size = \
            prepare_batch_if_necessary(
                input,
                point_dimension=self.point_dimension)

        # Tile the batch and the dummy-point mask once per structure
        # element.
        batch = torch.cat([batch] * self.n_elements, 1)
        not_dummy_points = torch.cat([not_dummy_points] * self.n_elements, 1)

        # Tile the centers and the (squared) sharpness to match the
        # tiled batch.
        centers = torch.cat([self.centers] * max_points, 1)
        centers = centers.view(-1, self.point_dimension)
        centers = torch.stack([centers] * batch_size, 0)

        sharpness = torch.pow(self.sharpness, 2)
        sharpness = torch.cat([sharpness] * max_points, 1)
        sharpness = sharpness.view(-1, self.point_dimension)
        sharpness = torch.stack([sharpness] * batch_size, 0)

        # Point-wise exponential activation:
        # exp(-<sharpness, (x - center)^2>).
        x = centers - batch
        x = x.pow(2)
        x = torch.mul(x, sharpness)
        x = torch.sum(x, 2)
        x = torch.exp(-x)

        # Zero out padded dummy points, then sum over each multiset.
        x = torch.mul(x, not_dummy_points)
        x = x.view(batch_size, self.n_elements, -1)
        x = torch.sum(x, 2)
        x = x.squeeze()

        return x

    def __repr__(self):
        return 'SLayerExponential (... -> {} )'.format(self.n_elements)
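
# Hedged usage sketch for ``SLayerExponential``: each of the
# ``n_elements`` structure elements yields one output coordinate per
# multiset, so a batch of two diagrams maps to a 2 x 8 tensor.
#
#   >>> sl = SLayerExponential(n_elements=8, point_dimension=2)
#   >>> dgms = [torch.rand(12, 2), torch.rand(30, 2)]
#   >>> sl(dgms).shape
#   torch.Size([2, 8])
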
class SLayerRational(Module):
    """
    Input layer for multisets, using rational structure elements: each
    element responds to a point with ``1/(1 + d)^p``, where ``d`` is a
    sharpness-weighted l1-distance to a learnable center.
    """

    def __init__(self,
                 n_elements: int,
                 point_dimension: int = 2,
                 centers_init: Optional[Tensor] = None,
                 sharpness_init: Optional[Tensor] = None,
                 exponent_init: Optional[Tensor] = None,
                 pointwise_activation_threshold=None,
                 share_sharpness=False,
                 share_exponent=False,
                 freeze_exponent=True):
        """
        Args:
            n_elements: Number of structure elements used.
            point_dimension: Dimensionality of the points of which the
                input multisets consist.
            centers_init: Initialization for the centers of the
                structure elements.
            sharpness_init: Initialization for the sharpness of the
                structure elements.
        """
        super().__init__()

        self.n_elements = int(n_elements)
        self.point_dimension = int(point_dimension)
        self.pointwise_activation_threshold = \
            float(pointwise_activation_threshold) \
            if pointwise_activation_threshold is not None else None
        self.share_sharpness = bool(share_sharpness)
        self.share_exponent = bool(share_exponent)
        self.freeze_exponent = freeze_exponent

        centers_init = parameter_init_from_arg(
            arg=centers_init,
            size=(self.n_elements, self.point_dimension),
            default=torch.rand)

        sharpness_init = parameter_init_from_arg(
            arg=sharpness_init,
            size=(1,) if self.share_sharpness
            else (self.n_elements, self.point_dimension),
            default=torch.ones,
            scalar_is_valid=True)

        exponent_init = parameter_init_from_arg(
            arg=exponent_init,
            size=(1,) if self.share_exponent else (self.n_elements,),
            default=torch.ones,
            scalar_is_valid=True)

        self.centers = Parameter(centers_init)
        self.sharpness = Parameter(sharpness_init)

        if self.freeze_exponent:
            # A frozen exponent is registered as a buffer, i.e., it is
            # moved with the module but excluded from optimization.
            self.register_buffer('exponent', exponent_init)
        else:
            self.exponent = Parameter(exponent_init)

    def forward(self, input) -> Tensor:
        batch, not_dummy_points, max_points, batch_size = \
            prepare_batch_if_necessary(
                input,
                point_dimension=self.point_dimension)

        # Broadcast batch, mask, and parameters to a common
        # batch_size x n_elements x max_points (x point_dimension) layout.
        batch = batch.unsqueeze(1).expand(
            batch_size, self.n_elements, max_points, self.point_dimension)

        not_dummy_points = not_dummy_points.unsqueeze(1).expand(
            -1, self.n_elements, -1)

        centers = self.centers.unsqueeze(1).expand(
            self.n_elements, max_points, self.point_dimension)
        centers = centers.unsqueeze(0).expand(batch_size, *centers.size())

        sharpness = self.sharpness
        if self.share_sharpness:
            sharpness = sharpness.expand(
                self.n_elements, self.point_dimension)
        sharpness = sharpness.unsqueeze(1).expand(-1, max_points, -1)
        sharpness = sharpness.unsqueeze(0).expand(
            batch_size, *sharpness.size())

        exponent = self.exponent
        if self.share_exponent:
            exponent = exponent.expand(self.n_elements)
        exponent = exponent.unsqueeze(1).expand(-1, max_points)
        exponent = exponent.unsqueeze(0).expand(
            batch_size, *exponent.size())

        # Rational activation: 1/(1 + d)^p with d the sharpness-weighted
        # l1-distance to the center.
        x = centers - batch
        x = x.abs()
        x = torch.mul(x, sharpness.abs())
        x = torch.sum(x, 3)
        x = 1 / (1 + x).pow(exponent.abs())

        if self.pointwise_activation_threshold is not None:
            x[(x < self.pointwise_activation_threshold).data] = 0

        # Zero out padded dummy points, then sum over each multiset.
        x = torch.mul(x, not_dummy_points)
        x = torch.sum(x, 2)

        return x

    def __repr__(self):
        return 'SLayerRational (... -> {} )'.format(self.n_elements)
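
# Hedged usage sketch for ``SLayerRational``. Note that with the default
# ``freeze_exponent=True`` the exponent is registered as a buffer, i.e.,
# it does not show up in ``parameters()`` and is not trained.
#
#   >>> sl = SLayerRational(n_elements=4, share_sharpness=True)
#   >>> dgms = [torch.rand(12, 2), torch.rand(30, 2)]
#   >>> sl(dgms).shape
#   torch.Size([2, 4])
#   >>> [p.shape for p in sl.parameters()]
#   [torch.Size([4, 2]), torch.Size([1])]
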
class SLayerRationalHat(Module):
    """
    Input layer for multisets, using hat-shaped structure elements built
    as the difference of two rational decays of the l1-distance to a
    learnable center.
    """

    def __init__(self,
                 n_elements: int,
                 point_dimension: int = 2,
                 centers_init: Optional[Tensor] = None,
                 radius_init: float = 1,
                 exponent: int = 1):
        """
        Args:
            n_elements: Number of structure elements used.
            point_dimension: Dimensionality of the points of which the
                input multisets consist.
            centers_init: Initialization for the centers of the
                structure elements.
            radius_init: Initialization for the radius of the zero
                level-set of the hat.
            exponent: Exponent of the rationals forming the hat.
        """
        super().__init__()

        self.n_elements = int(n_elements)
        self.point_dimension = int(point_dimension)
        self.exponent = int(exponent)

        centers_init = parameter_init_from_arg(
            arg=centers_init,
            size=(self.n_elements, self.point_dimension),
            default=torch.rand)

        radius_init = parameter_init_from_arg(
            arg=radius_init,
            size=(self.n_elements,),
            default=torch.ones,
            scalar_is_valid=True)

        self.centers = Parameter(centers_init)
        self.radius = Parameter(radius_init)
        self.norm_p = 1

    def forward(self, input) -> Tensor:
        batch, not_dummy_points, max_points, batch_size = \
            prepare_batch_if_necessary(
                input,
                point_dimension=self.point_dimension)

        batch = batch.unsqueeze(1).expand(
            batch_size, self.n_elements, max_points, self.point_dimension)

        not_dummy_points = not_dummy_points.unsqueeze(1).expand(
            -1, self.n_elements, -1)

        centers = self.centers.unsqueeze(1).expand(
            self.n_elements, max_points, self.point_dimension)
        centers = centers.unsqueeze(0).expand(batch_size, *centers.size())

        radius = self.radius
        radius = radius.unsqueeze(1).expand(-1, max_points)
        radius = radius.unsqueeze(0).expand(batch_size, *radius.size())
        radius = radius.abs()

        # Hat-shaped activation: the difference of two rational decays,
        # one anchored at the center and one at l1-distance ``radius``
        # from it.
        norm_to_center = centers - batch
        norm_to_center = torch.norm(norm_to_center, p=self.norm_p, dim=3)

        positive_part = 1 / (1 + norm_to_center).pow_(self.exponent)
        negative_part = 1 / (1 + (radius - norm_to_center).abs_()) \
            .pow_(self.exponent)

        x = positive_part - negative_part

        # Zero out padded dummy points, then sum over each multiset.
        x = torch.mul(x, not_dummy_points)
        x = torch.sum(x, 2)

        return x

    def __repr__(self):
        return 'SLayerRationalHat (... -> {} )'.format(self.n_elements)
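
# Hedged usage sketch for ``SLayerRationalHat``, analogous to the layers
# above:
#
#   >>> sl = SLayerRationalHat(n_elements=16, point_dimension=2)
#   >>> dgms = [torch.rand(12, 2), torch.rand(30, 2)]
#   >>> sl(dgms).shape
#   torch.Size([2, 16])
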
class LinearRationalStretchedBirthLifeTimeCoordinateTransform:
    # Maps (birth, death) to (birth, lifetime) coordinates and rationally
    # stretches lifetimes below ``nu`` towards -infinity.
    def __init__(self, nu):
        self._nu = nu
        self._nu_squared = nu ** 2
        self._2_nu = 2 * nu

    def __call__(self, dgm):
        if len(dgm) == 0:
            return dgm

        x, y = dgm[:, 0], dgm[:, 1]
        y = y - x

        i = (y <= self._nu)
        y[i] = -self._nu_squared / y[i] + self._2_nu

        return torch.stack([x, y], dim=1)


class LogStretchedBirthLifeTimeCoordinateTransform:
    # Maps (birth, death) to (birth, lifetime) coordinates and
    # log-stretches lifetimes below ``nu``.
    def __init__(self, nu):
        self.nu = nu

    def __call__(self, dgm):
        if len(dgm) == 0:
            return dgm

        x, y = dgm[:, 0], dgm[:, 1]
        y = y - x

        i = (y <= self.nu)
        y[i] = torch.log(y[i] / self.nu) * self.nu + self.nu

        return torch.stack([x, y], dim=1)


class UpperDiagonalThresholdedLogTransform:
    # Rotates diagram points into the basis spanned by (1, 1)/sqrt(2) and
    # (-1, 1)/sqrt(2) and log-stretches the second coordinate below ``nu``.
    def __init__(self, nu):
        self.b_1 = (torch.Tensor([1, 1]) / np.sqrt(2))
        self.b_2 = (torch.Tensor([-1, 1]) / np.sqrt(2))
        self.nu = nu

    def __call__(self, dgm):
        if len(dgm) == 0:
            return dgm

        self.b_1 = self.b_1.to(dgm.device)
        self.b_2 = self.b_2.to(dgm.device)

        x = torch.mul(dgm, self.b_1.repeat(dgm.size(0), 1))
        x = torch.sum(x, 1).squeeze()
        y = torch.mul(dgm, self.b_2.repeat(dgm.size(0), 1))
        y = torch.sum(y, 1).squeeze()

        i = (y <= self.nu)
        y[i] = torch.log(y[i] / self.nu) * self.nu + self.nu

        return torch.stack([x, y], 1)
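
# Hedged sketch of composing a coordinate transform with an SLayer:
# ``UpperDiagonalThresholdedLogTransform`` rotates (birth, death) into
# ((birth + death)/sqrt(2), (death - birth)/sqrt(2)) and log-stretches
# the second coordinate below ``nu``; the transformed diagrams can then
# be fed to any of the layers above.
#
#   >>> b = torch.rand(10, 1)
#   >>> d = b + torch.rand(10, 1)              # enforce death >= birth
#   >>> dgm = torch.cat([b, d], dim=1)
#   >>> tf = UpperDiagonalThresholdedLogTransform(nu=0.1)
#   >>> tf(dgm).shape
#   torch.Size([10, 2])
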