본문 바로가기

카테고리 없음

Scaling Instruction-Finetuned Language Models (FLAN 논문 리뷰)

728x90
반응형

[논문 링크]

https://arxiv.org/pdf/2210.11416

 

이번엔 instruction fine-tuning을 사용하여 FLAN이라는 모델을 만든 논문을 리뷰해보고자 한다. 개인적으로 이 논문을 읽었을 때는, GPT-2 논문 이후에 제시된 GPT-3 논문과 비슷한 느낌을 받았다.

이전에 제시된 Flan에서 조금 더 scaling된 Flan을 제시한 것에서 그렇게 느꼈고, 그래서 또 한번 데이터의 중요성을 깨닫게 되었다.

 

키워드는 다음과 같으며 가볍게 참고하자.

Keyword : Instruction Finetuning, Instruction Finetuning data

 


Pre-study


  • Instruction learning이란?
    • 해결하고자 하는 task를 natural language로 만들어서 이에 대한 것으로 fine-tuning
    • 번역하는 task에 대해서 번역하는 task에 대한 설명을 함께 제시하는 것
    • 그럼, 단순하게 template에 어떤 task를 수행해야 하는 지를 추가하여 fine-tuning을 진행한 것인가?
    • 왜 성능이 향상될까?를 생각해보자.


Abstract


 

  • instruction으로 구성된 dataset을 활용하여 Language model을 fine-tuning했을 때, model performance가 크게 증가하고 unseen task에 대한 generalization도 향상된 모습을 확인할 수 있었다.
  • 해당 논문에서는 3가지의 focus로 instruction fine-tuning을 탐구
    • scaling the number of tasks
    • scaling the model size
    • fine-tuning on chain-of-thought (CoT) data
  • 위의 3가지 측면에 대해 instruction fine-tuning을 진행했을 때, 많은 model class, prompt setup, evaluation benchmark에 대해서 performance가 dramatically하게 향상한 모습을 확인할 수 있었다.
    • Flan-PaLM 540B >>>> PaLM 540B
      • Flan-PaLM 540B = instruction fine-tuned model

결론적으로, instruction fine-tuning은 모델의 performance를 향상 시키고 pre-trained language model의 usability를 향상 시킬 수 있는 general method로 제안된다.


1. Introduction


 

unseen task에 대해서 generalization하는 많은 방법들이 고안되었고 그에 따른 progress도 많았다. 특히, instruction으로 phrase된 collection of task를 가지고 모델을 fine-tuning시켰을 때, 모델은 few-shot exemplars를 덜 필요로 하면서도 instruction에 대해 더 잘 반응하는 모습을 확인.

이 논문에서는 여러 방법으로 instruction fine-tuning을 발전

  • instruction fie-tuning에 대해 scaling에 대한 impact 확인
    • number of task에 대한 scaling
    • model size에 대한 scaling
    • 이 실험에 대한 결과들은 더 scaling을 해야 한다는 future work를 제시
  • reasoning task를 수행하는 model ability에 대한 fine-tuning effect
    • 이전 fine-tuning dataset에는 CoT 포함 X → CoT evaluation에서 degrade performance
    • 단순하게 9개의 CoT dataset을 포함 → 모든 evaluation에서 성능 향상

이러한 결과들을 바탕으로 540B-parameter model을 사용하여 Flan-PaLM을 training

  • finetuning task를 1.8K로 증가
  • CoT data 포함

이로 인해 다양한 SOTA를 갱신하고 performance가 향상

  • reasoning ability
    • arithmetic reasoning
  • multi-lingual ability
  • several responsible AI evaluation benchmarks

추가로 Flan-T5 model을 instruction fine-tuning했을 때, zero-shot과 few-shot, CoT ability가 모두 증가하는 모습을 보여주었으며 T5와 같은 public checkpoint에 대한 성능을 모두 앞서 가는 드으이 결과를 얻을 수 있었다.

 


2. Flan Finetuning


 

해당 논문에서 사용한 fine-tuning data에 대한 figure

 

총 473개의 dataset과 146개의 task category, 1836개의 total task로 이루어져 있다.

  • Fine-tuning task와 Held-out task로 구분되어 있음.

