Skip to content

losses

Definitions of the data loss functions and regularizers.

Classes:

  • LossRegularizer

    Base class for the regularizer losses.

  • LossSWTN

    Multi-level n-dimensional stationary wavelet transform loss function.

  • LossTGV

    Total Generalized Variation loss function.

  • LossTV

    Total Variation loss function.

Functions:

  • get_nd_wl_filters

    Generate all possible N-D separable wavelet filters.

  • swt_nd

    Perform N-dimensional Stationary Wavelet Transform (SWT).

LossRegularizer

Bases: MSELoss

Base class for the regularizer losses.

LossSWTN

LossSWTN(
    wl_dec_lo: Tensor,
    wl_dec_hi: Tensor,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    levels: int = 2,
    n_dims: int = 2,
    min_approx: bool = False,
)

Bases: LossRegularizer

Multi-level n-dimensional stationary wavelet transform loss function.

Methods:

  • forward

    Compute wavelet decomposition on current batch.

Source code in src/autoden/losses.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def __init__(
    self,
    wl_dec_lo: pt.Tensor,
    wl_dec_hi: pt.Tensor,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    levels: int = 2,
    n_dims: int = 2,
    min_approx: bool = False,
) -> None:
    """Multi-level n-dimensional stationary wavelet transform loss.

    Parameters
    ----------
    wl_dec_lo : pt.Tensor
        Low-pass wavelet decomposition filter.
    wl_dec_hi : pt.Tensor
        High-pass wavelet decomposition filter.
    lambda_val : float
        Weight of the regularization term.
    size_average : optional
        Legacy reduction flag, forwarded unchanged to the parent loss.
    reduce : optional
        Legacy reduction flag, forwarded unchanged to the parent loss.
    reduction : str, optional
        Reduction mode of the parent loss, by default "mean".
    isotropic : bool, optional
        If True, take the l2 norm across the stacked detail coefficients;
        otherwise sum their absolute values, by default True.
    levels : int, optional
        Number of wavelet decomposition levels, by default 2.
    n_dims : int, optional
        Number of spatial dimensions of the input volumes, by default 2.
    min_approx : bool, optional
        If True, the approximation coefficients are penalized as well,
        by default False.
    """
    super().__init__(size_average, reduce, reduction)
    self.wl_dec_lo = wl_dec_lo
    self.wl_dec_hi = wl_dec_hi
    self.lambda_val = lambda_val
    self.isotropic = isotropic
    self.levels = levels
    self.n_dims = n_dims
    self.min_approx = min_approx

forward

forward(img: Tensor) -> Tensor

Compute wavelet decomposition on current batch.

Source code in src/autoden/losses.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """Compute wavelet decomposition on current batch."""
    _check_input_tensor(img, self.n_dims)
    # Channel + spatial axes, i.e. everything except the batch dimension.
    reduce_axes = list(range(-(self.n_dims + 1), 0))

    decomposition = swt_nd(
        img, wl_dec_lo=self.wl_dec_lo, wl_dec_hi=self.wl_dec_hi, level=self.levels, normalize="scale"
    )

    # Skip the leading approximation band unless it should be penalized too.
    start = 0 if self.min_approx else 1

    level_norms = []
    for band in decomposition[start:]:
        stacked = pt.stack(band, dim=0)
        if self.isotropic:
            pointwise = pt.sqrt(pt.pow(stacked, 2).sum(dim=0))
        else:
            pointwise = stacked.abs().sum(dim=0)
        level_norms.append(pointwise.sum(reduce_axes))

    total = pt.stack(level_norms, dim=0).sum(dim=0).mean()
    # Normalize by the number of penalized bands (bool min_approx adds one).
    return self.lambda_val * total / ((self.levels + self.min_approx) ** 0.5)

LossTGV

LossTGV(
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    n_dims: int = 2,
)

Bases: LossTV

Total Generalized Variation loss function.

Methods:

  • forward

    Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    n_dims: int = 2,
) -> None:
    """Total Generalized Variation loss.

    Parameters
    ----------
    lambda_val : float
        Weight of the regularization term.
    size_average : optional
        Legacy reduction flag, forwarded unchanged to the parent loss.
    reduce : optional
        Legacy reduction flag, forwarded unchanged to the parent loss.
    reduction : str, optional
        Reduction mode of the parent loss, by default "mean".
    isotropic : bool, optional
        If True, use the l2 norm over the stacked derivatives; otherwise
        use the l1 (anisotropic) variant, by default True.
    n_dims : int, optional
        Number of spatial dimensions of the input volumes, by default 2.
    """
    super().__init__(size_average, reduce, reduction)
    self.lambda_val = lambda_val
    self.isotropic = isotropic
    self.n_dims = n_dims

