大致按照使用频率递减给出
这里可以给出一个结论性的规律, 便于判断张量形状, 索引操作时, 有几个 : , 就有几个维度
1. 基本下标与切片(Python 风格)
1 2 3 4 5 6 7 8
| import torch x = torch.arange(24).reshape(2,3,4)
x[0,0,0] x[0, :, :] x[:, 1, :] x[..., 2] x[0]
|
- 切片语法支持
start:stop:step(含 start,不含 stop),支持负数索引和步长。
- 例如
x[:, ::-1, :] 会在中间维度上反转顺序(返回 view 还是 copy 取决于实现;在 PyTorch 中 negative step 会返回 copy)。
2. 使用 None / np.newaxis(增加维度)
1 2 3
| y = torch.tensor([1,2,3]) y[None, :] y[:, None]
|
3. 布尔掩码(Boolean Masking)
1 2 3
| a = torch.tensor([0,5,2,7,3]) mask = a > 3 a[mask]
|
masked_select(a, mask) 等价于 a[mask]。
- 当 mask 是多维且与 a 同 shape 时,结果是扁平的 1D 张量(按行主序提取元素)。
- 可用于筛样本、实现 padding 掩码筛选等。
4. 花式索引(整数数组索引 / Advanced Indexing)
1 2 3 4 5 6 7
| M = torch.arange(12).reshape(3,4) rows = torch.tensor([2,0]) cols = torch.tensor([1,3]) M[rows] M[:, cols]
M[rows, cols]
|
- 若使用多个 1D 整数索引(同维度),会进行逐元素配对索引,输出长度等于索引数组长度。
- 若混合切片和整数数组索引,规则稍复杂:整数索引会先被应用,结果维度位置会消失或变为新维度。
例(多维):
1 2 3 4 5
| T = torch.arange(2*3*4).reshape(2,3,4) idx0 = torch.tensor([0,1]) idx1 = torch.tensor([2,0]) T[idx0, idx1]
|
5. 广播与索引(注意形状)
1 2 3 4
| A = torch.arange(6).reshape(2,3) idx = torch.tensor([0,2]) A[torch.arange(2), idx]
|
- 常用于按-batch 选择每个样本对应的索引(如分类预测的 top-k 判断)。
6. gather 与 scatter(按索引收集/写入,适用于高维批量操作)
1 2 3 4 5
| src = torch.tensor([[10,11,12],[20,21,22]]) index = torch.tensor([[2,1,0],[0,2,1]]) torch.gather(src, dim=1, index=index)
|
gather 要求 index 与 src 在除了 dim 外的维度完全相同;返回与 index 同形状的张量。
- 常用于实现按位置取值(例如 beam-search、按预测索引从概率张量中取值)。
scatter_/scatter 用于把值写入指定位置,可做 one-hot 化或累积(有 reduce 参数)。
7. index_select / take(按维度选择)
1 2 3 4 5
| v = torch.tensor([10,20,30,40]) torch.index_select(v, dim=0, index=torch.tensor([3,1]))
M = torch.arange(12).reshape(3,4) torch.index_select(M, dim=0, index=torch.tensor([2,0]))
|
index_select 返回的顺序与索引一致;与 fancy indexing(M[idx])相似,但有些后端实现行为细微不同(比如保留 contiguous 性)。
8. masked_fill, where(掩码赋值 / 条件选择)
1 2 3
| x = torch.tensor([1., -2., 3.]) x.masked_fill(x < 0, 0.) torch.where(x>0, x, torch.zeros_like(x))
|
where(cond, A, B) 返回与 A/B 广播后的形状相同的张量。
9. unsqueeze / squeeze 与 view/reshape(维度控制)
1 2 3 4
| a = torch.tensor([1,2,3]) a.unsqueeze(0) a.unsqueeze(1) torch.squeeze(a.unsqueeze(0))
|
squeeze(dim) 只在指定维度为 1 时删除该维度。
reshape/view 会改变内存视图(view 要求连续 contiguous;reshape 在必要时会复制)。
10. Ellipsis ...(省略号)
1 2 3
| X = torch.randn(4,5,6,7) X[..., 0] X[0, ...]
|
- 在不确定前面/后面维度数时非常有用,特别是在写通用层时。
11. 多维返回与维度插入(保持/丢失维度)
- 使用整数索引会减少维度(那一维被消除);
- 使用切片或
None,或保持长度为 1 的索引会保留维度。
示例:
1 2 3 4
| t = torch.randn(2,3,4) t[0].shape t[0:1].shape t[[0]].shape
|
12. 视图(view)与 copy(内存/contiguous)相关注意
- 大多数简单切片和整型索引会返回原张量的 view(共享内存),但有些操作会返回 copy(例如带负步长的切片、某些高级索引)。
is_contiguous() 可以检查是否连续。若对返回的张量执行 view() 可能会报错,需先 .contiguous()。
1 2 3 4
| s = torch.arange(6).reshape(2,3) t = s[:, ::-1] t.is_contiguous() t.contiguous().view(-1)
|
- 在 in-place 操作(如
t += 1)时,如果 t 与原张量共享内存,可能会影响原张量;对 copy 则无影响。
13. 反向传播(autograd)相关
- 索引、切片会保留计算图信息(如果原张量
requires_grad=True),因此从张量中取出的部分仍可对原张量反向传播。
- 但是,用高级索引赋值(
x[idx] = something)不记录梯度;需要使用 scatter 或构造新的张量再计算 loss。
detach() 可以切断梯度传播(例如 x = x.detach())。
示例(反向传播影响):
1 2 3 4
| x = torch.randn(3, requires_grad=True) y = x[1] * 2 y.backward() x.grad
|
14. 常见用途与模式(实战片段)
1 2 3 4
| probs = torch.randn(32, 10) pred = probs.argmax(dim=1)
selected = probs[torch.arange(32), pred]
|
1 2 3 4
| seq = torch.randint(0, 100, (4,7)) mask = (seq != PAD_TOKEN)
masked_sum = (embeddings * mask.unsqueeze(-1)).sum(dim=1)
|
1 2
| idx = torch.tensor([0,2,1]) onehot = torch.nn.functional.one_hot(idx, num_classes=4)
|
- 按索引更新参数(embedding lookup 与更新):
1 2
| embedding = torch.nn.Embedding(num_embeddings=1000, embedding_dim=64) out = embedding(idx_tensor)
|
15. 进阶:einsum 作为灵活替代(当索引和 reshape 太繁琐)
einsum 可以在一次表达式中完成复杂的通道/维度重新排列与约简,可替代多个 transpose + matmul 操作。
1 2
| torch.einsum('ik,jk->ij', a, b)
|
16. 常见陷阱与建议
- 整数数组索引通常会返回 copy(非 view) —— 这会影响内存并且后续 in-place 修改不会影响原张量。
- 带负步长的切片常常产生 copy,要注意
is_contiguous()。
- 混合使用布尔掩码与维度不适配会报错,确保 mask 与被掩的张量形状一致或能广播。
- 不要对需要 autograd 的部分用原地替换(x[idx] = …),会破坏计算图,使用
scatter/scatter_add 或构造新张量。
- 索引返回的张量可能会改变梯度分配,只有被实际用到(参与 loss)的元素才会有梯度。
- 尽量用
gather 实现批量按位置索引(可保持形状且直观),而不是复杂的循环。
17. 一张速查小表(常用 API)
- 基本:
x[i], x[:, j], x[..., k], x[start:stop:step]
- 维度:
unsqueeze, squeeze, transpose, permute, reshape, view, contiguous
- 选择/花式:
index_select, take, gather, scatter, where, masked_select, masked_fill
- 布尔/掩码:
x[mask], torch.where, torch.nonzero
- 其他:
one_hot, topk, argmax/argmin, torch.arange(用于构建 batch 索引)
如果你愿意,我可以:
- 把这些示例做成一个可以直接运行的 Jupyter notebook(包含针对常见错误的测试用例),或者
- 根据你常用的框架(PyTorch / TensorFlow / NumPy)把示例改写成你习惯的 API,或
- 针对你当前代码中的索引问题(把你代码贴来),帮你找 bug 并修正。
你想要哪种后续帮助?