将高维张量绘制为矩阵的矩阵。
Draw high dimensional tensors as a matrix of matrices

原始链接: https://blog.ezyang.com/2025/10/draw-high-dimensional-tensors-as-a-matrix-of-matrices/

可视化高维张量(4D+)时,标准方法如列出2D切片或展平可能会掩盖张量的结构。一种更直观的方法是将张量表示为“矩阵的矩阵”。这种方法建立在低维表示之上,随着维度增加,先水平堆叠矩阵,然后垂直堆叠。 这会产生一种分形状的模式,每个增加的维度都会在行和列中引入“跳跃”,清晰地表明轴之间的关系。例如,在4D张量中,“行”可能会跳过元素以表示沿更高维度的进展。当所有维度的大小均为2时,这种排列会生成类似于Morton曲线的序列。 文本中包含高达5D张量的示例,并演示了`torch.split()`将如何沿每个维度分割4D张量,直观地显示矩阵的矩阵表示中的结果块。这种可视化有助于理解张量的组织方式以及诸如分割之类的操作如何影响其结构。

Hacker News 新闻 | 过去 | 评论 | 提问 | 展示 | 招聘 | 提交 登录 将高维张量绘制为矩阵的矩阵 (ezyang.com) 24 分,由 matt_d 1 天前发布 | 隐藏 | 过去 | 收藏 | 3 条评论 IAmBroom 22 小时前 | 下一个 [–] 由于计算机上的矩阵实际上是存储为线性“数组”(好的,内存位置)的 2D 矩阵……我不确定这是什么。概念的标准实现。回复 saagarjha 20 小时前 | 上一个 | 下一个 [–] 好奇为什么 PyTorch 中没有这个功能。回复 mbowring 19 小时前 | 上一个 [–] MATLAB 这样做。回复 考虑申请 YC 的 2026 年冬季批次!申请截止日期为 11 月 10 日 指南 | 常见问题 | 列表 | API | 安全 | 法律 | 申请 YC | 联系方式 搜索:
相关文章

原文

I have recently needed to draw the contents of high-dimensional (e.g., 4D and up) tensors where it is important to ensure that is clear how to identify each of the dimensions in the representation. Common strategies I've seen people do in this situation include printing a giant list 2D slices (what the default PyTorch printer will do) or flattening the Tensor in some way back down to a 2D tensor. However, if you have a lot of horizontal space, there is a strategy that I like that makes it easy to identify all the axes of the higher dimensional tensor: draw it as a matrix of matrices.

Here are some examples, including the easy up-to-2D cases for completeness.

0D: torch.arange(1).view()

0

1D: torch.arange(2)

0  1

2D: torch.arange(4).view(2, 2 )

0  1
2  3

3D: torch.arange(8).view(2, 2, 2)

0  1    4  5
2  3    6  7

4D: torch.arange(16).view(2, 2, 2, 2)

 0  1    4  5
 2  3    6  7

 8  9   12 13
10 11   14 15

5D: torch.arange(32).view(2, 2, 2, 2, 2):

 0  1    4  5  :  16 17   20 21
 2  3    6  7  :  18 19   22 23
               :
 8  9   12 13  :  24 25   28 29
10 11   14 15  :  26 27   30 31

The idea is that every time you add a new dimension, you alternate between stacking the lower dimension matrices horizontally and vertically. You always stack horizontally before stacking vertically, to follow the standard row-major convention for printing in the 2D case. Dimensions always proceed along the x and y axis, but the higher dimensions (smaller dim numbers) involve skipping over blocks. For example, a "row" on dim 3 in the 4D tensor is [0, 1] but the "row" on dim 1 is [0, 4] (we skip over to the next block.) The fractal nature of the construction means we can keep repeating the process for as many dimensions as we like.

In fact, for the special case when every size in the tensor is 2, the generated sequence of indices form a Morton curve. But I don't call it that, since I couldn't find a popular name for the variation of the Morton curve where the radix of each digit in the coordinate representation can vary.

Knowledge check. For the 4D tensor of size (2, 2, 2, 2) arranged in this way, draw the line(s) that would split the tensor into the pieces that torch.split(x, 1, dim), for each possible dimension 0, 1, 2 and 3. Answer under the fold.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

.

dim=0

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=0)]
[tensor([0, 1, 2, 3, 4, 5, 6, 7]), tensor([ 8, 9, 10, 11, 12, 13, 14, 15])]

     0  1    4  5
     2  3    6  7
   ----------------
     8  9   12 13
    10 11   14 15


dim=1

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=1)]
[tensor([ 0, 1, 2, 3, 8, 9, 10, 11]), tensor([ 4, 5, 6, 7, 12, 13, 14, 15])]

     0  1 |  4  5
     2  3 |  6  7
          |
     8  9 | 12 13
    10 11 | 14 15

dim=2

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=2)]
[tensor([ 0, 1, 4, 5, 8, 9, 12, 13]), tensor([ 2, 3, 6, 7, 10, 11, 14, 15])]

     0  1    4  5
   ------- -------
     2  3    6  7

     8  9   12 13
   ------- -------
    10 11   14 15

dim=3

>>> [x.reshape(-1) for x in torch.arange(16).view(2,2,2,2).split(1,dim=3)]
[tensor([ 0, 2, 4, 6, 8, 10, 12, 14]), tensor([ 1, 3, 5, 7, 9, 11, 13, 15])]

     0 |  1    4 |  5
     2 |  3    6 |  7

     8 |  9   12 | 13
    10 | 11   14 | 15
联系我们 contact @ memedata.com