forward

forward(img: Tensor) -> Tensor

Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """Compute total variation statistics on current batch."""
    _check_input_tensor(img, self.n_dims)
    # Channel + spatial axes, i.e. everything except the batch dimension.
    reduce_axes = list(range(-(self.n_dims + 1), 0))

    spatial_dims = range(-self.n_dims, 0)
    # First-order forward differences, then second-order (backward of forward).
    grads = [_differentiate(img, dim=d, position="post") for d in spatial_dims]
    second_derivs = [_differentiate(g, dim=d, position="pre") for d in spatial_dims for g in grads]

    grads_stack = pt.stack(grads, dim=0)
    second_stack = pt.stack(second_derivs, dim=0)

    if self.isotropic:
        tv_term = pt.sqrt(pt.pow(grads_stack, 2).sum(dim=0))
        jac_term = pt.sqrt(pt.pow(second_stack, 2).sum(dim=0))
    else:
        tv_term = grads_stack.abs().sum(dim=0)
        jac_term = second_stack.abs().sum(dim=0)

    return self.lambda_val * (tv_term.sum(reduce_axes).mean() + jac_term.sum(reduce_axes).mean() / 4)

LossTV

LossTV(
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    n_dims: int = 2,
)

Bases: LossRegularizer

Total Variation loss function.

Methods:

  • forward

    Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(
    self,
    lambda_val: float,
    size_average=None,
    reduce=None,
    reduction: str = "mean",
    isotropic: bool = True,
    n_dims: int = 2,
) -> None:
    """Total Variation loss.

    Parameters
    ----------
    lambda_val : float
        Weight of the regularization term.
    size_average : optional
        Legacy reduction flag, forwarded unchanged to the parent loss.
    reduce : optional
        Legacy reduction flag, forwarded unchanged to the parent loss.
    reduction : str, optional
        Reduction mode of the parent loss, by default "mean".
    isotropic : bool, optional
        If True, use the l2 norm of the per-voxel gradient vector; otherwise
        use the l1 (anisotropic) variant, by default True.
    n_dims : int, optional
        Number of spatial dimensions of the input volumes, by default 2.
    """
    super().__init__(size_average, reduce, reduction)
    self.lambda_val = lambda_val
    self.isotropic = isotropic
    self.n_dims = n_dims

forward

forward(img: Tensor) -> Tensor

Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """Compute the total-variation statistics of the current batch.

    Parameters
    ----------
    img : pt.Tensor
        Input batch with `self.n_dims` trailing spatial dimensions
        (assumed (B, C, *dims) — validated by `_check_input_tensor`).

    Returns
    -------
    pt.Tensor
        Scalar loss: `lambda_val` times the batch mean of the TV semi-norm.
    """
    _check_input_tensor(img, self.n_dims)
    # Channel + spatial axes, i.e. everything except the batch dimension.
    axes = list(range(-(self.n_dims + 1), 0))

    # Forward finite differences along each spatial dimension.
    diffs = [_differentiate(img, dim=dim, position="post") for dim in range(-self.n_dims, 0)]
    diffs = pt.stack(diffs, dim=0)

    if self.isotropic:
        # l2 norm of the gradient vector at each position.
        tv_val = pt.sqrt(pt.pow(diffs, 2).sum(dim=0))
    else:
        # l1 (anisotropic) variant: sum of absolute partial derivatives.
        tv_val = diffs.abs().sum(dim=0)

    return self.lambda_val * tv_val.sum(axes).mean()

get_nd_wl_filters

get_nd_wl_filters(
    wl_lo: Tensor, wl_hi: Tensor, ndim: int
) -> list[Tensor]

Generate all possible N-D separable wavelet filters.

Source code in src/autoden/losses.py
108
109
110
111
112
113
114
115
116
117
118
119
def get_nd_wl_filters(wl_lo: pt.Tensor, wl_hi: pt.Tensor, ndim: int) -> list[pt.Tensor]:
    """
    Generate all possible N-D separable wavelet filters.

    The first returned filter is the N-D approximation (all low-pass) kernel,
    built as the N-fold outer product of `wl_lo`. It is followed by `ndim`
    detail kernels: the 1-D high-pass filter `wl_hi` oriented along each axis.

    Parameters
    ----------
    wl_lo : pt.Tensor
        1-D low-pass decomposition filter.
    wl_hi : pt.Tensor
        1-D high-pass decomposition filter.
    ndim : int
        Number of dimensions (1 or more).

    Returns
    -------
    list[pt.Tensor]
        `[approx, detail_0, ..., detail_{ndim-1}]`, where `approx` has `ndim`
        dimensions and `detail_i` is `wl_hi` reshaped to lie along axis `i`.
    """
    # Build the N-fold outer product via broadcasting. NOTE: the previous
    # implementation used `pt.outer`, which only accepts 1-D inputs and
    # therefore crashed for ndim >= 3; broadcasting handles any rank.
    approx = wl_lo
    for _ in range(ndim - 1):
        approx = approx.unsqueeze(-1) * wl_lo

    details: list[pt.Tensor] = []
    for axis in range(ndim):
        shape = [1] * ndim
        shape[axis] = -1
        details.append(wl_hi.reshape(shape))

    return [approx] + details

