Google JAX란 무엇입니까? 알아야 할 모든 것

Google JAX, 즉 “Just After Execution”은 머신 러닝 작업 처리 속도를 향상시키기 위해 Google에서 개발한 프레임워크입니다.

더욱 빠른 작업 실행, 과학적 컴퓨팅, 함수 변환, 딥 러닝, 신경망 등 다양한 분야에서 활용 가능한 파이썬 라이브러리라고 볼 수 있습니다.

Google JAX란 무엇인가?

파이썬에서 가장 기본적인 계산 패키지인 NumPy는 집계, 벡터 연산, 선형 대수, 다차원 배열 및 행렬 조작과 같은 다양한 기능을 제공합니다.

만약 NumPy를 기반으로 수행되는 계산 속도를 더 높일 수 있다면, 특히 대규모 데이터셋을 다룰 때 어떤 일이 일어날까요?

코드 변경 없이 GPU 또는 TPU와 같은 다양한 프로세서에서 동일하게 효율적으로 작동하는 방법은 없을까요?

시스템이 구성 가능한 함수 변환을 자동으로 더 효과적으로 처리할 수 있다면 어떨까요?

Google JAX는 이러한 요구를 충족하는 라이브러리 (또는 위키피디아에서 말하는 프레임워크)입니다. 성능 최적화 및 머신 러닝(ML) 및 딥 러닝 작업을 효율적으로 처리하기 위해 설계되었습니다. Google JAX는 다른 ML 라이브러리와 차별화되는 독특한 변환 기능을 제공하여 딥 러닝 및 신경망을 위한 고급 과학 컴퓨팅을 지원합니다. 주요 기능은 다음과 같습니다.

  • 자동 미분
  • 자동 벡터화
  • 자동 병렬화
  • JIT(Just-In-Time) 컴파일

Google JAX의 핵심 기능

모든 변환은 성능 및 메모리 최적화를 위해 XLA(Accelerated Linear Algebra)를 활용합니다. XLA는 선형 대수 연산을 수행하고 TensorFlow 모델을 가속화하는 도메인별 최적화 컴파일러 엔진입니다. XLA를 Python 코드와 함께 사용하면 코드 변경을 최소화하면서 성능 향상을 누릴 수 있습니다!

이러한 각 기능을 좀 더 자세히 살펴보겠습니다.

Google JAX의 주요 특징

Google JAX는 성능을 개선하고 딥 러닝 작업을 더욱 효율적으로 수행하기 위해 중요한 구성 가능한 변환 기능을 제공합니다. 예를 들어, 자동 미분은 함수의 기울기를 계산하고 모든 차수의 도함수를 찾습니다. 마찬가지로 자동 병렬화 및 JIT는 여러 작업을 동시에 수행할 수 있게 합니다. 이러한 변환들은 로봇, 게임 및 연구 분야와 같은 다양한 응용 프로그램에서 핵심적인 역할을 합니다.

구성 가능한 변환 함수는 데이터 집합을 다른 형태로 변환하는 순수 함수입니다. 독립적이고 (즉, 이러한 함수는 프로그램의 다른 부분에 의존하지 않음) 상태 비저장 (즉, 동일한 입력은 항상 동일한 출력을 산출함)이기 때문에 구성 가능하다고 합니다.

Y(x) = T: (f(x))

위의 방정식에서 f(x)는 변환이 적용될 원본 함수입니다. Y(x)는 변환이 적용된 후의 결과 함수입니다.

예를 들어, ‘total_bill_amt’라는 함수가 있고 이 함수에 변환을 적용하고 싶다면, 기울기(grad)와 같은 원하는 변환을 적용하면 됩니다.

grad_total_bill = grad(total_bill_amt)

grad()와 같은 함수를 사용하여 수치 함수를 변환하면 고차 도함수를 쉽게 얻을 수 있으며, 이는 경사하강법과 같은 딥 러닝 최적화 알고리즘에서 널리 활용되어 알고리즘을 더 빠르고 효율적으로 만듭니다. 마찬가지로 jit()를 사용하여 Python 프로그램을 적시에(게으르게) 컴파일할 수 있습니다.

#1. 자동 미분

Python은 autograd 함수를 사용하여 NumPy 및 기본 Python 코드를 자동으로 미분합니다. JAX는 수정된 버전의 autograd(즉, grad)를 사용하고 XLA(Accelerated Linear Algebra)를 결합하여 자동 미분을 수행하고 GPU(그래픽 처리 장치) 및 TPU(텐서 처리 장치)에 대한 모든 차수의 도함수를 찾습니다.

TPU, GPU 및 CPU에 대한 간략한 설명: CPU(중앙 처리 장치)는 컴퓨터의 모든 작업을 관리합니다. GPU는 컴퓨팅 성능을 향상시키고 고급 작업을 실행하는 추가 프로세서입니다. TPU는 AI 및 딥 러닝 알고리즘과 같은 복잡하고 무거운 워크로드를 위해 특별히 설계된 강력한 장치입니다.

