diff --git a/include/omath/linear_algebra/mat.hpp b/include/omath/linear_algebra/mat.hpp
index 99f8775..d6f2d33 100644
--- a/include/omath/linear_algebra/mat.hpp
+++ b/include/omath/linear_algebra/mat.hpp
@@ -10,7 +10,9 @@
 #include
 #include
 #include
-
+#ifdef OMATH_USE_AVX2
+#include <immintrin.h> // AVX2/FMA intrinsics used by the avx_multiply_* kernels
+#endif
 namespace omath
 {
 
@@ -155,10 +155,19 @@
 
         constexpr Mat operator*(const Mat& other) const
         {
+#ifdef OMATH_USE_AVX2
+            // Hand-vectorized AVX2 kernels; the cache-friendly scalar kernels
+            // below remain the fallback when OMATH_USE_AVX2 is not defined.
+            if constexpr (StoreType == MatStoreType::ROW_MAJOR)
+                return avx_multiply_row_major(other);
+            if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
+                return avx_multiply_col_major(other);
+#else
             if constexpr (StoreType == MatStoreType::ROW_MAJOR)
                 return cache_friendly_multiply_row_major(other);
             if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
                 return cache_friendly_multiply_col_major(other);
+#endif
             std::unreachable();
         }
 
@@ -391,6 +398,165 @@
             }
             return result;
         }
