How operator registration works in PyTorch

Note: this is a beginner's write-up, corrections are welcome! Everything below is based on PyTorch 2.0.0.

The official PyTorch tutorial https://pytorch.org/tutorials/advanced/extend_dispatcher.html shows that the main way to register an operator is:

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

In the PyTorch codebase, torch/library.h (here /home/pytorch/torch/library.h) defines the TORCH_LIBRARY_IMPL macro:

#define TORCH_LIBRARY_IMPL(ns, k, m) _TORCH_LIBRARY_IMPL(ns, k, m, C10_UID)

The _TORCH_LIBRARY_IMPL macro is defined as follows:

#define _TORCH_LIBRARY_IMPL(ns, k, m, uid)                             \
  static void C10_CONCATENATE(                                         \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);    \
  static const torch::detail::TorchLibraryInit C10_CONCATENATE(        \
      TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(              \
      torch::Library::IMPL,                                            \
      c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check( \
          c10::DispatchKey::k)>(                                       \
          []() {                                                       \
            return &C10_CONCATENATE(                                   \
                TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);           \
          },                                                           \
          []() { return [](torch::Library&) -> void {}; }),            \
      #ns,                                                             \
      c10::make_optional(c10::DispatchKey::k),                         \
      __FILE__,                                                        \
      __LINE__);                                                       \
  void C10_CONCATENATE(                                                \
      TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library & m)

First, look at C10_UID, which is defined as:

#define C10_UID __COUNTER__
#define C10_ANONYMOUS_VARIABLE(str) C10_CONCATENATE(str, __COUNTER__)

__COUNTER__ expands to an integer that increments every time it is used, so in effect it provides a unique ID for each macro expansion (within a translation unit).

C10_CONCATENATE is defined as follows:

#define C10_CONCATENATE_IMPL(s1, s2) s1##s2
#define C10_CONCATENATE(s1, s2) C10_CONCATENATE_IMPL(s1, s2)

As you can see, it simply pastes two tokens together; if this looks unfamiliar, look up the ## (token-pasting) operator of the C/C++ preprocessor.
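To make these two helpers concrete, here is a minimal sketch outside of PyTorch (the macro and variable names MY_CONCAT, MAKE_UNIQUE_VAR and dummy_ are made up) showing how token pasting plus __COUNTER__ yields a distinct identifier on every expansion:

#include <iostream>

#define MY_CONCAT_IMPL(s1, s2) s1##s2
#define MY_CONCAT(s1, s2) MY_CONCAT_IMPL(s1, s2)

// Each expansion pastes a fresh __COUNTER__ value onto the prefix,
// so the two variables below get different names (e.g. dummy_0, dummy_1).
#define MAKE_UNIQUE_VAR(prefix) static int MY_CONCAT(prefix, __COUNTER__) = 0

MAKE_UNIQUE_VAR(dummy_);
MAKE_UNIQUE_VAR(dummy_);  // no redefinition error: the pasted counter differs

int main() {
  std::cout << "two distinct globals were generated\n";
}

This is exactly the trick _TORCH_LIBRARY_IMPL relies on so that the same macro can be used several times in one file without name clashes.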

The definition of _TORCH_LIBRARY_IMPL can be broken down into three parts.

Part 1: declare a static function:

static void C10_CONCATENATE(TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid)(torch::Library&);

The function name is TORCH_LIBRARY_IMPL_init_ + ns + k + uid. Assuming the UID of TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) is 20, the generated function name is:
TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20

Part 2: define a constant internal to the .cpp file:
static const torch::detail::TorchLibraryInit C10_CONCATENATE(
    TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_, uid)(
    torch::Library::IMPL,
    c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(
        c10::DispatchKey::k)>(
        []() {
          return &C10_CONCATENATE(
              TORCH_LIBRARY_IMPL_init_##ns##_##k##_, uid);
        },
        []() { return [](torch::Library&) -> void {}; }),
    #ns,
    c10::make_optional(c10::DispatchKey::k),
    __FILE__,
    __LINE__);

This is a static const object of type torch::detail::TorchLibraryInit. Sticking with the example above, its name is TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20, which differs from the name of the static function declared above only by the extra "static_" piece. After macro expansion, the whole snippet reads:

static const torch::detail::TorchLibraryInit // type of the static constant
TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(
    torch::Library::IMPL, // argument 1, of type Library::Kind
    c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(
        []() { return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20; },
        []() { return [](torch::Library&) -> void {}; }
    ), // argument 2, of type InitFn*
    "aten", // argument 3, of type const char*
    c10::make_optional(c10::DispatchKey::AutogradPrivateUse1), // argument 4, of type c10::optional<c10::DispatchKey>
    __FILE__, // argument 5, of type const char*
    __LINE__); // argument 6, of type uint32_t

The TorchLibraryInit class is defined as follows:

class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

It contains a single private member variable of type Library. Note what the constructor does: the member-initializer list first constructs lib_ from kind, ns, k, file and line, and only then does the constructor body call the passed-in InitFn (a function of type void(Library&)) on lib_. Because the object generated by the macro lives at namespace scope, all of this runs during static initialization, i.e. before main() (or when the shared library containing it is loaded).
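As a minimal illustration (with made-up types Lib and LibInit, not PyTorch's), the construction order described above looks like this: the member is fully built by the initializer list before the constructor body hands it to the callback.

#include <iostream>

struct Lib {
  explicit Lib(const char* ns) { std::cout << "Lib(" << ns << ") constructed\n"; }
};

struct LibInit {
  Lib lib_;
  LibInit(void (*fn)(Lib&), const char* ns)
      : lib_(ns) {  // 1) the member lib_ is constructed first
    fn(lib_);       // 2) then the init callback is run against it
  }
};

// Namespace-scope object: its constructor (and therefore the callback)
// runs during static initialization, before main().
static const LibInit demo_init(+[](Lib&) { std::cout << "init fn called\n"; }, "aten");

int main() { std::cout << "main starts\n"; }

Running it prints "Lib(aten) constructed", then "init fn called", and only then "main starts".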

When TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20 is defined, the first argument (Library::Kind kind) is torch::Library::IMPL, and the second argument is:

c10::guts::if_constexpr<c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1)>(
    []() { return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20; },
    []() { return [](torch::Library&) -> void {}; }
), // argument 2, of type InitFn*

First, look at the template argument c10::impl::dispatch_key_allowlist_check(c10::DispatchKey::AutogradPrivateUse1), which is defined as:

constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
#ifdef C10_MOBILE
  return true;
  // Disabled for now: to be enabled later!
  // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
#else
  return true;
#endif
}

As you can see, it currently returns true unconditionally, so the second argument becomes:

c10::guts::if_constexpr<true>(
    []() { return &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20; },
    []() { return [](torch::Library&) -> void {}; }
), // argument 2, of type InitFn*

if_constexpr is defined as follows:

template <bool Condition, class ThenCallback, class ElseCallback>
decltype(auto) if_constexpr(
    ThenCallback&& thenCallback,
    ElseCallback&& elseCallback) {
#if defined(__cpp_if_constexpr)
  // If we have C++17, just use its "if constexpr" feature instead of wrapping
  // it. This will give us better error messages.
  if constexpr (Condition) {
    if constexpr (detail::function_takes_identity_argument<
                      ThenCallback>::value) {
      // Note that we use static_cast<T&&>(t) instead of std::forward (or
      // ::std::forward) because using the latter produces some compilation
      // errors about ambiguous `std` on MSVC when using C++17. This static_cast
      // is just what std::forward is doing under the hood, and is equivalent.
      return static_cast<ThenCallback&&>(thenCallback)(detail::_identity());
    } else {
      return static_cast<ThenCallback&&>(thenCallback)();
    }
  } else {
    if constexpr (detail::function_takes_identity_argument<
                      ElseCallback>::value) {
      return static_cast<ElseCallback&&>(elseCallback)(detail::_identity());
    } else {
      return static_cast<ElseCallback&&>(elseCallback)();
    }
  }
#else
  // C++14 implementation of if constexpr
  return detail::_if_constexpr<Condition>::call(
      static_cast<ThenCallback&&>(thenCallback),
      static_cast<ElseCallback&&>(elseCallback));
#endif
}

This is where it gets a little fancy; the example in the source comment explains it best:

Example 1: simple constexpr if/then/else
  template<int arg> int increment_absolute_value() {
    int result = arg;
    if_constexpr<(arg > 0)>(
        [&] { ++result; },  // then-case
        [&] { --result; }   // else-case
    );
    return result;
  }

So this is simply a compile-time if/else selected by a template parameter. Since the template argument here is true, the then-branch wins and the second constructor argument ends up being a pointer to the static function declared in part 1, TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20. The remaining arguments need no further comment, although it is worth noting that the fourth one, c10::make_optional(c10::DispatchKey::AutogradPrivateUse1), is rather involved.
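A stripped-down sketch of the same compile-time selection, using plain C++17 if constexpr and made-up names (real_init, pick_init), just to show that the branch not taken is discarded at compile time and the chosen branch hands back the real init function pointer:

#include <iostream>

using InitFn = void(int&);

static void real_init(int& x) { x = 42; }  // stands in for TORCH_LIBRARY_IMPL_init_...

template <bool Condition>
InitFn* pick_init() {
  if constexpr (Condition) {
    return &real_init;            // then-case: the real registration function
  } else {
    return +[](int&) -> void {};  // else-case: a no-op (unary + turns the lambda into a function pointer)
  }
}

int main() {
  int v = 0;
  pick_init<true>()(v);           // invokes real_init
  std::cout << v << std::endl;    // prints 42
}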

Part 3: provide the definition of the static function declared in part 1. After macro expansion it reads:

void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

Putting it all together: before expansion, the code was:

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

After macro expansion and simplification, it becomes:

static void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m);

static const torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20(
    torch::Library::IMPL, // argument 1, of type Library::Kind
    &TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20, // argument 2, of type InitFn*
    "aten", // argument 3, of type const char*
    c10::make_optional(c10::DispatchKey::AutogradPrivateUse1), // argument 4, of type c10::optional<c10::DispatchKey>
    __FILE__, // argument 5, of type const char*
    __LINE__); // argument 6, of type uint32_t

void TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20(torch::Library & m) {
  m.impl(<myadd_schema>, &myadd_autograd);
}

// The definition of TorchLibraryInit, from library.h
class TorchLibraryInit final {
 private:
  using InitFn = void(Library&);
  Library lib_;

 public:
  TorchLibraryInit(
      Library::Kind kind,
      InitFn* fn,
      const char* ns,
      c10::optional<c10::DispatchKey> k,
      const char* file,
      uint32_t line)
      : lib_(kind, ns, k, file, line) {
    fn(lib_);
  }
};

To summarize so far:

① Part 1 declares a static function, TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20.

② Part 2 defines a static constant of type torch::detail::TorchLibraryInit, named TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20. It holds a member variable of type Library, which is initialized from the constructor arguments and then handed to the static function declared in part 1.

③ Part 3 implements the function declared in part 1.

Note that this function performs the actual operator registration by calling the impl member function of its torch::Library parameter, and the argument passed in is precisely the private member variable of the static constant defined in part 2. That constant's name, TORCH_LIBRARY_IMPL_static_init_##ns##_##k##_##uid, therefore depends on the namespace, the dispatch key (CPU, CUDA, a custom backend, etc.) and the UID.

The constructor of TORCH_LIBRARY_IMPL_static_init_aten_AutogradPrivateUse1_20 first constructs its private member lib_ and then runs TORCH_LIBRARY_IMPL_init_aten_AutogradPrivateUse1_20 on it, which in turn calls lib_'s impl method to register the operator implementation.
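The whole mechanism can be condensed into a self-contained sketch of the "register during static initialization" idiom that the macro expands to. All names below (MiniLibrary, MiniLibraryInit, myops, myadd) are made up and the real Library forwards to the global dispatcher rather than storing ops itself; this just mirrors the shape of the generated code:

#include <functional>
#include <iostream>
#include <map>
#include <string>

// Stand-in for torch::Library: records op implementations by name.
struct MiniLibrary {
  std::map<std::string, std::function<int(int, int)>> ops;
  void impl(const std::string& name, std::function<int(int, int)> f) {
    ops[name] = std::move(f);  // the real Library::impl hands the function to the dispatcher instead
  }
};

// Stand-in for torch::detail::TorchLibraryInit: build the library,
// then let the init function fill it in.
struct MiniLibraryInit {
  MiniLibrary lib_;
  explicit MiniLibraryInit(void (*fn)(MiniLibrary&)) { fn(lib_); }
};

// Roughly the shape of what TORCH_LIBRARY_IMPL(myops, SomeKey, m) { ... } expands to:
static void my_init(MiniLibrary& m) {
  m.impl("myadd", [](int a, int b) { return a + b; });
}
static const MiniLibraryInit my_static_init(&my_init);  // constructor runs before main()

int main() {
  // In PyTorch the registration would now be visible through the dispatcher;
  // here we just read it back from the library object itself.
  std::cout << my_static_init.lib_.ops.at("myadd")(1, 2) << std::endl;  // prints 3
}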

Next, let's look at the impl method of the torch::Library class, which is defined as follows:

/// Register an implementation for an operator. You may register multiple
/// implementations for a single operator at different dispatch keys
/// (see torch::dispatch()). Implementations must have a corresponding
/// declaration (from def()), otherwise they are invalid. If you plan
/// to register multiple implementations, DO NOT provide a function
/// implementation when you def() the operator.
///
/// \param name The name of the operator to implement. Do NOT provide
///   schema here.
/// \param raw_f The C++ function that implements this operator. Any
///   valid constructor of torch::CppFunction is accepted here;
///   typically you provide a function pointer or lambda.
///
/// ```
/// // Example:
/// TORCH_LIBRARY_IMPL(myops, CUDA, m) {
///   m.impl("add", add_cuda);
/// }
/// ```
template <typename Name, typename Func>
Library& impl(Name name, Func&& raw_f, _RegisterOrVerify rv = _RegisterOrVerify::REGISTER) & {
  // TODO: need to raise an error when you impl a function that has a
  // catch all def
#if defined C10_MOBILE
  CppFunction f(std::forward<Func>(raw_f), NoInferSchemaTag());
#else
  CppFunction f(std::forward<Func>(raw_f));
#endif
  return _impl(name, std::move(f), rv);
}

…… Off to the gym; to be continued.