swt_nd

swt_nd(
    x: Tensor,
    wl_dec_lo: Tensor,
    wl_dec_hi: Tensor,
    level: int = 1,
    normalize: str | None = None,
) -> list[list[Tensor]]

Perform N-dimensional Stationary Wavelet Transform (SWT).

Parameters:

  • x (Tensor) –

    Input tensor of shape (B, 1, *dims) where dims can be 1D, 2D, or 3D.

  • wl_dec_lo (Tensor) –

    Low-pass wavelet decomposition filter.

  • wl_dec_hi (Tensor) –

    High-pass wavelet decomposition filter.

  • level (int, default: 1 ) –

    Number of decomposition levels (default is 1).

  • normalize (str or None, default: None ) –

    Normalization method ('none', 'energy', or 'scale'). If None, no normalization is applied (default is None).

Returns:

  • list of list of pt.Tensor

    List like [[approx], [detail_vols], ..., [detail_vols]].

Notes

The function performs the SWT on the input tensor x using the specified wavelet filters and decomposition level. The output is a list of lists, where each inner list contains the decomposition volumes. The first inner list contains the approximation coefficients, and the subsequent inner lists contain the detail coefficients for each level.

Source code in src/autoden/losses.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def swt_nd(
    x: pt.Tensor, wl_dec_lo: pt.Tensor, wl_dec_hi: pt.Tensor, level: int = 1, normalize: str | None = None
) -> list[list[pt.Tensor]]:
    """
    Perform N-dimensional Stationary Wavelet Transform (SWT).

    Parameters
    ----------
    x : pt.Tensor
        Input tensor of shape (B, 1, *dims) where dims can be 1D, 2D, or 3D.
    wl_dec_lo : pt.Tensor
        Low-pass wavelet decomposition filter.
    wl_dec_hi : pt.Tensor
        High-pass wavelet decomposition filter.
    level : int, optional
        Number of decomposition levels (default is 1).
    normalize : str or None, optional
        Normalization method ('none', 'energy', or 'scale'). If None, no normalization is applied (default is None).

    Returns
    -------
    list of list of pt.Tensor
        List like [[approx], [detail_vols], ..., [detail_vols]]: the first
        inner list holds the coarsest approximation coefficients, followed by
        the detail coefficients from the coarsest to the finest level.
    """
    ndim = len(x.shape[2:])
    conv_by_ndim = {1: F.conv1d, 2: F.conv2d, 3: F.conv3d}

    base_filters = get_nd_wl_filters(
        wl_dec_lo.to(dtype=pt.float32, device=x.device), wl_dec_hi.to(dtype=pt.float32, device=x.device), ndim
    )

    decomposition: list[list[pt.Tensor]] = []
    approx = x
    for lvl in range(1, level + 1):
        # "À trous" scheme: dilate the filters instead of down-sampling.
        dilation = 2 ** (lvl - 1)

        level_out: list[pt.Tensor] = []
        for base_filt in base_filters:
            kernel = _normalize_wl_filter(base_filt, lvl, normalize)
            kernel = kernel.unsqueeze(0).unsqueeze(0)  # shape (1, 1, ...)

            # Symmetric "same" padding accounting for dilation; F.pad expects
            # (before, after) pairs ordered from the last dimension backwards.
            spans = (pt.tensor(kernel.shape[2:]).flip(dims=[0]) - 1) * dilation
            pad_pairs = pt.concatenate([pt.tensor([s // 2, s - s // 2]) for s in spans])
            padded = F.pad(approx, pad_pairs.tolist(), mode="replicate")

            conv = conv_by_ndim.get(ndim)
            if conv is None:
                raise ValueError("Only 1D, 2D, 3D supported")
            level_out.append(conv(padded, kernel, dilation=dilation))

        # The all-low-pass result feeds the next level; the rest are details.
        approx = level_out[0]
        decomposition.append(level_out[1:])

    decomposition.append([approx])

    return list(reversed(decomposition))