mirror of
https://github.com/orange-cpp/omath.git
synced 2026-02-13 07:03:25 +00:00
added avx mutiplication
This commit is contained in:
@@ -10,7 +10,7 @@
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <utility>
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
namespace omath
|
||||
{
|
||||
@@ -155,10 +155,17 @@ namespace omath
|
||||
constexpr Mat<Rows, OtherColumns, Type, StoreType>
|
||||
operator*(const Mat<Columns, OtherColumns, Type, StoreType>& other) const
|
||||
{
|
||||
#ifdef OMATH_USE_AVX2
|
||||
if constexpr (StoreType == MatStoreType::ROW_MAJOR)
|
||||
return avx_multiply_row_major(other);
|
||||
if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
|
||||
return avx_multiply_col_major(other);
|
||||
#else
|
||||
if constexpr (StoreType == MatStoreType::ROW_MAJOR)
|
||||
return cache_friendly_multiply_row_major(other);
|
||||
if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
|
||||
return cache_friendly_multiply_col_major(other);
|
||||
#endif
|
||||
std::unreachable();
|
||||
}
|
||||
|
||||
@@ -391,6 +398,160 @@ namespace omath
|
||||
}
|
||||
return result;
|
||||
}
|
||||
#ifdef OMATH_USE_AVX2
|
||||
template<size_t OtherColumns> [[nodiscard]]
|
||||
constexpr Mat<Rows, OtherColumns, Type, MatStoreType::COLUMN_MAJOR>
|
||||
avx_multiply_col_major(const Mat<Columns, OtherColumns, Type, MatStoreType::COLUMN_MAJOR>& other) const
|
||||
{
|
||||
Mat<Rows, OtherColumns, Type, MatStoreType::COLUMN_MAJOR> result;
|
||||
|
||||
const Type* this_mat_data = this->raw_array().data();
|
||||
const Type* other_mat_data = other.raw_array().data();
|
||||
Type* result_mat_data = result.raw_array().data();
|
||||
|
||||
if constexpr (std::is_same_v<Type, float>)
|
||||
{
|
||||
// ReSharper disable once CppTooWideScopeInitStatement
|
||||
constexpr std::size_t vector_size = 8;
|
||||
for (std::size_t j = 0; j < OtherColumns; ++j)
|
||||
{
|
||||
auto* c_col = reinterpret_cast<float*>(result_mat_data + j * Rows);
|
||||
for (std::size_t k = 0; k < Columns; ++k)
|
||||
{
|
||||
const float bkj = reinterpret_cast<const float*>(other_mat_data)[k + j * Columns];
|
||||
__m256 bkjv = _mm256_set1_ps(bkj);
|
||||
|
||||
const auto* a_col_k = reinterpret_cast<const float*>(this_mat_data + k * Rows);
|
||||
|
||||
std::size_t i = 0;
|
||||
for (; i + vector_size <= Rows; i += vector_size)
|
||||
{
|
||||
__m256 cvec = _mm256_loadu_ps(c_col + i);
|
||||
__m256 avec = _mm256_loadu_ps(a_col_k + i);
|
||||
#if defined(__FMA__)
|
||||
cvec = _mm256_fmadd_ps(avec, bkjv, cvec);
|
||||
#else
|
||||
cvec = _mm256_add_ps(cvec, _mm256_mul_ps(avec, bkjv));
|
||||
#endif
|
||||
_mm256_storeu_ps(c_col + i, cvec);
|
||||
}
|
||||
for (; i < Rows; ++i)
|
||||
c_col[i] += a_col_k[i] * bkj;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (std::is_same_v<Type, double>)
|
||||
{ // double
|
||||
// ReSharper disable once CppTooWideScopeInitStatement
|
||||
constexpr std::size_t vector_size = 4;
|
||||
for (std::size_t j = 0; j < OtherColumns; ++j)
|
||||
{
|
||||
auto* c_col = reinterpret_cast<double*>(result_mat_data + j * Rows);
|
||||
for (std::size_t k = 0; k < Columns; ++k)
|
||||
{
|
||||
const double bkj = reinterpret_cast<const double*>(other_mat_data)[k + j * Columns];
|
||||
__m256d bkjv = _mm256_set1_pd(bkj);
|
||||
|
||||
const auto* a_col_k = reinterpret_cast<const double*>(this_mat_data + k * Rows);
|
||||
|
||||
std::size_t i = 0;
|
||||
for (; i + vector_size <= Rows; i += vector_size)
|
||||
{
|
||||
__m256d cvec = _mm256_loadu_pd(c_col + i);
|
||||
__m256d avec = _mm256_loadu_pd(a_col_k + i);
|
||||
#if defined(__FMA__)
|
||||
cvec = _mm256_fmadd_pd(avec, bkjv, cvec);
|
||||
#else
|
||||
cvec = _mm256_add_pd(cvec, _mm256_mul_pd(avec, bkjv));
|
||||
#endif
|
||||
_mm256_storeu_pd(c_col + i, cvec);
|
||||
}
|
||||
for (; i < Rows; ++i)
|
||||
c_col[i] += a_col_k[i] * bkj;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
std::unreachable();
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template<size_t OtherColumns> [[nodiscard]]
|
||||
constexpr Mat<Rows, OtherColumns, Type, MatStoreType::ROW_MAJOR>
|
||||
avx_multiply_row_major(const Mat<Columns, OtherColumns, Type, MatStoreType::ROW_MAJOR>& other) const
|
||||
{
|
||||
Mat<Rows, OtherColumns, Type, MatStoreType::ROW_MAJOR> result;
|
||||
|
||||
const Type* this_mat_data = this->raw_array().data();
|
||||
const Type* other_mat_data = other.raw_array().data();
|
||||
Type* result_mat_data = result.raw_array().data();
|
||||
|
||||
if constexpr (std::is_same_v<Type, float>)
|
||||
{
|
||||
// ReSharper disable once CppTooWideScopeInitStatement
|
||||
constexpr std::size_t vector_size = 8;
|
||||
for (std::size_t i = 0; i < Rows; ++i)
|
||||
{
|
||||
Type* c_row = result_mat_data + i * OtherColumns;
|
||||
for (std::size_t k = 0; k < Columns; ++k)
|
||||
{
|
||||
const auto aik = static_cast<float>(this_mat_data[i * Columns + k]);
|
||||
__m256 aikv = _mm256_set1_ps(aik);
|
||||
const auto* b_row = reinterpret_cast<const float*>(other_mat_data + k * OtherColumns);
|
||||
|
||||
std::size_t j = 0;
|
||||
for (; j + vector_size <= OtherColumns; j += vector_size)
|
||||
{
|
||||
__m256 cvec = _mm256_loadu_ps(c_row + j);
|
||||
__m256 bvec = _mm256_loadu_ps(b_row + j);
|
||||
#if defined(__FMA__)
|
||||
cvec = _mm256_fmadd_ps(bvec, aikv, cvec);
|
||||
#else
|
||||
cvec = _mm256_add_ps(cvec, _mm256_mul_ps(bvec, aikv));
|
||||
#endif
|
||||
_mm256_storeu_ps(c_row + j, cvec);
|
||||
}
|
||||
for (; j < OtherColumns; ++j)
|
||||
c_row[j] += aik * b_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (std::is_same_v<Type, double>)
|
||||
{ // double
|
||||
// ReSharper disable once CppTooWideScopeInitStatement
|
||||
constexpr std::size_t vector_size = 4;
|
||||
for (std::size_t i = 0; i < Rows; ++i)
|
||||
{
|
||||
Type* c_row = result_mat_data + i * OtherColumns;
|
||||
for (std::size_t k = 0; k < Columns; ++k)
|
||||
{
|
||||
const auto aik = static_cast<double>(this_mat_data[i * Columns + k]);
|
||||
__m256d aikv = _mm256_set1_pd(aik);
|
||||
const auto* b_row = reinterpret_cast<const double*>(other_mat_data + k * OtherColumns);
|
||||
|
||||
std::size_t j = 0;
|
||||
for (; j + vector_size <= OtherColumns; j += vector_size)
|
||||
{
|
||||
__m256d cvec = _mm256_loadu_pd(c_row + j);
|
||||
__m256d bvec = _mm256_loadu_pd(b_row + j);
|
||||
#if defined(__FMA__)
|
||||
cvec = _mm256_fmadd_pd(bvec, aikv, cvec);
|
||||
#else
|
||||
cvec = _mm256_add_pd(cvec, _mm256_mul_pd(bvec, aikv));
|
||||
#endif
|
||||
_mm256_storeu_pd(c_row + j, cvec);
|
||||
}
|
||||
for (; j < OtherColumns; ++j)
|
||||
c_row[j] += aik * b_row[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
std::unreachable();
|
||||
return result;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template<class Type = float, MatStoreType St = MatStoreType::ROW_MAJOR> [[nodiscard]]
|
||||
|
||||
Reference in New Issue
Block a user