Chanjong Park
9 min readMar 21, 2021

[AI 논문리뷰] Continual Learning on Deep Learning

1. Catastrophic Forgetting of Deep Learning

딥러닝은 데이터를 학습하고 분류하는데 매우 강력한 성능을 보여왔으며, 많은 곳에 활용되어 왔습니다. 하지만, 이 딥러닝의 단점 중 하나는 한 개의 모델이 데이터 학습 후 새로운 데이터셋을 새로 학습할 때, 이전에 학습됬던 데이터셋에 대한 기억을 잊어버린다는 것 입니다.

이를 논문들에서는 Catastrophic forgetting이라 지칭합니다. Cross Entropy 기반의 loss를 사용하여 딥뉴럴넷 학습은 은 현재 데이터에 대한 분포를 추정하는 것이므로, 이전 데이터에 대한 분포는 고려하지 못한채 학습을 하는 것으로 보입니다.

Figure 1. Catastrophic forgetting을 도식화한 그림. 1차 학습에서 MNIST 학습한 다음, 같은 모델을 2차 학습에서 SVHN 데이터로 Transfer learning 하였을 경우, 두 번의 학습 모두 0~9클래스 분류하는 같은 작업 일 지라도, 모델 A에서 MNIST데이터에 대한 분류 성능은 보장이 되지 않습니다. (Dataset images from the references.)

위의 Figure 1 예시를 통해 설명을 드리면, MNIST로 0~9까지의 숫자 이미지를 학습한 모델 A가 있다고 가정해봅니다. 그리고 1차학습이 끝난 후, 우리는 MNIST가 학습된 모델 A에 SVHN Dataset을 학습시킵니다. 이럴 경우, 모델은 1차로 MNIST 데이터셋을 학습하였고, 2차로 SVHN데이터를 학습하였으나, 이후 모델A의 MNIST 데이터셋 분류성능은 1차학습 때만큼 나오질 않습니다. SVHN데이터셋의 분포를 학습하면서 MNIST데이터를 잊어버린 것 입니다. 이를 Catastrophic forgetting이라 지칭합니다.

우리는 딥뉴럴넷이 적용된 시스템에서, 새로 추가된 데이터로 모델을 업데이트 해야할 경우가 있습니다. 이런 경우 우리는 Catastrophic forgetting 현상 때문에, 이전에 이미 모델이 학습했던 모든 데이터셋을 모두 저장소에 저장해두었다가 다시 모든 데이터를 사용하여 학습을 해야하는 번거로움이 있습니다.

2. Continual Learning

Continual learning은 딥뉴럴넷이 Catastrophic forgetting 현상에 영향받지 않고(이전에 학습된 데이터를 잊어버리지 않고), 새로운 데이터셋만을 이용해서 지속하여 학습할 수 있도록 하는 연구 분야 입니다.

이전 데이터를 잊어버리지 않고, 학습할 수 있도록 하는 여러가지 방법들이 논문을 통해 제안되었는데, 그 중 필자가 인상깊게 읽은 논문 3가지를 간단히소개드리려고 합니다.

본 포스팅에서는 이 논문들이 사용한 방법을 간단히 설명드리려고 합니다. 개인적인 의견이 포함되어 있을 수 있으니 참고해주세요.

2–1. Overcoming catastrophic forgetting in neural networks (Elastic Weight Consolidation)

첫번째 논문은 Deep Mind에서 나온 Overcoming catastrophic forgetting in neural networks 입니다. 주로 이 논문에서 사용한 알고리즘의 이름인 Elastic Weight Consolidation(EWC)라고도 불리기도 하는듯 합니다.

이 알고리즘을 간단히 설명하면, 이전 데이터(A)를 우선 학습합니다. 그리고 딥뉴럴넷에서 이 데이터를 학습+분류하는데 중요한 Weight(뉴런?) 와 중요하지 않은 Weight를 계산하여, 이 중요한 Weight들은 이후 추가되는 새로운 데이터 학습 시에 최대한 변화되지 않도록 하고, 중요하지 않았던 Weight들 위주로 학습하도록 합니다.

Figure2. Loss function from the ‘EWC’ paper [3]

논문에서는 이 방법을 고려한 Loss function을 제안합니다. 이전데이터(A)와 새로운 학습데이터(B)를 학습한다고 할때, F_i는 기존 학습데이터(A)에 대한 Weight들의 중요도를 나타내는 정보를 담고있으며, 세타는 현재 Weight값, 세타_A는 이전 데이터를 학습한 후의 Weight값들 입니다. 이 차이가 커질수록 기존 모델로부터 변화함을 뜻하며, 이 중 중요한 Weight들의 변화가 커질수록 Loss값이 커지도록 되어있습니다. 람다값이 커질수록 이 중요한 Weight값들의 변화를 억제하는 방향으로 학습을 하는 것으로 보입니다.

2–2. Gradient Episodic Memory for Continual Learning

Facebook research가 NIPS2017에서 공개한 논문으로, 개인적으로는 매우 합리적인 Continual Learning 방법을 제안합니다.

이 방법은 먼저 학습하는 데이터A, 이후 새로 학습하는 데이터B가 있다고 할때, 데이터A를 학습 후, 데이터A의 일부만을 가지고, 모델이 데이터A에 대한 성능을 기억 수 있도록 합니다.

