Skip to content

losses.py

Data losses definitions.

LossTV

Bases: MSELoss

Source code in src/autoden/losses.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class LossTV(nn.MSELoss):
    """Total-variation (TV) regularization term.

    Subclasses :class:`torch.nn.MSELoss` to reuse its constructor signature, but
    ``forward`` takes a single image tensor (no target) and computes a TV value
    instead of a mean-squared error.

    Parameters
    ----------
    lambda_val : float
        Weight that multiplies the TV value.
    size_average, reduce, reduction
        Forwarded unchanged to :class:`torch.nn.MSELoss`. They are NOT used by
        :meth:`forward`, which always sums per sample and averages over axis 0.
    isotropic : bool, optional
        If True (default), combine the two differences as ``sqrt(d1**2 + d2**2)``;
        otherwise as ``|d1| + |d2|``.
    """

    def __init__(
        self, lambda_val: float, size_average=None, reduce=None, reduction: str = "mean", isotropic: bool = True
    ) -> None:
        super().__init__(size_average, reduce, reduction)
        self.lambda_val = lambda_val
        self.isotropic = isotropic

    def forward(self, img: pt.Tensor) -> pt.Tensor:
        """Compute total variation statistics on current batch.

        Parameters
        ----------
        img : torch.Tensor
            4D tensor. Differences are taken along the last two axes; the last
            three axes are summed per entry of axis 0 (presumably the batch axis).

        Returns
        -------
        torch.Tensor
            Scalar tensor: ``lambda_val`` times the mean over axis 0 of the summed TV.

        Raises
        ------
        RuntimeError
            If ``img`` is not 4D.
        """
        if img.ndim != 4:
            raise RuntimeError(f"Expected input `img` to be a 4D tensor, but got {img.shape}")
        axes = [-3, -2, -1]

        # Differences along the last (dim=-1) and second-to-last (dim=-2) axes.
        diff1 = _differentiate(img, dim=-1)
        diff2 = _differentiate(img, dim=-2)
        if self.isotropic:
            # Same value as sqrt(diff1**2 + diff2**2), but `sqrt` has an infinite
            # derivative at 0, which turns into NaN gradients wherever both
            # differences vanish (e.g. flat regions). `vector_norm` returns a zero
            # subgradient there instead. `broadcast_tensors` keeps the original
            # broadcasting semantics of the `+` it replaces.
            tv_val = pt.linalg.vector_norm(pt.stack(pt.broadcast_tensors(diff1, diff2)), dim=0)
        else:
            tv_val = diff1.abs() + diff2.abs()

        return self.lambda_val * tv_val.sum(axes).mean()

forward(img)

Compute total variation statistics on current batch.

Source code in src/autoden/losses.py
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def forward(self, img: pt.Tensor) -> pt.Tensor:
    """Compute total variation statistics on current batch.

    Parameters
    ----------
    img : torch.Tensor
        4D tensor. Differences are taken along the last two axes; the last
        three axes are summed per entry of axis 0 (presumably the batch axis).

    Returns
    -------
    torch.Tensor
        Scalar tensor: ``self.lambda_val`` times the mean over axis 0 of the summed TV.

    Raises
    ------
    RuntimeError
        If ``img`` is not 4D.
    """
    if img.ndim != 4:
        raise RuntimeError(f"Expected input `img` to be a 4D tensor, but got {img.shape}")
    axes = [-3, -2, -1]

    # Differences along the last (dim=-1) and second-to-last (dim=-2) axes.
    diff1 = _differentiate(img, dim=-1)
    diff2 = _differentiate(img, dim=-2)
    if self.isotropic:
        # Same value as sqrt(diff1**2 + diff2**2), but `sqrt` has an infinite
        # derivative at 0, which turns into NaN gradients wherever both
        # differences vanish (e.g. flat regions). `vector_norm` returns a zero
        # subgradient there instead. `broadcast_tensors` keeps the original
        # broadcasting semantics of the `+` it replaces.
        tv_val = pt.linalg.vector_norm(pt.stack(pt.broadcast_tensors(diff1, diff2)), dim=0)
    else:
        tv_val = diff1.abs() + diff2.abs()

    return self.lambda_val * tv_val.sum(axes).mean()