#include "winograd_f6x3.h"

/// Winograd filter transform for a single input channel.
///
/// @param f      one 3x3 filter, row-major (row i starts at f[3*i]); exactly 9
///               floats are read.
/// @param zmm_u  receives 4 registers: zmm_u[i] = [ymm_a[2i] | ymm_a[2i+1]], where
///               ymm_a are the 8 transformed 8-lane rows.
[[gnu::always_inline]] inline
void perform_filter_transform_IC1(const float f[], __m512 zmm_u[]) {
    __m256 ymm_a[8], ymm_b[8];
    // Load only the 3 valid lanes of each filter row. A full 8-float load of row 2
    // reads f[6..13], i.e. 5 floats past this filter, which runs off the end of the
    // filter buffer for its last entry. Masked-out lanes are not accessed (no fault)
    // and are zero; presumably only lanes 0-2 are consumed by the 8x3 transpose, so
    // the transform result is unchanged.
    for (int i = 0; i < 3; ++i)
        ymm_b[i] = _mm256_maskz_loadu_ps(0x07, &f[3 * i]);
    transform_GD_6x3<__m256>(ymm_b, ymm_a);
    mm_transpose_8x3<__m256>(ymm_a, ymm_b);
    transform_GD_6x3<__m256>(ymm_b, ymm_a);
    for (int i = 0; i < 4; ++i) {
        zmm_u[i] = _mm512_castps256_ps512(ymm_a[2 * i]);
        zmm_u[i] = _mm512_insertf32x8(zmm_u[i], ymm_a[2 * i + 1], 1);
    }
}

/// Winograd filter transform for two consecutive input channels at once.
///
/// @param f      two 3x3 row-major filters back to back (channel c at f[0..8],
///               channel c+1 at f[9..17]); exactly 18 floats are read.
/// @param zmm_u  receives 8 registers as produced by mm_transpose_8x16_col2row.
[[gnu::always_inline]] inline
void perform_filter_transform_IC2(const float f[], __m512 zmm_u[]) {
    __m512 zmm_a[8], zmm_b[8];
    // Lower 256 bits: row i of the first filter; upper 256 bits: row i of the second.
    // Only the 3 valid lanes of each row are loaded. The previous full-width loads
    // (16 floats from f[6], 8 floats from f[15]) read up to f[22], past the 18 floats
    // of this filter pair and past the end of the filter buffer for its last pair.
    // Masked-out lanes are not accessed and are zero; presumably the 8x3 transpose
    // consumes only lanes 0-2 of each half, so the result is unchanged.
    for (int i = 0; i < 3; ++i) {
        const __m256 row_c0 = _mm256_maskz_loadu_ps(0x07, &f[3 * i]);
        const __m256 row_c1 = _mm256_maskz_loadu_ps(0x07, &f[9 + 3 * i]);
        zmm_b[i] = _mm512_insertf32x8(_mm512_castps256_ps512(row_c0), row_c1, 1);
    }
    transform_GD_6x3<__m512>(zmm_b, zmm_a);
    mm_transpose_8x3<__m512>(zmm_a, zmm_b);
    transform_GD_6x3<__m512>(zmm_b, zmm_a);
    mm_transpose_8x16_col2row(zmm_a, zmm_u);
}

/// Winograd input transform of one 8x8 image tile from a single input channel.
///
/// @param im1    top-left element of the tile; 8 floats are read from each of
///               8 rows spaced IW floats apart.
/// @param zmm_v  receives 4 registers: zmm_v[r] = [tile[2r] | tile[2r+1]], where
///               tile are the 8 transformed 8-lane rows.
/// @param IW     row stride of the image, in floats.
[[gnu::always_inline]] inline
void perform_image_transform_IC1(const float im1[], __m512 zmm_v[], int IW) {
    __m256 tile[8], scratch[8];
    for (int r = 0; r < 8; ++r)
        scratch[r] = _mm256_loadu_ps(&im1[r * IW]);

    // Transform along one axis, transpose, then transform along the other.
    transform_BtD_6x3<__m256>(scratch, tile);
    mm_transpose_8x8<__m256>(tile, scratch);
    transform_BtD_6x3<__m256>(scratch, tile);

    // Pack pairs of 256-bit rows into 512-bit registers.
    for (int r = 0; r < 4; ++r) {
        const __m256 even_row = tile[2 * r];
        const __m256 odd_row = tile[2 * r + 1];
        zmm_v[r] = _mm512_insertf32x8(_mm512_castps256_ps512(even_row), odd_row, 1);
    }
}

