Description
LLM 시대를 선도하는 최적의 딥러닝 라이브러리 JAX/Flax
JAX(잭스)는 대규모 계산의 확장성을 염두에 두고 설계된 고성능 라이브러리로, LLM 시대 애물단지로 전락한 파이토치를 빠르게 대체하고 있다. 모두의연구소 JAX/Flax LAB이 집필한 이 책은 JAX, 그리고 JAX와 함께 쓰이는 Flax(플랙스)를 본격적으로 다루는 국내 최초의 책이다. JAX 기초와 함수형 프로그래밍, 병렬처리 등의 특장점을 살펴보고, JAX와 Flax를 조합해서 CNN, ResNet, DCGAN, CLIP 모델을 실제로 구현해본다. 새로운 시대, 새로운 딥러닝의 방식을 익혀보자.
저자

이영빈,유현아,김한빈,조영빈,이태호,장진우,이승현,김형섭,박정현

저자:이영빈
모두의연구소에서AI교육을진행하고있으며JAX/FlaxLAB짱을맡고있다.

저자:유현아
IT업계의세미나와커뮤니티활동에힘쓰며JAXKR커뮤니티의운영진으로활동하고있다.AI기술의긍정적인가능성에대한깊은관심과열정을가지고있다.

저자:김한빈
P&G에서데이터사이언티스트로근무하고있다.AI/ML을활용하여비즈니스임팩트를창출하는것에큰관심이있다.

저자:조영빈
클라우드분야에서ML엔지니어로일하고있다.모델을서빙하고서비스에적용하는것에관심이있다.

저자:이태호
머신러닝을연구해비전,음성,로봇,자동차,재료공학,공장자동화등의분야에서상용서비스를만들었으며,다양한연구성과를내고있다.

저자:장진우
제조업분야에서데이터과학자로일하고있다.다양한분야에서데이터를활용한문제해결방법에관심이있다.

저자:이승현
컴퓨터비전,자연어처리등다양한분야의딥러닝모델을최적화및경량화하여,더적은비용으로효과적인서비스를제공하는방법을연구하고있다.

저자:김형섭
ML엔지니어.2018년부터의료영상,버추얼아이돌,패션도메인의스타트업들에서StyleGAN,VAE,스테이블디퓨전등딥러닝생성모델의활용을연구하고서비스화했다.

저자:박정현
ML엔지니어로일하고있으며,KerasKorea의Organizer로활동하고있다.

목차

베타리더후기viii
지은이소개x
JAX/FlaxLAB소개xii
들어가며xiii
이책에대하여xiv

CHAPTER1JAX/Flax를공부하기전에1
1.1JAX/Flax에대한소개와예시1
__1.1.1JAX란1
__1.1.2Flax란2
__1.1.3JAX로이루어진기타프레임워크들3
__1.1.4JAX프레임워크사용예시3
1.2함수형프로그래밍에대한이해5
__1.2.1부수효과와순수함수5
__1.2.2불변성과순수함수7
__1.2.3정리하며8
1.3JAX/Flax에서자주사용하는파이썬표준라이브러리9
__1.3.1functools.partial()10
__1.3.2typing모듈12
__1.3.3정리하며13
1.4JAX/Flax설치방법14
__1.4.1로컬에JAX/Flax설치하기14
__1.4.2코랩에서TPU사용하기14

CHAPTER2JAX의특징17
2.1NumPy에서부터JAX시작하기18
__2.1.1JAX와NumPy비교하기18
__2.1.2JAX에서미분계산하기19
__2.1.3손실함수의그레이디언트계산하기21
__2.1.4손실함수의중간과정확인하기22
__2.1.5JAX의함수형언어적특징이해하기23
__2.1.6JAX로간단한학습돌려보기25
2.2JAX의JIT컴파일28
__2.2.1JAX변환이해하기29
__2.2.2함수를JIT컴파일하기32
__2.2.3JIT컴파일이안되는경우34
__2.2.4JIT컴파일과캐싱37
2.3자동벡터화39
__2.3.1수동으로벡터화하기39
__2.3.2자동으로벡터화하기41
2.4자동미분42
__2.4.1고차도함수43
__2.4.2그레이디언트중지46
__2.4.3샘플당그레이디언트49
2.5JAX의난수52
__2.5.1NumPy의난수52
__2.5.2JAX의난수56
2.6pytree사용하기59
__2.6.1pytree의정의60
__2.6.2pytree함수사용법61
2.7JAX에서의병렬처리65
2.8상태를유지하는연산69
__2.8.1상태에대한이해69
__2.8.2모델에적용하기72

CHAPTER3Flax소개77
3.1FlaxCNN튜토리얼79
__3.1.1패키지로드하기79
__3.1.2데이터로드하기80
__3.1.3모델정의와초기화81
__3.1.4메트릭정의하기84
__3.1.5TrainState초기화84
__3.1.6훈련스텝과평가스텝정의하기85
__3.1.7모델학습하기87
__3.1.8모델추론하기89
3.2심화튜토리얼90
__3.2.1배치정규화적용91
__3.2.2드롭아웃적용95
__3.2.3학습률스케줄링98
__3.2.4체크포인트관리103

