跳至主要內容

在Go语言中优雅地使用AVX512

技术分享SIMD大约 15 分钟

在Go语言中优雅地使用AVX512

AVX512是英特尔发布的最新最新一代SIMD指令,能够在一个指令周期内对512位数据进行处理,相当于16个单精度浮点数或者8个双精度浮点数。Gorseopen in new window推荐系统中的推荐模型训练和推理过程需要大量的向量计算工作,AVX512理论上能够带来一定的加速效果。然而比较遗憾的是,Go语言的编译器并不能自动生成使用SIMD指令的机器码。

MinIOopen in new window曾经开源过一个将英特尔汇编转换为Go汇编的工具c2goasmopen in new window。首先,将需要向量化的函数用C语言实现,通过Clang编译出包含SIMD指令的汇编。然后,由于Go汇编open in new window支持AVX512open in new window的,将Intel汇编指令转换成Go汇编,即可通过汇编调用SIMD实现的函数。c2goasm的方案十分有效,然而项目已经将近4年没有更新,经过测试也无法处理AVX512指令。

为能够在Go语言中使用AVX512,我们开发了一个将C语言编译为Go汇编函数的工具包goatopen in new window,借助goat实现了一个向量化函数库github.com/gorse-io/gorse/base/floatsopen in new window。在继承c2goasm技术路线的基础上,goat实现了功能上进一步的增强:

  1. 从C语言源代码直接开始转换得到Go汇编函数,用户不必自行编译C语言源代码;
  2. 同时会根据C语言函数定义生成Go函数定义,用户不需要手写Go函数定义。

接下来的篇幅会详细介绍goat的技术实现思路,具体的实现欢迎阅读代码github.com/gorse-io/gorse/cmd/goatopen in new window

从C语言编译汇编代码

C语言实现的函数_mm512_mul_to将两个浮点数组相乘之后,将结果保存到第三个数组中。一般情况下,编译器会自动生成使用SIMD的汇编,此处使用了intrinsic确保一定生成SIMD指令。

#include <immintrin.h>

void _mm512_mul_to(float *a, float *b, float *c, int64_t n)
{
    int epoch = n / 16;
    int remain = n % 16;
    for (int i = 0; i < epoch; i++)
    {
        __m512 v1 = _mm512_loadu_ps(a);
        __m512 v2 = _mm512_loadu_ps(b);
        __m512 v = _mm512_mul_ps(v1, v2);
        _mm512_storeu_ps(c, v);
        a += 16;
        b += 16;
        c += 16;
    }
    if (remain >= 8)
    {
        __m256 v1 = _mm256_loadu_ps(a);
        __m256 v2 = _mm256_loadu_ps(b);
        __m256 v = _mm256_mul_ps(v1, v2);
        _mm256_storeu_ps(c, v);
        a += 8;
        b += 8;
        c += 8;
        remain -= 8;
    }
    for (int i = 0; i < remain; i++)
    {
        c[i] = a[i] * b[i];
    }
}

将代码保存到mm512_mul_to.c,使用以下命令将C代码编译为汇编以及二进制。

# 生成汇编
clang -S -c mm512_mul_to.c -o mm512_mul_to.s \
  -O3 -mavx -mfma -mavx512f -mavx512dq \
  -mno-red-zone -mstackrealign -mllvm -inline-threshold=1000 \
  -fno-asynchronous-unwind-tables -fno-exceptions -fno-rtti

# 生成二进制
clang -c mm512_mul_to.c -o mm512_mul_to.o \
  -O3 -mavx -mfma -mavx512f -mavx512dq \
  -mno-red-zone -mstackrealign -mllvm -inline-threshold=1000 \
  -fno-asynchronous-unwind-tables -fno-exceptions -fno-rtti
mm512_mul_to.s
 .globl _mm512_mul_to                   # -- Begin function _mm512_mul_to
 .p2align 4, 0x90
 .type _mm512_mul_to,@function
