Skip to content

losses

Definitions of the data loss functions.

Classes:

  • LossRegularizer

    Base class for the regularizer losses.

  • LossSWTN

    Multi-level n-dimensional stationary wavelet transform loss function.

  • LossTGV

    Total Generalized Variation loss function.

  • LossTV

    Total Variation loss function.

Functions:

  • get_nd_wl_filters

    Generate all possible N-D separable wavelet filters.

  • swt_nd

    Perform N-dimensional Stationary Wavelet Transform (SWT).

LossRegularizer

Bases: MSELoss

Base class for the regularizer losses.

LossSWTN

LossSWTN(
    wl_dec_lo: Tensor,
    wl_dec_hi: Tensor,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    levels: int = 2,
    ndims: int = 2,
    min_approx: bool = False,
)

Bases: LossRegularizer

Multi-level n-dimensional stationary wavelet transform loss function.

Methods:

  • forward

    Compute wavelet decomposition on current batch.

Source code in src/autoden/losses.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def __init__(
    self,
    wl_dec_lo: pt.Tensor,
    wl_dec_hi: pt.Tensor,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    levels: int = 2,
    ndims: int = 2,
    min_approx: bool = False,
) -> None:
    """
    Store the wavelet filters and the regularization settings.

    Parameters
    ----------
    wl_dec_lo : pt.Tensor
        Low-pass wavelet decomposition filter, handed to `swt_nd` by `forward`.
    wl_dec_hi : pt.Tensor
        High-pass wavelet decomposition filter, handed to `swt_nd` by `forward`.
    lambda_val : float
        Weight that multiplies the penalty returned by `forward`.
    size_average, reduce, reduction
        Forwarded unchanged to the parent class constructor.
    isotropic : bool, optional
        If True, `forward` combines the coefficients of one level through the
        square root of their summed squares; otherwise through the sum of
        their absolute values (default is True).
    levels : int, optional
        Number of decomposition levels requested from `swt_nd` (default is 2).
    ndims : int, optional
        Number of dimensions used to validate the input and to select the
        axes that `forward` sums over (default is 2).
    min_approx : bool, optional
        If True, `forward` also penalizes the approximation coefficients;
        otherwise only the detail coefficients (default is False).
    """
    super().__init__(size_average, reduce, reduction)
    self.wl_dec_lo = wl_dec_lo
    self.wl_dec_hi = wl_dec_hi
    self.lambda_val = lambda_val
    self.isotropic = isotropic
    self.levels = levels
    self.ndims = ndims
    self.min_approx = min_approx

forward

forward(img: Tensor) -> Tensor

Compute wavelet decomposition on current batch.

Source code in src/autoden/losses.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """
    Compute the weighted wavelet-sparsity penalty of the current batch.

    The batch is decomposed with `swt_nd` (using ``normalize="scale"``). For
    every retained level the coefficient volumes are combined element-wise
    (Euclidean norm when `isotropic`, sum of absolute values otherwise),
    summed over the last ``ndims + 1`` axes, accumulated over the levels and
    averaged over what remains.

    Parameters
    ----------
    img : pt.Tensor
        Input batch, validated by `_check_input_tensor` against `ndims`.

    Returns
    -------
    pt.Tensor
        Scalar penalty, scaled by `lambda_val` and divided by the square root
        of the number of retained levels.
    """
    _check_input_tensor(img, self.ndims)
    reduce_axes = list(range(-(self.ndims + 1), 0))

    coeffs = swt_nd(img, wl_dec_lo=self.wl_dec_lo, wl_dec_hi=self.wl_dec_hi, level=self.levels, normalize="scale")

    # coeffs[0] holds the approximation: skip it unless it should be penalized too.
    first_ind = int(not self.min_approx)

    wl_val = []
    for lvl_c in coeffs[first_ind:]:
        coeff = pt.stack(lvl_c, dim=0)

        if self.isotropic:
            # vector_norm has a zero (sub)gradient where all coefficients vanish,
            # whereas sqrt(sum(x**2)) back-propagates NaN (inf * 0) at those points.
            magnitude = pt.linalg.vector_norm(coeff, ord=2, dim=0)
        else:
            magnitude = coeff.abs().sum(dim=0)
        wl_val.append(magnitude.sum(reduce_axes))

    return self.lambda_val * pt.stack(wl_val, dim=0).sum(dim=0).mean() / ((self.levels + self.min_approx) ** 0.5)

LossTGV

LossTGV(
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    ndims: int = 2,
)

Bases: LossTV

Total Generalized Variation loss function.

Methods:

  • forward

    Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    ndims: int = 2,
) -> None:
    """
    Store the regularization settings.

    Parameters
    ----------
    lambda_val : float
        Weight that multiplies the penalty returned by `forward`.
    size_average, reduce, reduction
        Forwarded unchanged to the parent class constructor.
    isotropic : bool, optional
        If True, `forward` combines the finite differences through the square
        root of their summed squares; otherwise through the sum of their
        absolute values (default is True).
    ndims : int, optional
        Number of trailing axes along which `forward` takes finite
        differences (default is 2).
    """
    super().__init__(size_average, reduce, reduction)
    self.lambda_val = lambda_val
    self.isotropic = isotropic
    self.ndims = ndims