/// Winograd input transform of one 8x8 tile from each of two input channels.
///
/// @param im1  top-left element of the tile in the first channel.
/// @param im2  top-left element of the tile in the second channel.
///             For both, 8 floats are read from each of 8 rows spaced IW floats apart.
/// @param v    receives 8 registers as produced by mm_transpose_8x16_col2row.
/// @param IW   row stride of both images, in floats.
[[gnu::always_inline]] inline
void perform_image_transform_IC2(const float im1[], const float im2[], __m512 v[], int IW) {
    __m512 zmm_a[8], zmm_b[8];
    // Lower half: row i of im1; upper half: row i of im2. Only 8 floats are loaded
    // per row. A 512-bit load here would read 8 extra floats past each row window
    // (only to be overwritten by the insert), which is not guaranteed to be readable
    // since im1 and im2 are independent pointers.
    for (int i = 0; i < 8; ++i) {
        const __m256 row1 = _mm256_loadu_ps(&im1[i * IW]);
        const __m256 row2 = _mm256_loadu_ps(&im2[i * IW]);
        zmm_b[i] = _mm512_insertf32x8(_mm512_castps256_ps512(row1), row2, 1);
    }
    transform_BtD_6x3<__m512>(zmm_b, zmm_a);
    mm_transpose_8x8<__m512>(zmm_a, zmm_b);
    transform_BtD_6x3<__m512>(zmm_b, zmm_a);
    mm_transpose_8x16_col2row(zmm_a, v);
}

/// Element-wise Winograd product, summed over input channels, for two output
/// channels at once.
///
/// For every input channel ci in [0, sizeIC), four consecutive registers starting
/// at index 4*ci are read from u1, u2 and v. On return:
///   zmm_m[j]     = sum over ci of u1[4*ci + j] * v[4*ci + j]   (j = 0..3)
///   zmm_m[j + 4] = sum over ci of u2[4*ci + j] * v[4*ci + j]
[[gnu::always_inline, gnu::hot]] inline
void perform_mult_OC2(const __m512 u1[], const __m512 u2[], const __m512 v[], __m512 zmm_m[8], int sizeIC) {
    for (int j = 0; j < 4; ++j)
        zmm_m[j] = zmm_m[j + 4] = _mm512_setzero_ps();

    for (int ci = 0; ci < sizeIC; ++ci) {
        const __m512 *lhs_a = u1 + 4 * ci;
        const __m512 *lhs_b = u2 + 4 * ci;
        const __m512 *rhs = v + 4 * ci;
        for (int j = 0; j < 4; ++j) {
            const __m512 in = rhs[j];
            zmm_m[j] += lhs_a[j] * in;
            zmm_m[j + 4] += lhs_b[j] * in;
        }
    }
}

/// Inverse Winograd transform of two accumulated output tiles, added into `result`.
///
/// After the transform, the lower 256-bit half of each row i is accumulated into
/// r1 + i*OW and the upper half into r2 + i*OW (presumably output channels k and
/// k+1, whose products perform_mult_OC2 leaves in zmm_m[0..3] / zmm_m[4..7]).
/// Rows 0..over_x-1 are skipped; within a row only the lanes selected by
/// gen_mmask8_for_dump(over_y) are read and written.
[[gnu::always_inline, gnu::hot]] inline
void perform_store_transform_OC2(__m512 zmm_m[8], float *r1, float *r2, int OW, int over_x, int over_y) {
    __m512 zmm_a[8], zmm_b[8];
    unsigned char mask = gen_mmask8_for_dump(over_y);
    mm_transpose_8x16_row2col(zmm_m, zmm_b);
    transform_AtD_6x3<__m512>(zmm_b, zmm_a);
    mm_transpose_8x8<__m512>(zmm_a, zmm_b);
    transform_AtD_6x3<__m512>(zmm_b, zmm_a);
    for (int i = over_x; i < 6; ++i) {
        float *row1 = &r1[i * OW];
        float *row2 = &r2[i * OW];
        // Masked load: lanes outside `mask` are never stored, so they need not be
        // read. An unmasked 8-float load reaches 2 floats past column OW-1 on the
        // last tile of a row, i.e. past the end of `result` on its final row, and
        // reads neighbouring data another thread may be writing.
        _mm256_mask_storeu_ps(row1, mask, _mm512_extractf32x8_ps(zmm_a[i], 0) + _mm256_maskz_loadu_ps(mask, row1));
        _mm256_mask_storeu_ps(row2, mask, _mm512_extractf32x8_ps(zmm_a[i], 1) + _mm256_maskz_loadu_ps(mask, row2));
    }
}

