-
[Paper] Gradient Multi-Normalization for Stateless and Scalable LLM TrainingML engineer/Papers & CS generals 2025. 9. 11. 14:06반응형
https://arxiv.org/abs/2502.06742
Gradient Multi-Normalization for Stateless and Scalable LLM Training
Training large language models (LLMs) typically relies on adaptive optimizers like Adam (Kingma & Ba, 2015) which store additional state information to accelerate convergence but incur significant memory overhead. Recent efforts, such as SWAN (Ma et al., 2
arxiv.org
알고리즘 논문 답게, 최근의 LLM 모델 리포트 논문들과 달리 이론적 배경이 많이 포함되어 있어 제대로 읽는데 시간이 좀 걸렸네요.
먼저 섹션 3 "Multi-Normalized Gradient Descent" 부터 분석 해보겠습니다.
# Section 3 Big Idea
핵심으로, Large Language Model(LLM) 학습에 도움이 되는 새로운 일반화 가능한 optimizer 알고리즘을 소개 하고 있습니다.
목표는 stateless 즉 모델 파라미터 이외의 메모리를 사용하지 않는 optimizer를 만드는것입니다. Adam 과 같은 transformer 모델 학습에 주로 사용되는 optimizer는 해당 step의 gradient 외에 이전 gradient에 대한 정보를 momentum의 형태로 저장합니다.
(즉, 파라미터가 커질 수 록 optimizer가 그에 비례한 state를 저장해야하므로 메모리 사용량이 모델 크기에 비례해서 커지는 문제가 있습니다.)저자의 핵심 아이디어는 각 학습 step마다 gradient를 normalize하는것으로 전처리 해서 일정한 size 또는 length를 가지도록 합니다. 기존에도 normalizing하는 기법들 자체는 있었지만, 한가지 규칙에 의한 (single norm) 정규화만 했으나 (뒤에서 언급되겠지만 최근에 다른 논문에서도 2개의 norm을 이용하는 방법이 있긴 합니다) 본 논문에서는 K-개의 (multiple rules/norms) 정규화를 거칩니다.
이를 논문에선 Multi-Normalized Gradient Descent (MNGD)라 부릅니다.
## Notations
먼저 세부 알고리즘에 앞서 논문에 나온 수식을 해석 해보면:
- ∇ (or ∇t): 흔히 gradient 라고 부르죠? 모델의 error (loss)가 가장크게 증가하는 방향을 나타내는 행렬 또는 벡터라고 생각하면 됩니다. 모델을 학습 시키기 위해선 모델 weight가 gradient의 반대 방향으로 이동하면 됩니다.
- || . || (or g): 이 표기는 흔히 norm 을 나타냅니다. 수학적으로는 보통 벡터 또는 행렬의 size 또는 magnitude로 해석하면 되겠고요. 벡터 또는 행렬의 크기를 측정 하는 방법은 다양한데요, 그래서 본 논문에서는 최적의 gradient descent를 위해 multiple norm를 사용하는것이라 설명합니다.
- z: 보통 최적화 문제에서 변수죠 placeholder variable. arg max (∇, z) 같은 문제가 주어지면 보통, "(∇, z)를 최대화 하는 벡터 z를 찾아라."와 동일한 의미입니다. 그러므로, z 는, 하나의 가능한 후보 벡터인셈이고 특정 조건들을 충족하는 해를 찾는 과정입니다. 이러한 z를 찾는다면, 그게 우리가 사용할 normalized gradient가 되는겁니다.
- (∇, z): 대게 함수 처럼 보이겠지만, 모델 업데이트라는걸 알고 있으니.. gradient ∇ 와 vector z 간의 dot product (또는 inner product) 로 해석 하면 됩니다. Dot product는 z가 ∇와 같은 방향일때 가장 커집니다.
## Gradient Multi-Normalization: 논문의 핵심 문제
이제 논문의 핵심 수식들을 하나 하나 살펴 보겠습니다.
Equation (4): P||.||(x) := arg max(x, z) s.t. ||z||=1
- 수식의 의미: 벡터 x의 norm ||. ||에 대한 normalized projection입니다.
- 다소 의역:
- s.t. ||z||=1: 최적화 문제에서 자주 보는 조건이죠. "벡터 z의 ||.|| norm 사이즈가 1", 즉 ||.|| 에 대하여 unit sphere에 해당하는 조건입니다.
- arg max(x, z): "unit size 벡터들 중에서 원본 x와 가장 같은 방향인 벡터 z를 찾아라" (앞서 언급했듯이 dot product는 같은 방향의 벡터들 간에 가장 큰 값을 가진다고 했죠?)
- 요약하면: 이 수식은 x와 같은 방향의 unit-length 벡터를 의미 합니다.
Equation (5): arg max(∇, z) s.t. ∀ i ∈ [1, K], g_i(z) = 1
- 수식의 의미: 이상적인, 하지만 상당히 어려울 multi-normalization 문제입니다.
- 다소 의역:
- s.t. ∀ i ∈ [1, K], g_i(z) = 1: "모든 K-normalization 방법에 대해서 사이즈가 동시에, 정확히 1인 벡터 z를 찾아라!"
- arg max(∇, z): "이 조건을 충족한 벡터 중에서 원본 gradient ∇와 같은 방향인 벡터 하나를 골라라"
- 문제..: 문제는 저자들이 언급하듯이, 이건 NP-hard 문제입니다. 알고리즘으로 현실적으로 직접 풀기에는 너무 어려운 문제라는 소리죠. 즉, 동시에 여러가지 측정 규칙에 따라 해당 조건을 만족하는 벡터 z를 진짜로 찾는건 거의 불가능하다는 뜻입니다. (시간과 자원이 무한하지 않으니) - 어찌 보면 다행이죠! NP-hard문제들은 늘 approximate solution을 내는것으로 논문 거리가 되니까 말이죠.
Algorithm 1: MultiNorm(∇, L, g)
앞서 살펴본 수식 (5)를 직접 풀수 없다고 했으니, 저자는 현실적인 근사 방법(해결책)을 내놓습니다.
- 무엇을 하는가?: normalization 조건을 한개씩 (k개를 동시에 하지 않고, 한번에 1개씩) 반복문을 통해 충족 시킵니다.
- 단계별 해석:
- 입력: 원본 gradient ∇, 반복 횟수 L, 그리고 K개의 normalization 함수 g.
- 초기값: 변수 x를 만들고 원본 gradient ∇로 둡니다.
- 첫번째 반복문 (for l = 1 to L): 전체 과정을 L번 반복 하는 역할로, 반복이 늘어날 수 록 실제 솔루션에 수렴해갑니다.
- 내부 반복문 (for i = 1 to K): 각 K normalization을 수행합니다.
- x ← P_gi(x): 매 스텝 마다 x를 현재 norm인 g_i에 대하여 projection합니다. (수식 (4)처럼) x가 i번째 (K개 중의 하나) 조건을 충족하도록 합니다.
- Return x: L번 반복 후 업데이트된 x를 최종적으로 리턴 합니다. 즉 multi-normalized gradient의 근사값인 셈이죠.
SWAN: An Instance of MultiNorm
논문에서는 여기서 더 나아가, 기존 연구들 중에 SWAN 알고리즘이 앞서 설명한 MultiNorm 방식의 특수한 케이스임을 보입니다. (이론적으로 참 탄탄한 논문이네요.) 즉 앞서 설명한 방법이 "일반"화 가능한 형태임을 보이는것입니다.
(SWAN도 당연히 해당 저자 논문일 줄 알았는데 다른 저자 논문이였더라고요.)간단히 말하면, SWAN은 K = 2인, 즉 g1, g2에 대해서만 수행하는 multi-normalization 아이디어다 이 말씀.
- g1: 행렬의 행의 크기를 기반으로하는 normalization으로 모든 행이 비슷한 magnitude를 가지도록 합니다.
- g2: 행렬의 singular value(즉 spectral norm)로 normalization을 하고, 이를 "whitening"이라고 부르는데, 각 gradient의 요소들이 서로 상관관계를 가지지 않도록 해주는 효과라고 하네요.
이 부분은 사실 그냥 본인들 논문의 이론적 근거를 다지는 섹션이지만, 이 논문을 통해 오히려 SWAN 논문을 더 잘 이해할 수 있게 되었네요.
3.2 MultiNorm의 수렴함을 증명
알고리즘 1이 실제로 수렴하는 안정적인 알고리즘이여야겠죠? 사실 동시에 조건을 충족 시킨다는게 하나씩 충족 시킨다고 반드시 되는건 아닐테니까요. (그냥 값이 이리 저리 튀어 다니고 수렴하지 않을 가능성도 높죠)
Theorem 3.6
- 무엇을 증명하는건가?: 알고리즘 1을 여러번 반복 하면(L이 충분히 크면) 리턴값 x는 fixed-point에 수렴 함을 보인다.
- 그럼 fixed-point가 무엇인가? fixed-point 는 K-normalization 중 어떤걸 적용해도 값이 변하지 않는 벡터입니다. 즉, 모든 조건에 대해서 unit vector인 상태에 도달한 것으로 NP hard라 불렀던 문제의 최적해 되겠습니다. (P_g1(x) = x & P_g2(x) = x)
- 어쩌란 말인가?: 즉, 이 알고리즘을 충분히 많이 돌리(?)면 최적해에 수렴한다는 의미입니다. 물론 L을 적게 수행하면 근사값에 그치겠지만요.
자세한 증명은
이해하기가 난해하니.. 뭐 맞다고 칩시다논문을 참고..3.3 MNGD: 새로운 Stateless Optimizer
마지막으로 앞서 언급한 조각들을 한데 모아 새로운 stateless optimizer를 설명합니다. (별건 없고 그냥 위에 언급한 내용들을 알고리즘 1에 붙인것입니다.)
Algorithm 2: Multi-Normalized GD (MNGD)
- 무엇을 하는가?: 새로운 전처리를 통한 전체 gradient descent 학습 반복문입니다.
- 단계별 해석:
각 학습 스텝마다:- ∇t ← ∇θL(θt, x(t)): 현재 batch에 대해서 gradient를 구합니다.
- ∇t ← MultiNorm(∇t, L, g): (핵심 파트). 원본 gradient ∇t 를 알고리즘 1에 의해 (L번, k-norm 반복) 업데이트 합니다.
- θt+1 ← θt - ηt∇t: 모델 학습 과정이죠? 모델 weight(θ)를 약간(ηt) 업데이트 합니다. 어떻게? 방금 구한 gradient ∇t 방향으로요.
이렇게 MNGD 알고리즘을 완성 시켰습니다. 사실 알고 나면 별 내용이 없는데, 이게 된다는걸 발견하고 이론과 실험을 통해 증명한게 본 논문의 업적(?)이겠네요.
# Section 4 Big Idea: SinkGD
이제 MNGD를 바탕으로, 기존의 MNGD의 일종이라고 언급한 SWAN optimizer를 한번 더 개선합니다.
앞서 설명 했듯이 SWAN에는 두가지 normalization 스텝이 있는데요 (전처리 과정)- (1) Row wise normalization은 효과적이고 계산하기도 쉬운 방법이라 건드리지 않고
- (2) Spectral normalization은 singular value를 구하는 과정을 거치므로 연산량이 상당합니다. 즉 모델이 커지면 느려진다는 소리죠.여기에 착안하여 (2) 단계를 저자는 새로운 normalizer로 교체 합니다. 이를 column-wise normalization이라고 부르는데요, (1)과 유사하게, 다만 row 대신 col 방향으로 normalization을 취합니다. 근데 여기서 이 부분 역시 그냥 대충 해보고 어 되네? 하는게 아니라 (실제로는 해보고 어 되네? 이론을 찾아보자.. 순서 였을 수 도 있겠지만 ㅎㅎ)
Sinkhorn 알고리즘이라는 기존에 있는, 빠르게 수렴하는 알고리즘에서 착안했다고 합니다.
Sinkhorn알고리즘은 찾아보면 복잡해보이지만, 사실 스도쿠 아시죠? 스도쿠 처럼 가로 세로 합이 일정한 행렬을 만드는것입니다. (가로 세로 합이 1이 되도록 하는)따라서 저자가 제안하는 두가지 normalization은 다음과 같습니다.
- g1(W) (Row-wise l2-norm): max ||W_i,:||₂ / √n
- 무엇을 하나?: SWAN에서와 동일한 row-wise norm.
- 직역해보면: 가장 크기가 큰 row를 gradient 행렬에서 찾아서 (l2-norm 기준) 그 크기로 맞춥니다.
- 다소 의역: 신경망 네트워크 모델의 레이어를 생각해보면, weight 행렬에서 row는 하나의 출력에 해당합니다. 따라서 g1의 역할은 모든 출력 값이, 즉 gradient update 한번에 출력 뉴런(?)들이 일정한 수준의 magnitude를 가지도록 해서 특정 output만 유독 크거나 작은 값을 가지지 않게 해주는 효과?
- g2(W) (Column-wise l2-norm): max ||W_:,j||₂ / √m
- 무엇을 하나?: SWAN의 g2를 대체 합니다.
- 직역해보면: 마찬가지로 가장 크기가 큰 column을 찾아서 그 크기로 맞춥니다.
- 다소 의역: 각 column은 row와 마찬가지로 해석하면, 하나의 입력 feature로 볼 수 있습니다. 따라서 g2의 역할은 gradient update가 특정 feature하나에 너무 치중되지 않도록 normalize해주는 효과라고 볼 수 있습니다. 즉, 모델 업데이트에 모든 입력 feature가 고르게 영향을 미치게끔 하는것이죠.
이제 이 Sinkhorn 알고리즘을 연결 짓는 과정이 매우 자연스러워집니다..
1960대에 소개된 알고리즘인데, 원래 목적은 행렬의 가로 세로가 합이 항상 1이 되는 행렬을 만드는 알고리즘이였습니다. 행 방향으로 한번, 열 방향으로 한번씩 번갈아 가면서 normalize 하는 것이죠. 위에서 제안된 방식과 동일하죠? 이미 해당 알고리즘에서 이렇게 번갈아 normalize하는것만으로 수렴하는것을 증명 해두었으니 말그대로 거인의 어깨에 올라타는 셈이죠.
(차이라면, 그냥 1/n, 1/m 하는게 아니라 루트를 취했다는 차이만 있습니다.)Algorithm 4: Sinkhorn GD (SinkGD)
자 드디어 꾸역 꾸역 해석 해보다 보니, 최종 종착점에 도달했네요, 지금까지 확인한 내용을 한데 모아 알고리즘으로 표현한것 뿐입니다.
- 알고리즘의 목적: 앞서 언급한 MNGD (Algorithm 2) 하에, 좀 더 연산 효율적인, Sinkhorn 알고리즘을 이용한 Optimizer 입니다.
- 각 단계별 설명:
각 학습 스텝마다:- ∇t ← ∇θL(θt, x(t)): Gradient를 먼저 구하고,
- ∇t ← SR-Sinkhorn(∇t, L): row/column 방향으로 각각 normalize 해주는 전처리 과정을 L번 반복 해줍니다.
- θt+1 ← θt - ηt∇t: 이렇게 업데이트된 gradient 방향으로 모델 파라미터를 업데이트 합니다.
# 결론
- Multi-normalization 프레임웍 하에 잘 정돈된 알고리즘이고,
- 기존 SWAN 알고리즘의 비교적 연산이 느린 whitening 과정을 훨씬 빠르게 계산 가능한, column wise normalization으로 대체해서,
- 안정적이고 빠르게 수렴하는 optimizer를 이론적 근거를 가지고 제안합니다.
실험을 통해서는 state-less인 만큼 메모리 사용량이 대폭 줄어듦을 보이고, 근소하게나마 학습도 기존의 Adam이나 SWAN등에 비해서도 빠르게 잘 수렴함을 보입니다. (아마 비용적인 문제로 모델 크기는 1.3B 까지만 scale 해서 실험을 진행했네요.)
반응형'ML engineer > Papers & CS generals' 카테고리의 다른 글
IaaS / PaaS / SaaS (0) 2023.09.22 [Python] Pickle에 대한 오해와 Can’t Pickle local object Error 해결 (0) 2023.02.16 [coding] Notes on space complexity (0) 2023.02.06 Domain Name System (DNS) 개요 (0) 2023.01.22 분산 시스템 디자인 (0) 2023.01.19