BEiT
- class stable_pretraining.methods.BEiT(encoder_name: str | Module = 'vit_small_patch16_224', tokenizer: Callable[[Tensor], Tensor] | None = None, vocab_size: int = 8192, patch_size: int = 16, mask_ratio: float = 0.4, image_size: int = 224, pretrained: bool = False)
Bases: Module

BEiT masked image modeling with a discrete visual tokenizer.
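In BEiT-style masked image modeling, the tokenizer assigns every patch a discrete visual token ID, some patches are hidden from the encoder, and the model is trained to predict the token IDs of the hidden patches with cross-entropy over the vocabulary. The sketch below shows that objective for a single image in pure Python. `masked_token_loss` is a hypothetical helper, not part of stable_pretraining, and how this class actually computes and averages its loss over a batch is not documented here.

```python
import math
from typing import Sequence


def masked_token_loss(logits: Sequence[Sequence[float]],
                      targets: Sequence[int],
                      mask: Sequence[bool]) -> float:
    """Mean cross-entropy over masked patches only (hypothetical helper).

    logits[i]  -- vocab_size scores predicted for patch i
    targets[i] -- visual token ID the tokenizer produced for patch i
    mask[i]    -- True if patch i was hidden from the encoder
    """
    total, count = 0.0, 0
    for row, target, masked in zip(logits, targets, mask):
        if not masked:
            continue  # visible patches contribute no loss
        peak = max(row)
        # Numerically stable log-sum-exp over the vocabulary.
        log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_z - row[target]  # equals -log softmax(row)[target]
        count += 1
    if count == 0:
        raise ValueError("mask selects no patches")
    return total / count


# Uniform logits over a 4-token vocabulary give a loss of log(4).
print(round(masked_token_loss([[0.0] * 4] * 3, [0, 1, 2], [True, False, True]), 4))  # 1.3863
```

Restricting the loss to masked positions is what separates this from plain token classification: the visible patches are context, not targets.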
- Parameters:
encoder_name – timm ViT model name, or a Module instance to use as the encoder (default "vit_small_patch16_224").
tokenizer – Callable mapping a batch of images to a [B, N] int64 tensor of visual token IDs. If None, defaults to patch_kmeans_tokenizer(), a placeholder that is not state of the art.
vocab_size – Number of visual tokens (default 8192, matching the DALL-E tokenizer).
patch_size – Patch size of the encoder in pixels (default 16).
mask_ratio – Fraction of patches masked (default 0.4, the ratio used in BEiT).
image_size – Input image side length in pixels (default 224).
pretrained – If True, load pretrained timm weights (default False).
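The parameters above fix how many patches each image yields, how many of them are masked, and what shape the tokenizer output must have. The pure-Python sketch below works through those relationships under stated assumptions: `num_patches`, `nearest_centroid_tokens`, and `random_patch_mask` are hypothetical helpers, not stable_pretraining functions. The nearest-centroid tokenizer only illustrates the idea behind a k-means patch codebook (one ID per patch, with the number of centroids acting as vocab_size); it is not the implementation of patch_kmeans_tokenizer(). Uniform random masking is also a simplification: the BEiT paper used blockwise masking, and the sampling strategy of this class is not stated here.

```python
import random
from typing import List, Optional, Sequence


def num_patches(image_size: int = 224, patch_size: int = 16) -> int:
    """Number of non-overlapping patch_size x patch_size patches in a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side * side


def nearest_centroid_tokens(patches: Sequence[Sequence[float]],
                            centroids: Sequence[Sequence[float]]) -> List[int]:
    """Toy tokenizer: for each flattened patch, the index of the closest centroid.

    Hypothetical stand-in for a k-means codebook; real centroids would be fit on
    training patches. Returns N token IDs for one image (one row of the [B, N] output).
    """
    ids = []
    for p in patches:
        dists = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centroids]
        ids.append(min(range(len(centroids)), key=dists.__getitem__))
    return ids


def random_patch_mask(n: int, mask_ratio: float = 0.4,
                      seed: Optional[int] = None) -> List[bool]:
    """Mask round(n * mask_ratio) patches chosen uniformly at random (assumed rounding)."""
    rng = random.Random(seed)
    k = round(n * mask_ratio)
    masked = set(rng.sample(range(n), k))
    return [i in masked for i in range(n)]


n = num_patches(224, 16)                  # 14 x 14 grid with the defaults
mask = random_patch_mask(n, 0.4, seed=0)
print(n, sum(mask))                       # 196 78

centroids = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]  # made-up 3-token vocabulary
print(nearest_centroid_tokens([[0.9, 1.1], [0.1, -0.1], [0.2, 0.8]], centroids))  # [1, 0, 2]
```

With the defaults, image_size=224 and patch_size=16 give N = 196 patches, so a mask_ratio of 0.4 hides about 78 of them per image.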
- forward(images: Tensor) → BEiTOutput
Forward pass on a batch of images (overrides torch.nn.Module.forward()).
- Parameters:
images – Batch of input images.
- Returns:
A BEiTOutput.