forward

forward(img: Tensor) -> Tensor

Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """
    Compute the first- plus second-order variation penalty of the current batch.

    First-order differences are taken along each of the last `ndims` axes,
    and every one of them is differentiated again along each of those axes.
    Both sets are combined element-wise (Euclidean norm when `isotropic`, sum
    of absolute values otherwise), summed over the last ``ndims + 1`` axes
    and averaged over what remains. The second-order term is weighted by 1/4.

    Parameters
    ----------
    img : pt.Tensor
        Input batch, validated by `_check_input_tensor` against `ndims`.

    Returns
    -------
    pt.Tensor
        Scalar penalty, scaled by `lambda_val`.
    """
    _check_input_tensor(img, self.ndims)
    reduce_axes = list(range(-(self.ndims + 1), 0))
    diff_dims = range(-self.ndims, 0)

    diffs = [_differentiate(img, dim=dim, position="post") for dim in diff_dims]
    diffdiffs = [_differentiate(d, dim=dim, position="pre") for dim in diff_dims for d in diffs]

    first_order = pt.stack(diffs, dim=0)
    second_order = pt.stack(diffdiffs, dim=0)

    if self.isotropic:
        # vector_norm has a zero (sub)gradient where all differences vanish (flat
        # regions), whereas sqrt(sum(d**2)) back-propagates NaN (inf * 0) there.
        tv_val = pt.linalg.vector_norm(first_order, ord=2, dim=0)
        jac_val = pt.linalg.vector_norm(second_order, ord=2, dim=0)
    else:
        tv_val = first_order.abs().sum(dim=0)
        jac_val = second_order.abs().sum(dim=0)

    return self.lambda_val * (tv_val.sum(reduce_axes).mean() + jac_val.sum(reduce_axes).mean() / 4)

LossTV

LossTV(
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    ndims: int = 2,
)

Bases: LossRegularizer

Total Variation loss function.

Methods:

  • forward

    Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    ndims: int = 2,
) -> None:
    """
    Store the regularization settings.

    Parameters
    ----------
    lambda_val : float
        Weight that multiplies the penalty returned by `forward`.
    size_average, reduce, reduction
        Forwarded unchanged to the parent class constructor.
    isotropic : bool, optional
        If True, `forward` combines the finite differences through the square
        root of their summed squares; otherwise through the sum of their
        absolute values (default is True).
    ndims : int, optional
        Number of trailing axes along which `forward` takes finite
        differences (default is 2).
    """
    super().__init__(size_average, reduce, reduction)
    self.lambda_val = lambda_val
    self.isotropic = isotropic
    self.ndims = ndims

forward

forward(img: Tensor) -> Tensor

Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """
    Compute the total variation penalty of the current batch.

    Finite differences are taken along each of the last `ndims` axes and
    combined element-wise (Euclidean norm when `isotropic`, sum of absolute
    values otherwise). The result is summed over the last ``ndims + 1`` axes
    and averaged over what remains.

    Parameters
    ----------
    img : pt.Tensor
        Input batch, validated by `_check_input_tensor` against `ndims`.

    Returns
    -------
    pt.Tensor
        Scalar penalty, scaled by `lambda_val`.
    """
    _check_input_tensor(img, self.ndims)
    reduce_axes = list(range(-(self.ndims + 1), 0))

    per_axis = [_differentiate(img, dim=dim, position="post") for dim in range(-self.ndims, 0)]
    diffs = pt.stack(per_axis, dim=0)

    if self.isotropic:
        # vector_norm has a zero (sub)gradient where all differences vanish (flat
        # regions), whereas sqrt(sum(d**2)) back-propagates NaN (inf * 0) there.
        tv_val = pt.linalg.vector_norm(diffs, ord=2, dim=0)
    else:
        tv_val = diffs.abs().sum(dim=0)

    return self.lambda_val * tv_val.sum(reduce_axes).mean()

get_nd_wl_filters

get_nd_wl_filters(
    wl_lo: Tensor, wl_hi: Tensor, ndim: int
) -> list[Tensor]

Generate all possible N-D separable wavelet filters.