void winconv(const float *__restrict__ image, const int IH,
             const int IW, const int IC, const float *__restrict__ filter,
             const int OC, const int N, float *__restrict__ result) {
    const int OH = IH - 2;
    const int OW = IW - 2;
    const int TH = ceildiv(OH, 6);
    const int TW = ceildiv(OW, 6);

    const int sliceOC = min(OC, hintOC);
    const int sliceIC = min(IC, hintIC);

    const int size_result = N * OC * OH * OW;
#pragma omp parallel
#pragma omp for simd aligned(result) schedule(static)
    for (int i = 0; i < size_result; ++i) result[i] = 0;

    for (int startOC = 0; startOC < OC; startOC += sliceOC) {
        Range rOC(startOC, min(startOC + sliceOC, OC));
        for (int startIC = 0; startIC < IC; startIC += sliceIC) {
            Range rIC(startIC, min(startIC + sliceIC, IC));
#pragma omp parallel default(shared)
            {
                __m512 V[rIC.size() * 4];
                __m512 U[rOC.size() * rIC.size() * 4];
                for (int k = rOC.start; k < rOC.end; ++k) {
                    for (int c = rIC.start; c < rIC.end; c += 2) {
                        int ki = k - rOC.start, ci = c - rIC.start;
                        const float *f = &filter[(k * IC + c) * 9];
                        __m512 *u = &U[(ki * rIC.size() + ci) * 4];
                        if (rIC.end - c == 1) perform_filter_transform_IC1(f, u);
                        else perform_filter_transform_IC2(f, u);
                    }
                }
#pragma omp for schedule(static)
                for (int b = 0; b < N; ++b) {
                    for (int x_ = 0; x_ < TH; ++x_) {
                        for (int y_ = 0; y_ < TW; ++y_) {
                            int x = x_ * 6, y = y_ * 6;
                            int over_x = 0, over_y = 0;
                            if (x + 6 > OH) { over_x = x + 6 - OH; x -= over_x; }
                            if (y + 6 > OW) { over_y = y + 6 - OW; y -= over_y; }
                            for (int c = rIC.start; c < rIC.end; c += 2) {
                                int ci = c - rIC.start;
                                const float *im1 = &image[((b * IC + c) * IH + x) * IW + y];
                                const float *im2 = &im1[IW * IH];
                                __m512 *zmm_v = &V[ci * 4];
                                if (rIC.end - c == 1) perform_image_transform_IC1(im1, zmm_v, IW);
                                else perform_image_transform_IC2(im1, im2, zmm_v, IW);
                            }
                            for (int k = rOC.start; k < rOC.end; k += 2) {
                                int ki = k - rOC.start;
                                __m512 *zmm_u1 = &U[ki * rIC.size() * 4], *zmm_u2 = &zmm_u1[rIC.size() * 4];
                                __m512 *zmm_v = V;
                                __m512 zmm_m[8];
                                perform_mult_OC2(zmm_u1, zmm_u2, zmm_v, zmm_m, rIC.size());
                                float *r1 = &result[((b * OC + k) * OH + x) * OW + y];
                                float *r2 = &r1[OH * OW];
                                perform_store_transform_OC2(zmm_m, r1, r2, OW, over_x, over_y);
                            }
                        }
                    }
                }
            }
        }
    }
}