动机、参考资料、涉及内容

动机:

  • 在深入阅读一些深度学习模型代码时,会遇到一些关于 Tensor 的下标处理操作,例如:huggingface transformers 中的 generation/logit_process.py:RepetitionPenaltyLogitsProcess
  • numpy 的一些下标操作

参考资料:

  • pytorch 官方文档
  • numpy 官方文档
  • 一些博客

涉及内容:

  • API 解释
  • 一些综合使用的例子

torch.gather

import torch
torch.gather(input, dim: int, index)

其中 input 的形状假设是 (4, 3, 6),而 dim=1,index 的形状必须为 (4, K, 6),其中 K 可以任取(因为 dim=1)。而输出的 tensor 的形状与 index 完全一致,即:(4, K, 6)。

具体的例子与解释参考博客.

torch.scatter

repeat, repeat_interleave, squeeze, unsqueeze

broadcast_to, expand

torch.broadcast_totorch.expand 完全等价, 注意 torch.repeat 会发生内存拷贝, 而 torch.expand 不会, 这两者的使用方式也不一样(适用于 torch.expand 的入参可能不适用 torch.repeat, 反之亦然)

索引与切片

总的来说, 有如下几类

t[None]
t[:]
t[0, ...]
t[:3], t[-3:], t[2:5], t[start:end:step]
t[torch.tensor([True, False])]
t[torch.tensor([1, 2, 4])], t[torch.tensor([[1, 2, 4], [0, 1, 3]])]

pytorch 的文档中似乎对各种切片操作没有仔细介绍, 但应该基本上与 numpy.ndarray 的用法相似, 因此可以参照 numpy 的文档, 在姑且不论 view 与 copy 的区别时

首先引用 numpy 文档中的一段提示

Note that in Python, x[(exp1, exp2, …, expN)] is equivalent to x[exp1, exp2, …, expN]; the latter is just syntactic sugar for the former.

如果使用 bool 数组的方式进行索引时, 可以理解成将该位置的 bool 数组转化为列表

a[torch.tensor([True, False, True, False]), :]  # 等价于 a[[0, 2], :]
a[[0, 2, 1], [2, 3, 1]]  # 假设 a 只有两维, 注意返回值为 torch.tensor([a[0, 2], a[2, 3], a[1, 1]])

reduce

import torch
import einops
from functools import partial
x = torch.rand(2, 2, 3)
mean = eniops.reduce(x, "o ... -> o", "mean")  # (2,)
var = eniops.reduce(x, "o ... -> o", partial(torch.var, unbiased=False))  # [((x[0]-mean[0])**2)/6, ((x[1]-mean[1])**2)/6]
# 如果 unbiased = True, 则除以 5 而不是 6

einops.rearange

import torch
from einops.layers.torch import Rearrange

x = torch.tensor([
    [0, 1, 0, 1, 0, 1, 0, 1],
    [2, 3, 2, 3, 2, 3, 2, 3],
    [0, 1, 0, 1, 0, 1, 0, 1],
    [2, 3, 2, 3, 2, 3, 2, 3],
    [0, 1, 0, 1, 0, 1, 0, 1],
    [2, 3, 2, 3, 2, 3, 2, 3],
]).reshape(1, 1, 6, 8)

Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)(x)

# tensor([[[[0, 0, 0, 0],
#           [0, 0, 0, 0],
#           [0, 0, 0, 0]],

#          [[1, 1, 1, 1],
#           [1, 1, 1, 1],
#           [1, 1, 1, 1]],

#          [[2, 2, 2, 2],
#           [2, 2, 2, 2],
#           [2, 2, 2, 2]],

#          [[3, 3, 3, 3],
#           [3, 3, 3, 3],
#           [3, 3, 3, 3]]]])