_mm512_mul_to:                          # @_mm512_mul_to
# %bb.0:
 pushq %rbp
 movq %rsp, %rbp
 andq $-8, %rsp
 leaq 15(%rcx), %r9
 testq %rcx, %rcx
 cmovnsq %rcx, %r9
 shrq $4, %r9
 movl %r9d, %eax
 shll $4, %eax
 subl %eax, %ecx
 testl %r9d, %r9d
 jle .LBB3_6
# %bb.1:
 leal -1(%r9), %eax
 movl %r9d, %r8d
 andl $3, %r8d
 cmpl $3, %eax
 jb .LBB3_4
# %bb.2:
 andl $-4, %r9d
 negl %r9d
 .p2align 4, 0x90
.LBB3_3:                                # =>This Inner Loop Header: Depth=1
 vmovups (%rdi), %zmm0
 vmulps (%rsi), %zmm0, %zmm0
 vmovups %zmm0, (%rdx)
 vmovups 64(%rdi), %zmm0
 vmulps 64(%rsi), %zmm0, %zmm0
 vmovups %zmm0, 64(%rdx)
 vmovups 128(%rdi), %zmm0
 vmulps 128(%rsi), %zmm0, %zmm0
 vmovups %zmm0, 128(%rdx)
 vmovups 192(%rdi), %zmm0
 vmulps 192(%rsi), %zmm0, %zmm0
 vmovups %zmm0, 192(%rdx)
 addq $256, %rdi                      # imm = 0x100
 addq $256, %rsi                      # imm = 0x100
 addq $256, %rdx                      # imm = 0x100
 addl $4, %r9d
 jne .LBB3_3
.LBB3_4:
 testl %r8d, %r8d
 je .LBB3_6
 .p2align 4, 0x90
.LBB3_5:                                # =>This Inner Loop Header: Depth=1
 vmovups (%rdi), %zmm0
 vmulps (%rsi), %zmm0, %zmm0
 vmovups %zmm0, (%rdx)
 addq $64, %rdi
 addq $64, %rsi
 addq $64, %rdx
 addl $-1, %r8d
 jne .LBB3_5
.LBB3_6:
 cmpl $7, %ecx
 jle .LBB3_8
# %bb.7:
 vmovups (%rdi), %ymm0
 vmulps (%rsi), %ymm0, %ymm0
 vmovups %ymm0, (%rdx)
 addq $32, %rdi
 addq $32, %rsi
 addq $32, %rdx
 addl $-8, %ecx
.LBB3_8:
 testl %ecx, %ecx
 jle .LBB3_14
# %bb.9:
 movl %ecx, %ecx
 leaq -1(%rcx), %rax
 movl %ecx, %r8d
 andl $3, %r8d
 cmpq $3, %rax
 jae .LBB3_15
# %bb.10:
 xorl %eax, %eax
 jmp .LBB3_11
.LBB3_15:
 andl $-4, %ecx
 xorl %eax, %eax
 .p2align 4, 0x90
.LBB3_16:                               # =>This Inner Loop Header: Depth=1
 vmovss (%rdi,%rax,4), %xmm0            # xmm0 = mem[0],zero,zero,zero
 vmulss (%rsi,%rax,4), %xmm0, %xmm0
 vmovss %xmm0, (%rdx,%rax,4)
 vmovss 4(%rdi,%rax,4), %xmm0           # xmm0 = mem[0],zero,zero,zero
 vmulss 4(%rsi,%rax,4), %xmm0, %xmm0
 vmovss %xmm0, 4(%rdx,%rax,4)
 vmovss 8(%rdi,%rax,4), %xmm0           # xmm0 = mem[0],zero,zero,zero
 vmulss 8(%rsi,%rax,4), %xmm0, %xmm0
 vmovss %xmm0, 8(%rdx,%rax,4)
 vmovss 12(%rdi,%rax,4), %xmm0          # xmm0 = mem[0],zero,zero,zero
 vmulss 12(%rsi,%rax,4), %xmm0, %xmm0
 vmovss %xmm0, 12(%rdx,%rax,4)
 addq $4, %rax
 cmpq %rax, %rcx
 jne .LBB3_16
