728x90
Knowledge Distillation(KD)이란 더 크고 복잡한 모델(LLM)의 정보 혹은 지식을 더 작고 효율적인 모델(sLLM)에 전달(transfer)하는 기술입니다. KD의 목표는 LLM이 학습한 풍부한 표현(rich representation)을 sLLM에 전달해서 자원-효율적이면서 LLM과 비슷한 정확도를 갖게 하는것 입니다.
이를 적용하기 위해서 우선 teacher model(대규모 모델)을 먼저 훈련시킵니다. 이후 student model이 특정 Loss function을 통해 훈련합니다. 이 loss function은 단순히 훈련 데이터의 label 뿐만 아니라 teacher model이 생성한 soft output(probabilities)에도 기반합니다. 이러한 soft output은 hard 레이블보다 더 섬세한 이해를 제공하며, 다양한 class에 걸쳐 teacher model의 confidence를 전달한다.
학습 과정에서 temperature parameter를 활용해서 확률을 soften해서 분포(distribution)을 더 유익하고 student model이 학습하기 쉽게 만들어줍니다.
잘 알려진 teacher-student pair는 NLP 분야에서 BERT-DistillBERT가 있고 image classification 분야에서는 Resnet-50과 Mobile-Net이 있습니다.
실습 코드
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, targets, temperature):
soft_loss = F.kl_div(F.log_softmax(student_logits / temperature, dim=1),
F.softmax(teacher_logits / temperature, dim=1),
reduction='batchmean') * (temperature ** 2)
hard_loss = F.cross_entropy(student_logits, targets)
return soft_loss + hard_loss
# Assuming teacher_model and student_model are defined and loaded
# Define temperature and alpha for balancing the loss components
temperature = 5.0
alpha = 0.5
# Compute teacher and student outputs
teacher_logits = teacher_model(input_data)
student_logits = student_model(input_data)
#
Compute distillation loss
loss = distillation_loss(student_logits, teacher_logits,
targets, temperature)
# Backpropagate and update student model
loss.backward()
optimizer.step()
'인공지능 > LLM' 카테고리의 다른 글
할루시네이션 발생 원인 (0) | 2024.11.02 |
---|---|
LLM의 Risk (0) | 2024.11.02 |
RAG란? (0) | 2024.03.17 |
챗봇 구현 실습 (4) - 챗봇 구현 (0) | 2024.03.17 |
챗봇 구현 실습 (3) - streamlit (0) | 2024.03.17 |