Google JAX란 무엇입니까? 알아야 할 모든 것

Google JAX 또는 Just After Execution은 기계 학습 작업의 속도를 높이기 위해 Google에서 개발한 프레임워크입니다.

더 빠른 작업 실행, 과학 컴퓨팅, 함수 변환, 딥 러닝, 신경망 등에 도움이 되는 Python용 라이브러리라고 생각할 수 있습니다.

Google JAX 정보

Python에서 가장 기본적인 계산 패키지는 집계, 벡터 연산, 선형 대수, n차원 배열 및 행렬 조작, 기타 여러 고급 기능과 같은 모든 기능을 포함하는 NumPy 패키지입니다.

NumPy를 사용하여 수행되는 계산 속도를 더욱 높일 수 있다면, 특히 거대한 데이터 세트의 경우 어떻게 될까요?

코드 변경 없이 GPU 또는 TPU와 같은 다양한 유형의 프로세서에서 똑같이 잘 작동할 수 있는 것이 있습니까?

시스템이 구성 가능한 함수 변환을 자동으로 더 효율적으로 수행할 수 있다면 어떨까요?

Google JAX는 바로 그 이상을 수행하는 라이브러리(또는 Wikipedia에서 말하는 프레임워크)입니다. 성능을 최적화하고 머신 러닝(ML) 및 딥 러닝 작업을 효율적으로 수행하도록 구축되었습니다. Google JAX는 다른 ML 라이브러리와 차별화되고 딥 러닝 및 신경망을 위한 고급 과학 계산에 도움이 되는 다음과 같은 변환 기능을 제공합니다.

  • 자동 미분
  • 자동 벡터화
  • 자동 병렬화
  • JIT(Just-In-Time) 컴파일

Google JAX의 고유 기능

모든 변환은 더 높은 성능과 메모리 최적화를 위해 XLA(가속 선형 대수)를 사용합니다. XLA는 선형 대수학을 수행하고 TensorFlow 모델을 가속화하는 도메인별 최적화 컴파일러 엔진입니다. Python 코드 위에 XLA를 사용하면 코드를 크게 변경할 필요가 없습니다!

이러한 각 기능을 자세히 살펴보겠습니다.

Google JAX의 기능

Google JAX는 성능을 개선하고 딥 러닝 작업을 보다 효율적으로 수행하기 위해 중요한 구성 가능한 변환 기능을 제공합니다. 예를 들어, 자동 미분은 함수의 기울기를 구하고 모든 차수의 도함수를 찾습니다. 마찬가지로 자동 병렬화 및 JIT는 여러 작업을 병렬로 수행합니다. 이러한 변환은 로봇, 게임 및 연구와 같은 응용 프로그램의 핵심입니다.

구성 가능한 변환 함수는 데이터 집합을 다른 형식으로 변환하는 순수 함수입니다. 그것들은 독립적이고(즉, 이러한 함수는 프로그램의 나머지 부분과 종속되지 않음) 상태 비저장(즉, 동일한 입력이 항상 동일한 출력을 낳음)이기 때문에 구성 가능이라고 합니다.

Y(x) = T: (f(x))

위의 방정식에서 f(x)는 변환이 적용되는 원래 함수입니다. Y(x)는 변환이 적용된 후의 결과 함수입니다.

  Canva를 사용하여 전문가처럼 디자인하는 방법

예를 들어, ‘total_bill_amt’라는 이름의 함수가 있고 그 결과를 함수 변환으로 원하는 경우 원하는 변환, 즉 그래디언트(grad)를 사용하기만 하면 됩니다.

grad_total_bill = grad(total_bill_amt)

grad()와 같은 함수를 사용하여 수치 함수를 변환함으로써 고차 도함수를 쉽게 얻을 수 있으며, 이는 경사하강법과 같은 딥 러닝 최적화 알고리즘에서 광범위하게 사용할 수 있으므로 알고리즘을 더 빠르고 효율적으로 만듭니다. 마찬가지로 jit()를 사용하여 Python 프로그램을 적시에(게으르게) 컴파일할 수 있습니다.

#1. 자동 미분

Python은 autograd 함수를 사용하여 NumPy와 기본 Python 코드를 자동으로 구분합니다. JAX는 수정된 버전의 autograd(즉, grad)를 사용하고 XLA(Accelerated Linear Algebra)를 결합하여 자동 미분을 수행하고 GPU(그래픽 처리 장치) 및 TPU(텐서 처리 장치)에 대한 모든 차수의 도함수를 찾습니다.]

TPU, GPU 및 CPU에 대한 빠른 참고 사항: CPU 또는 중앙 처리 장치는 컴퓨터의 모든 작업을 관리합니다. GPU는 컴퓨팅 성능을 향상하고 고급 작업을 실행하는 추가 프로세서입니다. TPU는 AI 및 딥 러닝 알고리즘과 같은 복잡하고 과중한 워크로드를 위해 특별히 개발된 강력한 장치입니다.

