본문 바로가기

딥러닝/트랜스포머

[pytorch] forward를 수정하고 싶다면 - forward hook 예제 코드

pytorch 관련 모듈

  • (설명 안 함, 사용 안 해 봄) register_module_forward_pre_hook
  • register_module_forward_hook

Background

  • Llama 모델의 forward(feed forward layer) 과정에서 activation value를 확인하고 싶은데 기존에 공개된 논문의 코드에서는 vllm+MethodType을 이용하여 forward에 접근 중
  • vllm에서 llama3를 사용하기에는 아직 부족한 점이 있는 듯(내가 부족하다는 뜻)하여 이를 transformers 라이브러리로 변경
  • transformers 라이브러리만 사용하는 코드로 변경하였을 때, torch hook을 사용해서 forward에 접근
  • 처음으로 torch hook을 사용하면서 간단히 사용법을 정리하는 목적으로 작성

pytorch 모듈 설명: register_module_forward_hook

1. 실행 타이밍, 목적

  • 모든 모듈에 걸쳐 전역 상태 변경을 도입하므로 주로 디버깅이나 프로파일링 목적으로 사용
  • 이 후크는 모든 모듈의 forward() 메서드가 호출된 후에 실행되어,
    사용자가 모듈의 forward() 메서드의 출력을 검사하거나 수정할 수 있도록 함

 

2. 예제 코드

def hook_fn(idx):
    def llama_forward(self, inputs, outputs): 
        activation = self.act_fn(self.gate_proj(inputs[0])) 
        activation = activation.float()
        over_zero[idx, :] += (activation > 0).sum(dim=(0,1))
        return outputs
    return llama_forward

# register hooks for each layer
hooks = []
for idx, layer in enumerate(model.model.layers):
    if hasattr(layer.mlp, 'gate_proj'): # for llama model
        hook = layer.mlp.register_forward_hook(hook_fn(idx))
    else:
        hook =None #XXX: for other model
    hooks.append(hook)
  • forward의 결과를 받고 처리할 함수 hook_fn 정의
  • hook을 등록할 hooks 리스트 정의
  • model.model.layers 모델의 레이어를 순회(for)하면서 layer.mlp(LlamaMLP)에 등록(layer.mlp.register_forward_hook(hook_fn(idx)))
    • 여기서 idx를 입력으로 넣는 건 hook_fn에서 따로 내가 사용하기 위함
  • hooks.append(hook) hook 추가
  • self로 LlamaMLP class 객체에 바로 접근 가능
    • activation = self.act_fn(self.gate_proj(inputs[0])) activation function으로 value를 계산
    • 해당 부분은 아래의 실제 트랜스포머 라이브러리에서 확인 가능
class LlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) ## activation value
        return down_proj
  • forward 함수를 보면 down_proj 전, 활성화 값을 계산
  • 이는 중간 레이어를 업데이트 하기 전 차원을 intermediate_size 만큼 up하고, 다시 hidden_size로 down하는 과정
  • 차원을 down 하기 전, 활성화 값을 계산하는 부분을 hook_fn에서 이용 가능

 

 

 


references

- https://pytorch.org/docs/stable/generated/torch.nn.modules.module.register_module_forward_hook.html