본문 바로가기

학부/논문 리뷰

Learning to Prompt for Vision-Language Models (CoOp 논문 리뷰)

728x90
반응형

[논문 링크]

https://arxiv.org/pdf/2109.01134

 

이번 논문은 Prompt learning과 관련된 논문을 리뷰하고자 한다.

"CoOp"라고 많이들 알고 있을텐데, prompt를 learnable paramter로 설정하여 class에 맞는 알맞은 prompt를 자동으로 학습할 수 있도록 만드는 것을 의미한다. 크게 어려운 내용은 없었던 것 같다.

키워드는 읽는 데에 도움이 되는 정도로만 가볍게 참고하자.

 

Keyword: CoOp pipeline, learnable prompt vectors

 

 


Abstract


  • CLIP과 같은 large pre-trained vision-language model은 다양한 downstream task에서 transferable한 representation을 학습
  • traditional representation learning에서는 주로 discretized label들을 바탕으로 학습을 진행했지만, VLM은 text와 image를 common feature space에서 align
    • 전통적인 방법은 ‘image’ - ‘label’ 과 같은 방식으로 제시하지만, VLM에서는 이런 구분 없이 text와 image를 같은 space에서 정의
  • 이와 같은 방법은 prompting를 통해 downstream task를 위한 zero-shot transfer를 가능하게 만든다.
    • 관심 있는 class에 대해서 자연어로 묘사하는 방식
  • prompt engineering은 이러한 모델을 실생활에서 효율적으로 사용하는 데에 있어서 주된 challenge가 된다.
    • domain에 맞는 expertise도 필요하고 무엇보다 time-consuming이 크다.
    • 한 사람이 tuning하는 데에만 엄청난 시간이 걸리는 데, 이는 wording에서 조금이라도 변화가 생기면 결과에 큰 영향을 주기 때문이다.
  • 이러한 NLP 분야에서의 prompt learning search에서 본 논문은 **‘Context Optimization (CoOp)’**를 제안한다.
    • downstream image recognition을 위해 CLIP과 같은 VLM 모델에 적용할 수 있는 간단한 approach
    • 조금 더 구체적으로 말하자면 CoOp는 prompt의 context words를 일종의 learnable한 vector로 모델링 하는 것을 의미한다. 이 때, pre-trained parameter들은 모두 보존한다.

2가지의 CoOp implementations을 제공한다.

  • Unified context
  • class-specific context

learning based approach임에도 불구하고 hand-crafted prompt를 사용하는 zero-shot model과 비교했을 때, 최고의 성능을 자랑한다.

 

 


1. Introduction


  • 기존 vision model을 훈련할 때는 ‘discrete label’을 통해 category가 고정된 채로 학습을 진행
    • image feature와 fixed set of weight를 matching
    • 대부분 label은 text 형태로 되어 있지만, 실제로는 cross-entropy를 줄여주는 어떤 특정한 discrete label로 변환해서 사용
      • 따라서 새로운 데이터나 closed-set visual concept에 대해서 한정된 recognition system을 갖는다.
  • CLIP과 ALIGN 등의 vision-language model이 visual representation을 학습하는 방식으로 모델링
    • raw text와 image를 2개의 encoder를 사용해서 같은 feature space로 align
    • pair라면 가까이 (pull) , pair가 아니라면 멀리 (push) 하는 방식
  • large-scale로 사전 학습을 하면, 모델은 다양한 visual concept을 학습하고 어떠한 downstream task에 대해서도 prompting을 통해 쉽게 transfer할 수 있다.
    • 예를 들어, task-relevant category들을 묘사할 수 있는 문장을 text encoder에 제공하고 이를 바탕으로 image encoder에서 나온 image feature와 비교하며 classification을 수행할 수 있다.

2,3번째의 prompt 결과 참고

  • 사전 학습된 VLM은 text input인 prompt가 핵심적인 역할
    • 하지만 right prompt를 찾는 것은 쉽지 않은 일
    • (a)에서 2,3번째의 prompt 결과를 참고해보자. 조금만 바뀌어도 결과가 많이 달라지며 시간이 많이 소요되는 작업이다.
    • 게다가 prompt engineering은 task에 대한 사전 지식이 필요하며 언어 모델의 mechanism 또한 이해하고 있어야 한다.

b~d까지의 사진을 통한 예시

  • 결국 어떤 prompt가 정말 optimal한 지에 대해서 guarantee할 수 없음을 의미.
  • 따라서 자동으로 prompt engineering을 해줄 수 있는 Context Optimization (CoOp) 제시
    • 특히, 사전 학습된 vision-language model에 대해.

