JAX/Flax로 딥러닝 레벨업 - 고급 모델링과 병렬 가속화로 무장한 차세대 딥러닝 라이브러리를 만나다
이영빈 외 지음 / 제이펍 / 2024년 9월
평점 :
장바구니담기


최근 Pytorch에서 JAX/Flax로의 전환이 이루어지고 있다. 구글에서 나오는 대부분의 모델들은 JAX를 활용하고 있고, 허깅 페이스에선 기존 모델의 속도 향상을 위해 JAX로의 전환을 진행하고 있다. JAX는 구글 딥마인드에서 사용하고 있는 고성능 딥러닝 프레임워크이며, 자동 미분과 XLA(Accelerated Linear Algebra)를 결합하여 사용하기 때문에 PyTorch에서 사용되는 동적 그래프 방식보다 훨씬 빠르게 학습과 추론을 진행할 수 있다. Flax는 JAX + Flexibility를 합쳐 만들어진 말이며, JAX를 조금 더 쉽게 사용할 수 있게 만든 프레임워크이다. 이 책은 JAX와 Flax의 개념과 장점, 사용방법에 대해 자세히 설명해주는 책이다.

JAX/Flax를 이용하여 딥러닝 모델을 만드려면 함수형 프로그래밍에 대한 이해가 있어야 하는데, 이 책의 도입부에서는 이를 먼저 설명해준다. 함수형 프로그래밍이란 계산을 수학적 함수로 취급하고 상태 및 변경 가능한 데이터를 피하는 프로그래밍 방식이다. 따라서 코드의 재사용성이 높아지게 되고 간결해지지만, 작성 난이도는 다소 높아 진입 장벽이 있는 편이다. 이 책에선 이를 먼저 이해시켜주기 때문에 다소 편하게 JAX를 접할 수 있었다. 함수형 프로그래밍의 핵심인 불변성과 순수 함수에 관한 개념도 예시와 함께 자세히 설명되어 있어 독자가 엄격한 함수형 프로그래밍 방법을 이해하는데도 큰 도움을 주는 파트다.

또한 이 책에선 JAX/Flax로 간단한 CNN 모델을 만드는 모습을 코드와 함께 튜토리얼로써 보여주는데, 이 부분이 실질적으로 독자가 JAX/Flax를 이용해 필요한 모델을 제작하는데 큰 도움이 되는 부분이다. 딥러닝 모델을 구현한 실제 코드를 보면 이를 참고하여 여러 곳에 활용할 수 있기 때문에, 독자가 더욱 성능 좋은 딥러닝 라이브러리를 직접 사용해 모델을 만들어보는 과정을 이끌어줄 수 있는 부분이라고 생각한다.

물론 무조건 PyTorch가 JAX/Flax 보다 좋지 않다는 것은 아니고 장단점이 존재한다. Pytorch는 직관적인 API 방식의 사용법을 지원해 사용자 친화적이며, 수 년간 업계에서 메이저한 라이브러리로써 쓰여 왔기 때문에 강력한 커뮤니티가 지원되며, 디버깅이 용이하다는 장점이 있다. JAX는 기본적으로 자동 미분과 XLA 컴파일러를 사용해 GPU 및 TPU에서 고도로 최적화된 연산을 수행할 수 있기 때문에 방대한 연산이 필요한 LLM을 구현하기에 적합하고, 함수형 프로그래밍 스타일을 지원해 깔끔한 코드를 작성할 수 있다는 장점이 있다. 독자는 이러한 장단점을 이해하고 어떤 라이브러리를 이용해 딥러닝 모델을 제작할지 잘 선택해야 한다.

이 책은 pytorch나 tensorflow, keras로 딥러닝 모델을 구현하는 코드를 보여주고 있으며, 머신러닝/딥러닝 모델의 학습 과정과 매개변수 등을 포함한 이론과 개념에 익숙한 사람이라면 누구나 시도해 볼 만한 책이다. 자신의 딥러닝 모델을 성능과 속도면에서 업그레이드 하고 싶다면, 이 책을 읽어보는 것을 적극 추천한다.


댓글(0) 먼댓글(0) 좋아요(0)
좋아요
공유하기 북마크하기찜하기 thankstoThanksTo