.LBB3_11:
 testq %r8, %r8
 je .LBB3_14
# %bb.12:
 leaq (%rdx,%rax,4), %rcx
 leaq (%rsi,%rax,4), %rdx
 leaq (%rdi,%rax,4), %rax
 xorl %esi, %esi
 .p2align 4, 0x90
.LBB3_13:                               # =>This Inner Loop Header: Depth=1
 vmovss (%rax,%rsi,4), %xmm0            # xmm0 = mem[0],zero,zero,zero
 vmulss (%rdx,%rsi,4), %xmm0, %xmm0
 vmovss %xmm0, (%rcx,%rsi,4)
 addq $1, %rsi
 cmpq %rsi, %r8
 jne .LBB3_13
.LBB3_14:
 movq %rbp, %rsp
 popq %rbp
 vzeroupper
 retq
.Lfunc_end3:
 .size _mm512_mul_to, .Lfunc_end3-_mm512_mul_to
                                        # -- End function

在Go语言中,可以用exec.Command很方便地执行编译命令。

从汇编代码构造Go代码

转换汇编代码

mm512_mul_to.o是函数被编译后的二进制,通过objdump可以看到每条汇编都被转换成了机器码。

objdump -d mm512_mul_to.o --insn-width 16
objdump输出内容
0000000000000460 <_mm512_mul_to>:
 460:   55                              push   %rbp
 461:   48 89 e5                        mov    %rsp,%rbp
 464:   48 83 e4 f8                     and    $0xfffffffffffffff8,%rsp
 468:   4c 8d 49 0f                     lea    0xf(%rcx),%r9
 46c:   48 85 c9                        test   %rcx,%rcx
 46f:   4c 0f 49 c9                     cmovns %rcx,%r9
 473:   49 c1 e9 04                     shr    $0x4,%r9
 477:   44 89 c8                        mov    %r9d,%eax
 47a:   c1 e0 04                        shl    $0x4,%eax
 47d:   29 c1                           sub    %eax,%ecx
 47f:   45 85 c9                        test   %r9d,%r9d
 482:   0f 8e bc 00 00 00               jle    544 <_mm512_mul_to+0xe4>
 488:   41 8d 41 ff                     lea    -0x1(%r9),%eax
 48c:   45 89 c8                        mov    %r9d,%r8d
 48f:   41 83 e0 03                     and    $0x3,%r8d
 493:   83 f8 03                        cmp    $0x3,%eax
 496:   72 74                           jb     50c <_mm512_mul_to+0xac>
 498:   41 83 e1 fc                     and    $0xfffffffc,%r9d
 49c:   41 f7 d9                        neg    %r9d
 49f:   90                              nop
 4a0:   62 f1 7c 48 10 07               vmovups (%rdi),%zmm0
 4a6:   62 f1 7c 48 59 06               vmulps (%rsi),%zmm0,%zmm0
 4ac:   62 f1 7c 48 11 02               vmovups %zmm0,(%rdx)
 4b2:   62 f1 7c 48 10 47 01            vmovups 0x40(%rdi),%zmm0
 4b9:   62 f1 7c 48 59 46 01            vmulps 0x40(%rsi),%zmm0,%zmm0
 4c0:   62 f1 7c 48 11 42 01            vmovups %zmm0,0x40(%rdx)
 4c7:   62 f1 7c 48 10 47 02            vmovups 0x80(%rdi),%zmm0
 4ce:   62 f1 7c 48 59 46 02            vmulps 0x80(%rsi),%zmm0,%zmm0
 4d5:   62 f1 7c 48 11 42 02            vmovups %zmm0,0x80(%rdx)
 4dc:   62 f1 7c 48 10 47 03            vmovups 0xc0(%rdi),%zmm0
 4e3:   62 f1 7c 48 59 46 03            vmulps 0xc0(%rsi),%zmm0,%zmm0
 4ea:   62 f1 7c 48 11 42 03            vmovups %zmm0,0xc0(%rdx)
 4f1:   48 81 c7 00 01 00 00            add    $0x100,%rdi
 4f8:   48 81 c6 00 01 00 00            add    $0x100,%rsi
 4ff:   48 81 c2 00 01 00 00            add    $0x100,%rdx
 506:   41 83 c1 04                     add    $0x4,%r9d
 50a:   75 94                           jne    4a0 <_mm512_mul_to+0x40>
 50c:   45 85 c0                        test   %r8d,%r8d
 50f:   74 33                           je     544 <_mm512_mul_to+0xe4>
 511:   66 2e 0f 1f 84 00 00 00 00 00   nopw   %cs:0x0(%rax,%rax,1)
 51b:   0f 1f 44 00 00                  nopl   0x0(%rax,%rax,1)
 520:   62 f1 7c 48 10 07               vmovups (%rdi),%zmm0
 526:   62 f1 7c 48 59 06               vmulps (%rsi),%zmm0,%zmm0
 52c:   62 f1 7c 48 11 02               vmovups %zmm0,(%rdx)
 532:   48 83 c7 40                     add    $0x40,%rdi
 536:   48 83 c6 40                     add    $0x40,%rsi
 53a:   48 83 c2 40                     add    $0x40,%rdx
 53e:   41 83 c0 ff                     add    $0xffffffff,%r8d
 542:   75 dc                           jne    520 <_mm512_mul_to+0xc0>
 544:   83 f9 07                        cmp    $0x7,%ecx
 547:   7e 1b                           jle    564 <_mm512_mul_to+0x104>
 549:   c5 fc 10 07                     vmovups (%rdi),%ymm0
 54d:   c5 fc 59 06                     vmulps (%rsi),%ymm0,%ymm0
 551:   c5 fc 11 02                     vmovups %ymm0,(%rdx)
 555:   48 83 c7 20                     add    $0x20,%rdi
 559:   48 83 c6 20                     add    $0x20,%rsi
 55d:   48 83 c2 20                     add    $0x20,%rdx
 561:   83 c1 f8                        add    $0xfffffff8,%ecx
 564:   85 c9                           test   %ecx,%ecx
 566:   0f 8e ac 00 00 00               jle    618 <_mm512_mul_to+0x1b8>
 56c:   89 c9                           mov    %ecx,%ecx
 56e:   48 8d 41 ff                     lea    -0x1(%rcx),%rax
 572:   41 89 c8                        mov    %ecx,%r8d
 575:   41 83 e0 03                     and    $0x3,%r8d
 579:   48 83 f8 03                     cmp    $0x3,%rax
 57d:   73 04                           jae    583 <_mm512_mul_to+0x123>
 57f:   31 c0                           xor    %eax,%eax
 581:   eb 5b                           jmp    5de <_mm512_mul_to+0x17e>
 583:   83 e1 fc                        and    $0xfffffffc,%ecx
 586:   31 c0                           xor    %eax,%eax
 588:   0f 1f 84 00 00 00 00 00         nopl   0x0(%rax,%rax,1)
 590:   c5 fa 10 04 87                  vmovss (%rdi,%rax,4),%xmm0
 595:   c5 fa 59 04 86                  vmulss (%rsi,%rax,4),%xmm0,%xmm0
 59a:   c5 fa 11 04 82                  vmovss %xmm0,(%rdx,%rax,4)
 59f:   c5 fa 10 44 87 04               vmovss 0x4(%rdi,%rax,4),%xmm0
 5a5:   c5 fa 59 44 86 04               vmulss 0x4(%rsi,%rax,4),%xmm0,%xmm0
 5ab:   c5 fa 11 44 82 04               vmovss %xmm0,0x4(%rdx,%rax,4)
 5b1:   c5 fa 10 44 87 08               vmovss 0x8(%rdi,%rax,4),%xmm0
 5b7:   c5 fa 59 44 86 08               vmulss 0x8(%rsi,%rax,4),%xmm0,%xmm0
 5bd:   c5 fa 11 44 82 08               vmovss %xmm0,0x8(%rdx,%rax,4)
 5c3:   c5 fa 10 44 87 0c               vmovss 0xc(%rdi,%rax,4),%xmm0
 5c9:   c5 fa 59 44 86 0c               vmulss 0xc(%rsi,%rax,4),%xmm0,%xmm0
 5cf:   c5 fa 11 44 82 0c               vmovss %xmm0,0xc(%rdx,%rax,4)
 5d5:   48 83 c0 04                     add    $0x4,%rax
 5d9:   48 39 c1                        cmp    %rax,%rcx
 5dc:   75 b2                           jne    590 <_mm512_mul_to+0x130>
 5de:   4d 85 c0                        test   %r8,%r8
 5e1:   74 35                           je     618 <_mm512_mul_to+0x1b8>
 5e3:   48 8d 0c 82                     lea    (%rdx,%rax,4),%rcx
 5e7:   48 8d 14 86                     lea    (%rsi,%rax,4),%rdx
 5eb:   48 8d 04 87                     lea    (%rdi,%rax,4),%rax
 5ef:   31 f6                           xor    %esi,%esi
 5f1:   66 2e 0f 1f 84 00 00 00 00 00   nopw   %cs:0x0(%rax,%rax,1)
 5fb:   0f 1f 44 00 00                  nopl   0x0(%rax,%rax,1)
 600:   c5 fa 10 04 b0                  vmovss (%rax,%rsi,4),%xmm0
 605:   c5 fa 59 04 b2                  vmulss (%rdx,%rsi,4),%xmm0,%xmm0
 60a:   c5 fa 11 04 b1                  vmovss %xmm0,(%rcx,%rsi,4)
 60f:   48 83 c6 01                     add    $0x1,%rsi
 613:   49 39 f0                        cmp    %rsi,%r8
 616:   75 e8                           jne    600 <_mm512_mul_to+0x1a0>
 618:   48 89 ec                        mov    %rbp,%rsp
 61b:   5d                              pop    %rbp
 61c:   c5 f8 77                        vzeroupper 
 61f:   c3                              retq   