루프, 재귀, 분기 등을 통해 구별할 수 있는 autograd 함수와 동일한 라인을 따라 JAX는 역방향 모드 기울기(역전파)에 grad() 함수를 사용합니다. 또한 grad를 사용하여 함수를 어떤 순서로든 구별할 수 있습니다.

grad(grad(grad(sin θ))) (1.0)

고차의 자동 미분

앞에서 언급했듯이 grad는 함수의 편도함수를 찾는 데 매우 유용합니다. 손실을 최소화하기 위해 딥 러닝에서 신경망 매개변수에 대한 비용 함수의 기울기 하강을 계산하기 위해 편도함수를 사용할 수 있습니다.

편도함수 계산

함수에 여러 변수 x, y 및 z가 있다고 가정합니다. 다른 변수를 일정하게 유지하여 한 변수의 도함수를 찾는 것을 편도함수라고 합니다. 함수가 있다고 가정하고,

f(x,y,z) = x + 2y + z2

편도함수를 보여주는 예

x의 편도함수는 ∂f/∂x가 되며, 이는 다른 변수가 일정할 때 변수에 대한 함수가 어떻게 변경되는지 알려줍니다. 이를 수동으로 수행하는 경우 미분하는 프로그램을 작성하고 각 변수에 적용한 다음 경사하강법을 계산해야 합니다. 이것은 여러 변수에 대해 복잡하고 시간 소모적인 일이 될 것입니다.

자동 미분은 함수를 +, -, *, / 또는 sin, cos, tan, exp 등과 같은 기본 연산 세트로 나눈 다음 연쇄 규칙을 적용하여 도함수를 계산합니다. 정방향 및 역방향 모드 모두에서 이 작업을 수행할 수 있습니다.

이게 아니야! 이 모든 계산은 매우 빠르게 이루어집니다(위와 비슷한 백만 계산과 시간이 걸릴 수 있음을 생각해보세요!). XLA는 속도와 성능을 관리합니다.

  Android에서 색맹을 보정하기 위해 화면 필터를 활성화하는 방법

#2. 가속 선형 대수학

이전 방정식을 취합시다. XLA가 없으면 계산에는 3개(또는 그 이상)의 커널이 필요하며 각 커널은 더 작은 작업을 수행합니다. 예를 들어,

커널 k1 –> x * 2y(곱셈)

k2 –> x * 2y + z(덧셈)

k3 -> 감소

XLA에서 동일한 작업을 수행하면 단일 커널이 모든 중간 작업을 융합하여 처리합니다. 기본 연산의 중간 결과를 메모리에 저장하지 않고 스트리밍하여 메모리를 절약하고 속도를 향상시킵니다.

#삼. 적시 컴파일

JAX는 내부적으로 XLA 컴파일러를 사용하여 실행 속도를 높입니다. XLA는 CPU, GPU 및 TPU의 속도를 높일 수 있습니다. 이 모든 것은 JIT 코드 실행을 사용하여 가능합니다. 이것을 사용하기 위해 import를 통해 jit을 사용할 수 있습니다:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

또 다른 방법은 함수 정의 위에 jit를 장식하는 것입니다.

@jit
def my_function(x):
	…………some lines of code

이 코드는 변환이 Python 인터프리터를 사용하는 대신 호출자에게 컴파일된 버전의 코드를 반환하기 때문에 훨씬 빠릅니다. 이것은 배열 및 행렬과 같은 벡터 입력에 특히 유용합니다.

기존의 모든 파이썬 함수도 마찬가지입니다. 예를 들어 NumPy 패키지의 함수입니다. 이 경우 NumPy가 아닌 jnp로 jax.numpy를 가져와야 합니다.

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

이렇게 하면 DeviceArray라는 핵심 JAX 배열 객체가 표준 NumPy 배열을 대체합니다. DeviceArray는 지연됩니다. 값은 필요할 때까지 가속기에 보관됩니다. 이것은 또한 JAX 프로그램이 결과가 호출(Python) 프로그램으로 반환될 때까지 기다리지 않으므로 비동기 디스패치를 ​​따릅니다.

#4. 자동 벡터화(vmap)

일반적인 기계 학습 세계에는 백만 개 이상의 데이터 포인트가 있는 데이터 세트가 있습니다. 대부분의 경우 이러한 데이터 포인트 각각 또는 대부분에 대해 일부 계산 또는 조작을 수행할 것입니다. 이는 매우 시간과 메모리를 소모하는 작업입니다! 예를 들어, 데이터 세트에서 각 데이터 포인트의 제곱을 찾고 싶다면 가장 먼저 생각할 수 있는 것은 루프를 만들고 제곱을 하나씩 취하는 것입니다.