Source code in src/autoden/losses.py
143
144
145
146
147
148
149
150
151
152
153
154
def get_nd_wl_filters(wl_lo: pt.Tensor, wl_hi: pt.Tensor, ndim: int) -> list[pt.Tensor]:
    """
    Generate the N-D separable wavelet filters used by the transform.

    Parameters
    ----------
    wl_lo : pt.Tensor
        1-D low-pass filter.
    wl_hi : pt.Tensor
        1-D high-pass filter.
    ndim : int
        Number of dimensions of the filters to build.

    Returns
    -------
    list of pt.Tensor
        ``1 + ndim`` filters: first the approximation filter (the outer
        product of `wl_lo` with itself over `ndim` axes), then one detail
        filter per axis, i.e. `wl_hi` reshaped so that it lies along that axis
        and has size 1 along all the others.
    """
    approx = wl_lo
    for _ in range(ndim - 1):
        # Broadcasting outer product. `pt.outer` only accepts 1-D operands, so it
        # cannot extend the 2-D kernel to the third (or any further) axis.
        approx = approx.unsqueeze(-1) * wl_lo

    details = []
    for axis in range(ndim):
        new_shape = [1] * ndim
        new_shape[axis] = -1
        details.append(wl_hi.reshape(new_shape))

    return [approx, *details]

swt_nd

swt_nd(
    x: Tensor,
    wl_dec_lo: Tensor,
    wl_dec_hi: Tensor,
    level: int = 1,
    normalize: str | None = None,
) -> list[list[Tensor]]

Perform N-dimensional Stationary Wavelet Transform (SWT).

Parameters:

  • x (Tensor) –

    Input tensor of shape (B, 1, *dims) where dims can be 1D, 2D, or 3D.

  • wl_dec_lo (Tensor) –

    Low-pass wavelet decomposition filter.

  • wl_dec_hi (Tensor) –

    High-pass wavelet decomposition filter.

  • level (int, default: 1 ) –

    Number of decomposition levels (default is 1).

  • normalize (str or None, default: None ) –

    Normalization method ('none', 'energy', or 'scale'). If None, no normalization is applied (default is None).

Returns:

  • list of list of pt.Tensor

    List like [[approx], [detail_vols], ..., [detail_vols]].

Notes

The function performs the SWT on the input tensor x using the specified wavelet filters and decomposition level. The output is a list of lists, where each inner list contains the decomposition volumes. The first inner list contains the approximation coefficients, and the subsequent inner lists contain the detail coefficients for each level.

Source code in src/autoden/losses.py
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
def swt_nd(
    x: pt.Tensor, wl_dec_lo: pt.Tensor, wl_dec_hi: pt.Tensor, level: int = 1, normalize: str | None = None
) -> list[list[pt.Tensor]]:
    """
    Perform N-dimensional Stationary Wavelet Transform (SWT).

    Parameters
    ----------
    x : pt.Tensor
        Input tensor of shape (B, 1, *dims) where dims can be 1D, 2D, or 3D.
    wl_dec_lo : pt.Tensor
        Low-pass wavelet decomposition filter.
    wl_dec_hi : pt.Tensor
        High-pass wavelet decomposition filter.
    level : int, optional
        Number of decomposition levels (default is 1).
    normalize : str or None, optional
        Normalization method ('none', 'energy', or 'scale'). If None, no normalization is applied (default is None).

    Returns
    -------
    list of list of pt.Tensor
        List like [[approx], [detail_vols], ..., [detail_vols]], where the detail lists are ordered from the last
        (coarsest) decomposition level to the first.

    Raises
    ------
    ValueError
        If `x` does not have 1, 2 or 3 dimensions after the first two.

    Notes
    -----
    The function performs the SWT on the input tensor `x` using the specified wavelet filters and decomposition level.
    Each level convolves the previous approximation with the filters, dilated by ``2 ** (level - 1)``. The input is
    replicate-padded beforehand so that every output volume keeps the spatial shape of `x`.
    """
    ndim = len(x.shape[2:])

    # Validate the dimensionality before any filter is built or any padding is
    # applied, so that unsupported inputs fail with the intended error.
    conv_by_ndim = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}
    if ndim not in conv_by_ndim:
        raise ValueError("Only 1D, 2D, 3D supported")
    conv = conv_by_ndim[ndim]

    base_filters = get_nd_wl_filters(
        wl_dec_lo.to(dtype=pt.float32, device=x.device), wl_dec_hi.to(dtype=pt.float32, device=x.device), ndim
    )

    details_per_level = []
    current = x
    for lvl in range(1, level + 1):
        dilation = 2 ** (lvl - 1)

        res_l = []
        for base_filt in base_filters:
            filt = _normalize_wl_filter(base_filt, lvl, normalize)
            filt = filt.unsqueeze(0).unsqueeze(0)  # shape (1, 1, ...)

            # F.pad expects (before, after) pairs starting from the LAST axis,
            # hence the reversed iteration over the kernel sizes.
            pad = []
            for kernel_size in reversed(filt.shape[2:]):
                span = (kernel_size - 1) * dilation
                pad += [span // 2, span - span // 2]
            padded = F.pad(current, pad, mode='replicate')

            res_l.append(conv(padded, filt, dilation=dilation))

        # Split into approximation and details
        current = res_l[0]  # recurse on approximation
        details_per_level.append(res_l[1:])

    return [[current], *reversed(details_per_level)]