编码机器码

Go汇编提供了三种表示二进制机器码的指令:

  • BYTE表示一个字节长度的二进制数据
  • WORD表示两个字节长度的二进制数据
  • LONG表示四个字节长度的二进制数据

如果指令机器码长度刚好是二的倍数,例如

 498:   41 83 e1 fc                     and    $0xfffffffc,%r9d

可以转换为

 LONG $0xfce18341

但是如果长度不是二的倍数,例如

 4b2:   62 f1 7c 48 10 47 01            vmovups 0x40(%rdi),%zmm0

就需要三种指令组合表示

 LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups	64(%rdi), %zmm0

需要注意,指令编码的字节顺序和objdump输出的字节顺序是相反的。

函数定义和参数

在C语言汇编中,如果函数的参数不超过6个,那么参数会被保存在寄存器中传递给函数,摆放顺序为:%rdi%rsi%rdx%rcx%r8%r9。然而在Go汇编中,函数参数会放在从FP寄存器所保存地址开始的内存中,这就需要我们将内存中的参数移动到寄存器中。

_mm512_mul_to函数有四个参数,那么就需要在函数开始之前移动好4个参数:

TEXT ·_mm512_mul_to(SB), $0-32
  MOVQ a+0(FP), DI
  MOVQ b+8(FP), SI
  MOVQ c+16(FP), DX
  MOVQ n+24(FP), CX