이러한 점을 벡터로 생성하면 우리가 가장 좋아하는 NumPy를 사용하여 데이터 점에 대해 벡터 또는 행렬 조작을 수행하여 모든 정사각형을 한 번에 처리할 수 있습니다. 그리고 만약 당신의 프로그램이 이것을 자동으로 할 수 있다면 – 더 요청할 수 있습니까? 그것이 바로 JAX가 하는 일입니다! 모든 데이터 포인트를 자동으로 벡터화하여 모든 작업을 쉽게 수행할 수 있으므로 알고리즘을 훨씬 빠르고 효율적으로 만들 수 있습니다.

JAX는 자동 벡터화를 위해 vmap 함수를 사용합니다. 다음 배열을 고려하십시오.

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

위와 같이 하면 배열의 각 점에 대해 square 메서드가 실행됩니다. 그러나 다음을 수행하는 경우:

vmap(jnp.square(x))

함수를 실행하기 전에 데이터 포인트가 이제 vmap 메서드를 사용하여 자동으로 벡터화되고 루핑이 연산의 기본 수준으로 푸시되기 때문에 메서드 square는 한 번만 실행됩니다. .

  Xbox 360 컨트롤러를 PC에 연결할 수 없습니까? 여기 당신이해야 할 일입니다

#5. SPMD 프로그래밍(pmap)

SPMD 또는 단일 프로그램 다중 데이터 프로그래밍은 딥 러닝 컨텍스트에서 필수적입니다. 여러 GPU 또는 TPU에 있는 서로 다른 데이터 세트에 동일한 기능을 적용하는 경우가 많습니다. JAX에는 여러 GPU 또는 모든 가속기에서 병렬 프로그래밍을 허용하는 펌프라는 함수가 있습니다. JIT와 마찬가지로 pmap을 사용하는 프로그램은 XLA에 의해 컴파일되고 시스템 전체에서 동시에 실행됩니다. 이 자동 병렬화는 순방향 및 역방향 계산 모두에 작동합니다.

pmap은 어떻게 작동합니까?

다음과 같이 함수에 대해 어떤 순서로든 한 번에 여러 변환을 적용할 수도 있습니다.

pmap(vmap(jit(grad(f(x))))))

여러 구성 가능한 변환

Google JAX의 한계

Google JAX 개발자는 이러한 멋진 변환을 모두 도입하면서 딥 러닝 알고리즘의 속도를 높이는 것에 대해 잘 생각했습니다. 과학적인 계산 기능과 패키지는 NumPy 라인에 있으므로 학습 곡선에 대해 걱정할 필요가 없습니다. 그러나 JAX에는 다음과 같은 제한 사항이 있습니다.

  • Google JAX는 아직 개발 초기 단계에 있으며 주요 목적은 성능 최적화이지만 CPU 컴퓨팅에는 많은 이점을 제공하지 않습니다. NumPy는 더 나은 성능을 보이는 것으로 보이며 JAX를 사용하면 오버헤드만 추가될 수 있습니다.
  • JAX는 아직 연구 중이거나 초기 단계에 있으며 보다 확립되고 더 많은 사전 정의된 모델, 오픈 소스 프로젝트 및 학습 자료가 있는 TensorFlow와 같은 프레임워크의 인프라 표준에 도달하기 위해 더 많은 미세 조정이 필요합니다.
  • 현재 JAX는 Windows 운영 체제를 지원하지 않습니다. JAX가 작동하려면 가상 머신이 필요합니다.
  • JAX는 부작용이 없는 순수 함수에서만 작동합니다. 부작용이 있는 기능의 경우 JAX가 좋은 옵션이 아닐 수 있습니다.

Python 환경에 JAX를 설치하는 방법

시스템에 python 설정이 있고 로컬 시스템(CPU)에서 JAX를 실행하려면 다음 명령을 사용하십시오.

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

GPU 또는 TPU에서 Google JAX를 실행하려면 에 제공된 지침을 따르십시오. GitHub JAX 페이지. Python을 설정하려면 다음을 방문하십시오. 파이썬 공식 다운로드 페이지.

결론

Google JAX는 효율적인 딥 러닝 알고리즘, 로봇 및 연구를 작성하는 데 적합합니다. 제한 사항에도 불구하고 Haiku, Flax 등과 같은 다른 프레임워크와 함께 광범위하게 사용됩니다. 프로그램을 실행할 때 JAX가 하는 일을 이해하고 JAX를 사용하거나 사용하지 않고 코드를 실행할 때의 시간 차이를 볼 수 있습니다. 다음을 읽는 것으로 시작할 수 있습니다. 공식 Google JAX 문서이것은 매우 포괄적입니다.