Prototype to integrate SW optimizations for Arm® CPUs #6436
base: master
Conversation
gmiodice
commented
May 17, 2024
- The prototype only works on Arm® CPUs with the i8mm and dotprod extensions

Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
const int8x16_t rhs_vec_6_1 = vandq_s8(rhs_raw_vec_6, nibble_mask);
const int8x16_t rhs_vec_7_1 = vandq_s8(rhs_raw_vec_7, nibble_mask);

const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0));
This can do the transpose using ld2 lane, like the current i8mm gemm, avoiding the combine below.
const float scalar_min = params->scalar.min;
const float scalar_max = params->scalar.max;

if (m == 0) {
m == 0 should never happen... you can assert. It's also good to assert that all values are in the expected range.
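A minimal scalar sketch of the asserts suggested here; the function name and signature are illustrative, not the actual kernel API. It returns `m` only so the check is observable.

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical precondition checks: assert on impossible inputs instead of
// branching on them at runtime.
static size_t kai_check_matmul_args(size_t m, size_t n, size_t k,
                                    float scalar_min, float scalar_max) {
  assert(m != 0);                    // m == 0 should never reach the kernel
  assert(n != 0 && k != 0);          // dimensions must be non-zero
  assert(scalar_min <= scalar_max);  // clamp bounds must be ordered
  (void)n; (void)k; (void)scalar_min; (void)scalar_max;
  return m;
}
```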
main_acc1 = vmulq_f32(main_acc1, lhs_scale);

// Clip (min-max) operation
const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min);
hoist out of loop if there are enough registers, or dup from params if not.
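The hoisting suggested here can be modeled in scalar form: the clamp bounds are loop-invariant, so they should be materialized once before the loop (in the vector kernel, one `vdupq_n_f32` each, kept in registers), not re-created on every iteration. A portable sketch, not the kernel itself:

```c
#include <stddef.h>

// Scalar model of hoisting loop-invariant clamp constants out of the loop.
static void clamp_rows(float* acc, size_t n, float scalar_min, float scalar_max) {
  // Hoisted: created once, reused across all iterations.
  const float vmin = scalar_min;
  const float vmax = scalar_max;
  for (size_t i = 0; i < n; ++i) {
    float v = acc[i];
    v = v < vmin ? vmin : v;
    v = v > vmax ? vmax : v;
    acc[i] = v;
  }
}
```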
const uint8_t* lhs_ptr_start = lhs_p;

for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) {
  const uint8_t* rhs_ptr = rhs_p;
Use an indent of 2 spaces.
@@ -5958,3 +5962,160 @@ void xnn_qu8_igemm_minmax_rndnu_ukernel_4x8c4__neondot(
  }
} while (nc != 0);
}

void kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(
Is there a template/source for this? We don't manually put functions in amalgams. Each kernel has its own source file.
int32x4_t iacc0011 = vdupq_n_s32(0);
int32x4_t iacc2233 = vdupq_n_s32(0);
int32x4_t iacc4455 = vdupq_n_s32(0);
int32x4_t iacc6677 = vdupq_n_s32(0);
gemm initializes to bias?
qd8_f32_qc4w_gemm_config.nr = 8;
qd8_f32_qc4w_gemm_config.log2_kr = 16;
qd8_f32_qc4w_gemm_config.log2_sr = 2;
qd8_f32_qc4w_gemm_config.mr_lhs_pack = 4;
Packing the lhs shouldn't be necessary?
qd8_f32_qc4w_gemm_config.nr = 16;
qd8_f32_qc4w_gemm_config.log2_kr = 3;
qd8_f32_qc4w_gemm_config.planes = 2;
qd8_f32_qc4w_gemm_config.mr = 8;
A 4x16 tile takes fewer registers in practice and is more cache-friendly for LLM workloads. The wider output amortizes the per-m quantization cost.
  } \
} while (0)

#define KAI_UNUSED(x) (void)(x)
There are XNN versions of MIN/MAX in xnnpack/math.h.
@@ -2034,23 +2041,226 @@ void xnn_compute_f16_qd8_convert(
  context->convert_ukernel(n, input, output, &params);
}

#define KAI_ASSERT(x) \
This seems like the wrong place to put this... it is a packing.c / pack.h function. For the lhs, packx does the lhs packing for the ppmm gemm microkernels.
const float src0_0 = *(src_ptr + k_idx);

// Calculate the max
vmax0 = vsetq_lane_f32(KAI_MAX(src0_0, vgetq_lane_f32(vmax0, 0)), vmax0, 0);
This is slow. Better to handle it like the main loop but with a mask.
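A portable model of the suggested tail handling: rather than inserting leftover elements lane by lane, copy the tail into a small buffer pre-filled with the reduction identity (`-FLT_MAX` for a max) and run the same 4-wide step as the main loop. The function is illustrative, not the kernel's API.

```c
#include <float.h>
#include <stddef.h>
#include <string.h>

// Max reduction with a padded (mask-like) tail instead of per-lane inserts.
static float reduce_max_with_padded_tail(const float* src, size_t n) {
  float max4[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};
  size_t i = 0;
  for (; i + 4 <= n; i += 4) {  // main 4-wide loop
    for (int l = 0; l < 4; ++l)
      if (src[i + l] > max4[l]) max4[l] = src[i + l];
  }
  if (i < n) {  // tail: pad with the identity so the same step applies
    float tail[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};
    memcpy(tail, src + i, (n - i) * sizeof(float));
    for (int l = 0; l < 4; ++l)
      if (tail[l] > max4[l]) max4[l] = tail[l];
  }
  float m = max4[0];  // horizontal reduction of the four lanes
  for (int l = 1; l < 4; ++l)
    if (max4[l] > m) m = max4[l];
  return m;
}
```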
const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
const size_t k_block_len = kr / sr;

for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
This looks like rminmax... but less optimized. To optimize further, use 4 vectors and 4 accumulators, or call f32-rminmax.
for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
  float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
  float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
It's usually faster to initialize with a single value duped from the first float; see f32-rminmax.
const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + k_idx);

// Calculate the max
vmax0 = vmaxq_f32(src0_0, vmax0);
4 accumulators would allow co-issuing: vmax0, vmax1, vmin0, vmin1.
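A scalar model of the accumulator scheme described here: four independent accumulators break the serial dependency chain so the comparisons can co-issue, and they are seeded from the first element as in f32-rminmax. A sketch of the structure, not the vector kernel.

```c
#include <stddef.h>

// Unrolled min/max reduction with 4 independent accumulators (n >= 1).
static void rminmax4(const float* src, size_t n, float* out_min, float* out_max) {
  // Seed all accumulators from the first element (no FLT_MAX sentinels).
  float min0 = src[0], min1 = src[0], min2 = src[0], min3 = src[0];
  float max0 = src[0], max1 = src[0], max2 = src[0], max3 = src[0];
  size_t i = 0;
  for (; i + 4 <= n; i += 4) {  // 4 independent dependency chains
    if (src[i + 0] < min0) min0 = src[i + 0];
    if (src[i + 0] > max0) max0 = src[i + 0];
    if (src[i + 1] < min1) min1 = src[i + 1];
    if (src[i + 1] > max1) max1 = src[i + 1];
    if (src[i + 2] < min2) min2 = src[i + 2];
    if (src[i + 2] > max2) max2 = src[i + 2];
    if (src[i + 3] < min3) min3 = src[i + 3];
    if (src[i + 3] > max3) max3 = src[i + 3];
  }
  for (; i < n; ++i) {  // scalar tail
    if (src[i] < min0) min0 = src[i];
    if (src[i] > max0) max0 = src[i];
  }
  // Merge the four accumulators.
  min0 = min0 < min1 ? min0 : min1;
  min2 = min2 < min3 ? min2 : min3;
  max0 = max0 > max1 ? max0 : max1;
  max2 = max2 > max3 ? max2 : max3;
  *out_min = min0 < min2 ? min0 : min2;
  *out_max = max0 > max2 ? max0 : max2;
}
```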
const float rmin0 = KAI_MIN(0.0f, min0);
const float rmax0 = KAI_MAX(0.0f, max0);

const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
div is slow... consider a reciprocal. Unroll the loop by 4 to allow this portion of the code to be SIMD?
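A scalar model of the reciprocal path hinted at here: on Arm the division could be replaced by a `vrecpeq_f32` estimate refined with `vrecpsq_f32` Newton-Raphson steps. Below, a crude power-of-two estimate plus four refinement steps stand in for the intrinsic sequence; the step count is an assumption about the precision needed, not a drop-in replacement.

```c
#include <math.h>

// Newton-Raphson reciprocal, modeled in scalar code. Assumes x > 0.
static float recip_nr(float x) {
  int exp;
  (void)frexpf(x, &exp);         // x = mantissa * 2^exp, mantissa in [0.5, 1)
  float e = ldexpf(1.0f, -exp);  // crude estimate: x * e lands in [0.5, 1)
  for (int i = 0; i < 4; ++i) {
    e = e * (2.0f - x * e);      // NR step, the same form vrecpsq_f32 encodes
  }
  return e;
}

static float compute_scale(float rmin, float rmax, float qmin, float qmax) {
  // Multiply by the reciprocal instead of dividing by (rmax - rmin).
  return rmin == rmax ? 1.0f : (qmax - qmin) * recip_nr(rmax - rmin);
}
```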
// Quantize the channels
k_idx = 0;
for (; k_idx < num_cols; k_idx += k_block_len) {
  for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
Use the cvt microkernel for the float-to-int8 conversion.
int32x4_t iacc1 = vpaddq_s32(iacc4455, iacc6677);

// LHS offset
const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr);
These should be in the quantization params?
Note we are thinking about adding a third param, a zero point, which is awkwardly three floats.
I wish we had three arrays instead of an array of (offset, scale, zero point), but on Arm we can use LD3.
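The two layouts under discussion can be sketched as follows: an array of per-row structs (offset, scale, zero point) versus three separate arrays. On Arm, LD3 (`vld3q_f32`) de-interleaves the struct layout back into three vectors, which is why the interleaved form stays workable. Types and field names here are illustrative, not XNNPACK's actual params.

```c
#include <stddef.h>

// Array-of-structs layout for per-row quantization parameters.
struct qparam { float offset, scale, zero_point; };

// AoS -> SoA split: the same de-interleave LD3 would perform 4 rows at a time.
static void split_qparams(const struct qparam* p, size_t n,
                          float* offsets, float* scales, float* zero_points) {
  for (size_t i = 0; i < n; ++i) {
    offsets[i] = p[i].offset;
    scales[i] = p[i].scale;
    zero_points[i] = p[i].zero_point;
  }
}
```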
lhs_ptr += sizeof(int32_t);

// LHS scale
const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr);
A scaling value does not require a dup... a lane can be used. But I think it would be better to use vld2q_dup_f32 for 2 reasons:
1. One instruction is better than two.
2. clang has a bug with vld1q_dup that most often affects our clamp values, which come in pairs; the workaround is to use vld2 instead, or a simple LDR and a lane to get to the specific element.
const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min);
const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max);

main_acc0 = vmaxq_f32(main_acc0, vmin_f32);
is clamping needed? The scaling ensures the range?
Some overall concerns:
- It's not clear that packing the left-hand side has a benefit, and as is, it requires branching the code away from the qs8 gemm and ppmm APIs.
- The rminmax and cvt are reimplemented.
- Lack of template/source: microkernels live in their own source files, and update_microkernels.py builds them into the amalgams.
- Missing i8mm microkernel.