函数定义由三部分构成:TEXT关键字、以·开始以(SB)结尾的名字,最后是参数内存大小32字节。汇编中并没有参数数量信息,需要从C语言函数定义中获取(参考生成go函数定义)。

重定向跳转指令

在x86中跳转指令跳转的是绝对地址,对跳转指令进行直接编码是无法工作的。因此,跳转指令需要转换为Go汇编中的跳转指令。

  • 转换标签: Go汇编的标签不能以.开始,因此需要将.删除;
  • 转换命令: Go汇编的跳转指令是大写的。

经过上述三步骤处理,最终得到Go汇编代码

Go汇编
TEXT ·_mm512_mul_to(SB), $0-32
 MOVQ a+0(FP), DI
 MOVQ b+8(FP), SI
 MOVQ c+16(FP), DX
 MOVQ n+24(FP), CX
 BYTE $0x55               // pushq %rbp
 WORD $0x8948; BYTE $0xe5 // movq %rsp, %rbp
 LONG $0xf8e48348         // andq $-8, %rsp
 LONG $0x0f498d4c         // leaq 15(%rcx), %r9
 WORD $0x8548; BYTE $0xc9 // testq %rcx, %rcx
 LONG $0xc9490f4c         // cmovnsq %rcx, %r9
 LONG $0x04e9c149         // shrq $4, %r9
 WORD $0x8944; BYTE $0xc8 // movl %r9d, %eax
 WORD $0xe0c1; BYTE $0x04 // shll $4, %eax
 WORD $0xc129             // subl %eax, %ecx
 WORD $0x8545; BYTE $0xc9 // testl %r9d, %r9d
 JLE  LBB3_6
 LONG $0xff418d41         // leal -1(%r9), %eax
 WORD $0x8945; BYTE $0xc8 // movl %r9d, %r8d
 LONG $0x03e08341         // andl $3, %r8d
 WORD $0xf883; BYTE $0x03 // cmpl $3, %eax
 JB   LBB3_4
 LONG $0xfce18341         // andl $-4, %r9d
 WORD $0xf741; BYTE $0xd9 // negl %r9d

