NEPA
- class stable_pretraining.methods.NEPA(img_size: int = 224, patch_size: int = 14, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, use_rope: bool = True, use_qk_norm: bool = True, use_swiglu: bool = True, layer_scale_init: float = 1e-05)
Bases: Module

NEPA: Next-Embedding Predictive Autoregression.
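The name describes the training signal: the image becomes a sequence of patch embeddings, and the model predicts the embedding of the next patch from the ones before it. The sketch below illustrates that idea under stated assumptions. It counts how many patch tokens the constructor defaults yield and scores predictions with a mean cosine distance between the prediction at position t and the embedding at position t + 1. The cosine objective, the stop-gradient remark, and the helper names `num_patches`, `cosine`, and `next_embedding_loss` are hypothetical illustrations, not the library's documented loss or API.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def num_patches(img_size: int = 224, patch_size: int = 14) -> int:
    """Number of non-overlapping square patches in one image (one token each)."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    side = img_size // patch_size
    return side * side


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def next_embedding_loss(predictions: List[Vector], embeddings: List[Vector]) -> float:
    """Mean cosine distance between the prediction at t and the embedding at t + 1.

    Assumption: predictions[t] comes from a causally masked model that only saw
    embeddings[0..t], so the last prediction has no target and is ignored.
    A real training loop would typically also detach the targets (stop-gradient).
    """
    if len(predictions) != len(embeddings) or len(embeddings) < 2:
        raise ValueError("need one prediction per position and at least two positions")
    distances = [1.0 - cosine(p, e) for p, e in zip(predictions[:-1], embeddings[1:])]
    return sum(distances) / len(distances)
```

With the constructor defaults (img_size=224, patch_size=14), `num_patches()` gives a 16 x 16 grid, so each image yields 256 tokens. A perfect predictor reproduces each next embedding and drives the sketched loss to 0.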
Uses the standard TransformerBlock with these modern options enabled:
- use_rope=True: 2D Rotary Position Embedding
- use_qk_norm=True: Query-Key normalization
- mlp_type='swiglu': Gated MLP activation
- use_layer_scale=True: Residual scaling
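Two of these options are easy to show in isolation. The sketch below is a minimal, dependency-free illustration of a SwiGLU MLP and a LayerScale residual. It assumes bias-free projections and omits the hidden-width choice (presumably tied to mlp_ratio). RoPE and QK normalization are not shown. The mapping from the constructor flags (use_swiglu, layer_scale_init) to these TransformerBlock options is inferred from their names. The helpers `silu`, `matvec`, `swiglu_mlp`, and `layer_scaled_residual` are hypothetical and are not part of stable_pretraining.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def silu(x: float) -> float:
    """SiLU (swish): x * sigmoid(x), split by sign so math.exp never overflows."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    """Multiply an (out_dim x in_dim) matrix, given as a list of rows, by a vector."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def swiglu_mlp(x: Sequence[float], w_gate: Matrix, w_up: Matrix, w_down: Matrix) -> List[float]:
    """Gated MLP: W_down (SiLU(W_gate x) * (W_up x)), biases omitted (assumption)."""
    gate = [silu(v) for v in matvec(w_gate, x)]
    up = matvec(w_up, x)
    hidden = [g * u for g, u in zip(gate, up)]  # element-wise gating
    return matvec(w_down, hidden)


def layer_scaled_residual(
    x: Sequence[float], branch_out: Sequence[float], gamma: Sequence[float]
) -> List[float]:
    """LayerScale residual: x + gamma * branch(x), with gamma a per-channel vector."""
    return [x_i + g_i * b_i for x_i, g_i, b_i in zip(x, gamma, branch_out)]
```

If every entry of gamma starts at a small value, such as the layer_scale_init default of 1e-05, each residual branch contributes almost nothing at initialization. Each block then starts close to the identity, which is the usual motivation for LayerScale in deep Transformers.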
Causal masking is applied via attn_mask during training.

- forward(images: Tensor) → NEPAOutput
Runs the model on a batch of images. Overrides torch.nn.Module.forward().
- Parameters:
images (Tensor) – Input images.
- Returns:
NEPAOutput – The model's output.
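The class description says causal masking is applied through attn_mask during training. The sketch below is a minimal, dependency-free illustration of what such a mask does. Position i may attend only to positions j <= i, so the prediction at i never sees later patches. It builds both a boolean "allowed" matrix and an additive float mask (0.0 or -inf), then applies the additive form inside a softmax. The helper names are hypothetical, and which mask form NEPA actually passes is an assumption. Note that PyTorch's conventions differ: `torch.nn.functional.scaled_dot_product_attention` treats a boolean True as "may attend", while `torch.nn.MultiheadAttention` treats True as "not allowed". The additive float form avoids that ambiguity.

```python
import math
from typing import List


def causal_allowed(n: int) -> List[List[bool]]:
    """allowed[i][j] is True when query position i may attend to key position j (j <= i)."""
    return [[j <= i for j in range(n)] for i in range(n)]


def additive_causal_mask(n: int) -> List[List[float]]:
    """0.0 where attention is allowed, -inf where it is blocked (added to the logits)."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]


def masked_softmax(scores: List[List[float]], mask: List[List[float]]) -> List[List[float]]:
    """Row-wise softmax of scores + mask; blocked positions receive exactly zero weight."""
    rows = []
    for s_row, m_row in zip(scores, mask):
        logits = [s + m for s, m in zip(s_row, m_row)]
        peak = max(logits)  # finite: the diagonal is always allowed
        exps = [math.exp(v - peak) for v in logits]  # exp(-inf) == 0.0
        total = sum(exps)
        rows.append([e / total for e in exps])
    return rows
```

With uniform scores, the first position puts all of its attention on itself, the second splits its attention evenly between the first two positions, and only the last position sees the whole sequence.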