본 연구는 수 많은 딥러닝 모델들이 결과로 주는 confidence score를 조정하는 연구이다. Confidence score는 모델의 최종 layer에서의 sigmoid 또는 softmax 값을 의미하며, $[0,1]$ 값을 가진다. 논문의 주요 contribution은 다음과 같다.
- 딥러닝 모델이 커지면서 over-confident하게 되는 현상을 발견
- perfect calibration을 정의하고, 이를 근사하기 위한 metric 정의
- calibration 방법 비교 실험
논문의 주요 내용은 다음과 같다.
1. Introduction
모델의 예측값에 대한 confidence를 같이 제공하는 것은 모델의 성능 뿐만 아니라 신뢰도에 많은 영향을 준다. 현대의 neural network는 과거에 비해 기하 급수적으로 커졌는데, 동시에 관찰된 현상은 더 이상 well-calibrated되지 않는다는 것이다.
위 그림에서 확인해보자. LeNet은 5-layer 모델이고 ResNet은 110-layer 모델이다. 그림의 confidence histogram을 보면 모델의 크기가 커지면서 accuracy도 높아졌지만 confidence score 또한 1에 가까운 값으로 치우치게 출력되는 것을 확인할 수 있다.
이를 좀더 분석하기 위해서는 calibration에 대한 정의와 metric이 필요한다. 본 연구는 supervised learning 상황을 가정으로 진행되었다.
2. Definitions
먼저 용어 정의가 필요하다.
- $\hat{Y}$: class prediction | $\hat{P}$: associated confidence score
- $P$: real distribution of a class
예를 들어 100개의 데이터가 있고, $class 0$이 80개가 있다고 하자. 이때, perfect calibration은 confidence score가 true probability와 같은 상황으로 정의할 수 있다.
$$\mathbb{P}(\hat{Y}=Y|\hat{P}=P) = p, \ \ \forall x \in [0,1]$$
(조금 의문이 들긴 한다. $class 0,1$이 6:4로 섞여있어도, 각 data point에 대한 confidence score는 높은게 좋지 않나?)
이를 측정하기 위해 적용한 방법들 중 일부는 다음과 같다.
1) Reliability Diagrams
Reliability Diagrams는 confidence score에 대해 일종의 histogram을 그리는 방식이다. 아래 그림을 먼저 보자.
각 confidence score의 구간 별로 accuracy를 구할 수 있는데, perfect calibration 상황이라면 각 구간별로 accuracy가 confidence score와 비슷할 것이다(왼쪽). 반면 over-confident하다면 confidence score 값들이 오른쪽으로 치우치고 구간별 accuracy가 일정하지 않을 것이다(오른쪽).
Reliability Diagram은 다음과 같은 방법으로 그린다. 먼저 $M$개의 interval을 나누고 각 구간에 대한 accuracy를 구한다. $B_{m}$을 해당 interval $I_{m} = (\frac{m-1}{M}, \frac{m}{M}]$에 속하는 sample의 개수라고 하고, $\hat{y}_{i}, y_{i}$를 각각 predicted label, true label이라고 하자.
$B_{m}$의 accuracy는 다음과 같다.
$$acc(B_{m}) = \frac{1}{|B_{m}|} \sum_{i \in B_{m}} 1(\hat{y}_{i} = y_{i})$$
$B_{m}$의 average confidence는 다음과 같다.
$$conf(B_{m}) = \frac{1}{|B_{m}|} \sum_{i \in B_{m}} \hat{p}_{i}$$
위 정의를 기반으로 하면 prefectly calibrated model은 모든 구간 $m \in {1, ..., M}$에 대해 $acc(B_{m}) = conf(B_{m})$인 모델로 정의할 수 있다.
2) Expected Calibration Error(ECE)
miscalibration의 개념은 구체적으로는 accuracy와 confidence 분포의 차이로 생각해볼 수 있다.
$$\mathbb{E}_{\hat{P}}[|\mathbb{P}(\hat{Y}=Y|\hat{P}=p) - p |]$$
ECE는 위 수식에서 구간을 $M$개의 bin으로 나눠 근사사키는 방법으로 계산한다.
$$ECE = \sum_{m=1}^{M} \frac{|B_{m}|}{n} |acc(B_{m}) - conf(B_{m})|$$
위 개념들을 문서분류 모델에 대해 PoC를 진행해보았다. 적용 모델은 MPAD(Message Passing Attention networks for Document Understanding)이다.
적용 결과 calibaration 효과는 없는 것으로 판단된다. 모델이 매우 잘 학습된 경우(ex. F1-score 0.99) 당연히 큰 confidence score가 나올텐데, 실제 데이터 분포에 맞추도록 calibrate한다는 것이 제대로 된 정의는 안된 것 같다.
위 그림에 대한 코드는 아래와 같다.
def calibration_graph(path):
'''
input:
- preds, scores, trues: 1-d dimension list
'''
preds, scores, true = pickle.load(open(path, 'rb'))
preds= np.array(preds)
scores = np.array(scores)
true = np.array(true)
num_data = len(scores)
interval = 0.1
num_intervals = int(1/interval)
IntervalRange = collections.namedtuple('IntervalRange', 'start end')
intervals=[]
for i in range(num_intervals):
_intev = IntervalRange(start=i*interval, end=(i+1)*interval)
intervals.append(_intev)
percents = []
accs = []
eces = []
for itv in intervals:
interval_scores = scores[np.where((scores >= itv.start)&(scores < itv.end), True, False)]
interval_preds = preds[np.where((scores >= itv.start)&(scores < itv.end), True, False)]
interval_trues = true[np.where((scores >= itv.start)&(scores < itv.end), True, False)]
percent = len(interval_scores) / num_data
if list(interval_scores):
acc = np.sum(np.equal(interval_trues, interval_preds)) / len(interval_trues)
conf = np.average(scores)
percents.append(percent)
accs.append(acc)
eces.append((len(interval_scores)/num_data)*(np.abs(acc-conf)))
else:
percents.append(0.0)
accs.append('nan')
avg_acc = np.sum(np.equal(preds, trues))/len(trues)
plt.figure(figsize=(12,6))
plt.xlabel('confidence')
plt.subplot(1,2,1)
plt.bar([intv.start for intv in intervals], percents,width=0.1, color='lightskyblue', edgecolor='silver', align='edge', alpha=0.6)
plt.grid(True,alpha=0.7, linestyle='--')
plt.ylim(0,1)
plt.axvline(x=avg_acc, linestyle='--', color='b')
plt.text(avg_acc-0.08,0.5,'accuracy',rotation=90, color='b')
plt.axvline(x=np.average(scores), linestyle='--', color='g')
plt.text(np.average(scores)+0.03,0.4,'avg. confidence',rotation=90, color='g')
plt.ylabel('% of samples', fontsize='x-large')
plt.xlabel('confidence', fontsize='x-large')
plt.title('confidence histogram', fontsize='xx-large')
plt.subplot(1,2,2)
text = 'Error=%.4f'%np.sum(eces)
plt.bar([intv.start for intv in intervals], [intv.end for intv in intervals], width=0.1, color='plum', edgecolor='silver', align='edge', alpha=0.5, label='Ideal')
plt.bar([intv.start for intv in intervals], [a if a != 'nan' else 0.0 for a in accs ],width=0.1, color='lightskyblue', edgecolor='silver', align='edge', alpha=0.6, label='Outputs')
plt.grid(True,alpha=0.7, linestyle='--')
plt.legend()
plt.text(0.6,0.05,text, backgroundcolor='w', alpha=0.6)
plt.ylabel('accuracy', fontsize='x-large')
plt.xlabel('confidence', fontsize='x-large')
plt.title('reliability diagram', fontsize='xx-large')
plt.subplots_adjust(wspace=0.25)
plt.show()
'논문 및 개념 정리' 카테고리의 다른 글
[GPT3] 주요 내용 정리 (0) | 2022.02.16 |
---|---|
[2020] Spot The Bot: A Robust and Efficient Framework for the Evaluation of Conversational Dialogue Systems (0) | 2021.08.13 |
[2018] Universal Language Model Fine-tuning for Text Classification(ULMfiT) (0) | 2021.03.15 |
[2018] Deep contextualized word representations(ELMo) (0) | 2021.03.15 |
[2017] Attention is All you Need (0) | 2021.03.15 |