LBB3_3:
 LONG $0x487cf162; WORD $0x0710             // vmovups (%rdi), %zmm0
 LONG $0x487cf162; WORD $0x0659             // vmulps (%rsi), %zmm0, %zmm0
 LONG $0x487cf162; WORD $0x0211             // vmovups %zmm0, (%rdx)
 LONG $0x487cf162; WORD $0x4710; BYTE $0x01 // vmovups 64(%rdi), %zmm0
 LONG $0x487cf162; WORD $0x4659; BYTE $0x01 // vmulps 64(%rsi), %zmm0, %zmm0
 LONG $0x487cf162; WORD $0x4211; BYTE $0x01 // vmovups %zmm0, 64(%rdx)
 LONG $0x487cf162; WORD $0x4710; BYTE $0x02 // vmovups 128(%rdi), %zmm0
 LONG $0x487cf162; WORD $0x4659; BYTE $0x02 // vmulps 128(%rsi), %zmm0, %zmm0
 LONG $0x487cf162; WORD $0x4211; BYTE $0x02 // vmovups %zmm0, 128(%rdx)
 LONG $0x487cf162; WORD $0x4710; BYTE $0x03 // vmovups 192(%rdi), %zmm0
 LONG $0x487cf162; WORD $0x4659; BYTE $0x03 // vmulps 192(%rsi), %zmm0, %zmm0
 LONG $0x487cf162; WORD $0x4211; BYTE $0x03 // vmovups %zmm0, 192(%rdx)
 LONG $0x00c78148; WORD $0x0001; BYTE $0x00 // addq $256, %rdi
 LONG $0x00c68148; WORD $0x0001; BYTE $0x00 // addq $256, %rsi
 LONG $0x00c28148; WORD $0x0001; BYTE $0x00 // addq $256, %rdx
 LONG $0x04c18341                           // addl $4, %r9d
 JNE  LBB3_3

LBB3_4:
 WORD $0x8545; BYTE $0xc0 // testl %r8d, %r8d
 JE   LBB3_6

