# PROGRAMMING TENSOR CORES: NATIVE VOLTA TENSOR CORES WITH CUTLASS

Andrew Kerr, Timmy Liu, Mostafa Hagog, Julien Demouth, John Tran



March 20, 2019

# **PROGRAMMING TENSOR CORES IN CUDA**

- mma.sync (new instruction in CUDA 10.1)
- Feeding the Data Path
- CUTLASS 1.3 - Native Volta Tensor Cores GEMM (March 20, 2019)





#### **TENSOR CORES**

**Tensor Cores**

- **8x speedup** for mixed-precision matrix multiply
- Programmable via WMMA API (CUDA 9)

**Direct access to Volta Tensor Cores:** mma.sync (new instruction in CUDA 10.1)

- Maximum efficiency on Volta SM Architecture
- New in CUTLASS 1.3



#### **TENSOR CORES**

This talk is about <u>Volta</u> Tensor Cores.

- **WMMA API** (Warp-synchronous Matrix Multiply Accumulate): portable abstraction layer for Tensor Cores
- **mma.sync**: <u>Direct access</u> to Volta Tensor Cores



# VOLTA MMA.SYNC

Warp-scoped matrix multiply instruction

mma.sync: new instruction in CUDA 10.1

- Directly targets Volta Tensor Cores

#### Matrix multiply-accumulate

$\mathsf{D} = \mathsf{A} * \mathsf{B} + \mathsf{C}$

- A, B: half
- C, D: float or half

Warp-synchronous:

- Four independent 8-by-8-by-4 matrix multiply-accumulate operations
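The semantics above can be pinned down with a host-side reference model. This is a sketch, not how the hardware or CUTLASS executes the instruction: float stands in for half (standard C++ has no half type), all tiles are stored row-major, and the names `mma_8x8x4` and `warp_mma_sync` are hypothetical.

```cpp
#include <array>
#include <cassert>

// Hypothetical host-side model; float stands in for the half-precision
// A and B operands, and all tiles are stored row-major.
using TileA = std::array<float, 8 * 4>;  // 8 rows (M) x 4 columns (K)
using TileB = std::array<float, 4 * 8>;  // 4 rows (K) x 8 columns (N)
using TileC = std::array<float, 8 * 8>;  // 8 rows (M) x 8 columns (N)

// One 8-by-8-by-4 multiply-accumulate: D = A * B + C.
TileC mma_8x8x4(const TileA& a, const TileB& b, const TileC& c) {
  TileC d{};
  for (int m = 0; m < 8; ++m)
    for (int n = 0; n < 8; ++n) {
      float acc = c[m * 8 + n];
      for (int k = 0; k < 4; ++k) acc += a[m * 4 + k] * b[k * 8 + n];
      d[m * 8 + n] = acc;
    }
  return d;
}

// One warp-wide mma.sync: four independent 8x8x4 products, one per quad pair.
std::array<TileC, 4> warp_mma_sync(const std::array<TileA, 4>& a,
                                   const std::array<TileB, 4>& b,
                                   const std::array<TileC, 4>& c) {
  std::array<TileC, 4> d{};
  for (int qp = 0; qp < 4; ++qp) d[qp] = mma_8x8x4(a[qp], b[qp], c[qp]);
  return d;
}
```

The four products share nothing in this model; any data sharing between quad pairs comes from how the operands are distributed, not from the instruction itself.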



# VOLTA MMA.SYNC

Warp-scoped matrix multiply instruction

#### Warp is partitioned into Quad Pairs (eight threads each)

- **QP0:** T0..T3 T16..T19
- **QP1:** T4..T7 T20..T23
- **QP2:** T8..T11 T24..T27
- **QP3:** T12..T15 T28..T31

Each Quad Pair performs one **8-by-8-by-4** matrix multiply.
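The quad pair membership listed above follows a simple arithmetic pattern. A minimal host-side sketch (the helper names are hypothetical, not CUDA or CUTLASS APIs):

```cpp
#include <array>
#include <cassert>

// Quad pair (0..3) of a lane (0..31): lanes 4q..4q+3 and 16+4q..16+4q+3
// together form quad pair q.
int quad_pair_of_lane(int lane) {
  assert(lane >= 0 && lane < 32);
  return (lane % 16) / 4;
}

// The eight lanes of quad pair qp, in increasing order.
std::array<int, 8> lanes_of_quad_pair(int qp) {
  assert(qp >= 0 && qp < 4);
  std::array<int, 8> lanes{};
  for (int i = 0; i < 4; ++i) {
    lanes[i] = 4 * qp + i;           // first quad:  T(4qp)..T(4qp+3)
    lanes[i + 4] = 16 + 4 * qp + i;  // second quad: T(16+4qp)..T(16+4qp+3)
  }
  return lanes;
}
```

For example, lane 20 maps to QP1, and QP1 consists of lanes 4..7 and 20..23.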



## **COMPOSING MATRIX MULTIPLIES**

Replicate data to compute a warp-wide 16-by-16-by-4 matrix product:

- A<sub>0..7</sub>: QP0, QP2
- A<sub>8..15</sub>: QP1, QP3
- B<sub>0..7</sub>: QP0, QP1
- B<sub>8..15</sub>: QP2, QP3

1 x mma.sync: 16-by-16-by-4
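The replication rule above can be checked on the host. This sketch reads A<sub>0..7</sub> as rows 0..7 of A and B<sub>0..7</sub> as columns 0..7 of B (an interpretation, not stated on the slide), so quad pair `qp` reads A rows starting at `8 * (qp & 1)` and B columns starting at `8 * (qp >> 1)`. Float stands in for half, storage is row-major, and `warp_mma_16x16x4` is a hypothetical name.

```cpp
#include <array>
#include <cassert>

// Row-major 16x4 A, 4x16 B and 16x16 C/D; float stands in for half.
using MatA = std::array<float, 16 * 4>;
using MatB = std::array<float, 4 * 16>;
using MatC = std::array<float, 16 * 16>;

// Warp-wide 16x16x4 product built from four quad-pair 8x8x4 products.
// QP0 and QP2 share A rows 0..7; QP0 and QP1 share B columns 0..7.
MatC warp_mma_16x16x4(const MatA& a, const MatB& b, const MatC& c) {
  MatC d = c;
  for (int qp = 0; qp < 4; ++qp) {
    const int m0 = 8 * (qp & 1);   // A rows used by this quad pair
    const int n0 = 8 * (qp >> 1);  // B columns used by this quad pair
    // The 8x8x4 product owned by this quad pair.
    for (int m = m0; m < m0 + 8; ++m)
      for (int n = n0; n < n0 + 8; ++n)
        for (int k = 0; k < 4; ++k)
          d[m * 16 + n] += a[m * 4 + k] * b[k * 16 + n];
  }
  return d;
}
```

Each quad pair owns one 8-by-8 quadrant of D, so the four products cover the 16-by-16 output exactly once while each half of A and B is read by two quad pairs.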



## VOLTA MMA.SYNC D = A \* B + C

#### **PTX Syntax**

```
mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype d, a, b, c;

.alayout = {.row, .col};
.blayout = {.row, .col};
.ctype   = {.f16, .f32};
.dtype   = {.f16, .f32};
```

- d: 8 x .dtype
- a: 4 x .f16
- b: 4 x .f16
- c: 8 x .ctype

Note: .f16 elements must be packed into .f16x2
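The operand sizes above translate into 32-bit register counts per thread. This host-side illustration uses hypothetical helper names and assumes the lower-indexed .f16 element occupies the low 16 bits of an .f16x2 register, which is what a little-endian reinterpretation of a half array yields.

```cpp
#include <cassert>
#include <cstdint>

// Pack two half-precision values (raw 16-bit patterns) into one 32-bit
// .f16x2 register image. Assumption: element 0 goes in the low 16 bits.
std::uint32_t pack_f16x2(std::uint16_t lo, std::uint16_t hi) {
  return static_cast<std::uint32_t>(lo) | (static_cast<std::uint32_t>(hi) << 16);
}

// 32-bit registers a thread needs for `elements` values of one operand:
// .f16 values are packed two per register, .f32 values use one each.
int registers_for(int elements, bool is_f16) {
  assert(!is_f16 || elements % 2 == 0);
  return is_f16 ? elements / 2 : elements;
}
```

Under this accounting, a and b each occupy two registers, while c and d occupy four registers with .f16 and eight with .f32.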



# **THREAD-DATA MAPPING - F16 MULTIPLICANDS**

Distributed among threads in quad pair (QP0 shown)

mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype d, a, b, c;



## FEEDING THE DATA PATH

Efficiently storing and loading through shared memory



See <u>CUTLASS GTC 2018</u> talk for more details about this model.

### **CONFLICT-FREE ACCESS TO SHARED MEMORY**

Efficiently storing and loading through shared memory

Bank conflicts can only occur between threads in the same phase.

- 4B words are accessed in 1 phase (all 32 threads).
- 8B words are accessed in 2 phases:
  - Process addresses of the first 16 threads in a warp
  - Process addresses of the second 16 threads in a warp
- 16B words (128-bit access size) are accessed in 4 phases:
  - Each phase processes 8 consecutive threads of a warp

Slide borrowed from: Guillaume Thomas-Collignon and Paulius Micikevicius. "Volta Architecture and performance optimization." GTC 2018. http://on-demand.gputechconf.com/gtc/2018/presentation/s81006-volta-architecture-and-performance-optimization.pdf
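The phase rule can be modeled on the host to check an access pattern before writing a kernel. This is a simplified sketch, not a profiler substitute: it assumes 32 banks of 4 bytes, naturally aligned accesses, and that identical 4B words within a phase are broadcast; `max_conflict_ways` is a hypothetical name.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <set>

// Simplified shared memory model: 32 banks, each 4 bytes wide.
constexpr int kBanks = 32;
constexpr int kBankBytes = 4;

// Worst-case conflict degree of one warp-wide access: the largest number of
// distinct 4B words mapping to one bank within a single phase (1 == conflict-free).
// 4B, 8B and 16B accesses are split into phases of 32, 16 and 8 threads.
int max_conflict_ways(const std::array<std::uint32_t, 32>& byte_addr, int access_bytes) {
  assert(access_bytes == 4 || access_bytes == 8 || access_bytes == 16);
  const int threads_per_phase = 128 / access_bytes;
  const int words_per_access = access_bytes / kBankBytes;
  int worst = 1;
  for (int first = 0; first < 32; first += threads_per_phase) {
    std::array<std::set<std::uint32_t>, kBanks> words_in_bank;  // duplicates broadcast
    for (int t = first; t < first + threads_per_phase; ++t) {
      assert(byte_addr[t] % access_bytes == 0);  // naturally aligned
      for (int w = 0; w < words_per_access; ++w) {
        const std::uint32_t word = byte_addr[t] / kBankBytes + w;
        words_in_bank[word % kBanks].insert(word);
      }
    }
    for (const auto& words : words_in_bank)
      worst = std::max(worst, static_cast<int>(words.size()));
  }
  return worst;
}
```

With 128-bit accesses, eight threads reading consecutive 16B vectors fill all 32 banks once per phase, whereas eight threads each reading the start of a different 128-byte row all hit banks 0..3 and serialize eight ways.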

# FEEDING THE DATA PATH

#### Efficiently storing and loading through shared memory

Must move data from shared memory to registers as efficiently as possible

- 128-bit access size
- Conflict-free Shared Memory stores
- Conflict-free Shared Memory loads



# **MMA.SYNC GEMM: SPATIALLY INTERLEAVED**

Accumulator tiles may not be contiguous





1 x mma.sync: 16-by-16-by-4






4 x mma.sync: 32-by-32-by-4 (spatially interleaved)

### **THREAD-DATA MAPPING - F16 MULTIPLICANDS**



COL-ROW ("NT")

#### **SPATIALLY INTERLEAVED: 128 BIT ACCESSES**





4 x mma.sync: 32-by-32-by-4 (spatially interleaved)




## **GLOBAL MEMORY (CANONICAL)**

Striped over 8 x 4 threads

| A<sub>0..7, 0</sub> | A<sub>8..15, 0</sub> | A<sub>16..23, 0</sub> | A<sub>24..31, 0</sub> | A<sub>32..39, 0</sub> | A<sub>40..47, 0</sub> | A<sub>48..55, 0</sub> | A<sub>56..63, 0</sub> |
|---------------------|----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
| A<sub>0..7, 1</sub> | A<sub>8..15, 1</sub> | A<sub>16..23, 1</sub> | A<sub>24..31, 1</sub> | A<sub>32..39, 1</sub> | A<sub>40..47, 1</sub> | A<sub>48..55, 1</sub> | A<sub>56..63, 1</sub> |
| A<sub>0..7, 2</sub> | A<sub>8..15, 2</sub> | A<sub>16..23, 2</sub> | A<sub>24..31, 2</sub> | A<sub>32..39, 2</sub> | A<sub>40..47, 2</sub> | A<sub>48..55, 2</sub> | A<sub>56..63, 2</sub> |
| A<sub>0..7, 3</sub> | A<sub>8..15, 3</sub> | A<sub>16..23, 3</sub> | A<sub>24..31, 3</sub> | A<sub>32..39, 3</sub> | A<sub>40..47, 3</sub> | A<sub>48..55, 3</sub> | A<sub>56..63, 3</sub> |



### SHARED MEMORY (PERMUTED)

Permuted layout

| A<sub>0..7, 0</sub>   | A<sub>0..7, 1</sub>   | A<sub>0..7, 2</sub>   | A<sub>0..7, 3</sub>   | A<sub>16..23, 0</sub> | A<sub>16..23, 1</sub> | A<sub>16..23, 2</sub> | A<sub>16..23, 3</sub> |
|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
| A<sub>8..15, 1</sub>  | A<sub>8..15, 0</sub>  | A<sub>8..15, 3</sub>  | A<sub>8..15, 2</sub>  | A<sub>24..31, 1</sub> | A<sub>24..31, 0</sub> | A<sub>24..31, 3</sub> | A<sub>24..31, 2</sub> |
| A<sub>32..39, 2</sub> | A<sub>32..39, 3</sub> | A<sub>32..39, 0</sub> | A<sub>32..39, 1</sub> | A<sub>48..55, 2</sub> | A<sub>48..55, 3</sub> | A<sub>48..55, 0</sub> | A<sub>48..55, 1</sub> |
| A<sub>40..47, 3</sub> | A<sub>40..47, 2</sub> | A<sub>40..47, 1</sub> | A<sub>40..47, 0</sub> | A<sub>56..63, 3</sub> | A<sub>56..63, 2</sub> | A<sub>56..63, 1</sub> | A<sub>56..63, 0</sub> |



Each thread loads 128 bits from global memory (column-major) and stores the same 128 bits to shared memory (permuted). Because these are 128-bit accesses, each warp-wide store is processed in four phases of eight consecutive threads:

| Phase | Threads  | Vectors loaded from global memory (128 bits per thread)                          |
|-------|----------|----------------------------------------------------------------------------------|
| 1     | T0..T7   | A<sub>0..7, 0</sub>, A<sub>8..15, 0</sub>, ..., A<sub>56..63, 0</sub>             |
| 2     | T8..T15  | A<sub>0..7, 1</sub>, A<sub>8..15, 1</sub>, ..., A<sub>56..63, 1</sub>             |
| 3     | T16..T23 | A<sub>0..7, 2</sub>, A<sub>8..15, 2</sub>, ..., A<sub>56..63, 2</sub>             |
| 4     | T24..T31 | A<sub>0..7, 3</sub>, A<sub>8..15, 3</sub>, ..., A<sub>56..63, 3</sub>             |

Within every phase, the eight stores land in eight different 128-bit columns of the permuted layout. Each column spans 4 of the 32 banks, so each phase touches every bank exactly once and the stores are conflict-free.

### POINTER OFFSETS FOR PERMUTED SHARED MEMORY

#### Global Memory (column-major)


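```
int lane = threadIdx.x % 32;    // lane index within the warp
int c = lane % 8;               // 128-bit vector index along the contiguous dimension
int s = lane / 8;               // index along the strided dimension
int gmem_offset = c + s * lda;  // c and lda in units of 128-bit vectors
```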

#### Shared Memory (permuted)


```
int lane = threadIdx.x % 32;

// One 128-bit vector per thread: c along the contiguous dimension, s along the strided one
int c = lane % 8;
int s = lane / 8;

// Permuted row and 128-bit column (each column spans 4 of the 32 banks)
int smem_row = (c & 1) | ((c >> 1) & 2);
int bank = ((c << 1) & 4) | (s ^ smem_row);
int smem_offset = smem_row * ldm_smem + bank;  // ldm_smem: row pitch in 128-bit vectors
```
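The index arithmetic above can be checked exhaustively on the host. This sketch re-implements it for one 64-by-4 tile of A, assuming `ldm_smem` is 8 vectors (one 128-byte row, matching the 4-by-8 permuted layout shown earlier) and that all offsets are in 128-bit vectors; the helper names are hypothetical. It confirms that the 32 lanes map to 32 distinct shared memory slots and that each eight-lane store phase writes eight different columns.

```cpp
#include <cassert>
#include <set>

// Hypothetical host-side model of the permuted layout (offsets in 128-bit vectors).
constexpr int kLdmSmem = 8;  // assumed row pitch: 8 vectors = 128 bytes = 32 banks

// c: vector index along the contiguous dimension, s: along the strided one.
struct VecCoord { int c; int s; };

VecCoord lane_to_coord(int lane) {
  assert(lane >= 0 && lane < 32);
  return {lane % 8, lane / 8};
}

int gmem_offset(int lane, int lda) {
  const VecCoord p = lane_to_coord(lane);
  return p.c + p.s * lda;
}

int smem_offset(int lane) {
  const VecCoord p = lane_to_coord(lane);
  const int smem_row = (p.c & 1) | ((p.c >> 1) & 2);
  const int bank = ((p.c << 1) & 4) | (p.s ^ smem_row);  // 128-bit column
  return smem_row * kLdmSmem + bank;
}

// True if, in each of the four 8-lane store phases, the lanes write eight
// different 128-bit columns (so each phase touches all 32 banks once).
bool stores_conflict_free() {
  for (int phase = 0; phase < 4; ++phase) {
    std::set<int> columns;
    for (int lane = 8 * phase; lane < 8 * phase + 8; ++lane)
      columns.insert(smem_offset(lane) % kLdmSmem);
    if (columns.size() != 8) return false;
  }
  return true;
}
```

For instance, lane 1 (c = 1, s = 0) stores A<sub>8..15, 0</sub> to row 1, column 1, which is offset 9, matching the permuted table.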




Figure: conflict-free loads from the permuted shared memory layout into the mma.sync thread-data mapping, shown phase by phase (Phase 1: QP0, T0..T3; Phase 2: QP2 and QP3, T8..T15; Phase 3: QP0, QP1, QP2 and QP3).








#### CUTLASS CUDA C++ Template Library for Deep Learning



#### CUTLASS template library for GEMM computations

- Blocked structure to maximize data reuse
- Software pipelined to hide latency
- Conflict-free Shared Memory access to maximize data throughput

See <u>CUTLASS GTC 2018</u> talk.
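A host-side loop nest can make the blocked structure concrete. In this sketch, C += A * B is tiled at three levels (threadblock tile, warp tile, and a 16-by-16-by-4 tile matching one mma.sync). The tile sizes and the name `blocked_gemm` are hypothetical choices for illustration, not CUTLASS defaults, and software pipelining is omitted. On the GPU, each level's operands would be staged in shared memory and registers so they can be reused by the level below.

```cpp
#include <cassert>
#include <vector>

// Illustrative tile sizes (hypothetical, not CUTLASS defaults).
constexpr int kBlockM = 64, kBlockN = 64, kBlockK = 8;  // threadblock tile
constexpr int kWarpM = 32, kWarpN = 32;                 // warp tile
constexpr int kMmaM = 16, kMmaN = 16, kMmaK = 4;        // one mma.sync

// Row-major C(MxN) += A(MxK) * B(KxN); M, N, K must be multiples of the block tile.
void blocked_gemm(int M, int N, int K, const std::vector<float>& A,
                  const std::vector<float>& B, std::vector<float>& C) {
  assert(M % kBlockM == 0 && N % kBlockN == 0 && K % kBlockK == 0);
  for (int bm = 0; bm < M; bm += kBlockM)                         // threadblocks
    for (int bn = 0; bn < N; bn += kBlockN)
      for (int bk = 0; bk < K; bk += kBlockK)                     // mainloop over K
        for (int wm = bm; wm < bm + kBlockM; wm += kWarpM)        // warps
          for (int wn = bn; wn < bn + kBlockN; wn += kWarpN)
            for (int mk = bk; mk < bk + kBlockK; mk += kMmaK)
              for (int mm = wm; mm < wm + kWarpM; mm += kMmaM)    // mma tiles
                for (int mn = wn; mn < wn + kWarpN; mn += kMmaN)
                  for (int m = mm; m < mm + kMmaM; ++m)           // one mma.sync
                    for (int n = mn; n < mn + kMmaN; ++n)
                      for (int k = mk; k < mk + kMmaK; ++k)
                        C[m * N + n] += A[m * K + k] * B[k * N + n];
}
```

The result is identical to an untiled triple loop; only the order in which tiles are visited changes, which is what lets each level reuse data held closer to the Tensor Cores.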

# CUTLASS 1.3

Reusable components targeting Volta Tensor Cores



# STORING TO SHARED MEMORY

CUTLASS Tile Iterators to transform:

- Global Memory: canonical matrix layout → Shared Memory: permuted shared memory layout

cutlass/gemm/volta884_multiplicand.h

# LOADING FROM SHARED MEMORY



CUTLASS Tile Iterators to transform:

- Shared Memory: permuted shared memory layout → Register File: mma.sync thread-data mapping

cutlass/gemm/volta884_multiplicand.h

```
// Defines iterators for loading and storing multiplicands
template <
  /// Identifies multiplicand of GEMM (A or B)
  GemmOperand::Kind Operand,
  /// Specifies layout of data in source memory
  MatrixLayout::Kind Layout,
  /// Specifies threadblock tile shape
  typename Tile,
  /// Specifies warp tile shape
  typename WarpTile,
  /// Specifies the number of participating warps
  int WarpCount,
  /// Specifies the delta between warp tiles
  typename WarpDelta
>
struct Volta884Multiplicand {
  //
  // Thread-block load iterator (canonical matrix layout)
  //
  typedef ... LoadIterator;

  //
  // Thread-block store iterator (permuted SMEM layout)
  //
  typedef ... StoreIterator;

  //
  // Warp-level load iterator
  //
  typedef ... WarpLoadIterator;

  // ...
};
```

# EXECUTING MMA.SYNC

CUTLASS Warp-scoped matrix multiply:

- Register File: mma.sync thread-data mapping → Tensor Cores: mma.sync

Figure: Blocked GEMM hierarchy (Global Memory → Shared Memory → Register File → CUDA/Tensor Cores; Thread Block Tile, Warp Tile, mma Tile).

cutlass/gemm/volta884_multiply_add.h

```
template <
  /// Shape of a warp-level GEMM (K-by-N-by-M)
  typename WarpGemmShape,
  /// Layout of A multiplicand
  MatrixLayout::Kind LayoutA,
  /// Data type of A multiplicand
  typename ScalarA,
  /// Layout of B multiplicand
  MatrixLayout::Kind LayoutB,
  /// Data type of B multiplicand
  typename ScalarB,
  /// Data type of accumulators
  typename ScalarC,
  /// Whether infinite results are saturated to +-MAX_FLOAT
  bool SatFinite = false
>
struct Volta884MultiplyAdd {
  // ...

  /// Multiply : d = (-)a*b + c.
  CUTLASS_DEVICE void multiply_add(
    FragmentA const& A,
    FragmentB const& B,
    Accumulators const& C,
    Accumulators& D,
    bool negate = false) {
    // ...
  }
};
```
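The commented contract `d = (-)a*b + c`, together with the `negate` flag and the `SatFinite` option, can be captured by a host-side reference for testing. This sketch is not CUTLASS code: the name `reference_multiply_add`, row-major float storage, reading MAX_FLOAT as `FLT_MAX`, and leaving NaN results unchanged under SatFinite are all assumptions. The real `multiply_add` issues mma.sync instructions on register fragments.

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <vector>

// Hypothetical reference of d = (-)a*b + c over an M x N x K tile (row-major).
// With SatFinite, infinite results are clamped to +-FLT_MAX; NaN is left as-is.
template <bool SatFinite = false>
void reference_multiply_add(int M, int N, int K, const std::vector<float>& a,
                            const std::vector<float>& b, const std::vector<float>& c,
                            std::vector<float>& d, bool negate = false) {
  d.assign(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float prod = 0.0f;
      for (int k = 0; k < K; ++k) prod += a[m * K + k] * b[k * N + n];
      float result = (negate ? -prod : prod) + c[m * N + n];
      if (SatFinite && std::isinf(result))
        result = std::signbit(result) ? -FLT_MAX : FLT_MAX;
      d[m * N + n] = result;
    }
}
```

Comparing a kernel's accumulators against such a reference on small integer-valued inputs keeps the check exact, since those products and sums are representable in float.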

## SPEEDUP RELATIVE TO WMMA





# CONCLUSION

Volta Tensor Cores directly programmable in CUDA 10.1

- Complements WMMA API
- Direct access: mma.sync instruction for Volta Architecture

CUTLASS 1.3 (March 2019)

- CUDA C++ Template Library for Deep Learning
- Reusable components:
  - mma.sync for Volta Tensor Cores
  - Storing and loading from permuted shared memory
  - Efficient epilogue for updating output matrix
- New kernels:
  - Real- and complex-valued mixed precision GEMMs targeting Tensor Cores
  - Parallelized reductions for mma.sync GEMM (first added in CUTLASS 1.2)

#### https://github.com/NVIDIA/cutlass



# REFERENCES

CUTLASS source code: <u>https://github.com/NVIDIA/cutlass</u>

Volta Tensor Cores in CUDA

- mma.sync: <u>https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma</u>
- Matrix fragments: <u>https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma</u>

GEMM resources

- CUTLASS Parallel for All blog post
- GTC 2018 CUTLASS talk [video recording]



## EXTRA MATERIAL

## **THREAD-DATA MAPPING - F16 ACCUMULATION**

Accumulators distributed among threads (QP0 shown)

mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype d, a, b, c;



## **THREAD-DATA MAPPING - F32 ACCUMULATION**

Accumulators distributed among threads (QP0 shown)

mma.sync.aligned.m8n8k4.alayout.blayout.dtype.f16.f16.ctype d, a, b, c;

