Gating – Routing tokens to experts

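MoE 中比较重要的一步是将 tokens 路由到不同的 expert。考虑一个 token,它是一个 $d$ 维向量 $x \in \mathbb{R}^d$,通过 router(一个线性层 $W_g \in \mathbb{R}^{d \times n_{exp}}$)计算它对每个 expert 的打分 $H(x) = x W_g$;若启用 noisy top-k,还会加上可学习的噪声:$H(x)_i = (x W_g)_i + \epsilon_i \cdot \mathrm{Softplus}\big((x W_{noise})_i\big)$,其中 $\epsilon_i \sim \mathcal{N}(0, 1)$。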

代码实现如下:经过 Attention 之后,当前 tensor 的 shape 为 (B, C, d);router 是一个 (d, n_exp) 的 linear 层,输出 shape 为 (B, C, n_exp),表示每个 token 路由到各个 expert 的打分(logits),也就是上图中的 Router Logits。

接下来,只保留每个 token 打分最高的 top-k 个 expert,再对这 k 个 logits 做 Softmax,得到该 token 在这些 expert 上的路由权重。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch import nn
from torch.nn import functional as F

class BasicSoftmaxRouter(nn.Module):
    """Noisy top-k router from Shazeer et al. (https://arxiv.org/abs/1701.06538).

    Projects every token embedding to one logit per expert and keeps the
    ``top_k`` largest. Softmax is NOT applied here: the returned top-k logits
    are meant to be normalized by the caller.
    """

    def __init__(
        self,
        d,
        n_exp = 8,
        top_k = 2,
        use_noisy_top_k = True,
    ):
        """
        Arguments:
        d: size of embedding dimension
        n_exp: the number of experts to create in the expert layer
        top_k: the number of active experts for each token
        use_noisy_top_k: whether to add learned, input-dependent Gaussian noise
            to the router logits while training (eq (4) in the paper above);
            no noise is added in eval mode
        """
        super().__init__()

        # router settings
        self.top_k = top_k
        assert 1 <= self.top_k <= n_exp, f"top_k must be in [1, {n_exp}], got {top_k}"
        self.use_noisy_top_k = use_noisy_top_k

        # linear projection for (noisy) softmax routing
        # no bias used, see page 4 eq (4) in https://arxiv.org/abs/1701.06538
        self.w_g = nn.Linear(d, n_exp, bias=False)
        self.w_noise = nn.Linear(d, n_exp, bias=False) if self.use_noisy_top_k else None

    def forward(self, x):
        """Select the top-k experts for every token.

        Arguments:
        x: token embeddings whose last dimension is ``d`` (e.g. [B, C, d])

        Returns:
        top_k_logits: (noisy, when training) router logits of the chosen experts, [..., top_k]
        top_k_indices: indices of the chosen experts, [..., top_k]
        """
        # eq (4) in https://arxiv.org/abs/1701.06538
        logits = self.w_g(x)  # [B, C, d] -> [B, C, n_exp]
        # The noise is a training-time exploration / load-balancing device;
        # inference routing must be deterministic, so skip it in eval mode.
        if self.use_noisy_top_k and self.training:
            noise_std = F.softplus(self.w_noise(x))
            # out-of-place ops: do not mutate tensors that autograd may need
            logits = logits + noise_std * torch.randn_like(noise_std)
        top_k_logits, top_k_indices = logits.topk(self.top_k, dim=-1)  # [B, C, k]
        return top_k_logits, top_k_indices

Megatron-LM

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class MoELayer(BaseMoELayer):

    def forward(self, hidden_states: torch.Tensor):
        """Forward pass for the MoE layer.

        Runs four stages in order:
        1. Routing & preprocessing: choose experts for each token and prepare dispatch.
        2. Dispatch: move tokens to the devices hosting their experts.
        3. Expert computation: run the experts on the dispatched tokens.
        4. Combine: bring the expert outputs back and merge them.

        Args:
            hidden_states (torch.Tensor): The input tensor to the MoE layer.

        Returns:
            A tuple ``(output, mlp_bias)``.
        """
        # Stage 1: routing; the untouched input comes back as the residual.
        routed, routing_probs, residual = self.router_and_preprocess(hidden_states)
        # Stage 2: communication to the expert devices.
        expert_input, routing_probs = self.dispatch(routed, routing_probs)
        # Stage 3: local expert computation.
        expert_out, shared_out, mlp_bias = self.experts_compute(
            expert_input, routing_probs, residual
        )
        # Stage 4: gather results back and merge.
        return self.combine(expert_out, shared_out), mlp_bias

Routing

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
    def router_and_preprocess(self, hidden_states: torch.Tensor):
        """Route tokens and prepare them for the token dispatcher.

        ``self.router`` yields routing probabilities and a routing map. These are
        passed, together with the hidden states, to
        ``self.token_dispatcher.dispatch_preprocess``. The original input is also
        returned, unchanged, as the residual.

        Returns:
            (preprocessed hidden states, preprocessed probs, residual)
        """
        residual = hidden_states
        routing_probs, routing_map = self.router(hidden_states)
        preprocessed = self.token_dispatcher.dispatch_preprocess(
            hidden_states, routing_map, routing_probs
        )
        dispatcher_input, dispatcher_probs = preprocessed
        return dispatcher_input, dispatcher_probs, residual

Dispatch

1
2
3
4
5
6
7
    def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
        """Send tokens and their routing probabilities to the ranks hosting their experts.

        This is a thin wrapper: the communication itself (e.g. All-to-All) is
        carried out by ``self.token_dispatcher.token_dispatch``.
        """
        dispatcher = self.token_dispatcher
        return dispatcher.token_dispatch(hidden_states, probs)

Expert Compute

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
    def experts_compute(
        self, hidden_states: torch.Tensor, probs: torch.Tensor, residual: torch.Tensor
    ):
        """Run the local experts on the dispatched tokens.

        Steps:
        - If a shared expert is enabled and NOT overlapped with communication,
          apply it to ``residual`` here.
        - Post-process the dispatched tokens into per-expert (permuted) inputs.
        - Run ``self.experts`` on them and preprocess the result for combine.

        Returns:
            (combine-ready expert output, shared expert output or None, mlp_bias)
        """
        # When overlapped with communication the shared expert runs elsewhere.
        run_shared_here = self.use_shared_expert and not self.shared_expert_overlap
        shared_expert_output = self.shared_experts(residual) if run_shared_here else None

        postprocessed = self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
        dispatched_input, tokens_per_expert, permuted_probs = postprocessed
        expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert, permuted_probs)
        assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}"
        combine_ready = self.token_dispatcher.combine_preprocess(expert_output)
        return combine_ready, shared_expert_output, mlp_bias

Combine

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
    def combine(self, output: torch.Tensor, shared_expert_output: Optional[torch.Tensor]):
        """Bring expert outputs back via the token dispatcher and add the shared expert.

        ``token_combine`` performs the return communication (e.g. All-to-All) and
        ``combine_postprocess`` restores the layout. If a shared expert output is
        given, it is added to the result.
        """
        dispatcher = self.token_dispatcher
        combined = dispatcher.combine_postprocess(dispatcher.token_combine(output))
        if shared_expert_output is None:
            return combined
        return combined + shared_expert_output

Expert Parallelism

  1. Gating function: decide the target experts for each token
  2. Dispatch phase
    1. Layout transformation: tokens with the same target expert are grouped into a contiguous memory buffer
    2. All-to-All: dispatch tokens to their corresponding experts
  3. Expert compute: each expert processes its tokens
  4. Combine phase
    1. All-to-All: send the processed tokens back to their original GPUs
    2. Layout transformation: restore tokens to their original positions

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def moe_ep(self, x, topk_ids):
    """Run the routed experts on ``x``, with optional expert parallelism (EP).

    Tokens are sorted by target expert. When ``self.ep_size > 1`` they are
    exchanged across EP ranks with an All-to-All, processed by the experts local
    to this rank, and sent back. The results are then scattered back to their
    original (token, slot) order.

    Arguments:
    x: token features, indexed by token along dim 0 (presumably [N, d])
    topk_ids: chosen expert ids, shape [N, k]; k must equal self.num_experts_per_tok

    Returns:
    Tensor with N * k rows. Row ``t * k + j`` is expert ``topk_ids[t, j]`` applied
    to ``x[t]``. Routing weights are NOT applied and the k slots are NOT summed;
    the caller does that.
    """
    # Per-expert token counts over the whole (global) expert set.
    cnts = topk_ids.new_zeros((topk_ids.shape[0], self.n_routed_experts))
    cnts.scatter_(1, topk_ids, 1)
    tokens_per_expert = cnts.sum(dim=0)
    # Flat (token, slot) positions ordered by expert id; position // k recovers the token.
    idxs = topk_ids.view(-1).argsort()
    sorted_tokens = x[idxs // self.num_experts_per_tok]

    if self.ep_size > 1:
        import torch.distributed as dist

        # Exchange per-expert counts so each rank knows how many tokens it will receive.
        tokens_per_expert_group = torch.empty_like(tokens_per_expert)
        dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)
        output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(dim=1).cpu().tolist()
        input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1).cpu().tolist()
        # NOTE(review): All2All is an autograd-aware all-to-all defined elsewhere;
        # assumed signature (tensor, output_splits, input_splits, group).
        gathered_tokens = All2All.apply(sorted_tokens, output_splits, input_splits, self.ep_group)
        # Received tokens are laid out (source rank, local expert); label each with
        # its local expert id and sort so tokens of the same local expert are contiguous.
        local_expert_ids = torch.arange(
            tokens_per_expert_group.numel(), device=tokens_per_expert_group.device
        ) % self.experts_per_rank
        gathered_idxs = local_expert_ids.repeat_interleave(tokens_per_expert_group)
        gathered_idxs = gathered_idxs.to(idxs.device).argsort()
        sorted_tokens = gathered_tokens[gathered_idxs]
        tokens_per_expert = tokens_per_expert_group.view(self.ep_size, -1).sum(dim=0)
    tokens_per_expert = tokens_per_expert.cpu().tolist()

    outputs = []
    start_idx = 0
    for i, num_tokens in enumerate(tokens_per_expert):
        if num_tokens == 0:
            continue
        end_idx = start_idx + num_tokens
        expert = self.experts[i + self.ep_rank * self.experts_per_rank]
        outputs.append(expert(sorted_tokens[start_idx:end_idx]))
        start_idx = end_idx

    if outputs:
        outs = torch.cat(outputs, dim=0)
    else:
        # Keep the trailing feature dims. Otherwise an empty input yields shape (0,)
        # and the caller's reshape / the return All-to-All get a mis-shaped buffer.
        # Assumes experts preserve the feature shape (d -> d).
        outs = sorted_tokens.new_empty((0, *sorted_tokens.shape[1:]))

    if self.ep_size > 1:
        # Undo the local-expert regrouping, then send results back to the source ranks.
        regrouped = torch.empty_like(outs)
        regrouped[gathered_idxs] = outs
        outs = All2All.apply(regrouped, input_splits, output_splits, self.ep_group)

    # Undo the expert sort: row p of y belongs to flat (token, slot) position p.
    y = torch.empty_like(outs)
    y[idxs] = outs
    return y

Transformers 代码阅读

参考资料