Haiku is a library built on top of JAX designed to provide simple, composable abstractions for machine learning research.
It is a simple neural network library that lets users work with familiar object-oriented programming models while retaining full access to JAX's pure function transformations.
Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.
hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.
hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.
https://dm-haiku.readthedocs.io/en/latest/
https://github.com/deepmind/dm-haiku