
Saving Memory with xformers

History

2023.11.13.
Initial draft

Before We Begin

Recently, attention and transformer architectures have been used heavily in both natural language processing and computer vision. This has driven active research on efficient attention operations such as Flash Attention and xformers, and in this post I want to introduce how to use xformers.
Flash Attention has follow-up work and has even been integrated into PyTorch 2.0, but it is currently supported only on a limited range of GPUs. xformers, on the other hand, is a library implemented on top of PyTorch: depending on the problem you are solving and the design of your model its performance may fall short, but it has the advantage of being usable anywhere PyTorch runs.
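For reference, the PyTorch 2.0 integration mentioned above is exposed through torch.nn.functional.scaled_dot_product_attention, which dispatches to a Flash Attention kernel only when the GPU, dtype, and mask allow it. The sketch below is a minimal illustration with arbitrary shapes, assuming a CUDA GPU and PyTorch >= 2.0; it is not part of the xformers example that follows.

import torch
import torch.nn.functional as F

# Arbitrary illustrative shapes: (batch, heads, seq_len, head_dim)
q = torch.randn(4, 4, 4096, 64, device="cuda", dtype=torch.float16)
k, v = q.clone(), q.clone()

# PyTorch >= 2.0: fused attention; the backend (Flash, memory-efficient, or plain math)
# is chosen automatically depending on GPU, dtype, and mask support.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Force the Flash backend only; on unsupported GPUs this raises an error,
# which is exactly the "limited GPU support" caveat above.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)

Python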

Running the Example

""" Forked from https://github.com/facebookresearch/xformers/blob/main/HOWTO.md#blocksparseattention """ import torch from xformers.components import MultiHeadDispatch from xformers.components.attention import BlockSparseAttention BATCH = 4 HEADS = 4 SEQ = 4096 EMB = 512 BLOCK_SIZE = 32 DROPOUT = 0.0 dtype = torch.float16 # Let's try out a causal mask, but really it could be anything "block sparse enough" causal_mask = torch.tril(torch.ones((SEQ, SEQ), device=torch.device("cuda"), dtype=dtype)) blocks = SEQ // BLOCK_SIZE causal_layout = torch.tril(torch.ones([HEADS, blocks, blocks], dtype=torch.bool)) # Let's build our blocksparse attention. Please note that the layout can be # [SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] or [HEADS, SEQ//BLOCK_SIZE, SEQ//BLOCK_SIZE] # so that _you can pass a different layout per head_ attention = BlockSparseAttention(layout=causal_layout, block_size=BLOCK_SIZE, dropout=DROPOUT, num_heads=HEADS) # Out of commodity, let's build our multihead attention now # "multi_head" will be responsible for the forward multi_head = ( MultiHeadDispatch( dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attention, ) .cuda() .half() ) # Now FW some random data # Note that passing a per-coefficient mask makes it possible to remove extra coefficients, # which where required by the blockification query = torch.randn((BATCH, SEQ, EMB), requires_grad=True, device=torch.device("cuda"), dtype=dtype) # Self attention in this particular example, no limitations really att_val = multi_head(query=query, key=query, value=query)#, att_mask=causal_mask) ######################################### # Bonus: compare the memory use vs dense: def mem_use(fn, kwargs, title): # bookeeping import time start = time.time() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() # actually run the function fn(**kwargs) torch.cuda.synchronize() stop = time.time() # now report max_memory = torch.cuda.max_memory_allocated() // 2 ** 20 print(f"{title} - Peak memory use: {max_memory}MB - {round((stop-start)*1e6)/1e3}ms") pytorch_multihead = torch.nn.MultiheadAttention( EMB, HEADS, batch_first=True, device=torch.device("cuda"), dtype=torch.float16 ) mem_use(multi_head, {"query": query, "key": query, "value": query}, "Blocksparse") mem_use(pytorch_multihead, {"query": query, "key": query, "value": query}, "PyTorch")
Python
The results are as follows:
Blocksparse - Peak memory use: 980MB - 9.827ms
PyTorch - Peak memory use: 1504MB - 11.18ms
Plain Text

Using It as a Class

import torch
from torch import nn

from xformers.components import MultiHeadDispatch
from xformers.components.attention import BlockSparseAttention


class EfficientMHA(nn.Module):
    def __init__(self, embed_dim, num_heads, seq_len, block_size):
        super().__init__()
        # Lower-triangular block layout -> causal attention at block granularity
        causal_layout = torch.tril(
            torch.ones([num_heads, seq_len // block_size, seq_len // block_size], dtype=torch.bool)
        )
        self.attention_layer = MultiHeadDispatch(
            dim_model=embed_dim,
            num_heads=num_heads,
            attention=BlockSparseAttention(
                layout=causal_layout,
                block_size=block_size,
                num_heads=num_heads,
            ),
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, key, query, value):
        output = self.attention_layer(
            key=key,
            query=query,
            value=value,
        )
        # Residual connection followed by layer normalization
        output = output + query
        output = self.norm(output)
        return output
Python
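
For reference, here is a minimal usage sketch of the EfficientMHA module above. The shapes, the fp16 cast, and the .cuda() call are assumptions for illustration (BlockSparseAttention runs its sparse kernels on the GPU); they are not prescribed by the class itself.

import torch

# Hypothetical shapes chosen for illustration only
BATCH, SEQ, EMB, HEADS, BLOCK_SIZE = 4, 4096, 512, 4, 32

mha = EfficientMHA(embed_dim=EMB, num_heads=HEADS, seq_len=SEQ, block_size=BLOCK_SIZE).cuda().half()

x = torch.randn(BATCH, SEQ, EMB, device="cuda", dtype=torch.float16)

# Self-attention: the same tensor is used for key, query, and value
out = mha(key=x, query=x, value=x)
print(out.shape)  # torch.Size([4, 4096, 512])

Python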