Quick Start

This guide walks through the key functionality of Slapo. We use the BERT model from the HuggingFace Hub as an example and leverage Slapo to optimize its training performance.

Optimize a PyTorch Model with Slapo

We first import the Slapo package. Make sure you have already installed the PyTorch package.

import slapo
import torch

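We create a BERT model implemented in PyTorch with the HuggingFace transformers library. Only the bert-large-uncased configuration is downloaded from the HuggingFace Hub, so the model starts with randomly initialized weights.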

from transformers import BertLMHeadModel, AutoConfig

config = AutoConfig.from_pretrained("bert-large-uncased")
model = BertLMHeadModel(config)  # random initialization; no pretrained weights are loaded
print(model)
BertLMHeadModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=4096, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
  )
  (cls): BertOnlyMLMHead(
    (predictions): BertLMPredictionHead(
      (transform): BertPredictionHeadTransform(
        (dense): Linear(in_features=1024, out_features=1024, bias=True)
        (transform_act_fn): GELUActivation()
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      )
      (decoder): Linear(in_features=1024, out_features=30522, bias=True)
    )
  )
)
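The printed shapes are enough to estimate the model size by hand. The following back-of-the-envelope count is plain Python; the helper names are ours and are not part of Slapo or HuggingFace. It assumes the bert-large-uncased dimensions shown above and that the MLM decoder weight is tied to word_embeddings (the HuggingFace default), so only the decoder bias adds new parameters. Under those assumptions the total should agree with sum(p.numel() for p in model.parameters()).

```python
def linear_params(in_f: int, out_f: int, bias: bool = True) -> int:
    """Parameters of nn.Linear(in_f, out_f): weight matrix plus optional bias."""
    return in_f * out_f + (out_f if bias else 0)


def layernorm_params(dim: int) -> int:
    """elementwise_affine=True means one weight and one bias per feature."""
    return 2 * dim


def bert_lm_head_params(vocab: int = 30522, hidden: int = 1024, layers: int = 24,
                        intermediate: int = 4096, max_pos: int = 512,
                        type_vocab: int = 2) -> int:
    # Word, position, and token-type embeddings share the hidden size,
    # followed by one LayerNorm.
    embeddings = (vocab + max_pos + type_vocab) * hidden + layernorm_params(hidden)
    per_layer = (
        4 * linear_params(hidden, hidden)      # query, key, value, attention output dense
        + linear_params(hidden, intermediate)  # intermediate dense
        + linear_params(intermediate, hidden)  # output dense
        + 2 * layernorm_params(hidden)         # attention-output and block-output LayerNorms
    )
    # MLM head: transform dense + LayerNorm. Assumption: the decoder weight is
    # tied to word_embeddings, so only its bias (one entry per token) is new.
    head = linear_params(hidden, hidden) + layernorm_params(hidden) + vocab
    return embeddings + layers * per_layer + head


print(f"{bert_lm_head_params():,}")  # 335,174,458
```

At 2 bytes per parameter, the fp16 weights alone take about 670 MB, before gradients and optimizer state are counted.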

Once we have the model definition, we can create a schedule and optimize it. Slapo provides an apply_schedule API that applies a predefined schedule directly to the model. By default, the schedule injects the Flash Attention kernel, applies tensor parallelism, and fuses operators. Users can also customize the schedule with configurations such as the data type (fp16/bf16) or the activation checkpointing ratio (ckpt_ratio). The full set of schedule configurations is documented in slapo.model_schedule.
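Tensor parallelism splits each weight matrix across GPUs. As a hedged illustration, this sketch computes the per-rank parameter shapes of one BertLayer under Megatron-LM-style sharding: query/key/value and the first MLP projection are split along their output features, while the attention output and the second MLP projection are split along their input features, with an all-reduce summing the partial outputs. We assume Slapo's BERT schedule follows this layout; check slapo.model_schedule for the sharding it actually applies. The parameter names mirror the printed model, whereas shard_shape and bert_layer_shards are hypothetical helpers.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def shard_shape(shape: Shape, axis: int, world_size: int) -> Shape:
    """Shape of one rank's slice when `shape` is split evenly along `axis`."""
    if shape[axis] % world_size != 0:
        raise ValueError(f"dim {shape[axis]} is not divisible by {world_size}")
    sliced = list(shape)
    sliced[axis] //= world_size
    return tuple(sliced)


def bert_layer_shards(hidden: int = 1024, intermediate: int = 4096,
                      num_heads: int = 16, world_size: int = 4) -> Dict[str, Shape]:
    """Per-rank parameter shapes of one BertLayer (assumed Megatron-style layout)."""
    if num_heads % world_size != 0:
        raise ValueError("each rank must own a whole number of attention heads")
    shards: Dict[str, Shape] = {}
    # Column-parallel: nn.Linear weights are (out_features, in_features), so
    # splitting axis 0 gives each rank a subset of the output features (heads).
    for name in ("attention.self.query", "attention.self.key", "attention.self.value"):
        shards[f"{name}.weight"] = shard_shape((hidden, hidden), 0, world_size)
        shards[f"{name}.bias"] = shard_shape((hidden,), 0, world_size)
    shards["intermediate.dense.weight"] = shard_shape((intermediate, hidden), 0, world_size)
    shards["intermediate.dense.bias"] = shard_shape((intermediate,), 0, world_size)
    # Row-parallel: splitting axis 1 (input features) makes each rank compute a
    # partial sum that an all-reduce combines, so the bias stays whole.
    shards["attention.output.dense.weight"] = shard_shape((hidden, hidden), 1, world_size)
    shards["attention.output.dense.bias"] = (hidden,)
    shards["output.dense.weight"] = shard_shape((hidden, intermediate), 1, world_size)
    shards["output.dense.bias"] = (hidden,)
    return shards


if __name__ == "__main__":
    for name, shape in bert_layer_shards(world_size=4).items():
        print(f"{name:32s} {shape}")
```

With world_size=4, each rank holds 4 of the 16 attention heads, e.g. a (256, 1024) query weight instead of (1024, 1024). Splitting the first projection by columns and the second by rows keeps the intermediate activations local to each rank, so a single all-reduce per block is enough.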

After applying the schedule, we build the optimized model by calling slapo.build. Here we explicitly pass in the _init_weights method of the HuggingFace model so that the parameters of the optimized model are initialized the same way HuggingFace initializes them.

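def apply_and_build_schedule(model, config):
    from slapo.model_schedule import apply_schedule

    # Apply the predefined BERT schedule with fp16 and a checkpoint ratio of 0
    # (no activation checkpointing).
    sch = apply_schedule(
        model, "bert", model_config=config, prefix="bert", fp16=True, ckpt_ratio=0
    )
    # Build the optimized module; the second return value is unused here.
    opt_model, _ = slapo.build(sch, init_weights=model._init_weights)
    return opt_model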
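The init_weights argument is a per-module callback: HuggingFace's _init_weights receives one submodule and initializes it according to its type. The stdlib-only sketch below shows why such a callback can be reused on a transformed model. It reads the sizes from whatever module it is given, so a sharded Linear is initialized at its smaller shape. ToyModule, hf_style_init, and apply_recursively are hypothetical stand-ins for nn.Module, _init_weights, and nn.Module.apply, and we assume, without confirming against Slapo's source, that slapo.build visits every submodule in a similar way. The standard deviation of 0.02 matches the default initializer_range in BertConfig.

```python
import random
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ToyModule:
    """Hypothetical stand-in for nn.Module: a type tag, flat parameters, children."""
    kind: str
    params: Dict[str, List[float]] = field(default_factory=dict)
    children: List["ToyModule"] = field(default_factory=list)


def hf_style_init(module: ToyModule, std: float = 0.02) -> None:
    """Per-module callback in the spirit of HF BERT's _init_weights (hypothetical)."""
    if module.kind == "linear":
        weight = module.params["weight"]
        weight[:] = [random.gauss(0.0, std) for _ in weight]  # N(0, std^2)
        module.params["bias"][:] = [0.0] * len(module.params["bias"])
    elif module.kind == "layernorm":
        module.params["weight"][:] = [1.0] * len(module.params["weight"])
        module.params["bias"][:] = [0.0] * len(module.params["bias"])
    # Containers own no parameters of their own, so nothing to do.


def apply_recursively(module: ToyModule, fn: Callable[[ToyModule], None]) -> None:
    """Call fn on every submodule, children first, as nn.Module.apply does."""
    for child in module.children:
        apply_recursively(child, fn)
    fn(module)


# A toy "sharded" Linear (4 x 8 instead of 8 x 8) and a LayerNorm, both
# filled with NaN to show that every entry gets overwritten.
random.seed(0)
dense = ToyModule("linear", {"weight": [float("nan")] * (4 * 8), "bias": [float("nan")] * 4})
norm = ToyModule("layernorm", {"weight": [float("nan")] * 8, "bias": [float("nan")] * 8})
layer = ToyModule("container", children=[dense, norm])
apply_recursively(layer, hf_style_init)
```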

The optimized model is still a PyTorch nn.Module, so we can pass it to the PyTorch training loop as usual.

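def train(model, device="cuda", bs=8, seq_length=512):
    # Dummy inputs: every token id is 1, and the labels reuse the input ids.
    input_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    attention_mask = torch.ones(bs, seq_length, dtype=torch.float16, device=device)
    token_type_ids = torch.ones(bs, seq_length, dtype=torch.long, device=device)
    labels = input_ids.clone()
    # Keep the parameters on the same device as the inputs (no-op if already there).
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    for step in range(100):
        inputs = (input_ids, attention_mask, token_type_ids)
        loss = model(*inputs, labels=labels).loss
        # Standard PyTorch update: backpropagate, apply the step, reset gradients.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if step % 10 == 0:
            print(f"step {step} loss: {loss.item()}")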
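The labels passed to the model are simply a copy of input_ids. For BertLMHeadModel, HuggingFace computes a left-to-right (next-token) loss, so the prediction at position t is scored against the label at position t + 1. The pure-Python sketch below mimics that shift and the cross-entropy on toy logits; it illustrates the loss and is not HuggingFace's implementation, and cross_entropy and shifted_lm_loss are hypothetical names. Because every input and label in train is the token id 1, the model only has to learn to predict a single token, so the printed loss should fall quickly; treat the loop as a smoke test rather than a convergence benchmark.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-softmax of `target`, computed stably by subtracting the max."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def shifted_lm_loss(logits: List[List[float]], labels: List[int]) -> float:
    """Mean next-token loss: the logits at position t are scored against labels[t + 1]."""
    if len(logits) != len(labels):
        raise ValueError("need one logit vector per position")
    # Drop the last prediction (nothing to predict) and the first label (nothing predicts it).
    pairs = zip(logits[:-1], labels[1:])
    losses = [cross_entropy(row, target) for row, target in pairs]
    return sum(losses) / len(losses)


# Uniform logits over a 4-token vocabulary give a loss of ln(4) at every position.
uniform = [[0.0] * 4 for _ in range(5)]
print(shifted_lm_loss(uniform, [1] * 5))
```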

Total running time of the script: ( 0 minutes 4.989 seconds)