Learnable vector의 집합을 사용하여 prompt’s context를 modeling

핵심적인 내용은, learnable vector의 집합을 사용하여 prompt’s context를 modeling한다는 점이다.

  • 전체적인 pipeline
    1. random value / pre-trained word embedding으로 시작
    2. 2가지의 implementations을 따로 제공
      1. Unified context : 대부분의 category에 잘 작동하며, 모든 class에 동일한 prompt를 사용
      2. Class-specific context : 각 class에 대해서 특정한 context token 집합을 사용하여 더 fine-grained category에 적합하다.
    3. 고정된 entire pre-trained parameters를 유지하면서 learnable context vectors와의 cross-entropy를 최소화 하는 방식으로 훈련.
      1. gradient는 text encoder를 통해 모든 방향으로 back-propagated
      2. task-relevant context를 학습하기 위해 parameter 내로 encoding된 rich knowledge를 distillation

요약하자면, 다음과 같은 contribution이 존재한다.

  • downstream application과 관련되어 최근 제시된 VLM에 대한 timely study + efficiency와 관련된 문제점 지적 (prompt engineering)
  • prompt를 자동화 하기 위해 간단한 접근 방식을 제시 → 2가지 implementation
  • 제시된 prompt learning-based approach가 성능이 좋다는 것을 제시


3. Methodology


 

3.1 Vision-Language Pre-training

CLIP에 초점을 맞춰 VLM의 pre-training 방법을 간단하게 소개하며, 이는 CLIP과 같은 VLM에서 넓게 사용할 수 있다.

[Models]

  • CLIP은 2개의 encoder로 이루어져 있다.
    • image encoder : high-dimensional images → low-dimensional embedding space
      • ResNet-50과 같은 CNN 형태의 image encoder
    • text encoder : Transformer를 기반으로 built, 자연어로부터 text representation을 생성하는 것에 초점
  • sequence of words가 주어지면 CLIP은 각 토큰을 lower-cased BPE representation으로 변환한다.
    • unique numeric ID
  • 이 후, mini-batch processing을 용이하게 하기 위해서 각 text sequence를 [SOS]와 [EOS] token으로 감싼다.
  • IDs → 512-D word embedding vectors → transformers
  • 결론적으로 [EOS] token position에 있는 features들이 normalized, linear projection layer를 거쳐 나온다.

[Training]

CLIP은 text와 image에 대한 각각의 embedding space를 align하는 방법을 학습

contrastive loss로 구성된 objective function을 학습

image-text pair가 주어지면, 일치하는 pair에는 cosine similarity를 최대화 하고 일치하지 않는 pair는 최소화 하는 방식으로 학습

[Zero-Shot Inference]

  • CLIP 자체가 image가 텍스트로 된 description과 일치하는 지 안 하는 지에 대해 예측하는 것이므로, zero-shot recognition과 naturally fit함.
    • 특정 class에 대한 textual description을 처리하기 위한 text encoder에 의해 합성된 classification weight를 image feature와 비교함으로써 수행된다.

  • prediction probability는 (1)과 같이 계산된다.
    • $\bold{f}$ : image encoder로부터 추출된 image features
    • $x$ : image
    • $w_i$ : text encoder로부터 생성된 weight vector
      • ‘A photo of a [CLASS]’ 형태의 prompt로부터 derive
    • K : class의 개수

결국, 기존 classifier 학습 방법과 비교했을 때, vision-language pre-training은 high-capacity를 가지는 text encoder를 통해 open-set visual concept을 가능하게 했으며, 이는 결과적으로 downstream task에 대해 더 잘 transferable 하게 됨.

 

여기서, 기존 방법들은 random vector로부터 학습된다고 말하고 있다. 내 생각으로는 기존 방법들은 random하게 initialize되지만, CLIP은 어떤 ‘a photo of []’와 같은 형태로 들어간 text embedding을 사용하면 고정된 weight 값으로 넣을 수 있으므로, 그 차이점을 언급하고자 한 의도가 아닐까, 생각해본다.

3.2 Context Optimization

  • context words를 continuous vectors로 modeling하여 manual prompt tuning을 방지한다.
    • 방대한 pre-trained parameter들을 frozen한 상태로 data로부터 end-to-end 방식으로 continuous vector를 학습 시킨다.
    • Figure 2를 다시 한번 참고.

 

[Unified Context]

모든 class에 대해서 같은 context를 공유하는 unified-context version을 먼저 소개

g(-)를 text encoder로 가정.