루프, 재귀, 분기 등을 포함하여 미분 가능한 autograd 함수와 마찬가지로 JAX는 역방향 모드 기울기(역전파)에 grad() 함수를 사용합니다. 또한 grad를 사용하여 함수를 어떤 차수로든 미분할 수 있습니다.

grad(grad(grad(sin θ))) (1.0)

고차 자동 미분

앞서 언급했듯이 grad는 함수의 편도함수를 찾는 데 매우 유용합니다. 딥 러닝에서 손실을 최소화하기 위해 신경망 매개변수에 대한 비용 함수의 기울기 하강을 계산하는 데 편도함수를 사용할 수 있습니다.

편도함수 계산

함수에 여러 변수 x, y 및 z가 있다고 가정해 봅시다. 다른 변수를 일정하게 유지하면서 한 변수의 도함수를 찾는 것을 편도함수라고 합니다. 함수가 다음과 같다고 가정합니다.

f(x,y,z) = x + 2y + z2

편도함수를 보여주는 예시

x의 편도함수는 ∂f/∂x가 되며, 이는 다른 변수가 일정하게 유지될 때 변수에 대한 함수가 어떻게 변하는지 알려줍니다. 이 작업을 수동으로 수행하려면 미분하는 프로그램을 작성하고 각 변수에 적용한 다음 경사하강법을 계산해야 합니다. 이는 여러 변수를 다루어야 할 때 복잡하고 시간이 많이 걸리는 작업이 될 수 있습니다.

자동 미분은 함수를 +, -, *, / 또는 sin, cos, tan, exp와 같은 기본 연산 세트로 분해한 다음 연쇄 규칙을 적용하여 도함수를 계산합니다. 이 작업은 정방향 모드와 역방향 모드 모두에서 수행할 수 있습니다.

이게 다가 아닙니다! 이 모든 계산은 매우 빠르게 처리됩니다 (위와 같은 백만 번의 계산에 시간이 걸릴 수 있다는 것을 생각해 보세요!). XLA는 속도와 성능을 관리합니다.

#2. 가속 선형 대수학

이전 방정식을 다시 가져와 보겠습니다. XLA가 없으면 계산에 3개 이상의 커널이 필요하며 각 커널은 작은 작업을 수행합니다. 예를 들어,

커널 k1 –> x * 2y (곱셈)

k2 –> x * 2y + z (덧셈)

k3 -> 감소

XLA에서 동일한 작업을 수행하면 단일 커널이 모든 중간 작업을 융합하여 처리합니다. 메모리 사용량을 줄이고 속도를 높이기 위해 기본 연산의 중간 결과를 메모리에 저장하지 않고 스트리밍합니다.

#3. 적시 컴파일

JAX는 내부적으로 XLA 컴파일러를 사용하여 실행 속도를 높입니다. XLA는 CPU, GPU 및 TPU의 속도를 향상시킬 수 있습니다. 이 모든 것이 JIT 코드 실행을 통해 가능합니다. jit을 사용하려면 다음처럼 가져오기를 통해 사용할 수 있습니다.

from jax import jit
def my_function(x):
  … 코드 내용 …
my_function_jit = jit(my_function)

또는 함수 정의 위에 jit을 데코레이터로 사용할 수도 있습니다.

@jit
def my_function(x):
  … 코드 내용 …

이 코드는 Python 인터프리터를 사용하는 대신 호출자에게 컴파일된 버전의 코드를 반환하기 때문에 훨씬 빠릅니다. 이는 배열 및 행렬과 같은 벡터 입력에 특히 유용합니다.

기존의 모든 파이썬 함수도 마찬가지입니다. 예를 들어 NumPy 패키지의 함수도 그렇습니다. 이 경우 NumPy가 아닌 jnp로 jax.numpy를 가져와야 합니다.

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

이렇게 하면 DeviceArray라는 핵심 JAX 배열 객체가 표준 NumPy 배열을 대체합니다. DeviceArray는 필요할 때까지 값을 가속기에 보관하는 방식으로 작동합니다. 또한 JAX 프로그램은 결과가 (파이썬) 프로그램으로 반환될 때까지 기다리지 않기 때문에 비동기 디스패치를 따릅니다.

#4. 자동 벡터화(vmap)

일반적인 머신 러닝 환경에서는 백만 개 이상의 데이터 포인트를 포함하는 데이터 세트를 자주 접하게 됩니다. 대부분의 경우 이러한 데이터 포인트 각각에 대해 일부 계산 또는 조작을 수행해야 합니다. 이는 시간과 메모리를 많이 소모하는 작업이 될 수 있습니다! 예를 들어, 데이터 세트에서 각 데이터 포인트의 제곱을 찾고 싶다면 가장 먼저 떠오르는 것은 루프를 생성하고 각 제곱을 하나씩 계산하는 것입니다.

