NEPA

class stable_pretraining.methods.NEPA(img_size: int = 224, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_rope: bool = True, use_qk_norm: bool = True, use_swiglu: bool = True, layer_scale_init: float = 1e-05)

Bases: Module

NEPA: Next-Embedding Predictive Autoregression.
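The constructor defaults fix the token sequence length and the layer widths. The sketch below derives them under the usual ViT conventions, which are assumptions here: the image is split into non-overlapping patch_size by patch_size patches, attention heads split embed_dim evenly, and the nominal MLP hidden width is embed_dim * mlp_ratio. derived_dims is a hypothetical helper, not part of stable_pretraining. This page does not state whether NEPA adds extra tokens or rescales the SwiGLU hidden width.

```python
# Hypothetical helper (not part of stable_pretraining): derives sizes from the
# NEPA constructor defaults, assuming standard ViT patchification.
def derived_dims(img_size=224, patch_size=14, embed_dim=768,
                 num_heads=12, mlp_ratio=4.0):
    # Assumes square images cut into non-overlapping square patches.
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    grid = img_size // patch_size             # patches per side
    num_patches = grid * grid                 # sequence length seen by the blocks
    head_dim = embed_dim // num_heads         # width of each attention head
    mlp_hidden = int(embed_dim * mlp_ratio)   # nominal MLP hidden width
    return {"grid": grid, "num_patches": num_patches,
            "head_dim": head_dim, "mlp_hidden": mlp_hidden}

print(derived_dims())
# {'grid': 16, 'num_patches': 256, 'head_dim': 64, 'mlp_hidden': 3072}
```

With the defaults, a 224 by 224 image becomes a 16 by 16 grid of 256 patch tokens, and each of the 12 heads is 64 channels wide.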

Each block is a standard TransformerBlock with these modern options enabled by default:
  • use_rope=True: 2D rotary position embeddings (RoPE)

  • use_qk_norm=True: query-key normalization

  • mlp_type='swiglu': SwiGLU gated MLP

  • use_layer_scale=True: residual scaling (LayerScale)
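The four options above are standard Transformer components. The NumPy sketch below shows a textbook version of each, as an illustration only. The exact variants inside TransformerBlock are assumptions: for example, how 2D RoPE splits channels between the row and column axes, whether QK normalization is LayerNorm or RMSNorm, and whether it runs before RoPE. All function names here are hypothetical.

```python
import numpy as np

# Textbook sketches of the four options (assumed variants, not the library code).

def rope_1d(x, pos, base=10000.0):
    """Rotate consecutive channel pairs of x (..., n, d) by angle pos * freq."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)    # (d/2,) one frequency per pair
    angles = pos[:, None] * freqs[None, :]       # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x, grid):
    """Axial 2D RoPE: first half of channels encodes the row, second half the column."""
    half = x.shape[-1] // 2
    rows, cols = np.divmod(np.arange(grid * grid), grid)   # row-major patch order
    return np.concatenate(
        [rope_1d(x[..., :half], rows.astype(float)),
         rope_1d(x[..., half:], cols.astype(float))], axis=-1)

def qk_rms_norm(x, eps=1e-6):
    """QK normalization, here as parameter-free RMSNorm over the head dimension."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

def swiglu(x, w_gate, w_up, w_down):
    """SwiGLU MLP: (SiLU(x @ w_gate) * (x @ w_up)) @ w_down."""
    g = x @ w_gate
    return ((g / (1.0 + np.exp(-g))) * (x @ w_up)) @ w_down

def layer_scale_residual(x, branch_out, gamma):
    """Residual update x + gamma * f(x) with a per-channel scale gamma."""
    return x + gamma * branch_out

rng = np.random.default_rng(0)
grid, d, hidden = 16, 64, 128
x = rng.standard_normal((grid * grid, d))   # stands in for one head's projected queries
q = rope_2d(qk_rms_norm(x), grid)           # normalize, then rotate by (row, col)
mlp = swiglu(x, rng.standard_normal((d, hidden)) * 0.1,
             rng.standard_normal((d, hidden)) * 0.1,
             rng.standard_normal((hidden, d)) * 0.1)
y = layer_scale_residual(x, mlp, gamma=np.full(d, 1e-5))
```

RoPE rotates each channel pair by a position-dependent angle, so it preserves vector norms and makes query-key dot products depend only on relative position. A LayerScale gamma of 1e-5 keeps each residual branch nearly silent at initialization, which is consistent with the small layer_scale_init default.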

During training, causal masking is applied through attn_mask, so each position attends only to itself and earlier positions.
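A causal mask is lower-triangular: query position i may attend to key positions j ≤ i. The sketch below assumes a boolean mask in which True means "may attend", the same convention as the boolean attn_mask of torch.nn.functional.scaled_dot_product_attention. Whether NEPA builds its mask this way, and over which token order, is not documented on this page; causal_mask and masked_attention are hypothetical names.

```python
import numpy as np

def causal_mask(n):
    """Boolean (n, n) mask; True where query i may attend to key j, i.e. j <= i."""
    return np.tril(np.ones((n, n), dtype=bool))

def masked_attention(q, k, v, mask):
    """Single-head scaled dot-product attention; masked scores get zero weight."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)                                # exp(-inf) == 0
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
n, d = 6, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
mask = causal_mask(n)
out, weights = masked_attention(q, k, v, mask)
print(np.allclose(np.triu(weights, k=1), 0.0))   # True: no weight on later positions
```

Changing v at the last position leaves every earlier output unchanged, which is the property that lets the output at each position act as a prediction of what comes next.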

forward(images: Tensor) → NEPAOutput

Run NEPA on a batch of images.

Parameters:
  • images (Tensor) – Batch of input images.

Returns:

The model outputs, as a NEPAOutput.
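This page does not document the fields of NEPAOutput or the training loss. As a hedged illustration of next-embedding predictive autoregression, the sketch below shifts the sequence by one position and scores the prediction at each position against the following embedding with negative cosine similarity, treating the targets as constants (stop-gradient). The loss choice, the function name next_embedding_loss, the tensor names pred and target, and their shapes are all assumptions.

```python
import numpy as np

def next_embedding_loss(pred, target):
    """Negative cosine similarity between the prediction at t and the embedding at t+1.

    pred, target: (batch, n, d). In an autodiff framework, target would be
    detached (stop-gradient); NumPy has no gradients, so that is implicit here.
    """
    p = pred[:, :-1]   # predictions from positions 0 .. n-2
    z = target[:, 1:]  # embeddings at positions 1 .. n-1
    p = p / np.linalg.norm(p, axis=-1, keepdims=True)
    z = z / np.linalg.norm(z, axis=-1, keepdims=True)
    return -np.mean(np.sum(p * z, axis=-1))

rng = np.random.default_rng(0)
target = rng.standard_normal((2, 256, 768))   # hypothetical patch embeddings
perfect = np.roll(target, -1, axis=1)         # pred[:, t] == target[:, t + 1]
print(round(next_embedding_loss(perfect, target), 6))   # -1.0
```

Normalizing both sides makes the loss insensitive to embedding scale, and the last position is dropped because it has no next embedding to predict.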