BFloat16(Brain Floating Point 16)是近年来在机器学习和高性能计算领域广泛采用的一种16位浮点数格式。与传统的FP16不同,BFloat16保留了与FP32相同的8位指数位,仅将尾数位从23位缩减到7位。这种设计取舍带来了几个关键优势:
Arm的SME2(Scalable Matrix Extension 2)指令集在BFloat16支持上做了深度优化,主要特性包括:
{ <Zn1>.H-<Zn4>.H }语法支持2-4个向量寄存器同时操作关键提示:在启用BFloat16指令前,必须通过
ID_AA64ZFR0_EL1.B16B16检查硬件支持,否则会触发Undefined Instruction异常。
armasm复制BFCVT <Zd>.H, { <Zn1>.S-<Zn2>.S } // FP32转BFloat16
BFCVT <Zd>.B, { <Zn1>.H-<Zn2>.H } // BFloat16转FP8
转换过程遵循IEEE 754标准,关键处理逻辑包括:
典型使用场景:
c复制// 将FP32卷积权重转换为BFloat16存储
float32_t weights_fp32[1024];
bfloat16_t weights_bf16[1024];
for(int i=0; i<1024; i+=2) {
asm volatile(
"ldp q0, q1, [%0], #32\n\t"
"bfcvt v2.8h, v0.4s\n\t"
"bfcvt v3.8h, v1.4s\n\t"
"stp q2, q3, [%1], #32"
:: "r"(weights_fp32+i), "r"(weights_bf16+i)
);
}
| 指令 | NaN处理 | 零值比较 | 适用场景 |
|---|---|---|---|
| BFMAX | 受FPCR.DN控制 | -0 < +0 | 通用最大值 |
| BFMAXNM | 忽略quiet NaN | -0 < +0 | 数值计算 |
运算流程差异:
python复制def BFMax(a, b, fpcr):
if fpcr.AH == 1:
if (a == 0 and b == 0) or (is_nan(a) or is_nan(b)):
return b
return max(a, b)
def BFMaxNum(a, b, fpcr):
if is_nan(a) and not is_snan(a):
return b
if is_nan(b) and not is_snan(b):
return a
return max(a, b)
armasm复制BFDOT ZA.S[<Wv>, <offs>{, VGx4}], { <Zn1>.H-<Zn4>.H }, <Zm>.H[<index>]
执行过程分解:
<index>选择Zm中的BFloat16元素对acc = fma(a, b, acc),结果保持FP32精度性能优化要点:
<Wv>寄存器实现矩阵分块并行计算<index>实现数据复用,减少内存访问c复制void enable_sme2() {
uint64_t cpacr = read_cpacr_el1();
cpacr |= (3 << 16); // Enable FP/SIMD
write_cpacr_el1(cpacr);
uint64_t smcr = read_smcr_el2();
smcr |= (1 << 0); // Enable SME
write_smcr_el2(smcr);
asm volatile("msr SVCR, #1"); // Enter streaming mode
}
armasm复制// 假设: ZA[4x4] += A[4x8] * B[8x4]
mov x0, #0 // 初始化行计数器
.loop_row:
mov x1, #0 // 初始化列计数器
.loop_col:
// 加载A矩阵2x2块 (BFloat16)
ld1 {z0.h-z3.h}, [a_ptr], #64
// 加载B矩阵2x2块 (BFloat16)
ld1 {z4.h-z7.h}, [b_ptr], #64
// 计算2x2x2分块矩阵乘
bfdot za.s[w0, 0], {z0.h-z1.h}, z4.h[0]
bfdot za.s[w0, 1], {z0.h-z1.h}, z5.h[0]
bfdot za.s[w0+1,0], {z2.h-z3.h}, z4.h[0]
bfdot za.s[w0+1,1], {z2.h-z3.h}, z5.h[0]
add x1, x1, #2
cmp x1, #8
b.lt .loop_col
add x0, x0, #2
cmp x0, #4
b.lt .loop_row
armasm复制ld1 {z0.h-z3.h}, [a_ptr], #64 // 预加载下一块
bfdot za.s[w8,0], {z4.h-z5.h}, z16.h[0] // 计算当前块
armasm复制// FP32累加 + BFloat16乘
bfmla z0.s, z1.h, z2.h
// BF16->FP32转换 + 乘加
bfcvtn z3.h, z4.s
fmmla z5.s, z6.h, z7.h
| 异常类型 | 可能原因 | 解决方案 |
|---|---|---|
| Illegal Instruction | 缺少SME2支持 | 检查ID_AA64ZFR0_EL1 |
| Data Abort | 未对齐访问 | 确保数据64字节对齐 |
| FP Trap | 异常输入值 | 检查FPCR.DN设置 |
bash复制perf stat -e instructions,cycles,L1D-cache-misses ./bf16_app
armasm复制// 3x3卷积核实现
.macro conv3x3_kernel
ld1 {z0.h-z3.h}, [input_ptr], #64
ld1 {z4.h-z7.h}, [kernel_ptr], #64
bfdot za.s[w12,0], {z0.h-z1.h}, z4.h[0]
bfdot za.s[w12,1], {z0.h-z1.h}, z5.h[0]
...
.endm
c复制// Q*K^T矩阵计算
for(int i=0; i<heads; i++) {
asm volatile(
"bfdot za.s[%0,0], {z0.h-z3.h}, {z4.h-z7.h}\n\t"
: : "r"(i*4)
: "z0","z1","z2","z3","z4","z5","z6","z7"
);
}
经验总结:在实际部署中发现,当序列长度超过512时,采用VGx4版本指令相比VGx2可获得约1.7倍加速,但需注意寄存器压力导致的频率调节问题。