이러한 데이터 포인트를 벡터로 생성하면 NumPy를 사용하여 데이터 점에 대해 벡터 또는 행렬 조작을 수행하여 모든 제곱을 한 번에 처리할 수 있습니다. 만약 프로그램이 자동으로 이 작업을 수행할 수 있다면 더 바랄 것이 있을까요? 바로 그것이 JAX의 기능입니다! 모든 데이터 포인트를 자동으로 벡터화하여 모든 연산을 쉽게 수행할 수 있게 하므로 알고리즘을 더욱 빠르고 효율적으로 만들 수 있습니다.

JAX는 자동 벡터화를 위해 vmap 함수를 사용합니다. 다음 배열을 예시로 들어보겠습니다.

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

위와 같이 코드를 실행하면 배열의 각 점에 대해 square 메서드가 실행됩니다. 하지만 다음처럼 코드를 변경하면:

vmap(jnp.square(x))

함수를 실행하기 전에 데이터 포인트가 vmap 메서드를 사용하여 자동으로 벡터화되고 루프가 연산의 기본 수준으로 푸시되므로 square 메서드는 단 한 번만 실행됩니다.

#5. SPMD 프로그래밍 (pmap)

SPMD(Single Program Multiple Data) 프로그래밍은 딥 러닝 환경에서 필수적입니다. 여러 GPU 또는 TPU에 걸쳐 있는 서로 다른 데이터 세트에 동일한 기능을 적용해야 하는 경우가 많습니다. JAX에는 여러 GPU 또는 모든 가속기에서 병렬 프로그래밍을 가능하게 하는 pmap이라는 함수가 있습니다. JIT와 마찬가지로 pmap을 사용하는 프로그램은 XLA에 의해 컴파일되고 시스템 전체에서 동시에 실행됩니다. 이 자동 병렬화는 순방향 및 역방향 계산 모두에 적용됩니다.

pmap은 어떻게 작동할까요?

다음과 같이 함수에 대해 어떤 순서로든 한 번에 여러 변환을 적용할 수도 있습니다.

pmap(vmap(jit(grad(f(x))))))

여러 구성 가능한 변환

Google JAX의 한계

Google JAX 개발자들은 이러한 유용한 변환을 도입하여 딥 러닝 알고리즘의 속도를 높이는 데 심혈을 기울였습니다. 과학적 컴퓨팅 기능과 패키지는 NumPy와 유사하므로 학습 곡선에 대한 걱정을 크게 덜 수 있습니다. 하지만 JAX에는 다음과 같은 몇 가지 한계가 있습니다.

  • Google JAX는 아직 개발 초기 단계에 있으며 주요 목적은 성능 최적화이지만 CPU 컴퓨팅에는 큰 이점을 제공하지 못할 수 있습니다. NumPy가 더 나은 성능을 보이는 경우도 있으며 JAX를 사용하면 오버헤드만 추가될 수 있습니다.
  • JAX는 여전히 연구 단계에 있거나 초기 단계에 있으며, TensorFlow와 같이 더 많은 사전 정의 모델, 오픈 소스 프로젝트 및 학습 자료를 갖춘 더 확립된 프레임워크의 인프라 표준에 도달하려면 추가적인 개선이 필요합니다.
  • 현재 JAX는 Windows 운영 체제를 지원하지 않습니다. JAX가 작동하려면 가상 머신이 필요합니다.
  • JAX는 부작용이 없는 순수 함수에서만 작동합니다. 부작용이 있는 함수의 경우 JAX가 좋은 선택이 아닐 수도 있습니다.

Python 환경에 JAX를 설치하는 방법

시스템에 Python이 설정되어 있고 로컬 시스템(CPU)에서 JAX를 실행하려면 다음 명령을 사용하세요.

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

GPU 또는 TPU에서 Google JAX를 실행하려면 GitHub JAX 페이지에 제공된 지침을 따르세요. Python을 설정하려면 Python 공식 다운로드 페이지를 방문하세요.

결론

Google JAX는 효율적인 딥 러닝 알고리즘, 로봇 공학 및 연구를 개발하는 데 탁월한 선택입니다. 몇 가지 제한 사항이 있지만 Haiku, Flax와 같은 다른 프레임워크와 함께 널리 사용되고 있습니다. 프로그램을 실행할 때 JAX가 실제로 어떻게 작동하는지 이해하고 JAX를 사용하거나 사용하지 않고 코드를 실행할 때의 시간 차이를 직접 확인할 수 있습니다. 공식 Google JAX 문서를 읽는 것으로 시작해 보세요. 이 문서는 매우 포괄적인 정보를 제공합니다.