Jax的真正使命:在WebGL上使用光线进阶渲染器
Jax's true calling: Ray-Marching renderers on WebGL

原始链接: https://benoit.paris/posts/jax-ray-marcher/

这个演示展示了高性能数值计算库 JAX 的强大功能,通过仅用 100 行 Python 代码构建一个简单的 3D 图形渲染器——该渲染器可以直接在浏览器中通过 WebGL 运行。核心技术利用**符号距离函数 (SDF)** 来表示形状,提供诸如可组合性和易于碰撞检测等优势。 与传统的多边形渲染不同,这段代码将对象定义为距离,从而实现平滑混合和数学运算。JAX 的关键优势在于通过**向量化**(使用 `vmap`)并行处理所有像素,以及通过**自动微分** (`jax.grad`) 有效地计算表面法线。 作者强调了 JAX 几乎可以纯粹作为数学函数来表达图形算法的能力,从而减少了样板代码,并为该过程带来了“机械同理心”。这展示了 JAX 在创建简洁高效的图形应用程序方面的潜力。

对不起。
相关文章

原文

Demo

(move your mouse/thumb across the image)

Why, though?

Well, I’ve been drooling over this tool the cool kids use, and wondering how I can join the gang. It’s called JAX.

It’s got GPU accelerated functions over n-dimensional arrays. And built-in compile-time differentiability of these!? Auto-vectorization?? And you just have to do like with numpy. What’s not to like? Go home APL! So I’ve been doing the obvious, the thing JAX was truly meant for: a graphics renderer.

Why, do you ask? Well, the animated image above is a 3-dimensional [512 pixels][512 pixels][3 colors] array for starters (or tensor, if you like). And we can define its content from the output of a function. Start from mouse position and time input, plug in some maths, hard-code a sphere and a cube in there, and voilà, pixels are painted!

And for our first trick here is the code, at about just 100 lines of Python. Yes, Python for browser code, because JAX can also be exported and run on the browser. On WebGL.

Below are some of the techniques used and where JAX shines:

Distances

We won’t be drawing polygons here. We’ll be using Signed Distance Functions (SDF). These have a lot going for them:

  • They’re just beautiful. What’s a sphere but a distance to a point? What’s a cylinder but a distance to a line? Just define a function to be negative when inside an object, positive outside.

  • They’re composable. Want the union of two objects? You take the minimum of their SDF. Want the intersection? max() it is. Here is a non-exhaustive list of what you can do. Here, we’ll be using smooth version of a union, as we want to preserve differentiability. If we had wanted the intersection, we could have used something akin to a softmax (I’m told this is a trendy function at the moment).

  • They contain help for moving in space without colliding with the shape they represent. By definition if the closest point to an object is a length L away from you, then you can move by length L in any direction you like without hitting it. That’s the raymarching / sphere tracing algorithm.

  • Did I mention they’re functions? We can vectorize these! Using two invocations of JAX’s vmap, we can transform functions assigned to single pixels into functions that compute all pixels of an image in parallel:

ray_colors = jax.vmap(jax.vmap(ray_color, (None, 0, None)), (0, None, None))

And last but not least, our final trick:

normal_at_surface = jax.grad(distance_function)(point_at_surface)

<3 Mechanical-sympathy for maths!

Now Google probably did not decide to build JAX so that I could personally save a few keystrokes on this. But the primitives available make it so that the code can be 70% maths. It is a nice step in having higher level DSLs be almost math-like, while staying very close to the metal.


Addendum: Things I would have liked to try but did not have time to:

联系我们 contact @ memedata.com