$$ t = [V]_1,[V]_2 . . . [V]_M[CLASS], \tag{2} $$

  • $[V]_m$은 word embedding으로, same dimension에 있는 vector를 의미한다.
  • M은 context token의 개수로 hyperparameter에 해당한다.

prompt, $\bold{t}$가 text encoder $g(-)$를 통과하면, 우리는 visual concept을 나타내는 classification weight vector를 얻을 수 있다. (여전히 [EOS] token position) → $g(t)$가 w가 되는 것.

 

따라서 prediction probability는 다음과 같이 계산된다.

각 prompt, t_i안에 있는 class token ([CLASS])은 i번째 class 이름의 word embedding vector로 대체된다. Eq.(2)에서 sequence 끝에 [CLASS] token을 배치하는 것 외에도,

$$ t = [V]_1...[V]_{M/2}[CLASS] [V]_{M/2+1}...[V]_M, \tag{4} $$

 

이런 방식으로도 [CLASS] token을 중간에 집어넣을 수 있다.

  • 이는 학습에 있어서 flexibility를 늘린다.
    • prompt가 추가적인 description과 함께 마지막 cell을 채워 넣거나 아니면 termination signal을 사용하여 sentence를 일찍 끊어서 full stop처럼 사용하는 것
    • 이 문장의 의미는 Class token이 어디에 들어가느냐에 따라 성능이 바뀔 수도 있다는 여지를 제시하는 것으로 이해했다.

[Class-Specific Context]

이번엔 Class-specific context (CSC)에 대해 알아보자.

CSC는 각 class가 독립적인 context vector를 의미한다.

$$ [V]^i_1[V]^i_2. . . [V]^i_M \neq [V]^j_1[V]^j_2. . . [V]^j_ M \space\space\space i, j ∈ {1, . . . , K}. $$

unified context의 대체제로써, CSC는 특히 fine-grained classification에 유용하다.

[Training]

cross-entropy를 기반으로 하는 classification loss를 최소화하는 방식으로 훈련된다.

gradient는 text encoder를 통해 모든 방향으로 back-propagation이 가능하다.

continuous representations의 design은 word embedding space에서 full exploration을 가능하게 하며 이는 task와 관련된 context를 학습하는 데 용이하게 만든다.

3.3 Discussion

CoOp는 CLIP과 같은 최근에 많이 제시되는 vision-language model에서 나타나는 문제점들을 다루기 때문에, 기존 NLP에서 사용하던 prompt learning method와는 조금 다른 점들이 존재한다.

  1. backbone architecture가 언어 모델과 CLIP은 명백히 다르기 때문.
  2. pre-training objective가 다르다.
    1. contrastive learning vs autoregressive learning


5. Conclusion, Limitations and Future Work


  • Large pre-trained model은 다양한 downstream task에서 뛰어난 성능을 보였으나 “critically central yet incomplete”로 인해 downstream task를 위한 automated technique이 필요하다.
  • 논문에서 제시한 접근 방법은 CLIP을 어떻게 data-efficient learner가 될 것인가에 대한 insight를 제공한다.
    • CoOp는 직접 작성한 prompt보다 훨씬 더 성능이 좋은 결과를 도출해 냈으며, 성공적인 prompt learning을 수행한다.
  • 성능이 좋긴 하지만, 여러 한계가 존재한다.
    • difficult to interpret
    • sensitive to noisy labels
  • 그럼에도 불구하고 CoOp는 future work로 확장하는 것이 쉽고 여전히 해결해야 할 문제들이 많다.
    • cross-dataset transfer
    • test-time adaptation
    • 따라서 해당 논문에서 제시한 접근 방법이 efficient adaptation method에 대한 future work를 제시할 수 있기를 바란다.


Discussion


 

여기서 사용된 방법들의 메인 아이디어는 "Learnable context"를 활용한 게 아닐까? 라는 생각이 든다. 다만, 이 방법들은 다른 모델이나 AI 연구 분야에서 "Learnable query", "Learnable token" 등으로 굉장히 많이 활용되고 있는 방법들과 비슷한 종류라고 생각하기 때문에 엄청 새롭다, 신선하다 ! 라는 느낌은 잘 못 받았던 것 같다. 그럼에도 정말 술술 읽히는 논문이었던 것 같고, 시간이 없어서 코드를 살펴보지 않았지만 결국 이런 연구들의 핵심은 어떻게 구현하는가가 정말 중요하다고 생각하기 때문에 추후엔 시간을 좀 더 들여서 구현 단계를 살펴보고 싶다.

728x90
반응형