Contribution 세미나/Knowledge distillation

[논문세미나] Knowledge Distillation via the Target-aware Transformer

PaperGPT 2024. 1. 21. 17:36

CVPR 2022, 65회 인용

Introduction

 

기존의 feature기반의 knowledge distillation의 경우 one-to-one spatial matching을 한다.

(feature map에서 같은 위치는 같은 값을 가지도록)

다만 teacher와 student 구조의 경우 같은 feature resolution을 가지더라도 receptive field가 다르다.

따라서 서로 보는 영역이 다르기 때문에 semantic mismatch를 일으키고 이는 sub-optimal한 결과를 가져온다.

 

본 논문은 이를 해결하기 위해 one-to-all spatial matching을 제안한다.

이를 target-aware transformer라고 부른다.

 

Contribution

1. Target-aware transformer 제안, teacher 모델 각각의 spatial components가 student feature map 전체에 영향을 미친다.

2. Hierarchical distillation을 사용하여 연산량을 줄인다.

 

Proposed method

기존 distillation loss (logit)

 

3d feature map -> 2d feature map

feature matching

 

 

Target-aware Transformer (TaT) 

 

teacher와 student feature의 유사도로 weight를 구함

해당 weight로 새로운 student feature map 업데이트

해당 feature map과 teacher feature map 사이의 loss 계산

 

여기서 더 나아가 각각의 feature에 linear function을 적용하여 성능을 향상 시키려는 시도

 

이를 parametric TaT로 명칭.

 

전체 loss

 

그런데 feature map의 크기가 너무 크면 연산량이 너무 커진다.
(하나의 spatial feature를 업데이트 하기 위해 전체 feature와의 유사도를 구해야 함)

 

이러한 문제를 완화시키기 위해 patch-group distillation 사용

기본적인 컨셉은 유사도를 구할때 전체 feature map을 보는게 아닌 patch 단위로 계산한다.

다만 이렇게 되면 patch 사이의 유사도는 고려가 되지 않기 때문에 여러 patch를 group으로 지정하여 group 단위로 유사도를 계산 어느정도는 patch 사이의 유사도가 고려된다. 

 

하지만 그렇게 해도 global 정보가 고려가 덜 된다.

이를 보정하기 위해 anchor-point distillation을 사용한다.

이는 average pooling을 사용하여 patch를 구성하고, dimension을 줄이면서 전체 feature를 고려하는 효과를 가져온다.

 

 

Experiment result

 

 

기존 KD 방식들에 비해 좋은 성능을 보여줌

 

 

 

Linear function을 적용하는 방식에 따른 성능 비교, teacher feature에는 linear function을 적용하지 않는게 더 성능이 좋다.

 

 

TaT에 대한 loss weight에 따른 성능 비교

 

 

TaT 방식에 따른 성능 비교

 

 

Group 개수에 따른 성능 비교