알고리즘을 간단히 설명드리면, 모델은 데이터A 학습 후 데이터A의 일부만 샘플링하여 저장해둡니다. 이후 데이터B를 학습할 때, 매번 Back-propagation 때마다, 샘플링된 데이터A를 이용해 계산된 Gradient A와 데이터B에 대한 Gradient B를 비교하여, 만약 Gradient B의 학습이 데이터A에 대한 Loss가 증가하는 방향이라면 두 Gradient를 이용하여 새로 조정된 Gradient를 통해 학습합니다.

이 때 데이터 A에 대하여 Loss가 증가하는지, 하지않는지 여부를 두 Gradient사이의 각도차이로 판단합니다. 그래서 데이터 A에대한 Loss가 증가하지 않을 경우, 그대로 Weight update를 진행하고, 만약 A에대한 Loss가 증가할 경우, Gradient 방향을 이전 데이터에 대한 loss값이 상승하지 않는 조건을 만족하는 가장 가까운 다른 Gradient에 투영하여 업데이트 합니다.

Gradient로 Weight를 업데이트하는 2가지 케이스 : 데이터B에 대한 Gradient(Blue)와 데이터A의 샘플에 대한 Gradient(Black)간의 각도를 비교하여, 데이터B에 대한 Gradient만으로 업데이트할지, 조정된 Gradient(Green)으로 업데이트 할지가 결정됩니다.

이렇게 하면 모델이 데이터B를 학습을 할때, 이전 데이터A에 대한 기억을 잃지않는 방향으로 Gradient를 보정해가며 학습한다고 합니다.

2–3. Continual Learning with Deep Generative Replay

이 논문 역시 NIPS2017에서 발표되었으며, Generative adversarial network(GAN)를 이용해서 데이터를 기억하고, 재현하는 방법을 통해 Continual Learning 프레임 워크를 제안합니다.

Generator와 Solver 2가지 Network로 학습하며, Generator는 학습했던 데이터를 다시 재생성 해낼수 있으므로, 데이터를 저장하는 역할이며, Solver는 진짜 데이터와 생성된 데이터를 이용하여 Task를 학습하는(분류기) 역할 입니다.

Figure2. Sequential 학습을 도식화한 그림. New Scholar는 현재 Task에 대한것은 진짜 데이터(x), 이전 (Old Task)에 대한 것은 이전 모델(Old Scholar : Generator and Solver)이 생성한 이미지(x`)를 사용하여 학습하게 됩니다. [Figure Image is from the paper [3]]

만약 데이터1, 2, 3을 순차적으로 학습할 경우, 데이터 2를 학습할 차례에서는 Generator에서 생성된 데이터1과 실제 데이터2 둘을 이용해서 학습하게 됩니다. 이렇게 Generator모델을 이용하여 이전에 학습된 데이터를 지속적으로 저장해가며 학습해가므로, 따로 방대한 데이터를 저장소에 저장해두지 않아도 되게 됩니다.

Table from the paper[3] : 논문에서는 Solver 1 (분류기) 학습에서만 Real data를 사용하였으며, 2~5번 학습에서는 이전 Scholar networks(Generator, Solver)로 부터 학습한 결과라고 합니다. 2번째부터는 Generator모델에 의해 생성된 MNIST이미지만을 사용하여 학습했지만 Test Accuracy 는 거의 일정하게 유지되는 것을 보입니다.

이런 지속적인 학습이 되려면, Generator에서 이전데이터를 정확하게 잘 생성하고, Solver는 Generator가 생성한 데이터만으로도 이전에 학습했던 데이터의 분포를 잘 재현 할 수 있어야 할텐데, 논문에서는 이 부분을 MNIST데이터를 사용하여 실험하고 진행했습니다.

3. Summary

지금까지 3개 논문들에게서 제안된 Continual Learning 방법을 간단하게 알아보았습니다. 정리하자면 Continual Learning은 딥뉴럴넷이 새로운 학습시 이전 데이터를 잊어버리는 단점(Catastrophic forgetting)을 극복하기 위한 연구들입니다.

그리고 소개된 3가지 논문에 대해 간단히 비교하였으며 내용은 다음과 같습니다.

Continual Learning에 관한 논문 3가지와 간단한 비교.

본 포스팅에서는 학습 방법만 간단하게 나열하였고, 전체 실험과 모델 성능은 기재하지를 못했지만, 논문들에서는 대부분 MNIST 데이터를 이용하여 자세하게 성능 기재가 되어있으며, 데이터별로 의미있는 성능이 보이는 것 같습니다.

그리고 세 논문 모두 2017년 또는 이전에 발표된 논문이므로, 최신 연구결과는 더 많을 것으로 생각되지만, 딥뉴럴넷을 이용한 Continual Learning의 방법들을 제시한 논문으로써 의미가 있어보였기에 간단히 정리하게 되었습니다. 이 논문들과 새로운 논문들의 자세한 성능과 내용은 다음 포스팅에서 다루어보도록 하겠습니다.

Chanjong Park
Chanjong Park

Written by Chanjong Park

Deep Learning Researcher&Developer