大致按照使用频率递减给出

这里可以给出一个结论性的规律, 便于判断张量形状, 索引操作时, 有几个 : , 就有几个维度

1. 基本下标与切片(Python 风格)

1
2
3
4
5
6
7
8
import torch
x = torch.arange(24).reshape(2,3,4) # shape (2,3,4)
# x[batch, row, col]
x[0,0,0] # 标量 tensor(0)
x[0, :, :] # 第一维=0,取出 shape (3,4)
x[:, 1, :] # 所有batch, 第二个row -> shape (2,4)
x[..., 2] # 省略号,等价于 x[:, :, 2] -> shape (2,3)
x[0] # 等价于 x[0, :, :] -> shape (3,4)
  • 切片语法支持 start:stop:step(含 start,不含 stop),支持负数索引和步长。
  • 例如 x[:, ::-1, :] 会在中间维度上反转顺序(返回 view 还是 copy 取决于实现;在 PyTorch 中 negative step 会返回 copy)。

2. 使用 None / np.newaxis(增加维度)

1
2
3
y = torch.tensor([1,2,3])        # shape (3,)
y[None, :] # shape (1,3)
y[:, None] # shape (3,1)
  • 常用于把向量转为列/行以便广播。

3. 布尔掩码(Boolean Masking)

1
2
3
a = torch.tensor([0,5,2,7,3])
mask = a > 3 # tensor([False, True, False, True, False])
a[mask] # tensor([5,7]) -> 1D 输出,丢失原shape信息
  • masked_select(a, mask) 等价于 a[mask]
  • 当 mask 是多维且与 a 同 shape 时,结果是扁平的 1D 张量(按行主序提取元素)。
  • 可用于筛样本、实现 padding 掩码筛选等。

4. 花式索引(整数数组索引 / Advanced Indexing)

1
2
3
4
5
6
7
M = torch.arange(12).reshape(3,4)  # shape (3,4)
rows = torch.tensor([2,0])
cols = torch.tensor([1,3])
M[rows] # 按行选择 -> shape (2,4)
M[:, cols] # 按列选择 -> shape (3,2)
# 对应元素选择(pairwise)
M[rows, cols] # 取 (2,1) 和 (0,3) -> shape (2,)
  • 若使用多个 1D 整数索引(同维度),会进行逐元素配对索引,输出长度等于索引数组长度。
  • 若混合切片和整数数组索引,规则稍复杂:整数索引会先被应用,结果维度位置会消失或变为新维度。

例(多维):

1
2
3
4
5
T = torch.arange(2*3*4).reshape(2,3,4)
idx0 = torch.tensor([0,1]) # 用于第0维
idx1 = torch.tensor([2,0]) # 用于第1维
T[idx0, idx1] # 逐对索引 -> shape (2,4)
# 等价于: torch.stack([T[0,2], T[1,0]], dim=0)

5. 广播与索引(注意形状)

1
2
3
4
A = torch.arange(6).reshape(2,3)  # (2,3)
idx = torch.tensor([0,2]) # (2,)
A[torch.arange(2), idx] # (2,) -> 每 batch 对应列取值
# torch.arange(2) 为 [0,1],与 idx 配对 -> 取 (0,0) 和 (1,2)
  • 常用于按-batch 选择每个样本对应的索引(如分类预测的 top-k 判断)。

6. gatherscatter(按索引收集/写入,适用于高维批量操作)

1
2
3
4
5
# gather 示例:从 src 中按 index 收集(需要指定 dim)
src = torch.tensor([[10,11,12],[20,21,22]]) # shape (2,3)
index = torch.tensor([[2,1,0],[0,2,1]]) # shape (2,3)
torch.gather(src, dim=1, index=index)
# -> shape (2,3): [[12,11,10],[20,22,21]]
  • gather 要求 indexsrc 在除了 dim 外的维度完全相同;返回与 index 同形状的张量。
  • 常用于实现按位置取值(例如 beam-search、按预测索引从概率张量中取值)。
  • scatter_/scatter 用于把值写入指定位置,可做 one-hot 化或累积(有 reduce 参数)。

7. index_select / take(按维度选择)

1
2
3
4
5
v = torch.tensor([10,20,30,40])
torch.index_select(v, dim=0, index=torch.tensor([3,1])) # tensor([40,20])
# 对于矩阵按行选:
M = torch.arange(12).reshape(3,4)
torch.index_select(M, dim=0, index=torch.tensor([2,0])) # shape (2,4)
  • index_select 返回的顺序与索引一致;与 fancy indexing(M[idx])相似,但有些后端实现行为细微不同(比如保留 contiguous 性)。

8. masked_fill, where(掩码赋值 / 条件选择)

1
2
3
x = torch.tensor([1., -2., 3.])
x.masked_fill(x < 0, 0.) # 把负数置0
torch.where(x>0, x, torch.zeros_like(x)) # 条件选择,相当于 np.where
  • where(cond, A, B) 返回与 A/B 广播后的形状相同的张量。

9. unsqueeze / squeezeview/reshape(维度控制)

1
2
3
4
a = torch.tensor([1,2,3])      # (3,)
a.unsqueeze(0) # (1,3)
a.unsqueeze(1) # (3,1)
torch.squeeze(a.unsqueeze(0)) # 恢复
  • squeeze(dim) 只在指定维度为 1 时删除该维度。
  • reshape/view 会改变内存视图(view 要求连续 contiguous;reshape 在必要时会复制)。

