Loading [MathJax]/jax/output/CommonHTML/jax.js
본문 바로가기

학술/deep learning

DDPM에서 노이즈 예측 신경망 평균 공식 유도

다음은 DDPM에서 역확산 단계의 평균 μq(xt,x0)μq(xt,x0)=1αt(xt1αt1ˉαtϵ) 로 계산되는 과정을 단계별로 유도하는 방법입니다. 1. Forward Process에서 xt 표현 DDPM의 forward process에서는 xt=ˉαtx0+1ˉαtϵ,ϵN(0,I), 로 정의됩니다. 여기서 ˉαt=ts=1αs. 이를 x0에 대해 풀면, x0=1ˉαt(xt1ˉαtϵ). --- 2. 역확산 단계의 평균 μq(xt,x0) 기본 식 논문 등에서 증명된 바에 따르면, q(xt1xt,x0)=N(xt1;μq(xt,x0),˜βtI), 의 평균은 μq(xt,x0)=ˉαt1βt1ˉαtx0+αt(1ˉαt1)1ˉαtxt, 로 주어집니다. 여기서βt=1αt이며, ˉαt=ˉαt1αt 관계를 사용합니다. 3. x0ϵ로 표현하여 대입 1. x0를 위의 forward process 식에서 표현한 결과를 대입합니다. x0=1ˉαt(xt1ˉαtϵ). 2. 이를 μq(xt,x0)에 대입하면: μq(xt,x0)=ˉαt1βt1ˉαt1ˉαt(xt1ˉαtϵ)+αt(1ˉαt1)1ˉαtxt=ˉαt1βtˉαt(1ˉαt)xtˉαt1βt1ˉαt1ˉαtˉαtϵ+αt(1ˉαt1)1ˉαtxt. 3. xt에 대한 계수를 묶어 정리하면: Coefficient of xt=ˉαt1βtˉαt(1ˉαt)+αt(1ˉαt1)1ˉαt. **여기서** ˉαt=ˉαt1αt이므로, ˉαt=ˉαt1αt이고, βt=1αt임을 고려하면, ˉαt1βtˉαt(1ˉαt)=1αtαt(1ˉαt). 따라서 두 항의 합은: 1αtαt(1ˉαt)+αt(1ˉαt1)1ˉαt=11ˉαt(1αtαt+αt(1ˉαt1)). **주목할 점:** 1ˉαt는 다음과 같이 쓸 수 있습니다. 1ˉαt=1αtˉαt1=(1αt)+αt(1ˉαt1). 그러므로, 1αtαt+αt(1ˉαt1)=(1αt)+αt(1ˉαt1)αt=1ˉαtαt. 최종적으로 xt의 계수는: 11ˉαt1ˉαtαt=1αt. 4. ϵ에 대한 항은: ˉαt1βt1ˉαt1ˉαtˉαtϵ. 동일하게 βt=1αtˉαt=ˉαt1αt를 사용하면, ˉαt1(1αt)1ˉαt1ˉαtˉαt1αt=1αtαt1ˉαt. 5. 따라서 최종적으로 μq(xt,x0)=1αtxt1αtαt1ˉαtϵ, 즉, μq(xt,x0)=1αt(xt1αt1ˉαtϵ). 4. 노이즈 예측 네트워크와의 연결 실제 학습에서는 네트워크 eθ(xt,t)ϵ를 예측하도록 설계합니다. 따라서 위 식에서 ϵeθ(xt,t)로 대체하여 μq(xt,x0)=1αt(xt1αt1ˉαteθ(xt,t)) 와 같이 사용하게 됩니다. 요약 1. **Forward process**에서 xtx0ϵ의 선형 결합으로 표현한 후, 2. **역확산 단계**의 평균 μq(xt,x0)x0xt의 선형 결합으로 주어짐을 확인합니다. 3. x0xtϵ로 표현한 식을 대입하고, 계수를 정리하면 xt의 계수가 1/αt가 되고, ϵ에 대한 항이 (1αt)/(αt1ˉαt)가 되어 최종 식이 도출됩니다. 이와 같이, 위 유도 과정을 통해 DDPM의 역확산 단계에서 μq(xt,x0)가 위와 같이 계산되는 이유를 알 수 있습니다.