【PyTorch 源码阅读】DispatchKeySet 源码篇

Posted by Masutangu on June 16, 2024

代码路径

https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKeySet.h

定义

BackendComponent

根据注释:

1
2
3
4
5
6
7
8
9
10
11
12
// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.

// In implementation terms, the backend component identifies a specific "bit" in
// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom
// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to
// functionalities. When we encounter a functionality bit that is known to be
// customizeable per-backend, then we also look at the lower BackendComponent
// bits and take the highest bit to determine which backend's implementation to
// use.

每个 BackendComponent 的值标识了一个调度的后端,有些功能(functionalities)允许为不同的后端注册不同的处理程序,BackendComponent 用于确定要调度到哪个后端实现。这样,我们可以根据需要选择适当的后端处理程序来执行特定的功能。

实现上,BackendComponent 由 DispatchKeySet 中一个特定的 bit 位来标记。DispatchKeySet 中的位可分为 BackendComponent 位和功能位。当遇到后端自定义的功能位时,会查看 BackendComponent 位,并选择其中最高位来确定使用哪个后端的实现。具体见 源码阅读-DispatchKeySet 小节中的 highestPriorityTypeId 函数。

PyTorch 使用宏来生成 BackendComponent 的枚举值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \
  _(CPU, extra)                                 \
  _(CUDA, extra)                                \
  _(HIP, extra)                                 \
  _(XLA, extra)                                 \
  _(MPS, extra)                                 \
  _(IPU, extra)                                 \
  _(XPU, extra)                                 \
  _(HPU, extra)                                 \
  _(VE, extra)                                  \
  _(Lazy, extra)                                \
  _(MTIA, extra)                                \
  _(PrivateUse1, extra)                         \
  _(PrivateUse2, extra)                         \
  _(PrivateUse3, extra)                         \
  _(Meta, extra)

// WARNING!  If we add a new per-backend functionality key that has higher
// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys

#define C10_FORALL_FUNCTIONALITY_KEYS(_) \
  _(Dense, )                             \
  _(Quantized, Quantized)                \
  _(Sparse, Sparse)                      \
  _(NestedTensor, NestedTensor)          \
  _(AutogradFunctionality, Autograd)

enum class BackendComponent : uint8_t {

  // A "backend" is colloquially used to refer to handlers for dispatch
  // which actually implement the numerics of an operation in question.
  //
  // Due to the nature of the enum, these backends are specified in
  // an ordered way, but for most backends this order is not semantically
  // meaningful (e.g., it's valid to reorder these backends without changing
  // semantics).  The only situation when backend ordering is meaningful
  // is when the backend participates in multiple dispatch with another
  // backend; e.g., CPU and CUDA (cuda must have higher priority).

  // These keys don't correspond to individual kernels.
  // Instead, they represent the backends that are allowed to override specific
  // pieces of functionality:
  // - dense kernels (e.g. DispatchKey::CPU)
  // - sparse kernels (e.g. DispatchKey::SparseCPU)
  // - quantized kernels (e.g. DispatchKey::QuantizedCPU)
  // - autograd kernels (e.g. DispatchKey::AutogradCPU)
  // We reserve space in the runtime operator table for this full cross product
  // of
  // [backends in this enum] x [keys below that are explicitly marked as having
  // per-backend functionality]

  InvalidBit = 0,
#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit,
  C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused)
#undef DEFINE_BACKEND_COMPONENT

  // Define an alias to represent end of backend dispatch keys.
  // If you add new backend keys after PrivateUse3, please also update it here.
  EndOfBackendKeys = MetaBit,
};

C10_FORALL_BACKEND_COMPONENTS 接受两个参数:_extra。传入DEFINE_BACKEND_COMPONENTunused 时,参数 _DEFINE_BACKEND_COMPONENT,参数 extraunused。执行 _(CPU, extra),即对 DEFINE_BACKEND_COMPONENT(n, _) 进行宏展开,其中 nCPU_unused,展开结果为 CPU##BitCPUBit

再次强调,BackendComponent 定义的这些键并不对应于单个内核。它们代表允许覆盖特定功能的后端,例如:

  • 密集内核(例如 DispatchKey::CPU)
  • 稀疏内核(例如 DispatchKey::SparseCPU)
  • 量化内核(例如 DispatchKey::QuantizedCPU)
  • 自动微分内核(例如 DispatchKey::AutogradCPU)

DispatchKey

