FlashAttention
flash_attn_func
|
|
flash_attn_varlen_func
主要记录 flash_attn.flash_attn_varlen_func 这个接口的使用, 精髓在于理解函数签名和输入形状: 函数签名需要每个 seq 的 offset, 输入形状需要 (bs, seqlen) 平坦化后的 (total_num, nhead, headdim)
|
|
需要注意使用 causal 这个参数才能进入 causal 模式
|
|
值得注意的是 qkv 输入形状上需要是 (total_num, nheads, headdim) 而不是 (batch_size, seqlen, nheads, headdim), 和 flash_attn_func 是不同的。这是因为在变长 batching 中把 bs * seqlen 打平展开了,然后再结合 offset 去找到每个 batch 的其实位置做计算。
最重要的参数就是 cu_seqlens_q 和 cu_seqlens_k, 用于记录找到每个 batch 需要的 offset。比如 seq0的 offset=0, seq1的 offset=seq0.len, seq2的 offset=seq0.len+seq1.len, 因此就是一个不包含自身的前缀和, 可以通过 torch.cumsum 减去各自的 seqlen 获得:
|
|
demo
|
|
Linked Mentions
-
No backlinks found.