LBB3_5:
 LONG $0x487cf162; WORD $0x0710 // vmovups (%rdi), %zmm0
 LONG $0x487cf162; WORD $0x0659 // vmulps (%rsi), %zmm0, %zmm0
 LONG $0x487cf162; WORD $0x0211 // vmovups %zmm0, (%rdx)
 LONG $0x40c78348               // addq $64, %rdi
 LONG $0x40c68348               // addq $64, %rsi
 LONG $0x40c28348               // addq $64, %rdx
 LONG $0xffc08341               // addl $-1, %r8d
 JNE  LBB3_5

LBB3_6:
 WORD $0xf983; BYTE $0x07 // cmpl $7, %ecx
 JLE  LBB3_8
 LONG $0x0710fcc5         // vmovups (%rdi), %ymm0
 LONG $0x0659fcc5         // vmulps (%rsi), %ymm0, %ymm0
 LONG $0x0211fcc5         // vmovups %ymm0, (%rdx)
 LONG $0x20c78348         // addq $32, %rdi
 LONG $0x20c68348         // addq $32, %rsi
 LONG $0x20c28348         // addq $32, %rdx
 WORD $0xc183; BYTE $0xf8 // addl $-8, %ecx

LBB3_8:
 WORD $0xc985             // testl %ecx, %ecx
 JLE  LBB3_14
 WORD $0xc989             // movl %ecx, %ecx
 LONG $0xff418d48         // leaq -1(%rcx), %rax
 WORD $0x8941; BYTE $0xc8 // movl %ecx, %r8d
 LONG $0x03e08341         // andl $3, %r8d
 LONG $0x03f88348         // cmpq $3, %rax
 JAE  LBB3_15
 WORD $0xc031             // xorl %eax, %eax
 JMP  LBB3_11

LBB3_15:
 WORD $0xe183; BYTE $0xfc // andl $-4, %ecx
 WORD $0xc031             // xorl %eax, %eax

LBB3_16:
 LONG $0x0410fac5; BYTE $0x87   // vmovss (%rdi,%rax,4), %xmm0
 LONG $0x0459fac5; BYTE $0x86   // vmulss (%rsi,%rax,4), %xmm0, %xmm0
 LONG $0x0411fac5; BYTE $0x82   // vmovss %xmm0, (%rdx,%rax,4)
 LONG $0x4410fac5; WORD $0x0487 // vmovss 4(%rdi,%rax,4), %xmm0
 LONG $0x4459fac5; WORD $0x0486 // vmulss 4(%rsi,%rax,4), %xmm0, %xmm0
 LONG $0x4411fac5; WORD $0x0482 // vmovss %xmm0, 4(%rdx,%rax,4)
 LONG $0x4410fac5; WORD $0x0887 // vmovss 8(%rdi,%rax,4), %xmm0
 LONG $0x4459fac5; WORD $0x0886 // vmulss 8(%rsi,%rax,4), %xmm0, %xmm0
 LONG $0x4411fac5; WORD $0x0882 // vmovss %xmm0, 8(%rdx,%rax,4)
 LONG $0x4410fac5; WORD $0x0c87 // vmovss 12(%rdi,%rax,4), %xmm0
 LONG $0x4459fac5; WORD $0x0c86 // vmulss 12(%rsi,%rax,4), %xmm0, %xmm0
 LONG $0x4411fac5; WORD $0x0c82 // vmovss %xmm0, 12(%rdx,%rax,4)
 LONG $0x04c08348               // addq $4, %rax
 WORD $0x3948; BYTE $0xc1       // cmpq %rax, %rcx
 JNE  LBB3_16

LBB3_11:
 WORD $0x854d; BYTE $0xc0 // testq %r8, %r8
 JE   LBB3_14
 LONG $0x820c8d48         // leaq (%rdx,%rax,4), %rcx
 LONG $0x86148d48         // leaq (%rsi,%rax,4), %rdx
 LONG $0x87048d48         // leaq (%rdi,%rax,4), %rax
 WORD $0xf631             // xorl %esi, %esi