根据注释:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// Semantically, a dispatch key identifies a possible "level" in our
// dispatch, for which a handler may be registered. Each handler corresponds
// to a type of functionality.
//
// In implementation terms, the dispatch key identifies a specific "bit" in a
// DispatchKeySet.  Higher bit indexes get handled by dispatching first (because
// we "count leading zeros" when we extract the highest priority dispatch
// key.)
//
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalized that are customizable per backend (e.g. Dense, Sparse,
// AutogradFunctionality) (4) per-backend instances of customizable
// functionalities (e.g. CPU, SparseCPU, AutogradCPU) (5) alias keys (e.g.
// CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
// ("Runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.

可以为调度键注册相应的 handler,每个 handler 对应一种特定类型的功能。 实现上,调度键标识了 DispatchKeySet 中某一特定的 “位”,较高的位优先调度。

调度键可分为以下几种类别:

  1. 不可自定义的后端(例如 FPGA)
  2. 不可自定义的功能(例如 Functionalize)
  3. 可以根据后端自定义的功能(例如 Dense、Sparse、AutogradFunctionality)
  4. 可自定义功能的后端实例(例如 CPU、SparseCPU、AutogradCPU)
  5. 别名键(例如 CompositeImplicitAutograd)

在上述类别中:

  • 1、2 和 3 在 DispatchKeySet 中有自己的专用位
  • 1、2 和 4 在运行时操作符表(runtime operator table)中有自己的专用 slot

类似的,PyTorch 使用宏来生成 DispatchKey 的枚举值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#define C10_FORALL_FUNCTIONALITY_KEYS(_) \
  _(Dense, )                             \
  _(Quantized, Quantized)                \
  _(Sparse, Sparse)                      \
  _(NestedTensor, NestedTensor)          \
  _(AutogradFunctionality, Autograd)

enum class DispatchKey : uint16_t {
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // This is not a "real" functionality, but it exists to give us a "nullopt"
  // element we can return for cases when a DispatchKeySet contains no elements.
  // You can think a more semantically accurate definition of DispatchKey is:
  //
  //    using DispatchKey = optional<RealDispatchKey>
  //
  // and Undefined == nullopt.  We didn't actually represent
  // it this way because optional<RealDispatchKey> would take two
  // words, when DispatchKey fits in eight bits.

  Undefined = 0,
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
  // Every value in the enum (up to EndOfFunctionalityKeys)
  // corresponds to an individual "functionality" that can be dispatched to.
  // This is represented in the DispatchKeySet by assigning each of these enum
  // values to each of the remaining (64 - len(BackendComponent)) bits.
  //
  // Most of these functionalities have a single handler assigned to them,
  // making them "runtime keys".
  // That map to a single slot in the runtime operator table.
  //
  // A few functionalities are allowed to be customizable per backend.
  // See [Note: Per-Backend Functionality Dispatch Keys] for details.

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Dense,
  // Below are non-extensible backends.
  // These are backends that currently don't have their own overrides for
  // Autograd/Sparse/Quantized kernels,
  // and we therefore don't waste space in the runtime operator table allocating
  // space for them.
  // If any of these backends ever need to customize, e.g., Autograd, then we'll
  // need to add a DispatchKey::*Bit for them.
  FPGA,
  ...
  // See [Note: Per-Backend Functionality Dispatch Keys]
  Quantized,
  // See [Note: Per-Backend Functionality Dispatch Keys]
  Sparse,
  AutogradFunctionality,
  ...
  EndOfFunctionalityKeys,  // End of functionality keys.
  
// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.

#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n,

#define DEFINE_PER_BACKEND_KEYS(fullname, prefix)      \
  StartOf##fullname##Backends,                         \
      C10_FORALL_BACKEND_COMPONENTS(                   \
          DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \
  EndOf##fullname##Backends = prefix##Meta,

  C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS)

#undef DEFINE_PER_BACKEND_KEYS
#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND

  EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends,

  // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // Note [Alias Dispatch Keys]
  // **Alias dispatch keys are synthetic dispatch keys which map to multiple
  // runtime dispatch keys**. Alisa keys have precedence, but they are always
  // lower precedence than runtime keys. You can register a kernel to an
  // alias key, the kernel might be populated to the mapped runtime keys
  // during dispatch table computation.
  // If a runtime dispatch key has multiple kernels from alias keys, which
  // kernel wins is done based on the precedence of alias keys (but runtime
  // keys always have precedence over alias keys).
  // Alias keys won't be directly called during runtime.
  // Define an alias key to represent end of alias dispatch keys.
  // If you add new alias keys after Autograd, please also update it here.
  Autograd,
  StartOfAliasKeys = Autograd,
  EndOfAliasKeys = Autograd, //
};

C10_FORALL_FUNCTIONALITY_KEYS 生成的运行时键如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
StartOfDenseBackends,
CPU,
CUDA,
Meta,
EndOfDenseBackends = Meta,
StartOfQuantizedBackends,
QuantizedCPU,
QuantizedCUDA,
QuantizedMeta,
EndOfQuantizedBackends = QuantizedMeta,
StartOfSparseBackends,
SparseCPU,
SparseCUDA,
SparseMeta,
EndOfSparseBackends = SparseMeta,
StartOfNestedTensorBackends,
NestedTensorCPU,
NestedTensorCUDA,
NestedTensorMeta,
EndOfNestedTensorBackends = NestedTensorMeta,
StartOfAutogradFunctionalityBackends,
AutogradCPU,
AutogradCUDA,
AutogradMeta,
EndOfAutogradFunctionalityBackends = AutogradMeta,

关于注释中提到的 Note: Per-Backend Functionality Dispatch Keys:

1
2
3
4
5
6
7
8
9
// [Note: Per-Backend Functionality Dispatch Keys]
// These keys correspond to functionalities that can be customized indivually
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
// they map to (# backends) slots in the operator table.
// Each of these keys also has a separate set of "runtime keys" in the dispatch
// key enum, per backend, which *do* map to the individual operator table slots.
// For example, the "Sparse" key maps to an individual bit in the
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
// slots in the runtime operator table.

Per-Backend Functionality Dispatch Key 对应于可以由后端单独定制的功能。虽然这些键在DispatchKeySet位集中只占用一个位,但在操作符表中映射到【后端数量】个槽位。每个键在调度键枚举中有一个单独的“运行时键”集合,每个后端有一个枚举值,映射到操作符表单独的槽位。举例来说,“Sparse“ 键映射到 DispatchKeySet 中的一个单独位,而 SparseCPUSparseCUDA 等则映射到运行时操作符表中的单独槽位。

DispatchKeySet

根据注释:

1
2
3
4
5
6
7
// A representation of a set of DispatchKeys. **A DispatchKeySet contains both
// "functionality" bits and "backend bits", and every tensor holds its own
// DispatchKeySet.** **The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, or’ing them together, and dispatching to a
// specific piece of functionality.** **The functionality bits are *ordered***. When
// **multiple functionality bits are set, we use the highest priority**
// **functionality**. 

DispatchKeySet 即调度键的集合。DispatchKeySet 包含了“功能(functionality)”位“后端(backend)位”,每个 tensor 都有自己的 DispatchKeySet。Dispatcher 通过获取每个 input tensorDispatchKeySet,进行逻辑或操作,然后调度到特定的功能。功能性位是有序的。当设置了多个功能性位时,使用最高优先级的功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For example,
// we have:
// - "Dense":     CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse":    SparseCPU, SparseCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd":  AutogradCPU, AutogradCUDA, Autograd XLA, ...
// The problem is that total number of keys grows quadratically with [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of the
// bitset over time.

调度键和 DispatchKeySet 中的位并不是一对一的关系,主要是因为 Per-Backend Functionality Dispatch Key 的存在,如下:

  • “Dense”: CPU、CUDA、XLA 等
  • “Sparse”: SparseCPU、SparseCUDA 等
  • “Quantized”: QuantizedCPU、QuantizedCUDA、QuantizedXLA 等
  • “Autograd”: AutogradCPU、AutogradCUDA、AutogradXLA 等

调度键的总数将随着【后端数量】 x 【功能数量】呈二次增长,如果直接将每个键映射到 DispatchKeySet 中的一个位,将会大大增加 DispatchKeySet 的大小。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 5 categories.
//
// (1) "Building block" keys
//    (a) backends: jEverything in the BackendComponent enum (e.g. CPUBit, 
//        CUDABIt) 
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//        (e.g. AutogradFunctionality, Sparse, Dense)
// (2) "Runtime" keys
//    (a) "non-customizable backends" (e.g. FPGA)
//    (b) "non-customizable functionalities" (e.g. Functionalize)
//    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
//        SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])

BackendComponent 和 DispatchKey 枚举大致可以分为 5 个类别:

  1. “构建块”键
    a. 后端:BackendComponent 枚举中的所有值(CPUBit、CUDABit 等)
    b. 功能:(每个后端)功能位的 DispatchKeys(例如 AutogradFunctionality、Sparse、Dense)

  2. “运行时”键
    a. “不可定制的后端”(例如 FPGA)
    b. “不可定制的功能”(例如 Functionalize)
    c. “可定制功能的每个后端实例”(例如 CPU、SparseCPU、AutogradCPU)

  3. “别名” DispatchKeys(见 Note [Alias Dispatch Keys])

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
//     auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit,
//     DispatchKey::Dense});
//     // The keyset has the runtime dense-cpu key.
//     dense_cpu_ks.has(DispatchKey::CPU);
//     // And it contains the building block keys too.
//     dense_cpu_ks.has(DispatchKey::CPUBit);
//     dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overriden.

“构建块”键始终对应于 DispatchKeySet 中的某个位。它们也可以被组合形成实际的运行时键,例如:

1
2
3
4
5
6
7
auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit, DispatchKey::Dense});
// 该键集具有运行时的 dense-cpu 键。
dense_cpu_ks.has(DispatchKey::CPU);
// 它还包含构建块键。
dense_cpu_ks.has(DispatchKey::CPUBit);
dense_cpu_ks.has(DispatchKey::Dense);

