DINO


class stable_pretraining.methods.DINO(encoder_name: str | Module = 'vit_small_patch16_224', projector_hidden_dim: int = 2048, projector_bottleneck_dim: int = 256, n_prototypes: int = 65536, temperature_student: float = 0.1, temperature_teacher_warmup: float = 0.04, temperature_teacher: float = 0.07, warmup_epochs_temperature_teacher: int = 30, center_momentum: float = 0.9, ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, encoder_kwargs: dict | None = None, pretrained: bool = False)

Bases: Module

DINO self-distillation with multi-crop and an EMA teacher.

Architecture:
  • Backbone (student) wrapped in TeacherStudentWrapper (teacher is an EMA copy).

  • Projector (student) wrapped in TeacherStudentWrapper: 3-layer MLP -> L2 normalization -> linear prototype layer (default 65,536 prototypes).

  • Loss: DINOv1Loss with classical centering.
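The projector and the EMA teacher can be pictured with the minimal PyTorch sketch below. It is not the library's code: `DINOHeadSketch` and `ema_update` are hypothetical names. The GELU activations follow the original DINO head as an assumption, and the sketch omits the weight normalization that the original DINO applies to its prototype layer. The teacher copy receives no gradients. Instead, its parameters track an exponential moving average of the student's.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class DINOHeadSketch(nn.Module):
    """Hypothetical projector: 3-layer MLP -> L2-norm -> linear prototypes."""

    def __init__(self, in_dim: int, hidden_dim: int = 2048,
                 bottleneck_dim: int = 256, n_prototypes: int = 65536):
        super().__init__()
        # Assumption: GELU activations, as in the original DINO head.
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, bottleneck_dim),
        )
        # Bias-free linear layer whose weight rows play the role of prototypes.
        self.prototypes = nn.Linear(bottleneck_dim, n_prototypes, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = F.normalize(self.mlp(x), dim=-1)  # L2-normalize the bottleneck
        return self.prototypes(z)             # one logit per prototype


@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, decay: float) -> None:
    """teacher <- decay * teacher + (1 - decay) * student, parameter-wise."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(decay).add_(p_s.detach(), alpha=1.0 - decay)


# Small prototype count so the sketch runs quickly (the documented default is 65536).
student_head = DINOHeadSketch(in_dim=384, n_prototypes=1024)
teacher_head = copy.deepcopy(student_head)  # teacher starts as an exact copy
for p in teacher_head.parameters():
    p.requires_grad_(False)  # teacher is updated only through ema_update
```

With a decay close to 1 (0.996 by default here), the teacher changes slowly, which is what makes it a stable target for the student.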

The teacher sees only global crops; the student sees both global and local crops. The loss is the cross-entropy between every teacher view and every student view, averaged over all pairs. Pairs in which the student and the teacher see the same global crop are excluded; DINOv1Loss handles this exclusion internally.
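The pairwise loss with classical centering can be sketched in plain PyTorch as follows. `dino_loss_sketch` and `update_center` are hypothetical helpers, not DINOv1Loss itself. The sketch assumes that the student list puts the global views first, in the same order as the teacher list, so that index `i` in both lists refers to the same crop. Teacher targets are centered, sharpened with the teacher temperature, and detached. The center is an EMA of the batch-mean teacher logits.

```python
from typing import Sequence

import torch
import torch.nn.functional as F


def dino_loss_sketch(student_logits: Sequence[torch.Tensor],
                     teacher_logits: Sequence[torch.Tensor],
                     center: torch.Tensor,
                     temp_student: float = 0.1,
                     temp_teacher: float = 0.04) -> torch.Tensor:
    """Average soft cross-entropy over (teacher view, student view) pairs.

    student_logits: [B, K] tensors, global views first, then local views.
    teacher_logits: [B, K] tensors, one per global view (same order).
    center: [K] running center subtracted from the teacher logits.
    """
    # Teacher targets: centered, sharpened, and cut off from the graph.
    targets = [F.softmax((t - center) / temp_teacher, dim=-1).detach()
               for t in teacher_logits]
    log_probs = [F.log_softmax(s / temp_student, dim=-1) for s in student_logits]

    total, n_pairs = 0.0, 0
    for i, q in enumerate(targets):
        for j, log_p in enumerate(log_probs):
            if i == j:  # same global crop for teacher and student: skip
                continue
            total = total + (-(q * log_p).sum(dim=-1)).mean()
            n_pairs += 1
    return total / n_pairs


@torch.no_grad()
def update_center(center: torch.Tensor,
                  teacher_logits: Sequence[torch.Tensor],
                  momentum: float = 0.9) -> torch.Tensor:
    """EMA of the batch-mean teacher logits (classical centering)."""
    batch_mean = torch.cat(list(teacher_logits)).mean(dim=0)
    return momentum * center + (1.0 - momentum) * batch_mean
```

With 2 global and 6 local crops, this averages over 2 × 8 − 2 = 14 pairs. Centering and sharpening act in opposite directions, which is what keeps the teacher from collapsing to a single prototype.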

Parameters:
  • encoder_name – timm model name (default "vit_small_patch16_224") or a pre-built nn.Module. Because multi-crop feeds crops of different resolutions, the backbone must accept variable input sizes; for timm ViTs, pass dynamic_img_size=True via encoder_kwargs.

  • projector_hidden_dim – Hidden dim of the 3-layer MLP (default 2048).

  • projector_bottleneck_dim – Bottleneck dim before prototypes (default 256).

  • n_prototypes – Number of prototypes / output dim (default 65536).

  • temperature_student – Student softmax temperature (default 0.1).

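  • temperature_teacher_warmup – Teacher softmax temperature at the start of training (default 0.04).

  • temperature_teacher – Teacher softmax temperature after warmup (default 0.07).

  • warmup_epochs_temperature_teacher – Number of epochs over which the teacher temperature is linearly warmed up from temperature_teacher_warmup to temperature_teacher (default 30).

  • center_momentum – EMA momentum for the teacher centering (default 0.9).

  • ema_decay_start – Initial EMA decay rate for the teacher backbone and projector (default 0.996).

  • ema_decay_end – Final EMA decay rate (default 1.0).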

  • encoder_kwargs – Extra kwargs forwarded to timm.create_model.

  • pretrained – Load pretrained timm weights for the encoder.
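The two schedules implied by these parameters can be sketched as plain functions. `teacher_temperature` and `ema_decay` are hypothetical names, not the library's API. The temperature schedule follows the linear warmup described above. The documentation states only the start and end EMA decay values, not the shape of the ramp, so the cosine ramp used here is an assumption borrowed from the original DINO recipe.

```python
import math


def teacher_temperature(epoch: int, warmup_temp: float = 0.04,
                        final_temp: float = 0.07, warmup_epochs: int = 30) -> float:
    """Linear warmup of the teacher temperature, then held constant."""
    if epoch >= warmup_epochs:
        return final_temp
    return warmup_temp + (final_temp - warmup_temp) * epoch / warmup_epochs


def ema_decay(step: int, total_steps: int, start: float = 0.996,
              end: float = 1.0) -> float:
    """Cosine ramp of the EMA decay from `start` to `end` (assumed shape)."""
    progress = min(step / max(total_steps, 1), 1.0)
    # At progress 0 the cosine term is 1 (decay = start); at progress 1 it is 0 (decay = end).
    return end - (end - start) * (math.cos(math.pi * progress) + 1.0) / 2.0
```

With ema_decay_end = 1.0, the teacher is effectively frozen by the end of training.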

Example:

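```python
import torch
from stable_pretraining.methods import DINO

# dynamic_img_size=True lets the timm ViT accept both 224x224 and 96x96 crops.
model = DINO("vit_small_patch16_224", encoder_kwargs={"dynamic_img_size": True})
model.train()  # global_views are required in training mode

# Two global crops and six local crops, each a batch of 8 RGB images.
global_views = [torch.randn(8, 3, 224, 224) for _ in range(2)]
local_views = [torch.randn(8, 3, 96, 96) for _ in range(6)]

out = model(global_views=global_views, local_views=local_views)
out.loss.backward()
```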

forward(global_views: Sequence[Tensor] | None = None, local_views: Sequence[Tensor] | None = None, images: Tensor | None = None) → DINOOutput

Forward pass.

Parameters:
  • global_views – Sequence of n_global tensors of shape [B, C, H, W] (e.g., two 224x224 crops). Required in training mode.

  • local_views – Sequence of n_local tensors of shape [B, C, h, w] (e.g., six 96x96 crops). Optional.

  • images – A single batch of images of shape [B, C, H, W], for evaluation. If supplied, only the teacher CLS embedding is returned.
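Global and local views are usually produced by multi-crop augmentation. The following is a simplified, self-contained sketch in plain PyTorch. `random_resized_crop` and `multi_crop` are hypothetical helpers. The crop-scale ranges, (0.4, 1.0) for global and (0.05, 0.4) for local views, are assumptions taken from the original DINO recipe. For brevity, the sketch applies one square crop to the whole batch and skips aspect-ratio jitter, flips, and color augmentation.

```python
import math
import random

import torch
import torch.nn.functional as F


def random_resized_crop(images: torch.Tensor, out_size: int,
                        scale: tuple[float, float]) -> torch.Tensor:
    """Crop one random square region (shared by the whole batch) and resize it."""
    _, _, H, W = images.shape
    area = H * W * random.uniform(*scale)  # fraction of the image area to keep
    side = max(1, int(round(math.sqrt(area))))
    h, w = min(side, H), min(side, W)
    top = random.randint(0, H - h)
    left = random.randint(0, W - w)
    crop = images[:, :, top:top + h, left:left + w]
    return F.interpolate(crop, size=(out_size, out_size), mode="bilinear",
                         align_corners=False)


def multi_crop(images: torch.Tensor, n_global: int = 2, n_local: int = 6,
               global_size: int = 224, local_size: int = 96,
               global_scale: tuple[float, float] = (0.4, 1.0),
               local_scale: tuple[float, float] = (0.05, 0.4)):
    """Return (global_views, local_views) lists of [B, C, size, size] tensors."""
    global_views = [random_resized_crop(images, global_size, global_scale)
                    for _ in range(n_global)]
    local_views = [random_resized_crop(images, local_size, local_scale)
                   for _ in range(n_local)]
    return global_views, local_views
```

The two returned lists have the shapes forward expects for global_views and local_views.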

Returns:

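DINOOutput. In training mode, its loss field holds the scalar DINO loss to backpropagate, as in out.loss.backward() in the example above.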