我之前的一篇文章介绍了如何给PyTorch添加CPU上的简单的加法算子。在这篇文章里,我将继续展示一个更具体的PyTorch自定义算子示例——自己动手复现二维卷积算子。这个示例是基于PyTorch Extension的,在迁移项目时,不需要自己生成动态库,只需要用setup.py
重新编译一遍即可。我会同时介绍CPU版和CUDA版的实现。
许多前沿的神经网络都会对卷积进行一些修改。比如大名鼎鼎的可变形卷积(deformable convolution)。相信看完这篇文章后,大家能看懂PyTorch卷积的实现代码,并大概了解如何修改卷积的实现细节,并把新写好的卷积运用到自己的PyTorch项目中。
PyTorch Extension 实现二维卷积
搭建项目
在开始写代码前,要准备一个崭新的目录,在这个文件夹里搭建项目。
在根目录下,先创建一个setup.py
,之后要填写这份安装文件。
之后,创建一个文件夹,其名字是项目名。在这个文件夹里合适的地方新建一个子文件夹,专门用来放和算子相关的文件。我的项目名叫做panoflow
,算子相关文件放在了panoflow/core/op
子文件夹下。
接下来,和算子实现相关的文件都应该放在算子文件夹里。使用和测试算子的文件可以放在项目文件夹的其他地方。
由于在实现中我借用了MMCV的代码,还要提前准备好一些头文件。首先新建一个文件pytorch_cpp_helper.hpp
:
1 |
|
再创建一个文件pytorch_cuda_helper.hpp
:
1 |
|
还有一个common_cuda_helper.hpp
:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
int optimal_block_num = (N + num_threads - 1) / num_threads;
int max_block_num = 4096;
return min(optimal_block_num, max_block_num);
}
template <typename T>
__device__ T bilinear_interpolate(const T* input, const int height,
const int width, T y, T x,
const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
if (y <= 0) y = 0;
if (x <= 0) x = 0;
int y_low = (int)y;
int x_low = (int)x;
int y_high;
int x_high;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// do bilinear interpolation
T v1 = input[y_low * width + x_low];
T v2 = input[y_low * width + x_high];
T v3 = input[y_high * width + x_low];
T v4 = input[y_high * width + x_high];
T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
return val;
}
template <typename T>
__device__ void bilinear_interpolate_gradient(
const int height, const int width, T y, T x, T& w1, T& w2, T& w3, T& w4,
int& x_low, int& x_high, int& y_low, int& y_high,
const int index /* index for debug only*/) {
// deal with cases that inverse elements are out of feature map boundary
if (y < -1.0 || y > height || x < -1.0 || x > width) {
// empty
w1 = w2 = w3 = w4 = 0.;
x_low = x_high = y_low = y_high = -1;
return;
}
if (y <= 0) y = 0;
if (x <= 0) x = 0;
y_low = (int)y;
x_low = (int)x;
if (y_low >= height - 1) {
y_high = y_low = height - 1;
y = (T)y_low;
} else {
y_high = y_low + 1;
}
if (x_low >= width - 1) {
x_high = x_low = width - 1;
x = (T)x_low;
} else {
x_high = x_low + 1;
}
T ly = y - y_low;
T lx = x - x_low;
T hy = 1. - ly, hx = 1. - lx;
// reference in forward
// T v1 = input[y_low * width + x_low];
// T v2 = input[y_low * width + x_high];
// T v3 = input[y_high * width + x_low];
// T v4 = input[y_high * width + x_high];
// T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
return;
}
这些文件添加了CPU和CUDA实现时需要的头文件和定义,后面的C++源码会用到它们。
CPU
C++实现
在用C++实现一个算子时,我们要编写一个形如这样的文件:
1 |
|
这个C++文件主要包含两部分内容:算子的实现函数和C++接口绑定。在实现卷积时,也是要实现这两部分内容。
在修改一个现有的算子时,最好的方法不是从头写一个,而是去开源库里找一份实现,并在这个基础上进行修改。
我在MMCV的仓库里找到了可变形卷积的实现,并把它拆解回了普通的卷积。我参考了这篇教程:手把手教你如何高效地在 MMCV 中贡献算子。另外,这份笔记还参考了PyTorch官方Extension教程。
找到了卷积的实现后,在算子文件夹下新建一个cpp源文件。比如我的文件路径就是panoflow/core/op/my_conv.cpp
。这样一个普通卷积的实现如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
void my_conv_im2col_cpu(Tensor data_im,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, Tensor data_col);
void my_conv_im2col_cuda(Tensor data_im,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, Tensor data_col);
void my_conv_shape_check(at::Tensor input,
at::Tensor weight, int kH,
int kW, int dH, int dW, int padH, int padW,
int dilationH, int dilationW, int group)
{
TORCH_CHECK(
weight.ndimension() == 4,
"4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, but got: %s",
weight.ndimension());
TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
TORCH_CHECK(kW > 0 && kH > 0,
"kernel size should be greater than zero, but got kH: %d kW: %d",
kH, kW);
TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
"kernel size should be consistent with weight, ",
"but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d",
kH, kW, weight.size(2), weight.size(3));
TORCH_CHECK(dW > 0 && dH > 0,
"stride should be greater than zero, but got dH: %d dW: %d", dH,
dW);
TORCH_CHECK(
dilationW > 0 && dilationH > 0,
"dilation should be greater than 0, but got dilationH: %d dilationW: %d",
dilationH, dilationW);
int ndim = input.ndimension();
int dimf = 0;
int dimh = 1;
int dimw = 2;
if (ndim == 4)
{
dimf++;
dimh++;
dimw++;
}
TORCH_CHECK(ndim == 3 || ndim == 4,
"3D or 4D input tensor expected but got: %s", ndim);
long nInputPlane = weight.size(1) * group;
long inputHeight = input.size(dimh);
long inputWidth = input.size(dimw);
long nOutputPlane = weight.size(0);
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
if (outputWidth < 1 || outputHeight < 1)
AT_ERROR(
"Given input size: (%ld x %ld x %ld). "
"Calculated output size: (%ld x %ld x %ld). Output size is too small",
nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
outputWidth);
TORCH_CHECK(input.size(1) == nInputPlane,
"invalid number of input planes, expected: %d, but got: %d",
nInputPlane, input.size(1));
TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
"input image is smaller than kernel");
}
void my_conv_forward(Tensor input, Tensor weight, Tensor bias,
Tensor output, Tensor columns, int kW,
int kH, int dW, int dH, int padW, int padH,
int dilationW, int dilationH, int group,
int im2col_step)
{
bool isCuda = false;
if (input.device().is_cuda())
{
CHECK_CUDA_INPUT(input);
CHECK_CUDA_INPUT(weight);
CHECK_CUDA_INPUT(bias);
CHECK_CUDA_INPUT(output);
CHECK_CUDA_INPUT(columns);
isCuda = true;
}
else
{
CHECK_CPU_INPUT(input);
CHECK_CPU_INPUT(weight);
CHECK_CPU_INPUT(bias);
CHECK_CPU_INPUT(output);
CHECK_CPU_INPUT(columns);
}
my_conv_shape_check(input, weight, kH, kW, dH, dW, padH,
padW, dilationH, dilationW, group);
at::DeviceGuard guard(input.device());
int batch = 1;
if (input.ndimension() == 3)
{
// Force batch
batch = 0;
input.unsqueeze_(0);
}
long batchSize = input.size(0);
long nInputPlane = input.size(1);
long inputHeight = input.size(2);
long inputWidth = input.size(3);
long nOutputPlane = weight.size(0);
long outputWidth =
(inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
long outputHeight =
(inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
outputHeight, outputWidth});
columns = at::zeros(
{nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
input.options());
input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
inputHeight, inputWidth});
Tensor output_buffer = at::zeros({batchSize / im2col_step, nOutputPlane,
im2col_step * outputHeight, outputWidth},
output.options());
output_buffer = output_buffer.view(
{output_buffer.size(0), group, output_buffer.size(1) / group,
output_buffer.size(2), output_buffer.size(3)});
for (int elt = 0; elt < batchSize / im2col_step; elt++)
{
if (isCuda)
{
my_conv_im2col_cuda(input[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, columns);
}
else
{
my_conv_im2col_cpu(input[elt], nInputPlane, inputHeight,
inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
dilationW, im2col_step, columns);
}
columns = columns.view({group, columns.size(0) / group, columns.size(1)});
weight = weight.view({group, weight.size(0) / group, weight.size(1),
weight.size(2), weight.size(3)});
for (int g = 0; g < group; g++)
{
output_buffer[elt][g] = output_buffer[elt][g]
.flatten(1)
.addmm_(weight[g].flatten(1), columns[g])
.view_as(output_buffer[elt][g]);
}
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
weight.size(3), weight.size(4)});
}
output_buffer = output_buffer.view(
{output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
output_buffer.size(3), output_buffer.size(4)});
output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
im2col_step, outputHeight, outputWidth});
output_buffer.transpose_(1, 2);
output.copy_(output_buffer);
output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
bias = bias.view({1, bias.size(0), 1, 1});
output.add_(bias);
input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
if (batch == 0)
{
output = output.view({nOutputPlane, outputHeight, outputWidth});
input = input.view({nInputPlane, inputHeight, inputWidth});
}
}
template <typename T>
void my_conv_im2col_cpu_kernel(
const int n, const T *data_im, const int height,
const int width, const int kernel_h, const int kernel_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int batch_size,
const int num_channels, const int height_col,
const int width_col, T *data_col)
{
for (int index = 0; index < n; index++)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
T val = static_cast<T>(0);
const int h_im = h_in + i * dilation_h;
const int w_im = w_in + j * dilation_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
val = data_im_ptr[h_im * width + w_im];
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
void my_conv_im2col_cpu(Tensor data_im,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, Tensor data_col)
{
int height_col =
(height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
int width_col =
(width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
int num_kernels = channels * height_col * width_col * parallel_imgs;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
data_im.scalar_type(), "", [&]
{ my_conv_im2col_cpu_kernel<scalar_t>(
num_kernels, data_im.data_ptr<scalar_t>(),
height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
parallel_imgs, channels,
height_col, width_col, data_col.data_ptr<scalar_t>()); });
}
void my_conv_im2col_cuda(Tensor data_im,
const int channels, const int height,
const int width, const int ksize_h,
const int ksize_w, const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int parallel_imgs, Tensor data_col)
{
}
PYBIND11_MODULE(my_ops, m)
{
m.def("my_conv_forward", my_conv_forward, "my_conv_forward",
py::arg("input"), py::arg("weight"), py::arg("bias"),
py::arg("output"), py::arg("columns"), py::arg("kW"),
py::arg("kH"), py::arg("dW"), py::arg("dH"), py::arg("padW"),
py::arg("padH"), py::arg("dilationW"), py::arg("dilationH"),
py::arg("group"), py::arg("im2col_step"));
}
这份实现非常长,我挑一些重点的内容讲解。
从最下面的PYBIND11_MODULE(my_ops, m)
看起。这里的my_ops
是生成的库名,可以随便取名。待会要import这个库名。代码块里m.def
用于定义C++函数的Python接口。"my_conv_forward"
是Python调用时的函数名称,my_conv_forward
是被Python代码调用的这份代码里的C++函数名称。也就是说,这份卷积实现的入口函数就是my_conv_forward
。我们从这个函数看起。
1 | void my_conv_forward(Tensor input, Tensor weight, Tensor bias, |
my_conv_forward
就是卷积的主函数。它的参数除了PyTorch的Conv2d
传入的参数外,还多了两个参数output, columus
。这两个张量是保存中间结果的,在PyTorch侧是看不到的。output
用于保存卷积输出,columns
用于保存卷积时的列矩阵。底层实现卷积时,会先把图像转换成一个用列表示的矩阵,再把卷积操作当成一个矩阵乘法来完成。其中,第一步操作叫做”im2col”。对此原理不熟的话可以参考这篇文章:https://zhuanlan.zhihu.com/p/63974249。
my_conv_forward
函数的大部分内容都是在做类型检查和张量形状转换。在修改卷积实现时,这些东西都可以不用改。整个卷积操作的核心都在这一部分:
1 | for (int elt = 0; elt < batchSize / im2col_step; elt++) |
这段代码先做了im2col
操作,再做了矩阵乘法。其实,包括可变形卷积在内,各种稀奇古怪的卷积操作通过靠修改im2col
来完成的。CPU和CUDA版卷积的主要区别,也体现在im2col
中(后面的矩阵乘法在CPU和CUDA上都能用)。
由于是讲CPU实现,这里的CUDA实现我暂时放了一个空函数。my_conv_im2col_cpu
的内容如下:
1 | void my_conv_im2col_cpu(Tensor data_im, |
这个函数其实只是处理了一下输入,真正的实现在my_conv_im2col_cpu_kernel
里。AT_DISPATCH_FLOATING_TYPES_AND_HALF
可以让实现兼容半精度和普通float,所以实现my_conv_im2col_cpu_kernel
得写成一个模板函数。
my_conv_im2col_cpu_kernel
的实现如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44template <typename T>
void my_conv_im2col_cpu_kernel(
const int n, const T *data_im, const int height,
const int width, const int kernel_h, const int kernel_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
const int batch_size,
const int num_channels, const int height_col,
const int width_col, T *data_col)
{
for (int index = 0; index < n; index++)
{
// index index of output matrix
const int w_col = index % width_col;
const int h_col = (index / width_col) % height_col;
const int b_col = (index / width_col / height_col) % batch_size;
const int c_im = (index / width_col / height_col) / batch_size;
const int c_col = c_im * kernel_h * kernel_w;
const int h_in = h_col * stride_h - pad_h;
const int w_in = w_col * stride_w - pad_w;
T *data_col_ptr =
data_col +
((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
const T *data_im_ptr =
data_im + (b_col * num_channels + c_im) * height * width;
for (int i = 0; i < kernel_h; ++i)
{
for (int j = 0; j < kernel_w; ++j)
{
T val = static_cast<T>(0);
const int h_im = h_in + i * dilation_h;
const int w_im = w_in + j * dilation_w;
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
{
val = data_im_ptr[h_im * width + w_im];
}
*data_col_ptr = val;
data_col_ptr += batch_size * height_col * width_col;
}
}
}
}
它的作用就是把图像里的数据搬到做卷积运算的column
里。循环遍历每一次卷积的每一个位置,把待运算的量填入column
。卷积里的所有参数(pad, stride, …)都是在这段函数里生效的。想实现可变形卷积等改进,也要修改这个函数。
Python封装
实现好了后,如果编译完了的话,刚刚的卷积接口可以通过以下方式在Python里调用:
1 | import my_ops |
这里的my_ops
这个名称必须和开始PYBIND11_MODULE(my_ops, m)
里面那个库名称对应。
基于这个接口,可以仿照PyTorch中Conv2d
的接口,编写一个和Conv2d
等价的torch.nn.Module
出来。我的这个Python文件的路径是panoflow/core/op/my_conv.py
1 | import torch |
以后,用自己的卷积MyConv2d
就和用普通的Conv2d
一样了。
编译
打开外面的setup.py
,填写以下内容。
1 | from setuptools import setup |
其中的路径要根据自己的实际情况修改。
和编译相关的内容都写在cpp_extension.CppExtension
里。其中,源文件要写在第二个参数里,头文件目录要写在include_dirs
。由于我的源文件放在panoflow/core/op
里,我写了个源文件名数组cpp_src
,在传参前把路径组合了一下。由于include_dirs
和源文件在同一个目录下,我也填的是panoflow/core/op
。
写完了setup.py
后,运行python setup.py develop
,就能一键编译和安装。如果运行后没有报编译错误,就可以把实现的卷积用起来了。
单元测试
用单元测试可以快速地验证卷积是否实现成功。我写了一个简单的单元测试文件,在任意一个文件夹下创建该文件即可。
1 | import torch |
其中,panoflow.core.op.my_conv
是我刚刚放MyConv2d
的Python模块。
直接运行这个Python文件,如果没有任何输出(报错信息),就说明卷积实现成功了。
CUDA
C++实现
在刚刚的实现中,有一个my_conv_im2col_cuda
的实现是空着的。在CUDA版本中,我们要实现这个函数。不过,这个函数要放在一个用nvcc
编译的.cu
文件里。注意!注意!注意! 因此,my_conv.cpp
里那个空的my_conv_im2col_cuda
实现应该全部删掉。
新建一个文件my_conv_cuda.cu
。
1 | // Modify from https://github.com/open-mmlab/mmcv/blob/my_conv/mmcv/ops/csrc/common/cuda/deform_conv_cuda_kernel.cuh |
和CPU版的类似,my_conv_im2col_cuda
也是预处理了输入,并调用核函数my_conv_im2col_gpu_kernel
来实现im2col
。
CUDA实现和CPU几乎一样,唯一的区别就是for循环变成了CUDA_1D_KERNEL_LOOP(index, n)
。这个宏是头文件里帮我们定义的,它简化了CUDA的一维循环。
编译
修改setup.py
:
1 | from setuptools import setup |
首先,要把源文件加入cpp_src
里。之后,把CppExtension
改成CUDAExtension
。这样,就能编译新写的CUDA文件了。
写完了之后,再次python setup.py develop
编译即可。
编译小技巧:不拿IDE直接写C++和CUDA源代码是很容易出错误的。但如果你想只用
setup.py
来验证代码的正确性,可以python setup.py develop > tmp.txt
把编译输出重定向到一个文件里来查看。由于编译时的信息过多,在命令行里很难从一堆编译warning里找到最重要的error。
测试
由于Python部分在之前都已经写好了,可以直接用刚刚的单元测试文件测试了。只要把刚刚那份文件的device_name
改成cuda:0
即可。
1 | import torch |
同样,没报错就说明写对了。