【pytorch】广播机制
# 背景## 广播机制
在上面的部分中,我们看到了如何在相同形状的两个张量上执行按元素操作。
在某些情况下,[**即使形状不同,我们仍然可以通过调用
*广播机制*(broadcasting mechanism)来执行按元素操作**]。
这种机制的工作方式如下:
1. 通过适当复制元素来扩展一个或两个数组,以便在转换之后,两个张量具有相同的形状;
2. 对生成的数组执行按元素操作。
在大多数情况下,我们将沿着数组中长度为1的轴进行广播,如下例子:
```python
a = torch.arange(3).reshape((3, 1))
b = torch.arange(2).reshape((1, 2))
a, b
```
输出
```
(tensor([,
,
]),
tensor([]))
```
由于 `a`和 `b`分别是$3\times1$和$1\times2$矩阵,如果让它们相加,它们的形状不匹配。
我们将两个矩阵*广播*为一个更大的$3\times2$矩阵,如下所示:矩阵 `a`将复制列,矩阵 `b`将复制行,然后再按元素相加。
```python
a + b
```
输出
```
tensor([,
,
])
```
这时候问题来了,以上都是二维的张量操作,如果上升到三维,会怎么样呢?
用其他形状(例如三维张量)替换广播机制中按元素操作的两个张量。结果是否与预期相同?
# 尝试
## 出错
```
a = torch.arange(6).reshape((3, 1, 2))
b = torch.arange(6).reshape((1, 2, 3))
a, b
```
输出
```
(tensor([[],
[],
[]]),
tensor([[,
]]))
```
接着,我们执行张量按元素相加
```
a + b #报错RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 2
```
报错了,维度不相同。第三个维度a是2,b是3,不相等,所以报错。
## 换一下,成功
```
a = torch.arange(6).reshape((3, 1, 2))
b = torch.arange(6).reshape((1, 3, 2))
a, b
```
输出
```
(tensor([[],
[],
[]]),
tensor([[,
,
]]))
```
执行相加
```
a + b
```
输出
```
tensor([[[ 0,2],
[ 2,4],
[ 4,6]],
[[ 2,4],
[ 4,6],
[ 6,8]],
[[ 4,6],
[ 6,8],
[ 8, 10]]])
```
这个时候,三维张量又满足了广播机制了。
# 总结
满足广播的条件:需要两个张量的每个维度满足以下条件:
a.这两个维度的大小相等;
b.某个维度,一个张量有,一个张量没有;
c.某个维度,一个张量有,一个张量也有但大小是1。
以上规则再回到示例中回顾去,确实是这样。
页:
[1]