여기서 Appendix F를 잠시 참고해보자.

public하게 사용할 수 있는 dataset들을 사용했으며, 모든 MMLU task를 삭제해서 evaluation을 위한 57개의 held-out task로 활용했다. 따라서, 전체적으로 1836개의 task를 다루게 된다.

 

이러한 데이터 셋들을 사용할 때, data distribution에 대해 unwanted bias가 존재할 수 있으므로 downstream user들은 이를 반드시 고려해서 사용해야 한다.

또한, sensitive human attributes를 포함할 수 있다.

Data distributions과 관련하여 top category가 mixture마다 존재하기 때문에 이에 대한 점을 반드시 고려해야 한다.


Figure 2에서 소개된 collection of data source로 Figure 3과 같이 template을 만들어서 instruction fine-tuning을 진행했다. 이러한 fine-tuning procedure을 Flan (Fine-tuning language models)라고 하며 결과적으로 fine-tuned된 모델 앞에 ‘Flan’ 이라는 이름을 붙여서 사용한다.

Flan 작업이 여러 모델의 사이즈와 architecture에서 사용될 수 있다는 것을 Table 2에서 확인할 수 있다.

  • Table 2에서 fine-tuning하는 데에 cost가 적게 사용된다.

※ 추가로 Flan은 fine-tuning procedure, FLAN은 모델을 의미한다.

2.1 Fine-tuning Data (Keyword로 등장)

Task Mixtures

기존 연구들은 fine-tuning에서 task의 숫자를 늘리면 unseen task에 대한 generalization이 증가한다는 것을 발견했다. 따라서 1836개의 fine-tuning task로 늘려서 사용한다.

이전 연구로부터 4개의 mixture를 결합하여 1836개의 fine-tuning task를 제작한다.

  • Muffin (80 tasks) + 26 new tasks (our addition)
    • Muffin = Multi-task fine-tuning with instructions
  • T0-SF (193 tasks)
    • Muffin과 겹치지 않게.
  • NIV2 (1554 tasks)

Chain-of-thought fine-tuning mixture

마지막 4번째 fine-tuning data mixture은 CoT annotation을 포함하고 있는데, 이는 CoT annotation에서 fine-tuning을 했을 때, unseen reasoning task에서 성능이 향상하는 지를 확인하기 위해 추가했다.

human raters가 직접 제작한 9개의 CoT annotation을 합쳐서 만들었으며, arithmetic reasoning, multi-hop reasoning, NLP inference 등을 추가한다. 그리고 각 task에 대해 10개의 instruction templates을 구성한다.

Templates and formatting

사용된 4개의 mixture data에서, 각각의 mixture data의 creator가 제시한 templates을 사용했으며 CoT같은 경우 10개 정도의 templates을 사용했다.

few-shot templates을 제작하기 위해, 다양한 exemplar delimiters를 썼으며 random하게 example level에서 적용했다.

2.2 Fine-tuning procedure

T5, PaLM, U-PaLM을 포함하는 넓은 범위의 model family에 대해 instruction fine-tuning을 진행했다. 이러한 model family는 size의 크기가 80M parameters (Flan-T5-small)부터 540B parameters (U-PaLM)까지 확장된다.

  • 모두 동일한 training procedure
    • learning rate, batch size와 같은 hyper parameter는 일부 변동
  • constant learning rate scheduler
    • Adafactor optimizer
  • Packing 사용 → 여러 training example을 하나의 sequence로 통합하기 위해서.
    • end-of-sequence token을 사용해서 target으로부터 input을 분리
  • Masking 사용
    • packed example boundary에서 token들이 다른 token에 attending하지 않도록.
  • 모든 evaluation에는 single checkpoint 사용
  • optimal step은 periodic evaluations에 기반하여 선택된다.
  • fine-tuning에 필요한 computation은 pre-training에 비하면 small fraction.
  • JAX-based T5X framework를 사용

여기서 Appendix E를 참고해보자. (Fine-tuning details)

Table 23은 각 mixture에 대한 proportion을 나눠서 사용했다.


2.3 Evaluation protocol

Evaluation benchmarks

