Compute and Reduce with Tuple Inputs
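It is often useful to compute several outputs of the same shape within a single loop, or to perform a reduction that involves more than one value, such as argmax. Both problems can be solved with tuple inputs.
This tutorial introduces the usage of tuple inputs in TVM.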
from __future__ import absolute_import, print_function
import tvm
from tvm import te
import numpy as np
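Describe Batchwise Computation
Operators that have the same shape can be passed together as the outputs of a single te.compute if they should be scheduled together in the following scheduling step.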
n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
A1 = te.placeholder((m, n), name="A1")
# The lambda returns a tuple, so te.compute yields one tensor per element.
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name="B")
# The generated IR code would be:
s = te.create_schedule(B0.op)
print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True))
Output:
primfn(A0_1: handle, A1_1: handle, B.v0_1: handle, B.v1_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {B.v1: Buffer(B.v1_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
             B.v0: Buffer(B.v0_2: Pointer(float32), float32, [m, n], [stride_2: int32, stride_3: int32], type="auto"),
             A1: Buffer(A1_2: Pointer(float32), float32, [m, n], [stride_4: int32, stride_5: int32], type="auto"),
             A0: Buffer(A0_2: Pointer(float32), float32, [m, n], [stride_6: int32, stride_7: int32], type="auto")}
  buffer_map = {A0_1: A0, A1_1: A1, B.v0_1: B.v0, B.v1_1: B.v1} {
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      B.v0_2[((i*stride_2) + (j*stride_3))] = ((float32*)A0_2[((i*stride_6) + (j*stride_7))] + 2f32)
      B.v1_2[((i*stride) + (j*stride_1))] = ((float32*)A1_2[((i*stride_4) + (j*stride_5))]*3f32)
    }
  }
}
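The loop nest in the IR above can be mirrored in plain NumPy to check what the two outputs should contain. The following is only an illustrative reference sketch: it does not use TVM, and the helper name batchwise_reference is hypothetical, not part of any TVM API.

```python
import numpy as np


def batchwise_reference(a0, a1):
    """Hypothetical helper: same loop nest as the lowered IR, written in NumPy.

    Both outputs are filled inside the same (i, j) loop, which is what the
    tuple-valued te.compute expresses.
    """
    m, n = a0.shape
    b0 = np.empty_like(a0)
    b1 = np.empty_like(a1)
    for i in range(m):
        for j in range(n):
            b0[i, j] = a0[i, j] + 2  # corresponds to B.v0
            b1[i, j] = a1[i, j] * 3  # corresponds to B.v1
    return b0, b1


a0 = np.arange(6, dtype="float32").reshape(2, 3)
a1 = np.ones((2, 3), dtype="float32")
b0, b1 = batchwise_reference(a0, a1)
```

A reference like this can be compared against the output of the built TVM function with np.testing.assert_allclose.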
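Describe Reduction with Collaborative Inputs
Some reduction operators need several inputs that work together, for example argmax. During the reduction, argmax has to compare the values of the operands and also keep track of the index of the winning operand. This can be expressed with te.comm_reducer() as follows: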
# x and y are the operands of the reduction; each of them is a tuple of
# (index, value).
def fcombine(x, y):
    lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
    rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
    return lhs, rhs


# The identity element also needs to be a tuple, so fidentity accepts
# two dtypes as inputs.
def fidentity(t0, t1):
    return tvm.tir.const(-1, t0), tvm.te.min_value(t1)


argmax = te.comm_reducer(fcombine, fidentity, name="argmax")

# Describe the reduction computation
m = te.var("m")
n = te.var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="int32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
# The generated IR code would be:
s = te.create_schedule(T0.op)
print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True))
Output:
primfn(idx_1: handle, val_1: handle, T.v0_1: handle, T.v1_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {T.v1: Buffer(T.v1_2: Pointer(int32), int32, [m: int32], [stride: int32], type="auto"),
             val: Buffer(val_2: Pointer(int32), int32, [m, n: int32], [stride_1: int32, stride_2: int32], type="auto"),
             T.v0: Buffer(T.v0_2: Pointer(int32), int32, [m], [stride_3: int32], type="auto"),
             idx: Buffer(idx_2: Pointer(int32), int32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
  buffer_map = {idx_1: idx, val_1: val, T.v0_1: T.v0, T.v1_1: T.v1} {
  for (i: int32, 0, m) {
    T.v0_2[(i*stride_3)] = -1
    T.v1_2[(i*stride)] = -2147483648
    for (k: int32, 0, n) {
      T.v0_2[(i*stride_3)] = @tir.if_then_else(((int32*)val_2[((i*stride_1) + (k*stride_2))] <= (int32*)T.v1_2[(i*stride)]), (int32*)T.v0_2[(i*stride_3)], (int32*)idx_2[((i*stride_4) + (k*stride_5))], dtype=int32)
      T.v1_2[(i*stride)] = @tir.if_then_else(((int32*)val_2[((i*stride_1) + (k*stride_2))] <= (int32*)T.v1_2[(i*stride)]), (int32*)T.v1_2[(i*stride)], (int32*)val_2[((i*stride_1) + (k*stride_2))], dtype=int32)
    }
  }
}
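To make the roles of the combiner and the identity element concrete, the same argmax reduction can be sketched in plain Python with functools.reduce. This is an illustration under stated assumptions rather than TVM code: the function argmax_rows and the no-argument form of fidentity are hypothetical, and the identity values (-1 and the int32 minimum) are taken from the initialisation visible in the IR above.

```python
from functools import reduce

import numpy as np

INT32_MIN = int(np.iinfo(np.int32).min)  # -2147483648, as in the IR above


def fcombine(x, y):
    # x and y are (index, value) pairs; on a tie the left operand
    # (the accumulator) wins, matching the >= in the TVM combiner.
    return x if x[1] >= y[1] else y


def fidentity():
    # Identity element of the reduction: an invalid index and the
    # smallest int32 value, so any real element replaces it.
    return (-1, INT32_MIN)


def argmax_rows(idx, val):
    """Hypothetical helper: row-wise argmax over (idx, val) pairs."""
    out_idx, out_val = [], []
    for i in range(val.shape[0]):
        pairs = zip(idx[i].tolist(), val[i].tolist())
        best_idx, best_val = reduce(fcombine, pairs, fidentity())
        out_idx.append(best_idx)
        out_val.append(best_val)
    return np.array(out_idx, dtype="int32"), np.array(out_val, dtype="int32")


val = np.array([[3, 9, 9], [-5, -7, -6]], dtype="int32")
idx = np.tile(np.arange(3, dtype="int32"), (2, 1))
t0, t1 = argmax_rows(idx, val)
```

In the first row the maximum 9 occurs twice; because ties keep the accumulator, the index of the first occurrence is returned.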
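Note
If you are not familiar with reduction, please refer to "Define General Commutative Reduction Operation".
Schedule Operation with Tuple Inputs
Although a single batch operation produces multiple outputs, those outputs can only be scheduled together, because scheduling works at the level of the operation rather than the individual tensor.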
n = te.var("n")
m = te.var("m")
A0 = te.placeholder((m, n), name="A0")
B0, B1 = te.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name="B")
A1 = te.placeholder((m, n), name="A1")
C = te.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name="C")

s = te.create_schedule(C.op)
# B0 and B1 share one operation, so moving B0 also moves B1.
s[B0].compute_at(s[C], C.op.axis[0])
# As you can see in the generated IR code below:
print(tvm.lower(s, [A0, A1, C], simple_mode=True))
Output:
primfn(A0_1: handle, A1_1: handle, C_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {C: Buffer(C_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type="auto"),
             A1: Buffer(A1_2: Pointer(float32), float32, [m, n], [stride_2: int32, stride_3: int32], type="auto"),
             A0: Buffer(A0_2: Pointer(float32), float32, [m, n], [stride_4: int32, stride_5: int32], type="auto")}
  buffer_map = {A0_1: A0, A1_1: A1, C_1: C} {
  attr [B.v0: Pointer(float32)] "storage_scope" = "global";
  allocate(B.v0, float32, [n]);
  attr [B.v1: Pointer(float32)] "storage_scope" = "global";
  allocate(B.v1, float32, [n]);
  for (i: int32, 0, m) {
    for (j: int32, 0, n) {
      B.v0[j] = ((float32*)A0_2[((i*stride_4) + (j*stride_5))] + 2f32)
      B.v1[j] = ((float32*)A0_2[((i*stride_4) + (j*stride_5))]*3f32)
    }
    for (j_1: int32, 0, n) {
      C_2[((i*stride) + (j_1*stride_1))] = ((float32*)A1_2[((i*stride_2) + (j_1*stride_3))] + (float32*)B.v0[j_1])
    }
  }
}
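The effect of compute_at on a tuple-output operation can be seen by writing the lowered loop nest out by hand. The NumPy sketch below is illustrative only and does not call TVM; the function name fused_schedule_reference is hypothetical. It shows that both row buffers are filled inside the outer loop of C, even though C only reads the first one.

```python
import numpy as np


def fused_schedule_reference(a0, a1):
    """Hypothetical helper: mirrors the loop structure of the IR above."""
    m, n = a0.shape
    c = np.empty_like(a0)
    # One row-sized scratch buffer per tuple output, like allocate(B.v0/B.v1, [n]).
    b_v0 = np.empty(n, dtype=a0.dtype)
    b_v1 = np.empty(n, dtype=a0.dtype)
    for i in range(m):
        for j in range(n):
            b_v0[j] = a0[i, j] + 2
            b_v1[j] = a0[i, j] * 3  # still computed, although C never uses it
        for j in range(n):
            c[i, j] = a1[i, j] + b_v0[j]
    return c


a0 = np.arange(6, dtype="float32").reshape(2, 3)
a1 = np.full((2, 3), 10.0, dtype="float32")
c = fused_schedule_reference(a0, a1)
```

Because B0 and B1 belong to one operation, the work for B.v1 cannot be scheduled away separately from B.v0 in this arrangement.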
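Summary
This tutorial introduced the usage of tuple inputs.
- Describe normal batchwise computation.
- Describe a reduction operation with tuple inputs.
- Note that computation can only be scheduled in terms of operations, not tensors.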