
How is env.xla() implemented? #192

Answered by mavenlin
uduse asked this question in Q&A

Hi @uduse, `xla` here is implemented via an XLA custom call. Any C++ code (in this case EnvPool's code) can be wrapped in a custom call and jitted into the JAX computation graph.
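To illustrate the idea of embedding opaque, non-JAX code inside a jitted graph, here is a minimal pure-Python analogue. Note that it swaps in `jax.pure_callback` (a host callback) in place of a real XLA custom call, which would instead be compiled from C++ and registered with XLA as described in the post linked below. `host_step` is a hypothetical stand-in for an environment step, not EnvPool code.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical host-side "environment step": plain NumPy code that JAX cannot trace.
# In EnvPool this role is played by C++ code behind an XLA custom call.
def host_step(state: np.ndarray, action: np.ndarray) -> np.ndarray:
    return np.asarray(state + action, dtype=np.float32)

@jax.jit
def jitted_step(state, action):
    # Declare the result's shape/dtype so JAX can build the graph
    # without running host_step while tracing.
    out_spec = jax.ShapeDtypeStruct(state.shape, jnp.float32)
    # pure_callback stands in for the custom call: an opaque node in the graph.
    new_state = jax.pure_callback(host_step, out_spec, state, action)
    # Ordinary traced JAX ops can consume the callback's output.
    return new_state * 2.0

state = jnp.zeros(3, dtype=jnp.float32)
action = jnp.ones(3, dtype=jnp.float32)
print(jitted_step(state, action))  # [2. 2. 2.]
```

A real custom call avoids the Python round trip that `pure_callback` incurs, which is why a C++ implementation is preferable for throughput.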

GPU device support is faked: an extra call copies the memory to the GPU device.
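A rough sketch of the "compute on host, then copy to device" pattern, assuming the step itself always runs on the CPU. `host_step` and `step_then_copy` are hypothetical names for illustration; `jax.device_put` performs the extra copy, and the target defaults to the first device JAX reports (a GPU if one is available, otherwise the CPU).

```python
import numpy as np
import jax

# Hypothetical host-side step (stand-in for EnvPool's C++ code running on CPU).
def host_step(state: np.ndarray, action: np.ndarray) -> np.ndarray:
    return np.asarray(state + action, dtype=np.float32)

def step_then_copy(state: np.ndarray, action: np.ndarray, device=None) -> jax.Array:
    # Step 1: the actual computation happens in host memory.
    result = host_step(state, action)
    # Step 2: an extra call copies the host buffer to the target device.
    target = device if device is not None else jax.devices()[0]
    return jax.device_put(result, target)

out = step_then_copy(np.zeros(3, np.float32), np.ones(3, np.float32))
print(out, out.devices())
```

The device never runs the step itself, so from the caller's side the result "lives" on the GPU even though only a memory copy touched it.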

This post, https://dfm.io/posts/extending-jax/, is our main reference for the implementation.


Answer selected by uduse