在上述示例中,dense_cpu_ks 是一个包含 DispatchKey::CPUBitDispatchKey::Dense 的调度键集。它既包含了运行时键 DispatchKey::CPU,也包含了构建块键 DispatchKey::CPUBitDispatchKey::Dense

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.

// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
// customizeable per backend.
//   In order to do so, we'd need to promote it to a per-backend functionality
//   "building block" key.
// - non-customizeable backends (e.g. FPGA) can NOT customize existing
// functionality like Sparse, Autograd, etc.
//   In order to do so, we'd need to promote it to a backend "building block"
//   key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.

“运行时”键直接对应于操作符的运行时调度表中的一个 slot,可以直接为运行时键注册内核。像 “Dense” 或 “AutogradFunctionality” 这样的 per-backend 功能,可以将相应的运行时调度键视为该功能在每个后端上的 “实例”。例如,”CPU”、”CUDA”、”XLA” 等都是 “Dense” 构建块键的运行时实例。

2a 和 2b 在 DispatchKeySet 逻辑中的表示方式是相同的:

  • 与后端无关的功能(例如 FuncTorchBatched)不能在每个后端上进行自定义。要实现这一点,我们需要将其提升为 per-backend 功能 “构建块” 键。
  • 不可自定义的后端(例如 FPGA)不能自定义现有的功能,如 Sparse、Autograd 等。要实现这一点,我们需要将其提升为后端 “构建块” 键。

