
Prototype to integrate SW optimizations for Arm® CPUs #6436

Open · wants to merge 1 commit into base: master

Conversation

gmiodice

  • The prototype only works on Arm® CPUs with the i8mm and dotprod extensions
Signed-off-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
const int8x16_t rhs_vec_6_1 = vandq_s8(rhs_raw_vec_6, nibble_mask);
const int8x16_t rhs_vec_7_1 = vandq_s8(rhs_raw_vec_7, nibble_mask);

const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0));
Contributor:

This can do the transpose using an LD2 lane load, like the current i8mm GEMM, avoiding the combine below.
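
For illustration, a minimal sketch of the LD2 idea (not taken from the PR; the exact layout in the i8mm GEMM differs): a de-interleaving load splits even and odd bytes into two registers in one instruction, so no separate combine step is needed afterwards.

#include <arm_neon.h>

// De-interleaving load: val[0] receives bytes 0,2,4,... and val[1]
// bytes 1,3,5,..., replacing two plain loads plus a later zip/combine.
static inline int8x16x2_t load_rhs_deinterleaved(const int8_t* p) {
  return vld2q_s8(p);
}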

const float scalar_min = params->scalar.min;
const float scalar_max = params->scalar.max;

if (m == 0) {
Contributor:

m == 0 should never happen... you can assert. It's also good to assert that all values are in the expected range.
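
A minimal sketch of the suggested entry checks (function and parameter names are illustrative, not from the PR):

#include <assert.h>
#include <stddef.h>

// Sanity checks at microkernel entry: callers tile m before dispatch,
// so a zero or oversized m indicates a caller bug.
static void check_matmul_args(size_t m, size_t n, size_t k, size_t mr) {
  assert(m != 0);
  assert(m <= mr);
  assert(n != 0);
  assert(k != 0);
}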

main_acc1 = vmulq_f32(main_acc1, lhs_scale);

// Clip (min-max) operation
const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min);
Contributor:

Hoist this out of the loop if there are enough registers, or dup from params if not.
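
A sketch of the hoisting suggestion, assuming the clamp bounds are loop-invariant (names are illustrative):

#include <arm_neon.h>
#include <stddef.h>

// Materialize the clamp vectors once, outside the loop, instead of
// re-duplicating scalar_min/scalar_max on every iteration.
static void clamp_accumulators(float32x4_t* acc, size_t count, float scalar_min, float scalar_max) {
  const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min);  // hoisted
  const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max);  // hoisted
  for (size_t i = 0; i < count; ++i) {
    acc[i] = vmaxq_f32(acc[i], vmin_f32);
    acc[i] = vminq_f32(acc[i], vmax_f32);
  }
}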

const uint8_t* lhs_ptr_start = lhs_p;

for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) {
const uint8_t* rhs_ptr = rhs_p;
Contributor:

Use an indent of 2.

@@ -5958,3 +5962,160 @@ void xnn_qu8_igemm_minmax_rndnu_ukernel_4x8c4__neondot(
}
} while (nc != 0);
}

void kai_matmul_clamp_f32_qai8dxp1x8_qsu4cxp8x8_1x8x32_neon_dotprod(
Contributor:

Is there a template/source for this? We don't manually put functions in amalgams. Each kernel has its own source file.

int32x4_t iacc0011 = vdupq_n_s32(0);
int32x4_t iacc2233 = vdupq_n_s32(0);
int32x4_t iacc4455 = vdupq_n_s32(0);
int32x4_t iacc6677 = vdupq_n_s32(0);
Contributor:

Shouldn't the GEMM initialize the accumulators to the bias instead of zero?
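
For reference, a sketch of the XNNPACK GEMM convention the comment refers to: accumulators start from the bias stored at the head of the packed weights. The layout shown is an assumption for illustration.

#include <arm_neon.h>

// Seed the int32 accumulators from the packed bias instead of zero;
// `w` is assumed to point at the nr int32 bias values that precede
// the packed weights.
static void init_acc_from_bias(const int32_t* w, int32x4_t* iacc0123, int32x4_t* iacc4567) {
  *iacc0123 = vld1q_s32(w + 0);
  *iacc4567 = vld1q_s32(w + 4);
}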

qd8_f32_qc4w_gemm_config.nr = 8;
qd8_f32_qc4w_gemm_config.log2_kr = 16;
qd8_f32_qc4w_gemm_config.log2_sr = 2;
qd8_f32_qc4w_gemm_config.mr_lhs_pack = 4;
Contributor:

Packing the LHS shouldn't be necessary?

qd8_f32_qc4w_gemm_config.nr = 16;
qd8_f32_qc4w_gemm_config.log2_kr = 3;
qd8_f32_qc4w_gemm_config.planes = 2;
qd8_f32_qc4w_gemm_config.mr = 8;
Contributor:

A 4x16 tile takes fewer registers in practice and is more cache friendly for LLMs. The wider output amortizes the per-m quantization cost.

} \
} while (0)

#define KAI_UNUSED(x) (void)(x)
Contributor:

There are XNN versions of MIN/MAX in xnnpack/math.h.
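
A sketch of swapping in those helpers (assuming the math_min_f32/math_max_f32 names from xnnpack/math.h):

#include "xnnpack/math.h"  // assumed helpers: math_min_f32, math_max_f32

// Extend the quantization range to include zero using the existing
// XNN helpers instead of new KAI_MIN/KAI_MAX macros.
static void extend_range_to_zero(float min0, float max0, float* rmin0, float* rmax0) {
  *rmin0 = math_min_f32(0.0f, min0);
  *rmax0 = math_max_f32(0.0f, max0);
}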

@@ -2034,23 +2041,226 @@ void xnn_compute_f16_qd8_convert(
context->convert_ukernel(n, input, output, &params);
}

#define KAI_ASSERT(x) \
Contributor:

This seems like the wrong place to put this... it belongs in packing.c / pack.h. For the LHS, packx does the LHS packing for the ppmm GEMM microkernels.

const float src0_0 = *(src_ptr + k_idx);

// Calculate the max
vmax0 = vsetq_lane_f32(KAI_MAX(src0_0, vgetq_lane_f32(vmax0, 0)), vmax0, 0);
Contributor:

This is slow. Better to handle it like the main loop, but with a mask.
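
A sketch of the masked-tail idea for the max side (illustrative; assumes reading a full 4-float window at the tail is safe, e.g. via an overlapping reload when k >= 4):

#include <arm_neon.h>
#include <float.h>
#include <stddef.h>
#include <stdint.h>

// Keep the tail vectorized: lanes beyond `remainder` are replaced with
// -FLT_MAX so they cannot influence the running maximum.
static float32x4_t tail_max(const float* src, size_t remainder, float32x4_t vmax) {
  const uint32_t lane_idx[4] = {0, 1, 2, 3};
  const uint32x4_t vmask = vcltq_u32(vld1q_u32(lane_idx), vdupq_n_u32((uint32_t) remainder));
  const float32x4_t v = vbslq_f32(vmask, vld1q_f32(src), vdupq_n_f32(-FLT_MAX));
  return vmaxq_f32(vmax, v);
}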

const size_t dst_stride = mr * (k * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset);
const size_t k_block_len = kr / sr;

for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
Contributor:

This looks like rminmax... but less optimized. To optimize it further, use 4 vectors and 4 accumulators, or call f32-rminmax.


for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) {
float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX);
float32x4_t vmin0 = vdupq_n_f32(FLT_MAX);
Contributor:

It is usually faster to initialize with a single value duped from the first float; see f32-rminmax.
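
A sketch of that initialization (as in f32-rminmax, seeding both running vectors from the first element so no sentinel values are needed):

#include <arm_neon.h>

// Seed min and max with the first input value; every later vminq/vmaxq
// can only move them toward the true range.
static void init_minmax(const float* src_ptr, float32x4_t* vmin0, float32x4_t* vmax0) {
  *vmin0 = vld1q_dup_f32(src_ptr);
  *vmax0 = *vmin0;
}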

const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + k_idx);

// Calculate the max
vmax0 = vmaxq_f32(src0_0, vmax0);
Contributor:

4 accumulators (vmax0, vmax1, vmin0, vmin1) would allow co-issuing.
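
A sketch of the unrolled loop with the four accumulators (illustrative; assumes k is a multiple of 8, with the tail handled as discussed above):

#include <arm_neon.h>
#include <stddef.h>

// Two independent min chains and two independent max chains let the
// core issue the vminq/vmaxq pairs in parallel.
static void rminmax_x8(const float* src, size_t k, float* out_min, float* out_max) {
  float32x4_t vmin0 = vld1q_dup_f32(src);
  float32x4_t vmax0 = vmin0;
  float32x4_t vmin1 = vmin0;
  float32x4_t vmax1 = vmax0;
  for (size_t i = 0; i < k; i += 8) {
    const float32x4_t v0 = vld1q_f32(src + i);
    const float32x4_t v1 = vld1q_f32(src + i + 4);
    vmin0 = vminq_f32(vmin0, v0);
    vmax0 = vmaxq_f32(vmax0, v0);
    vmin1 = vminq_f32(vmin1, v1);
    vmax1 = vmaxq_f32(vmax1, v1);
  }
  *out_min = vminvq_f32(vminq_f32(vmin0, vmin1));  // horizontal reduction
  *out_max = vmaxvq_f32(vmaxq_f32(vmax0, vmax1));
}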

const float rmin0 = KAI_MIN(0.0f, min0);
const float rmax0 = KAI_MAX(0.0f, max0);

const float scale0 = rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
Contributor:

Division is slow... consider a reciprocal. Unroll the loop by 4 to allow this portion of the code to be SIMD?
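
A sketch of the reciprocal approach (a VRECPE estimate refined by two Newton-Raphson steps), which also becomes SIMD once the row loop is unrolled by 4:

#include <arm_neon.h>

// Approximate reciprocal: start from the ~8-bit VRECPE estimate and
// refine twice with VRECPS, then multiply instead of dividing.
static float32x4_t recip_f32(float32x4_t d) {
  float32x4_t r = vrecpeq_f32(d);
  r = vmulq_f32(r, vrecpsq_f32(d, r));  // refinement step 1
  r = vmulq_f32(r, vrecpsq_f32(d, r));  // refinement step 2
  return r;
}
// e.g. scales = (qmax - qmin) * recip_f32(vsubq_f32(vrmax, vrmin)) for 4 rows at once.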

// Quantize the channels
k_idx = 0;
for (; k_idx < num_cols; k_idx += k_block_len) {
for (size_t k_block_idx = 0; k_block_idx < k_block_len; ++k_block_idx) {
Contributor:

Use the cvt microkernel for the float-to-int8 conversion.
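
For reference, a sketch of the conversion such an f32-to-qs8 cvt microkernel performs (scale, round to nearest, add zero point, saturating narrow); names are illustrative:

#include <arm_neon.h>

// Quantize 8 floats to int8: multiply by the inverse scale, round to
// nearest (even), add the zero point, then narrow twice with saturation.
static int8x8_t quantize_f32_qs8(float32x4_t v0, float32x4_t v1, float32x4_t vinv_scale, int32x4_t vzero_point) {
  const int32x4_t i0 = vaddq_s32(vcvtnq_s32_f32(vmulq_f32(v0, vinv_scale)), vzero_point);
  const int32x4_t i1 = vaddq_s32(vcvtnq_s32_f32(vmulq_f32(v1, vinv_scale)), vzero_point);
  return vqmovn_s16(vcombine_s16(vqmovn_s32(i0), vqmovn_s32(i1)));
}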

int32x4_t iacc1 = vpaddq_s32(iacc4455, iacc6677);

// LHS offset
const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr);
Contributor:

These should be in the quantization params? Note we are thinking about adding a third param, a zero point, which awkwardly makes three floats. I wish we had three arrays instead of an array of (offset, scale, zero point) structs, but on Arm we can use LD3.
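
A sketch of the LD3 point: with an array of (offset, scale, zero point) float triples, one de-interleaving load yields a vector per field for four consecutive rows. The packed layout is assumed for illustration.

#include <arm_neon.h>

// De-interleave 4 packed {offset, scale, zero_point} triples:
// val[0] = offsets, val[1] = scales, val[2] = zero points.
static float32x4x3_t load_quant_params(const float* packed) {
  return vld3q_f32(packed);
}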

lhs_ptr += sizeof(int32_t);

// LHS scale
const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr);
Contributor:

A scaling value does not require a dup... a lane can be used. But I think it would be better to use vld2q_dup_f32, for two reasons:
1. One instruction is better than two.
2. clang has a bug with vld1q_dup that most often affects our clamp values, which come in pairs; the workaround is to use vld2 instead, or a simple LDR plus a lane move to reach the specific element.
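
A sketch of the vld2 workaround, assuming min and max are adjacent floats in the params struct (as in params->scalar.min/max earlier):

#include <arm_neon.h>

// One de-interleaving dup load broadcasts both clamp values at once:
// val[0] = min in all lanes, val[1] = max in all lanes.
static void load_clamp_params(const float* minmax /* -> {min, max} */, float32x4_t* vmin, float32x4_t* vmax) {
  const float32x4x2_t v = vld2q_dup_f32(minmax);
  *vmin = v.val[0];
  *vmax = v.val[1];
}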

const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min);
const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max);

main_acc0 = vmaxq_f32(main_acc0, vmin_f32);
Contributor:

Is the clamping needed? Doesn't the scaling already ensure the range?

Contributor @fbarchard left a comment:

Some overall concerns:
- It's not clear that packing the left-hand side has a benefit, and as is, it requires branching the code away from the qs8 GEMM and ppmm APIs.
- The rminmax and cvt microkernels are reimplemented.
- Lack of a template/source: microkernels live in their own source files, and update_microkernels.py builds them into amalgams.
- Missing i8mm microkernel.