CHAPTER4JAX/Flax를활용한딥러닝모델만들기105
4.1순수JAX로구현하는CNN106
__4.1.1패키지로드하기107
__4.1.2데이터로드하기108
__4.1.3레이어구현108
__4.1.4네트워크정의하기115
__4.1.5학습및평가준비116
__4.1.6학습및평가118
__4.1.7추론121
4.2ResNet122
__4.2.1패키지로드하기123
__4.2.2데이터로드하기123
__4.2.3모델정의및초기화124
__4.2.4메트릭정의하기129
__4.2.5TrainState초기화129
__4.2.6훈련스텝과평가스텝정의하기131
__4.2.7모델학습하기132
__4.2.8결과시각화하기135
4.3DCGAN136
__4.3.1패키지로드하기136
__4.3.2데이터로드하기137
__4.3.3모델정의및초기화138
__4.3.4학습방법정의하기140
__4.3.5TrainState초기화143
__4.3.6모델학습하기145
__4.3.7결과시각화하기146
4.4CLIP148
__4.4.1CIFAR10데이터셋으로CLIP미세조정진행하기150
__4.4.2JAX로만들어진데이터셋구축클래스150
__4.4.3이미지데이터구축함수뜯어보기151
__4.4.4CLIP모델불러오기154
__4.4.5CLIP에사용하기위한전처리및미세조정155
__4.4.6모델학습에필요한함수정의하기156
__4.4.7하이퍼파라미터설정과TrainState구축하기160
__4.4.8모델저장하고체크포인트만들기161
__4.4.9요약클래스만들기162
__4.4.10학습에필요한스텝정의와랜덤인수복제163
__4.4.11모델학습하기와모델저장하기163
4.5DistilGPT2미세조정학습166
__4.5.1패키지설치167
__4.5.2환경설정168
__4.5.3토크나이저학습169
__4.5.4데이터셋전처리171
__4.5.5학습및평가173
__4.5.6추론178

CHAPTER5TPU환경설정181
5.1코랩에서TPU설정하기181
5.2캐글에서TPU세팅하기182
5.3TRC프로그램신청하기183

마무리하며186
찾아보기188

출판사 서평

주요내용

함수형프로그래밍,파이썬라이브러리등JAX사용시알아야할기초
JIT컴파일,자동벡터화,pytree,병렬처리등JAX의주요특징
CNN튜토리얼로알아보는Flax기초
ResNet,DCGAN,CLIP모델을구축하며Flax에익숙해지기
코랩,캐글에서TPU환경설정하기

책속에서

JAX는구글에서개발한고성능수치계산라이브러리로,특히병렬가속화기능을통해대규모모델의효율적인학습과추론이가능합니다.Flax는JAX기반의심플한신경망라이브러리로,JAX의장점을살려유연하고확장가능한모델구축을지원합니다.이책은모두의연구소JAX/FlaxLAB이다양한경험과지식을바탕으로JAX를어떻게실용적으로활용할수있는지에중점을두고집필한책입니다.이론만을설명하는것이아니라실제예제를통해적용방법까지소개합니다.
---p.xiv

JAX에서는해당메서드를데커레이터---p.decorator)로활용합니다.데커레이터로사용하면코드간결성이나코드재사용이늘어난다는장점을갖고있습니다.---p.…)이번예제는@partial에만집중해서살펴보겠습니다.해당데커레이터는jax.jit라는함수에서고정시키고싶은인수인n을static_argnames로고정시키고컴파일됩니다.이방식을취하면n은컴파일되어추가적인계산을진행하지않습니다./따라서JAX에서partial데커레이터를사용하면굳이선언할필요없이병렬처리를할수있게도와줍니다.
---p.11

먼저SELU---p.scaledexponentiallinearunit)를구현한예시를봅시다.---p.…)이출력결과는구글코랩의T4가속기에서실행한결과입니다.이제XLA컴파일러를이용해보겠습니다.JAX는jax.jit변환을통해JAX와호환되는함수들을JIT컴파일합니다.얼마나빨라지는지확인해보겠습니다.---p.…)7배이상빨라진것을확인할수있습니다.jax.jit이함수selu---p.)를selu_jit---p.)으로변환해준덕입니다.
---p.33

드롭아웃에서는뉴런의무작위적삭제를위한무작위연산이필요합니다.따라서PRNG상태를제공해야하며,이를위해PRNG키가필요합니다.즉기존에사용했던키외에추가적인키가필요하게됩니다.이미알고있는것처럼JAX에는PRNG키를제공하는명시적인방법이있습니다.key,subkey=jax.random.split---p.key)처럼split메서드를사용하여추가적인키를발급할수있습니다.---p.…)TrainState의초기화방법에도변화가있습니다.먼저사용자정의TrainState클래스에key필드를추가해야합니다.이제TrainState에서PRNG키또한상태로관리할수있습니다.이를위해TrainState.create---p.)메서드에dropout_key를전달합니다.
---p.96

이번절에서는Flax와TPU를활용하여허깅페이스의Transformers라이브러리에있는DistilGPT2모델을미세조정학습하는과정을다룰것입니다.이예제는고성능하드웨어의이점을최대한활용하면서도최신소프트웨어도구의유연성을결합하여효과적인학습파이프라인을구축하는방법을보여줍니다.---p.…)DistilGPT2는GPT-2---p.GenerativePre-trainedTransformer2)의가장작은버전으로,영어를지원하는사전훈련된모델입니다.지식증류---p.knowledgedistillation)기법을활용하여기존124M개의매개변수를가진모델을82M의매개변수로축소한모델입니다.이를통해모델의크기를줄이면서도성능을유지하는효율적인학습이가능해집니다.
---pp.166-167