fine-tuning data에 포함되어 있지 않은 held-out task로 performance를 수행하는 것에 초점을 맞췄다. 일반적인 world knowledge와 reasoning task에 대한 Flan-PaLM의 전반적인 capabilities를 측정하기 위해 다양한 범위의 benchmarks를 사용했다.

여기서, GPT3 논문에서 사용된 evaluation set을 사용하지 않고 여전히 expert human rater보다 성능이 낮게 평가되는 challenging benchmark에 대해서 평가했다.

  • MMLU
    • mathematics, history, law, medicine, etc.
  • BBH
    • BIG-Bench (PaLM은 human rater보다 성능이 낮음)
  • TyDiQA
    • question-answering benchmark
      • diverse language
    • MGSM
      • multi-lingual benchmark of math word problems (10 languages)

이러한 dataset들은 PaLM 논문에서 한번 언급이 되었고, 이전 연구에서 data contamination에 대한 의미 있는 결과들이 없었다.

Appendix C에서 data contamination, toxicity와 관련된 내용들을 더 찾아볼 수 있다.

Evaluation methods and metrics

  • MMLU & BBH
    • direct prompting을 통해 direct하게 답변을 prediction하는 능력을 평가
    • CoT와 같은 prompting에서도 마찬가지이며, 이 때 답변도 CoT를 제시하고 마지막에 답변을 제시해야 함.
    • five-shot / three-shot
  • TyDiQA
    • direct prompting exact-match score를 측정
    • sophisticated answer를 필요로 하지 않기 때문에 가능.
    • one-shot
  • MGSM
    • CoT prompting accuracy로 평가
    • direct prompting은 매우 낮은 performance를 보이고 있기 때문에.
    • 8-shot

또한, BIG-Bench에서 제시한 “normalized preferred metric”에 이어 “normalized average” metric을 사용한다.

  • macro-average over six normalized scores
    • MMLU-Direct
    • MMLU-CoT
    • BBH-Direct
    • BBH-CoT
    • TyDiQA-Direct
    • MGSM-CoTFlan 작업이 여러 모델의 사이즈와 architecture에서 사용될 수 있다는 것을 Table 2에서 확인할 수 있다.


3. Scaling to 540B parameters and 1.8K tasks


Model scaling : 8B → 62B → 540B

task scaling : CoT → Muffin → T0-SF → NIV2 순서대로 mixture을 추가

Individual Benchmark에 대한 결과

모델 크기와 task를 함께 scaling 했을 때, joint effect에 대한 figure.

  • Model parameter의 입장에서, 9.4% performance gain을 얻을 정도로 성능 향상
  • number of task를 증가했을 때도, improvement 확인
    • 282 tasks에서 더 증가 시켰을 때는, 크게 유의미한 차이를 보이진 않았음. 이에 대한 2가지 potential explanation이 존재
      • additional task가 특정하게 다양하지 않기 때문에, 그리고 model에 새로운 지식을 제공하지 않기 때문에.
      • multi-task instruction fine-tuning의 대부분의 gain이 모델이 이미 사전 훈련 과정에서 습득한 지식을 더 잘 표현하는 것을 학습하는 것에서 오는건데, 282 tasks 이상 사용했을 때는 이것에 큰 도움이 되지 않기 때문에.
    • fine-tuning은 1.4B tokens을 사용하지만 pre-training은 780B tokens을 사용한다는 점에서 2번째의 이유가 더 일리 있음.
  • 마지막으로 fine-tuning을 하던 안하던, 모델 규모가 커짐에 따라 performance가 향상한다.
  • 따라서, 모델 규모에 따른 성능 향상에 대한 효과가 어떻게 다른 지를 말하는 것은 어렵다.
    • 모두 fine-tuning이 성능을 향상 시키는 것은 동일하지만 작은 모델이 absolute performance에 대해서는 큰 모델보다 더 향상 폭이 크지만, 큰 모델이 relative error rate는 더 많이 감소된 것을 확인했다.

이러한 scaling curve에 대한 plot은 모델 크기와 task 개수를 어떻게 scaling해야 더 좋은 performance를 얻을 수 있을 지에 대한 insight를 제공한다.

