NTXEntLoss

class stable_pretraining.losses.NTXEntLoss(temperature: float = 0.5)

Bases: InfoNCELoss

Normalized temperature-scaled cross entropy loss.

Introduced in the SimCLR paper [Chen et al., 2020]. Also used in MoCo [He et al., 2020].

Parameters:

temperature (float, optional) – The temperature scaling factor. Default is 0.5.
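To illustrate what the temperature does, here is a minimal sketch in plain Python. It is not the library's code: the `softmax` helper and the similarity values are hypothetical and chosen only for illustration. Dividing the similarities by a smaller temperature concentrates more probability mass on the most similar candidate.

```python
import math


def softmax(similarities, temperature):
    """Temperature-scaled softmax over a list of similarities (illustrative helper)."""
    scaled = [s / temperature for s in similarities]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up cosine similarities; the first entry plays the role of the positive.
sims = [0.9, 0.5, 0.1]

soft = softmax(sims, temperature=0.5)   # the documented default
sharp = softmax(sims, temperature=0.1)  # a lower temperature, for comparison

# The positive receives a larger share of the probability at the lower temperature.
print(soft[0], sharp[0])
```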

forward(z_i: Tensor, z_j: Tensor) → Tensor

Compute the NT-Xent loss.

Under DDP, the negatives span all ranks. The local anchors ([z_i; z_j]) are scored against the candidates gathered from every process. The targets and the self-mask are offset by rank, so that each anchor’s positive (its other view) and its self-exclusion point at the correct rows of the global candidate set. In a single process (rank == 0, world_size == 1, gather is a no-op), this reduces exactly to the classic SimCLR formulation.
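The rank offset can be sketched with plain index arithmetic. The layout below is an assumption, not something this page specifies: each rank is taken to contribute one block of 2B rows ([z_i; z_j]), and the blocks are concatenated in rank order. The helper name `global_indices` is hypothetical.

```python
def global_indices(rank, local_batch_size):
    """Return (self_idx, positive_idx) into the gathered candidate set.

    Assumed layout (hypothetical): every rank contributes a block of 2B rows,
    [z_i; z_j], and the blocks are concatenated in rank order.
    """
    block = 2 * local_batch_size
    offset = rank * block  # first row of this rank's block
    # Row that each local anchor must exclude (itself).
    self_idx = [offset + k for k in range(block)]
    # Row holding each anchor's other view, inside the same block.
    positive_idx = [offset + (k + local_batch_size) % block for k in range(block)]
    return self_idx, positive_idx


# Rank 1 with a local batch of 2 samples: its block covers global rows 4..7.
self_idx, positive_idx = global_indices(rank=1, local_batch_size=2)
print(self_idx)      # [4, 5, 6, 7]
print(positive_idx)  # [6, 7, 4, 5]
```

With rank 0 the offset is zero, so the indices match those of the single-process case.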

Parameters:
  • z_i (torch.Tensor) – Latent representation of the first augmented view of the batch.

  • z_j (torch.Tensor) – Latent representation of the second augmented view of the batch.

Returns:

The computed contrastive loss.

Return type:

torch.Tensor
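For reference, the classic single-process SimCLR formulation that forward reduces to can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the function name `nt_xent_single_process` is hypothetical, and it assumes cosine similarity on L2-normalized embeddings with a mean reduction over the 2B anchors.

```python
import torch
import torch.nn.functional as F


def nt_xent_single_process(z_i, z_j, temperature=0.5):
    """Classic SimCLR NT-Xent for one process (a sketch, not the library code)."""
    batch_size = z_i.shape[0]
    # Stack both views: rows 0..B-1 are view 1, rows B..2B-1 are view 2.
    z = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)
    # Cosine similarities scaled by the temperature, shape (2B, 2B).
    logits = z @ z.T / temperature
    # Exclude each anchor's similarity with itself.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # The positive for row k is the other view of the same sample.
    targets = (torch.arange(2 * batch_size, device=z.device) + batch_size) % (2 * batch_size)
    # Cross entropy over the remaining candidates, averaged over all 2B anchors.
    return F.cross_entropy(logits, targets)


torch.manual_seed(0)
z_i = torch.randn(8, 16)  # embeddings of the first augmented view
z_j = torch.randn(8, 16)  # embeddings of the second augmented view
loss = nt_xent_single_process(z_i, z_j)
```

In this sketch the result is a zero-dimensional tensor, so `loss.backward()` can be called on it when the inputs require gradients.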