LBB3_13:
 LONG $0x0410fac5; BYTE $0xb0 // vmovss (%rax,%rsi,4), %xmm0
 LONG $0x0459fac5; BYTE $0xb2 // vmulss (%rdx,%rsi,4), %xmm0, %xmm0
 LONG $0x0411fac5; BYTE $0xb1 // vmovss %xmm0, (%rcx,%rsi,4)
 LONG $0x01c68348             // addq $1, %rsi
 WORD $0x3949; BYTE $0xf0     // cmpq %rsi, %r8
 JNE  LBB3_13

LBB3_14:
 WORD $0x8948; BYTE $0xec // movq %rbp, %rsp
 BYTE $0x5d               // popq %rbp
 WORD $0xf8c5; BYTE $0x77 // vzeroupper
 BYTE $0xc3               // retq

生成Go函数定义

将C语言函数定义转换为Go函数需要一个C语言解释器,cznic用Go实现了一个C语言到Go语言的转换器cc/v3open in new window中提供了C语言解释器。函数定义转换包含两个部分:

  • 函数名称: 直接将C函数的函数名作为Go函数名
  • 函数参数: 需要检查函数参数必须为64位类型,在Go函数中使用unsafe.Pointer传递

_mm512_mul_to转换为Go函数的定义如下

import "unsafe"

//go:noescape
func _mm512_mul_to(a, b, c, n unsafe.Pointer)

使用go generate构建

对于生成代码,go generate命令能够调用自定义命令完成C函数到Go函数的转换。Gorse在floats_amd64.goopen in new window文件中添加了go generate命令,在根目录执行go generate ./...即可自动生成Go向量化函数。

//go:generate go run ../../cmd/goat src/floats_avx.c -O3 -mavx
//go:generate go run ../../cmd/goat src/floats_avx512.c -O3 -mavx -mfma -mavx512f -mavx512dq

函数调用

调用向量化函数并不是一个简单的任务,因为只有比较新的CPU支持AVX512,一些比较古老的CPU甚至不支持AVX2,这就需要在调用向量化函数的时候:

  1. 需要提供一个非向量化函数实现用于不支持AVX512的CPU;
  2. init函数中检测当前运行的CPU是否支持AVX512;
  3. 包装一个外层函数根据CPU指令支持情况自动选择向量化函数或者非向量化函数。

完整的包装如下

import (
 "github.com/klauspost/cpuid/v2"
 "unsafe"
)

// CPU指令支持情况全局变量
var impl = Default

const (
 Default int = iota
 AVX512
)

func init() {
 // 检测CPU是否支持AVX512
 if cpuid.CPU.Supports(cpuid.AVX512F, cpuid.AVX512DQ) {
  impl = AVX512
 }
}

// 非向量化实现
func mulTo(a, b, c []float32) {
 for i := range a {
  c[i] = a[i] * b[i]
 }
}

// 包装函数
func MulTo(a, b, c []float32) {
 if len(a) != len(b) || len(a) != len(c) {
  panic("floats: slice lengths do not match")
 }
 // 自动选择执行的函数
 switch impl {
 case AVX512:
  _mm512_mul_to(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(&c[0]), unsafe.Pointer(uintptr(len(a))))
 default:
  mulTo(a, b, c)
 }
}

Go汇编需要保存到后缀为.s的文件中,Go函数定义和包装函数保存在Go文件中,这些文件必须在同一个文件夹中,使用相同的包名。

总结

最后,我们对比下非向量化函数和向量化函数的性能。

  • 首选,向量化函数用时明显少于非向量化函数,尤其在向量长度为128的时候;
  • 然后,AVX512相对AVX2有微量的性能提升。
向量化实现性能对比

另外,ARM平台的SIMD方案Neon同样可以通过本文的方案在Go语言中使用,goat同样支持ARM平台,Gorse同样实现了Neon指令的向量化计算库(参考github.com/gorse-io/gorse/base/floatsopen in new window)。欢迎大家使用goat在Go项目中构建自己的向量化函数,更多问题欢迎在Issuesopen in new window或者QQ群open in new window中交流。