DINO
- class stable_pretraining.methods.DINO(encoder_name: str | Module = 'vit_small_patch16_224', projector_hidden_dim: int = 2048, projector_bottleneck_dim: int = 256, n_prototypes: int = 65536, temperature_student: float = 0.1, temperature_teacher_warmup: float = 0.04, temperature_teacher: float = 0.07, warmup_epochs_temperature_teacher: int = 30, center_momentum: float = 0.9, ema_decay_start: float = 0.996, ema_decay_end: float = 1.0, encoder_kwargs: dict | None = None, pretrained: bool = False)
Bases: Module
DINO self-distillation with multi-crop and an EMA teacher.
- Architecture:
Backbone (student) wrapped in TeacherStudentWrapper (teacher is an EMA copy).
Projector (student) wrapped in TeacherStudentWrapper: 3-layer MLP -> L2-norm -> linear prototypes (default 65k).
Loss: DINOv1Loss with classical centering.
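The EMA teacher and the centering statistic both reduce to the same running-average update. The following is a minimal, framework-free sketch of that update, not the library's implementation: `ema_update` is a hypothetical helper, and plain Python lists stand in for parameter and logit tensors.

```python
def ema_update(running, new, decay):
    """Return decay * running + (1 - decay) * new, element-wise.

    Hypothetical helper for illustration; not part of stable_pretraining.
    """
    return [decay * r + (1.0 - decay) * n for r, n in zip(running, new)]


# Teacher weights track the student slowly (decay close to 1).
teacher = [1.0, 2.0]
student = [0.0, 4.0]
teacher = ema_update(teacher, student, decay=0.996)

# The center tracks the batch-mean teacher output, using center_momentum.
center = [0.0, 0.0]
batch_mean_teacher_logits = [0.5, -0.5]
center = ema_update(center, batch_mean_teacher_logits, decay=0.9)
```

With a decay of 0.996 the teacher moves only 0.4% of the way toward the student per update, which is why it acts as a smoothed copy rather than a mirror.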
The teacher only sees global crops; the student sees both global and local crops. The loss is the average pairwise cross-entropy between every student view and every teacher view. Same-view pairs are excluded, which DINOv1Loss handles internally.
- Parameters:
encoder_name – timm model name (default "vit_small_patch16_224") or pre-built nn.Module. For multi-crop, the backbone must accept variable input sizes; pass dynamic_img_size=True via encoder_kwargs for timm ViTs.
projector_hidden_dim – Hidden dim of the 3-layer MLP (default 2048).
projector_bottleneck_dim – Bottleneck dim before prototypes (default 256).
n_prototypes – Number of prototypes / output dim (default 65536).
temperature_student – Student softmax temperature (default 0.1).
temperature_teacher_warmup – Teacher temp at start (default 0.04).
temperature_teacher – Teacher temp after warmup (default 0.07).
warmup_epochs_temperature_teacher – Linear warmup epochs (default 30).
center_momentum – EMA momentum for the teacher centering (default 0.9).
ema_decay_start – Initial backbone/projector EMA (default 0.996).
ema_decay_end – Final EMA (default 1.0).
encoder_kwargs – Extra kwargs forwarded to timm.create_model.
pretrained – Load pretrained timm weights for the encoder.
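The two scheduled quantities above can be sketched as plain functions. The teacher temperature follows the linear warmup the parameter list describes. The shape of the EMA decay ramp is not stated here, so the cosine ramp below is an assumption borrowed from the original DINO recipe. Both function names are hypothetical, and whether the library interpolates per epoch or per step is likewise assumed.

```python
import math


def teacher_temperature(epoch, warmup_start=0.04, final=0.07, warmup_epochs=30):
    """Linear warmup from warmup_start to final, then constant."""
    if epoch >= warmup_epochs:
        return final
    return warmup_start + (final - warmup_start) * epoch / warmup_epochs


def ema_decay(step, total_steps, start=0.996, end=1.0):
    """Cosine ramp from start to end (assumed shape, not confirmed by the docs)."""
    progress = step / max(1, total_steps)
    return end - (end - start) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Print a few points of each schedule.
for epoch in (0, 15, 30, 60):
    print(epoch, teacher_temperature(epoch))
for step in (0, 500, 1000):
    print(step, ema_decay(step, total_steps=1000))
```

A decay that ends at 1.0 means the teacher is effectively frozen by the end of training, since the update then keeps the running weights unchanged.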
Example:
    import torch
    from stable_pretraining.methods import DINO

    # dynamic_img_size=True lets the timm ViT accept both 224x224 and 96x96 crops.
    model = DINO("vit_small_patch16_224", encoder_kwargs={"dynamic_img_size": True})
    global_views = [torch.randn(8, 3, 224, 224), torch.randn(8, 3, 224, 224)]
    local_views = [torch.randn(8, 3, 96, 96) for _ in range(6)]
    out = model(global_views=global_views, local_views=local_views)
    out.loss.backward()
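To make the pairwise cross-entropy concrete, here is a per-sample sketch in plain Python. It is an illustration under stated assumptions, not the code inside DINOv1Loss: `log_softmax` and `dino_loss` are hypothetical helpers, each view is a single logit vector, and the student list is assumed to hold the global views first so that equal indices mean the same crop.

```python
import math


def log_softmax(logits, temperature):
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scaled))
    return [s - log_norm for s in scaled]


def dino_loss(student_logits, teacher_logits, center, t_student=0.1, t_teacher=0.04):
    """Average cross-entropy over all (teacher view, student view) pairs.

    student_logits: one logit vector per view, global views first (assumption).
    teacher_logits: one logit vector per global view.
    """
    total, n_pairs = 0.0, 0
    for t_idx, t_view in enumerate(teacher_logits):
        # Center the teacher logits, then sharpen with a low temperature.
        centered = [x - c for x, c in zip(t_view, center)]
        target = [math.exp(v) for v in log_softmax(centered, t_teacher)]
        for s_idx, s_view in enumerate(student_logits):
            if s_idx == t_idx:
                continue  # skip the pair where both networks saw the same crop
            log_probs = log_softmax(s_view, t_student)
            total -= sum(p * lq for p, lq in zip(target, log_probs))
            n_pairs += 1
    return total / n_pairs


# Two global views for the teacher; two global plus one local view for the student.
teacher_out = [[2.0, 0.0, -1.0], [1.5, 0.5, -1.0]]
student_out = [[1.0, 0.2, -0.5], [0.8, 0.1, -0.3], [0.5, 0.0, 0.0]]
loss = dino_loss(student_out, teacher_out, center=[0.0, 0.0, 0.0])
```

With two teacher views and three student views, the loop averages over 2 x 3 - 2 = 4 pairs.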
- forward(global_views: Sequence[Tensor] | None = None, local_views: Sequence[Tensor] | None = None, images: Tensor | None = None) -> DINOOutput
Forward pass.
- Parameters:
global_views – List of n_global tensors [B, C, H, W] (e.g. two 224x224 crops). Required in training mode.
local_views – List of n_local tensors [B, C, h, w] (e.g. six 96x96 crops). Optional.
images – Single batch of images for evaluation. If supplied, returns the teacher CLS embedding only.
- Returns:
DINOOutput.