10. Ellipsis ...(省略号)

1
2
3
X = torch.randn(4,5,6,7)
X[..., 0] # 等价 X[:, :, :, 0]
X[0, ...] # 等价 X[0, :, :, :]
  • 在不确定前面/后面维度数时非常有用,特别是在写通用层时。

11. 多维返回与维度插入(保持/丢失维度)

  • 使用整数索引会减少维度(那一维被消除);
  • 使用切片或 None,或保持长度为 1 的索引会保留维度。
    示例:
1
2
3
4
t = torch.randn(2,3,4)
t[0].shape # (3,4) -- 整数索引去掉第0维
t[0:1].shape # (1,3,4) -- 切片保留第0维
t[[0]].shape # (1,3,4) -- 用长度1的索引数组也保留

12. 视图(view)与 copy(内存/contiguous)相关注意

  • 大多数简单切片和整型索引会返回原张量的 view(共享内存),但有些操作会返回 copy(例如带负步长的切片、某些高级索引)。
  • is_contiguous() 可以检查是否连续。若对返回的张量执行 view() 可能会报错,需先 .contiguous()
1
2
3
4
s = torch.arange(6).reshape(2,3)
t = s[:, ::-1] # 可能是 copy(不连续)
t.is_contiguous() # 可能 False
t.contiguous().view(-1) # 安全
  • 在 in-place 操作(如 t += 1)时,如果 t 与原张量共享内存,可能会影响原张量;对 copy 则无影响。

13. 反向传播(autograd)相关

  • 索引、切片会保留计算图信息(如果原张量 requires_grad=True),因此从张量中取出的部分仍可对原张量反向传播。
  • 但是,用高级索引赋值(x[idx] = something)不记录梯度;需要使用 scatter 或构造新的张量再计算 loss。
  • detach() 可以切断梯度传播(例如 x = x.detach())。

示例(反向传播影响):

1
2
3
4
x = torch.randn(3, requires_grad=True)
y = x[1] * 2
y.backward() # 会为 x[1] 累积梯度,但 x[0], x[2] 为 0
x.grad # tensor([0., 2., 0.])

14. 常见用途与模式(实战片段)

  • 按 batch 取样(分类概率取预测值):
1
2
3
4
probs = torch.randn(32, 10)  # logits or probs
pred = probs.argmax(dim=1) # (32,)
# 如果想从 probs 中收集每个 batch 对应预测的概率:
selected = probs[torch.arange(32), pred] # shape (32,)
  • padding mask(seq 长短不一):
1
2
3
4
seq = torch.randint(0, 100, (4,7))   # batch, seq_len
mask = (seq != PAD_TOKEN) # True 表示有效
# 通过 mask 做池化:
masked_sum = (embeddings * mask.unsqueeze(-1)).sum(dim=1)
  • one-hot
1
2
idx = torch.tensor([0,2,1])
onehot = torch.nn.functional.one_hot(idx, num_classes=4) # shape (3,4)
  • 按索引更新参数(embedding lookup 与更新):
1
2
embedding = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64)
out = embedding(idx_tensor) # embedding 内部就是高级索引/ gather 实现

15. 进阶:einsum 作为灵活替代(当索引和 reshape 太繁琐)

  • einsum 可以在一次表达式中完成复杂的通道/维度重新排列与约简,可替代多个 transpose + matmul 操作。
1
2
# 例:批量矩阵乘积 sum over k: c_{ij} = sum_k a_{ik} b_{jk}
torch.einsum('ik,jk->ij', a, b)

16. 常见陷阱与建议

  1. 整数数组索引通常会返回 copy(非 view) —— 这会影响内存并且后续 in-place 修改不会影响原张量。
  2. 带负步长的切片常常产生 copy,要注意 is_contiguous()
  3. 混合使用布尔掩码与维度不适配会报错,确保 mask 与被掩的张量形状一致或能广播。
  4. 不要对需要 autograd 的部分用原地替换(x[idx] = …),会破坏计算图,使用 scatter/scatter_add 或构造新张量。
  5. 索引返回的张量可能会改变梯度分配,只有被实际用到(参与 loss)的元素才会有梯度。
  6. 尽量用 gather 实现批量按位置索引(可保持形状且直观),而不是复杂的循环。

17. 一张速查小表(常用 API)

  • 基本:x[i], x[:, j], x[..., k], x[start:stop:step]
  • 维度:unsqueeze, squeeze, transpose, permute, reshape, view, contiguous
  • 选择/花式:index_select, take, gather, scatter, where, masked_select, masked_fill
  • 布尔/掩码:x[mask], torch.where, torch.nonzero
  • 其他:one_hot, topk, argmax/argmin, torch.arange(用于构建 batch 索引)

如果你愿意,我可以:

  • 把这些示例做成一个可以直接运行的 Jupyter notebook(包含针对常见错误的测试用例),或者
  • 根据你常用的框架(PyTorch / TensorFlow / NumPy)把示例改写成你习惯的 API,或
  • 针对你当前代码中的索引问题(把你代码贴来),帮你找 bug 并修正。

你想要哪种后续帮助?