Another order of magnitude로 모델 사이즈를 scaling하는 것이 상당한 performance gain을 유도할 수 있고, fine-tuning task 개수를 늘리는 것 또한 동일한 방식으로 performance gain을 얻을 수 있다.

전반적으로, scaling curves는 instruction fine-tuning에 대한 지속적인 scaling을 future work로 제시한다.

 


4. Fine-tuning with chain-of-thought annotations


Traditional NLP task뿐만 아니라 multi-step reasoning ability를 포함하는 다양한 evaluation에 대해 improved checkpoint를 만들기 위해 Flan fine-tuning을 한다.

여기서 CoT data를 instruction fine-tuning dataset에 포함되었을 때의 effect를 살펴본다.

  • CoT dataset을 포함했을 때, Flan-PaLM의 reasoning abilities가 증가하는 모습을 확인
  • CoT fine-tuning data에 대한 ablation study를 통해, 모든 evaluation에서 CoT dataset을 단순히 포함하는 것만으로도 performance가 증가했으며, CoT dataset이 없었을 때는 reasoning ability가 감소
  • CoT fine-tuning dataset을 포함했을 때, challenging BIG-Bench task에 대해서도 “let’s think step-by-step”을 통해 zero-shot reasoning을 향상 시켰음을 보여준다.

4.1 Fine-tuning on chain-of-thought improves reasoning on held-out tasks

  • CoT annotation이 포함된 9개의 dataset을 추가하는 것만으로도 reasoning ability가 향상된 것을 확인할 수 있다.
    • Flan-PaLM이 PaLM을 모두 outperform
  • 또한, self-consistency와 CoT prompting을 결합했을 때에 대한 결과도 보여주고 있다.
    • MMLU와 MGSM에서 SOTA를 달성
    • 하지만 BBH-alg에서는 좋은 모습을 보여주지 못했는데, 이는 symbolic manipulation이 필요한 task이기 때문이라고 생각.

4.2 Some chain-of-thought data is needed to maintain reasoning ability

Instruction fine-tuning에서 CoT dataset을 ablate하여 effect를 비교.

  • Held-out CoT benchmark인 MMLU, BBH, MGSM에 대해 evaluation을 진행.
  • Held-out non-CoT benchmarks인 MMLU, BBH, TyDiQA에 대해 evaluation 진행

  • Left figure
    • Held-out CoT benchmarks에서는 단순히 CoT만 사용하는 것보다 non-CoT와 함께 결합해서 fine-tuning하는 것이 가장 성능이 좋았다.
    • 이 그림을 통해 일부 CoT example을 활용하는 것이 reasoning ability를 유지하는 데에 매우 중요하다는 것을 의미하고 있다.
      • green line과 비교해보면 더 와닿음.
  • Right Figure
    • Held-out non-CoT benchmarks에서는 CoT를 함께 사용했을 때, 성능을 저하 시키지 않는 다는 것을 확인.

이전 연구에서 instruction fine-tuning이 unseen task에 대한 improvement를 얻을 수 있다고 했기 때문에, 왼쪽 그림은 조금 의아할 수도 있음.

  • 하지만 이전 연구들은 held-out NP tasks로만 평가
  • 또한, CoT reasoning을 평가하기에 너무 작은 모델을 사용

결론적으로, CoT / non-CoT를 함께 사용해야 한다.

4.3 Unlocking zero-shot reasoning

Figure 7을 통해 PaLM과 Flan-PaLM에서 사용된 3개의 zero-shot CoT example을 확인할 수 있다.

“let’s think step-by-step”을 통해 reasoning ability가 향상된 것을 확인할 수 있음. 반면에 fine-tuning하지 않은 PaLM은 굉장히 낮은 accuracy를 보임.


7. Discussion


Instruction fine-tuning을 3가지로 확장

  • scaling the number of fine-tuning task
  • scaling the size of model
  • fine-tuning on CoT data

이 논문으로부터 takeaway (핵심) 를 다음과 같이 요약할 수 있다.

Scaling curvers for instruction fine-tuning

  • model size를 키우는 것과 fine-tuning task를 증가시키는 것은 performance를 향상 시켰다.
  • 이전 연구들은 templates, task, model size를 개별적으로 증가 시켰지만 section 3을 통해 함께 증가 시켰을 때 joint effect를 확인할 수 있었다.

