
Solution for multiple losses: mse_loss, bce_loss, KLD, and loss_contra #421

@linjing-lab

Here is a loss function that combines several losses, taken from a spatial transcriptomics task (spatial domain identification). Can losses computed by different loss functions be treated as separate tasks for multi-task learning, as in the change below?

  import torch
  import torch.nn.functional as F
  from torch.optim import Adam

+ from torchjd import mtl_backward
+ from torchjd.aggregation import UPGrad

  optimizer = Adam(params, lr=0.1)
+ aggregator = UPGrad()

  for ...
      mse_loss = F.mse_loss(decoded, x)
      bce_loss = F.binary_cross_entropy_with_logits(preds, labels)
      # KL term with logvar used as log(sigma): 1 + 2*log(sigma) - mu^2 - sigma^2
      KLD = -0.5 / n_nodes * torch.mean(torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1))
      loss_contra = contrastive_loss(...)
      optimizer.zero_grad()
-     loss = mse_loss + bce_loss + KLD + loss_contra
-     loss.backward()
+     # features: the shared representation (e.g. encoder output) that every loss is computed from
+     mtl_backward(losses=[mse_loss, bce_loss, KLD, loss_contra], features=features, aggregator=aggregator)
      optimizer.step()
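
To see what the change replaces, it helps to treat the four losses as separate rows of a Jacobian rather than as one summed scalar. This is a minimal sketch under stated assumptions: it uses a hypothetical toy linear encoder, random data, and placeholder reconstruction, link-prediction, and contrastive terms; only the `KLD` expression is copied from the issue. It computes one gradient per loss with `torch.autograd.grad`, checks that summing the losses gives the same gradient as summing those per-loss gradients, and prints their pairwise cosine similarities.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy stand-in for the model in the issue: a shared linear encoder
# producing mu and logvar. Shapes, data, and the reconstruction, link-prediction,
# and contrastive terms are placeholders, not the original model.
n_nodes, in_dim, latent_dim = 16, 8, 4
x = torch.randn(n_nodes, in_dim)
labels = torch.randint(0, 2, (n_nodes, latent_dim)).float()
encoder = torch.nn.Linear(in_dim, 2 * latent_dim)
params = list(encoder.parameters())

h = encoder(x)
mu, logvar = h[:, :latent_dim], h[:, latent_dim:]
z = mu  # deterministic embedding (no reparameterization sampling) for simplicity

mse_loss = F.mse_loss(z, x[:, :latent_dim])               # placeholder reconstruction
bce_loss = F.binary_cross_entropy_with_logits(z, labels)  # placeholder link prediction
KLD = -0.5 / n_nodes * torch.mean(
    torch.sum(1 + 2 * logvar - mu.pow(2) - logvar.exp().pow(2), 1)
)
loss_contra = -F.cosine_similarity(z[:8], z[8:]).mean()   # placeholder contrastive term
losses = [mse_loss, bce_loss, KLD, loss_contra]


def flat_grad(loss):
    """Gradient of one scalar loss w.r.t. all shared parameters, flattened."""
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])


# One row per loss: this is the matrix a Jacobian-descent aggregator combines.
jacobian = torch.stack([flat_grad(loss) for loss in losses])  # shape (4, n_params)

# The removed `loss.backward()` path: the gradient of the summed loss equals
# the plain sum of the rows.
summed_grad = flat_grad(sum(losses))
print(torch.allclose(jacobian.sum(dim=0), summed_grad, atol=1e-5))

# Pairwise cosine similarity between per-loss gradients; negative off-diagonal
# entries mean two losses push the shared parameters in opposing directions.
unit = F.normalize(jacobian, dim=1)
cosine = unit @ unit.T
print(cosine)
```

Losses computed by different functions pose no problem for this decomposition: each one only has to be a scalar that depends on the shared parameters. The plain sum weights every row equally, so a loss whose gradient is large or points against the others can dominate the update. According to the torchjd documentation, `UPGrad` replaces that sum with a combination designed not to conflict with any individual loss gradient. In the actual `mtl_backward` call, only the part of the model that produces `features` is treated as shared, whereas the sketch treats every parameter as shared for simplicity. Whether this helps the spatial domain task is an empirical question the sketch does not settle.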