+#ifdef OMATH_USE_AVX2
+        // Column-major product C = A * B. For each output column j and each
+        // inner index k, broadcasts B(k, j) and accumulates A(:, k) * B(k, j)
+        // into C(:, j) in 8-float / 4-double AVX2 vectors, plus a scalar tail.
+        template<std::size_t OtherColumns> [[nodiscard]]
+        constexpr Mat<Rows, OtherColumns, Type, StoreType>
+        avx_multiply_col_major(const Mat<Columns, OtherColumns, Type, StoreType>& other) const
+        {
+            Mat<Rows, OtherColumns, Type, StoreType> result;
+
+            const Type* this_mat_data = this->raw_array().data();
+            const Type* other_mat_data = other.raw_array().data();
+            Type* result_mat_data = result.raw_array().data();
+
+            if constexpr (std::is_same_v<Type, float>)
+            {
+                // ReSharper disable once CppTooWideScopeInitStatement
+                constexpr std::size_t vector_size = 8;
+                for (std::size_t j = 0; j < OtherColumns; ++j)
+                {
+                    auto* c_col = reinterpret_cast<float*>(result_mat_data + j * Rows);
+                    for (std::size_t k = 0; k < Columns; ++k)
+                    {
+                        const float bkj = reinterpret_cast<const float*>(other_mat_data)[k + j * Columns];
+                        __m256 bkjv = _mm256_set1_ps(bkj);
+
+                        const auto* a_col_k = reinterpret_cast<const float*>(this_mat_data + k * Rows);
+
+                        std::size_t i = 0;
+                        for (; i + vector_size <= Rows; i += vector_size)
+                        {
+                            __m256 cvec = _mm256_loadu_ps(c_col + i);
+                            __m256 avec = _mm256_loadu_ps(a_col_k + i);
+#if defined(__FMA__)
+                            cvec = _mm256_fmadd_ps(avec, bkjv, cvec);
+#else
+                            cvec = _mm256_add_ps(cvec, _mm256_mul_ps(avec, bkjv));
+#endif
+                            _mm256_storeu_ps(c_col + i, cvec);
+                        }
+                        for (; i < Rows; ++i)
+                            c_col[i] += a_col_k[i] * bkj;
+                    }
+                }
+            }
+            else if constexpr (std::is_same_v<Type, double>)
+            { // double
+                // ReSharper disable once CppTooWideScopeInitStatement
+                constexpr std::size_t vector_size = 4;
+                for (std::size_t j = 0; j < OtherColumns; ++j)
+                {
+                    auto* c_col = reinterpret_cast<double*>(result_mat_data + j * Rows);
+                    for (std::size_t k = 0; k < Columns; ++k)
+                    {
+                        const double bkj = reinterpret_cast<const double*>(other_mat_data)[k + j * Columns];
+                        __m256d bkjv = _mm256_set1_pd(bkj);
+
+                        const auto* a_col_k = reinterpret_cast<const double*>(this_mat_data + k * Rows);
+
+                        std::size_t i = 0;
+                        for (; i + vector_size <= Rows; i += vector_size)
+                        {
+                            __m256d cvec = _mm256_loadu_pd(c_col + i);
+                            __m256d avec = _mm256_loadu_pd(a_col_k + i);
+#if defined(__FMA__)
+                            cvec = _mm256_fmadd_pd(avec, bkjv, cvec);
+#else
+                            cvec = _mm256_add_pd(cvec, _mm256_mul_pd(avec, bkjv));
+#endif
+                            _mm256_storeu_pd(c_col + i, cvec);
+                        }
+                        for (; i < Rows; ++i)
+                            c_col[i] += a_col_k[i] * bkj;
+                    }
+                }
+            }
+            else
+                std::unreachable(); // only float/double are vectorized
+
+            return result;
+        }
+
+        // Row-major product C = A * B. For each output row i and each inner
+        // index k, broadcasts A(i, k) and accumulates A(i, k) * B(k, :) into C(i, :).
+        template<std::size_t OtherColumns> [[nodiscard]]
+        constexpr Mat<Rows, OtherColumns, Type, StoreType>
+        avx_multiply_row_major(const Mat<Columns, OtherColumns, Type, StoreType>& other) const
+        {
+            Mat<Rows, OtherColumns, Type, StoreType> result;
+
+            const Type* this_mat_data = this->raw_array().data();
+            const Type* other_mat_data = other.raw_array().data();
+            Type* result_mat_data = result.raw_array().data();
+
+            if constexpr (std::is_same_v<Type, float>)
+            {
+                // ReSharper disable once CppTooWideScopeInitStatement
+                constexpr std::size_t vector_size = 8;
+                for (std::size_t i = 0; i < Rows; ++i)
+                {
+                    Type* c_row = result_mat_data + i * OtherColumns;
+                    for (std::size_t k = 0; k < Columns; ++k)
+                    {
+                        const auto aik = static_cast<float>(this_mat_data[i * Columns + k]);
+                        __m256 aikv = _mm256_set1_ps(aik);
+                        const auto* b_row = reinterpret_cast<const float*>(other_mat_data + k * OtherColumns);
+
+                        std::size_t j = 0;
+                        for (; j + vector_size <= OtherColumns; j += vector_size)
+                        {
+                            __m256 cvec = _mm256_loadu_ps(c_row + j);
+                            __m256 bvec = _mm256_loadu_ps(b_row + j);
+#if defined(__FMA__)
+                            cvec = _mm256_fmadd_ps(bvec, aikv, cvec);
+#else
+                            cvec = _mm256_add_ps(cvec, _mm256_mul_ps(bvec, aikv));
+#endif
+                            _mm256_storeu_ps(c_row + j, cvec);
+                        }
+                        for (; j < OtherColumns; ++j)
+                            c_row[j] += aik * b_row[j];
+                    }
+                }
+            }
+            else if constexpr (std::is_same_v<Type, double>)
+            { // double
+                // ReSharper disable once CppTooWideScopeInitStatement
+                constexpr std::size_t vector_size = 4;
+                for (std::size_t i = 0; i < Rows; ++i)
+                {
+                    Type* c_row = result_mat_data + i * OtherColumns;
+                    for (std::size_t k = 0; k < Columns; ++k)
+                    {
+                        const auto aik = static_cast<double>(this_mat_data[i * Columns + k]);
+                        __m256d aikv = _mm256_set1_pd(aik);
+                        const auto* b_row = reinterpret_cast<const double*>(other_mat_data + k * OtherColumns);
+
+                        std::size_t j = 0;
+                        for (; j + vector_size <= OtherColumns; j += vector_size)
+                        {
+                            __m256d cvec = _mm256_loadu_pd(c_row + j);
+                            __m256d bvec = _mm256_loadu_pd(b_row + j);
+#if defined(__FMA__)
+                            cvec = _mm256_fmadd_pd(bvec, aikv, cvec);
+#else
+                            cvec = _mm256_add_pd(cvec, _mm256_mul_pd(bvec, aikv));
+#endif
+                            _mm256_storeu_pd(c_row + j, cvec);
+                        }
+                        for (; j < OtherColumns; ++j)
+                            c_row[j] += aik * b_row[j];
+                    }
+                }
+            }
+            else
+                std::unreachable(); // only float/double are vectorized
+            return result;
+        }
+#endif
     };
 
     template [[nodiscard]]