CoT finetuning is critical for reasoning abilities

  • 이전 instruction fine-tuning이 unseen non-CoT tasks에서 performance가 향상 하긴 했지만, CoT tasks에서 degraded performance를 보임
  • CoT가 포함된 dataset으로 fine-tuning을 진행하면 모델의 reasoning ability를 보존할 수 있음.

Instruction fine-tuning generalizes across models

  • Section 5에서 Decoder only, encoder-decoder와 같은 다양한 architecture를 가진 모델과 다양한 모델의 size, pre-training objectives에도 적용할 수 있다는 것을 확인했다.
  • instruction fine-tuning에 대한 effectiveness가 consistent하다는 것을 입증
  • 추가로, U2LR과 같은 model adaptation technique을 함께 사용했을 때 더 좋은 결과를 얻을 수 있었다. (Flan-U-PaLM)

Instruction fine-tuning improves usability and mitigates some potential harms

전문가가 아닌 사람들이 pre-trained checkpoint를 직접 사용하는 것이 어렵다. 왜냐하면 다음 token이 무엇인지 예측을 하도록 훈련된 모델 혼자서는 언제 generating을 멈춰야 할지도 모르고, user의 input에 responding하는 것이 아니라 continuing하는 실수를 할 수도 있기 때문이다.

Flan-PaLM이 PaLM을 open-ended evaluation에서 outperform하는 것을 확인했고, 특히 CoT task에서 더 좋은 performance를 보였다. Flan-PaLM은 responsible AI benchmarks에서 PaLM보다 더 유능하며 특히 toxic language에 대한 benchmarks에서 더 좋은 모습을 보인다.

모델의 Zero-shot usability는 prompt engineering, 그리고 few-shot exemplars가 필요하지 않은 넓은 범위의 LLM에서 굉장히 중요하다.

Instruction fine-tuning is relatively compute-efficient

비록 모델 size를 scaling하는 것이 유의미한 결과를 보이더라도, 상당한 계산이 필요하다. 따라서 compute-efficient에 대한 기술이 발전되는 것이 중요하다.

  • existing checkpoint를 이용하는 기술 등

Instruction fine-tuning은 상대적으로 작은 계산 양으로 모델의 performance를 향상 시킨다. 게다가 instrution fine-tuning된 더 작은 모델은 fine-tuning되지 않은 더 큰 모델을 outperform하는 경우도 있다.

결론적으로, instruction fine-tuning이 few-shot, zero-shot, CoT, open-ended generation evaluation에 어떻게 기여하는 지에 대해 논문에서 서술했다. Instruction fine-tuning은 다양한 모델에 generalization할 수 있으며, 다른 기술들과도 잘 결합하여 사용될 수 있기 때문에 거의 모든 pre-trained LLM에 instruction fine-tuning을 추천한다.

 


My discussion


Scaling함으로써 더 좋은 성능의 모델을 얻었다는 점, 그리고 실제로 FLAN-T5 모델 등을 손쉽게 사용할 수 있다는 점에서 또 다른 LLM의 축을 제공한 논문이 아닐까 싶다. 

또한, 여러 데이터에 대한 결과들과 모델 사이즈에 대한 결과 등을 통해서 scaling law는 여전히 참인 것과 동시에 데이터의 중요성도 알게 되었다.

새로운 방법론이라기 보다, fine-tuning할 때 우리가 어떤 목적으로 이 데이터를 학습 시킬 지에 대한 instruction을 가볍게 추가하는 것만으로도 성능이 향상 됐다는 점을 어필하는 것처럼 보였다.

이 분야를 오래 공부하지 않았던 학생의 입장으로, prompt engineering라는 분야에 대해서 의아함을 많이 가지고 있었는데 이런걸 보다 보면 기업이나 여러 연구 분야에서 왜 prompt engineering을 놓지 못하고 있는 지에 대해서도 어느정도 납득이 갔다. 

아무튼,, 아직 갈길이 많이 멀다.. 라는 걸 느꼈다. 

728x90
반응형