flash_attn_func

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
                    window_size=(-1, -1), alibi_slopes=None, deterministic=False):
    """dropout_p should be set to 0.0 during evaluation.

    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
    If window_size != (-1, -1), implements sliding window local attention. Query at position i
    will only attend to keys between
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Default to 1 / sqrt(headdim).
        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
    Return:
        out: (batch_size, seqlen, nheads, headdim).
    """
    # Signature reference only — the real CUDA kernel lives in the flash_attn package.
    ...

flash_attn_varlen_func

主要记录 flash_attn.flash_attn_varlen_func 这个接口的使用, 精髓在于理解函数签名和输入形状: 函数签名需要每个 seq 的 offset, 输入形状需要 (bs, seqlen) 平坦化后的 (total_num, nhead, headdim)

1
from flash_attn import flash_attn_varlen_func

需要注意使用 causal 这个参数才能进入 causal 模式

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
def flash_attn_varlen_func(
    q,             # (total_q, nheads, headdim)
    k,             # (total_k, nheads, headdim)
    v,             # (total_v, nheads, headdim)
    cu_seqlens_q,  # (batch_size + 1,) exclusive prefix sums of q sequence lengths
    cu_seqlens_k,  # (batch_size + 1,) exclusive prefix sums of k sequence lengths
    max_seqlen_q,  # length of the longest q sequence in the batch
    max_seqlen_k,  # length of the longest k sequence in the batch
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    return_attn_probs=False,
):
    """Signature stub for flash_attn.flash_attn_varlen_func (real kernel is in flash_attn)."""
    pass

值得注意的是 qkv 输入形状上需要是 (total_num, nheads, headdim) 而不是 (batch_size, seqlen, nheads, headdim), 和 flash_attn_func 是不同的。这是因为在变长 batching 中把 bs * seqlen 打平展开了,然后再结合 offset 去找到每个 batch 的起始位置做计算。

最重要的参数就是 cu_seqlens_q 和 cu_seqlens_k, 用于记录找到每个 batch 需要的 offset。比如 seq0的 offset=0, seq1的 offset=seq0.len, seq2的 offset=seq0.len+seq1.len, 因此就是一个不包含自身的前缀和, 可以通过 torch.cumsum 减去各自的 seqlen 获得:

1
prefill_start_pos = torch.cumsum(seq_len, dim=0, dtype=torch.int32) - seq_len

demo

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch  
from flash_attn import flash_attn_varlen_func, flash_attn_func  
  
def main():
    """Demo: check that flash_attn_varlen_func on packed sequences matches
    flash_attn_func run per-sequence (batch of 1 each)."""
    dtype = torch.bfloat16
    HEAD = 2
    HEAD_DIM = 2
    seqlens = [1, 2, 3, 4]

    # Keep per-sequence tensors in lists and pack them with ONE torch.cat at
    # the end, instead of re-concatenating inside the loop (avoids quadratic
    # copying that the original empty-tensor + cat-per-iteration pattern had).
    querys = []
    keys = []
    values = []
    for l in seqlens:
        querys.append(torch.rand(l, HEAD, HEAD_DIM, dtype=dtype).cuda())
        keys.append(torch.rand(l, HEAD, HEAD_DIM, dtype=dtype).cuda())
        values.append(torch.rand(l, HEAD, HEAD_DIM, dtype=dtype).cuda())
    query = torch.cat(querys, dim=0)  # (total_q, HEAD, HEAD_DIM)
    key = torch.cat(keys, dim=0)
    value = torch.cat(values, dim=0)

    print("===Standard===")
    for q, k, v in zip(querys, keys, values):
        # flash_attn_func expects (batch, seqlen, nheads, headdim): add batch dim of 1.
        out = flash_attn_func(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0))
        print(out)
    print("=========\n")

    seq_len = torch.tensor(seqlens, dtype=torch.int32).cuda()
    # NOTE: flash_attn_varlen_func needs cu_seqlens_q / cu_seqlens_k of length
    # (bs + 1): the exclusive prefix sums of the sequence lengths, with the
    # total length appended as the final entry.
    prefill_start_pos = torch.cumsum(seq_len, dim=0, dtype=torch.int32) - seq_len
    prefill_start_pos = torch.cat(
        [prefill_start_pos,
         torch.tensor([torch.sum(seq_len)], dtype=torch.int32, device="cuda")],
        dim=0,
    )
    print(prefill_start_pos.shape)
    print(prefill_start_pos)

    print(query.shape, key.shape, value.shape)
    cu_seqlens_q = prefill_start_pos
    cu_seqlens_k = prefill_start_pos
    max_seqlen_q = max(seqlens)
    max_seqlen_k = max(seqlens)

    out = flash_attn_varlen_func(query, key, value, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)

    print("===Varlen===")
    # Slice the packed output back into per-sequence chunks for comparison.
    offset = 0
    for l in seqlens:
        print(out[offset:offset + l])
        offset += l
    print("=========\n")
  
# Run the demo only when executed as a script (not when imported as a module).
if __name__ == "__main__":  
    main()