在这两种情况下,这些键直接对应于操作符表中的运行时 slot。

总结

DispatchKey

类型 功能键/后端键 占用位/占用 slot
不可自定义的后端 后端键 占用位/占用slot
不可自定义的功能 功能键 占用位/占用slot
可以根据后端自定义的功能 功能键 占用位
可以自定义功能的后端实例 后端键 占用slot

DispatchKey + BackendComponent

类型 构建键/运行时键 占用位/占用 slot
BackendComponent 所有枚举值(可定制的后端) 构建键 占用位
(per-backend)的功能(可自定义的功能) 构建键 占用位
不可自定义的后端 运行时键 占用位/占用slot
不可自定义的功能 运行时键 占用位/占用slot
可自定义功能的后端 运行时键 占用slot

源码阅读

DispatchKey

主要关注这三个转换函数:

  • toBackendComponent:计算出该 DispatchKey(需要是 Per-Backend Functionality Dispatch Key)对应的 BackendComponent
  • toFunctionalityKey:计算出该 DispatchKey(需要是 Per-Backend Functionality Dispatch Key)对应的 Functionality
  • toRuntimePerBackendFunctionalityKey :将 BackendComponent 和 Functionality 组合成 Per-Backend Functionality Dispatch Key
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr BackendComponent toBackendComponent(DispatchKey k) {
  if (k >= DispatchKey::StartOfDenseBackends &&
      k <= DispatchKey::EndOfDenseBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
  } else if (
      k >= DispatchKey::StartOfQuantizedBackends &&
      k <= DispatchKey::EndOfQuantizedBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
  }
  ...
  } else {
    return BackendComponent::InvalidBit;
  }
}

constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
  if (k <= DispatchKey::EndOfFunctionalityKeys) {
    return k;
  } else if (k <= DispatchKey::EndOfDenseBackends) {
    return DispatchKey::Dense;
  } else if (k <= DispatchKey::EndOfQuantizedBackends) {
    return DispatchKey::Quantized;
  } else if (k <= DispatchKey::EndOfSparseBackends) {
    return DispatchKey::Sparse;
  } else if (k <= DispatchKey::EndOfNestedTensorBackends) {
    return DispatchKey::NestedTensor;
  } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
    return DispatchKey::AutogradFunctionality;
  } else {
    return DispatchKey::Undefined;
  }
}

// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns
// DispatchKey::CUDA.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
    DispatchKey functionality_k,
    BackendComponent backend_k) {
  if (functionality_k == DispatchKey::Dense) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::Sparse) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::Quantized) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::NestedTensor) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::AutogradFunctionality) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(
            DispatchKey::StartOfAutogradFunctionalityBackends) +
        static_cast<uint8_t>(backend_k));
  }
  return DispatchKey::Undefined;
}

DispatchKeySet

1
2
3
4
class DispatchKeySet final {
private:
 uint64_t repr_ = 0;
}

构造函数

由 BackendComponent / DispatchKey 初始化 DispatchKeySet 中对应的位。重点关注 DispatchKey 对不同类型键的处理(尤其是 case 3)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
  constexpr explicit DispatchKeySet(BackendComponent k) {
    if (k == BackendComponent::InvalidBit) {
      repr_ = 0;
    } else {
      repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
    }
  }

  constexpr explicit DispatchKeySet(DispatchKey k) {
    if (k == DispatchKey::Undefined) {
      repr_ = 0;
    } else if (k <= DispatchKey::EndOfFunctionalityKeys) {
      // Case 2: handle "functionality-only" keys
      // These keys have a functionality bit set, but no backend bits
      // These can technically be either:
      // - valid runtime keys (e.g. DispatchKey::AutogradOther,
      // DispatchKey::FuncTorchBatched, etc)
      // - "building block" keys that aren't actual runtime keys (e.g.
      // DispatchKey::Dense or Sparse)
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(k) - 1);
      repr_ = functionality_val;
    } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
      // Case 3: "runtime" keys that have a functionality bit AND a backend bit.
      // First compute which bit to flip for the functionality.
      auto functionality_k = toFunctionalityKey(k);
      // The - 1 is because Undefined is technically a "functionality" that
      // doesn't show up in the bitset. So e.g. Dense is technically the second
      // functionality, but the lowest functionality bit.
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(functionality_k) - 1);

      // then compute which bit to flip for the backend
      // Case 4a: handle the runtime instances of "per-backend functionality"
      // keys For example, given DispatchKey::CPU, we should set:
      // - the Dense functionality bit
      // - the CPUBit backend bit
      // first compute which bit to flip for the backend
      auto backend_k = toBackendComponent(k);
      uint64_t backend_val = backend_k == BackendComponent::InvalidBit
          ? 0
          : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
      repr_ = functionality_val + backend_val;
    } else {
      // At this point, we should have covered every case except for alias keys.
      // Technically it would be possible to add alias dispatch keys to a
      // DispatchKeySet, but the semantics are a little confusing and this
      // currently isn't needed anywhere.
      repr_ = 0;
    }
  }

优先级计算

indexOfHighestBit 返回最高有效位的索引:

1
2
3
4
5
6
7
8
  
  // Returns the index of the most-significant bit in the keyset.
  // This is used to as part of the calculation into the operator table to get:
  // - the highest "functionality" bit in the keyset.
  // - the highest "backend" bit in the keyset.
  uint8_t indexOfHighestBit() const {
    return 64 - llvm::countLeadingZeros(repr_);
  }

highestFunctionalityKey:返回最高优先级的 Functionality
highestBackendKey:返回最高优先级的 BackendComponent

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
  DispatchKey highestFunctionalityKey() const {
    auto functionality_idx = indexOfHighestBit();
    // This means that none of the functionality bits were set.
    if (functionality_idx < num_backends)
      return DispatchKey::Undefined;
    // The first num_backend bits in the keyset don't correspond to real
    // dispatch keys.
    return static_cast<DispatchKey>(functionality_idx - num_backends);
  }

  BackendComponent highestBackendKey() const {
    // mask to mask out functionality bits
    auto backend_idx =
        DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
    // all zeros across the backend bits means that no backend bits are set.
    if (backend_idx == 0)
      return BackendComponent::InvalidBit;
    return static_cast<BackendComponent>(backend_idx);
  }

highestPriorityTypeId:返回 DispatchKeySet 中最高优先级的 DispatchKey,如果是 per-backend 功能,则需要调用 toRuntimePerBackendFunctionalityKey,返回 backend 和 functionality 组合成的 DispatchKey:

1
2
3
4
5
6
7
8
9
  // returns the DispatchKey of highest priority in the set.
  DispatchKey highestPriorityTypeId() const {
    auto functionality_k = highestFunctionalityKey();
    if (isPerBackendFunctionalityKey(functionality_k)) {
      return toRuntimePerBackendFunctionalityKey(
          functionality_k, highestBackendKey());
    }
    return functionality_k;
  }