added avx mutiplication

This commit is contained in:
2025-09-17 19:47:29 +03:00
parent a2de6f8fae
commit d773985822

View File

@@ -10,7 +10,7 @@
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <utility> #include <utility>
#include <immintrin.h>
namespace omath namespace omath
{ {
@@ -155,10 +155,17 @@ namespace omath
constexpr Mat<Rows, OtherColumns, Type, StoreType> constexpr Mat<Rows, OtherColumns, Type, StoreType>
operator*(const Mat<Columns, OtherColumns, Type, StoreType>& other) const operator*(const Mat<Columns, OtherColumns, Type, StoreType>& other) const
{ {
#ifdef OMATH_USE_AVX2
if constexpr (StoreType == MatStoreType::ROW_MAJOR)
return avx_multiply_row_major(other);
if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
return avx_multiply_col_major(other);
#else
if constexpr (StoreType == MatStoreType::ROW_MAJOR) if constexpr (StoreType == MatStoreType::ROW_MAJOR)
return cache_friendly_multiply_row_major(other); return cache_friendly_multiply_row_major(other);
if constexpr (StoreType == MatStoreType::COLUMN_MAJOR) if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
return cache_friendly_multiply_col_major(other); return cache_friendly_multiply_col_major(other);
#endif
std::unreachable(); std::unreachable();
} }
@@ -391,6 +398,160 @@ namespace omath
} }
return result; return result;
} }
#ifdef OMATH_USE_AVX2
template<size_t OtherColumns> [[nodiscard]]
constexpr Mat<Rows, OtherColumns, Type, MatStoreType::COLUMN_MAJOR>
avx_multiply_col_major(const Mat<Columns, OtherColumns, Type, MatStoreType::COLUMN_MAJOR>& other) const
{
Mat<Rows, OtherColumns, Type, MatStoreType::COLUMN_MAJOR> result;
const Type* this_mat_data = this->raw_array().data();
const Type* other_mat_data = other.raw_array().data();
Type* result_mat_data = result.raw_array().data();
if constexpr (std::is_same_v<Type, float>)
{
// ReSharper disable once CppTooWideScopeInitStatement
constexpr std::size_t vector_size = 8;
for (std::size_t j = 0; j < OtherColumns; ++j)
{
auto* c_col = reinterpret_cast<float*>(result_mat_data + j * Rows);
for (std::size_t k = 0; k < Columns; ++k)
{
const float bkj = reinterpret_cast<const float*>(other_mat_data)[k + j * Columns];
__m256 bkjv = _mm256_set1_ps(bkj);
const auto* a_col_k = reinterpret_cast<const float*>(this_mat_data + k * Rows);
std::size_t i = 0;
for (; i + vector_size <= Rows; i += vector_size)
{
__m256 cvec = _mm256_loadu_ps(c_col + i);
__m256 avec = _mm256_loadu_ps(a_col_k + i);
#if defined(__FMA__)
cvec = _mm256_fmadd_ps(avec, bkjv, cvec);
#else
cvec = _mm256_add_ps(cvec, _mm256_mul_ps(avec, bkjv));
#endif
_mm256_storeu_ps(c_col + i, cvec);
}
for (; i < Rows; ++i)
c_col[i] += a_col_k[i] * bkj;
}
}
}
else if (std::is_same_v<Type, double>)
{ // double
// ReSharper disable once CppTooWideScopeInitStatement
constexpr std::size_t vector_size = 4;
for (std::size_t j = 0; j < OtherColumns; ++j)
{
auto* c_col = reinterpret_cast<double*>(result_mat_data + j * Rows);
for (std::size_t k = 0; k < Columns; ++k)
{
const double bkj = reinterpret_cast<const double*>(other_mat_data)[k + j * Columns];
__m256d bkjv = _mm256_set1_pd(bkj);
const auto* a_col_k = reinterpret_cast<const double*>(this_mat_data + k * Rows);
std::size_t i = 0;
for (; i + vector_size <= Rows; i += vector_size)
{
__m256d cvec = _mm256_loadu_pd(c_col + i);
__m256d avec = _mm256_loadu_pd(a_col_k + i);
#if defined(__FMA__)
cvec = _mm256_fmadd_pd(avec, bkjv, cvec);
#else
cvec = _mm256_add_pd(cvec, _mm256_mul_pd(avec, bkjv));
#endif
_mm256_storeu_pd(c_col + i, cvec);
}
for (; i < Rows; ++i)
c_col[i] += a_col_k[i] * bkj;
}
}
}
else
std::unreachable();
return result;
}
template<size_t OtherColumns> [[nodiscard]]
constexpr Mat<Rows, OtherColumns, Type, MatStoreType::ROW_MAJOR>
avx_multiply_row_major(const Mat<Columns, OtherColumns, Type, MatStoreType::ROW_MAJOR>& other) const
{
Mat<Rows, OtherColumns, Type, MatStoreType::ROW_MAJOR> result;
const Type* this_mat_data = this->raw_array().data();
const Type* other_mat_data = other.raw_array().data();
Type* result_mat_data = result.raw_array().data();
if constexpr (std::is_same_v<Type, float>)
{
// ReSharper disable once CppTooWideScopeInitStatement
constexpr std::size_t vector_size = 8;
for (std::size_t i = 0; i < Rows; ++i)
{
Type* c_row = result_mat_data + i * OtherColumns;
for (std::size_t k = 0; k < Columns; ++k)
{
const auto aik = static_cast<float>(this_mat_data[i * Columns + k]);
__m256 aikv = _mm256_set1_ps(aik);
const auto* b_row = reinterpret_cast<const float*>(other_mat_data + k * OtherColumns);
std::size_t j = 0;
for (; j + vector_size <= OtherColumns; j += vector_size)
{
__m256 cvec = _mm256_loadu_ps(c_row + j);
__m256 bvec = _mm256_loadu_ps(b_row + j);
#if defined(__FMA__)
cvec = _mm256_fmadd_ps(bvec, aikv, cvec);
#else
cvec = _mm256_add_ps(cvec, _mm256_mul_ps(bvec, aikv));
#endif
_mm256_storeu_ps(c_row + j, cvec);
}
for (; j < OtherColumns; ++j)
c_row[j] += aik * b_row[j];
}
}
}
else if (std::is_same_v<Type, double>)
{ // double
// ReSharper disable once CppTooWideScopeInitStatement
constexpr std::size_t vector_size = 4;
for (std::size_t i = 0; i < Rows; ++i)
{
Type* c_row = result_mat_data + i * OtherColumns;
for (std::size_t k = 0; k < Columns; ++k)
{
const auto aik = static_cast<double>(this_mat_data[i * Columns + k]);
__m256d aikv = _mm256_set1_pd(aik);
const auto* b_row = reinterpret_cast<const double*>(other_mat_data + k * OtherColumns);
std::size_t j = 0;
for (; j + vector_size <= OtherColumns; j += vector_size)
{
__m256d cvec = _mm256_loadu_pd(c_row + j);
__m256d bvec = _mm256_loadu_pd(b_row + j);
#if defined(__FMA__)
cvec = _mm256_fmadd_pd(bvec, aikv, cvec);
#else
cvec = _mm256_add_pd(cvec, _mm256_mul_pd(bvec, aikv));
#endif
_mm256_storeu_pd(c_row + j, cvec);
}
for (; j < OtherColumns; ++j)
c_row[j] += aik * b_row[j];
}
}
}
else
std::unreachable();
return result;
}
#endif
}; };
template<class Type = float, MatStoreType St = MatStoreType::ROW_MAJOR> [[nodiscard]] template<class Type = float, MatStoreType St = MatStoreType::ROW_MAJOR> [[nodiscard]]