improved matrix class

This commit is contained in:
2024-11-23 12:31:01 +03:00
parent 5454c43b18
commit ea4e27b87b

View File

@@ -17,7 +17,15 @@ namespace omath
{
size_t rows, columns;
};
template<size_t Rows = 0, size_t Columns = 0, class Type = float>
enum class MatStoreType : uint8_t
{
ROW_MAJOR = 0,
COLUMN_MAJOR
};
template<size_t Rows = 0, size_t Columns = 0, class Type = float, MatStoreType StoreType = MatStoreType::ROW_MAJOR>
requires (std::is_floating_point_v<Type> || std::is_integral_v<Type>)
class Mat final
{
public:
@@ -26,7 +34,6 @@ namespace omath
Clear();
}
constexpr Mat(const std::initializer_list<std::initializer_list<Type>>& rows)
{
if (rows.size() != Rows)
@@ -42,18 +49,21 @@ namespace omath
auto colIt = rowIt->begin();
for (size_t j = 0; j < Columns; ++j, ++colIt)
{
At(i, j) = *colIt;
At(i, j) = std::move(*colIt);
}
}
}
constexpr explicit Mat(const Type* rawData)
{
std::copy_n(rawData, Rows * Columns, m_data.begin());
}
constexpr Mat(const Mat &other) noexcept
{
m_data = other.m_data;
}
constexpr Mat(Mat &&other) noexcept
{
m_data = std::move(other.m_data);
@@ -77,13 +87,22 @@ namespace omath
return {Rows, Columns};
}
[[nodiscard]] constexpr const Type &At(const size_t rowIndex, const size_t columnIndex) const
{
if (rowIndex >= Rows || columnIndex >= Columns)
throw std::out_of_range("Index out of range");
if constexpr (StoreType == MatStoreType::ROW_MAJOR)
return m_data[rowIndex * Columns + columnIndex];
else if constexpr (StoreType == MatStoreType::COLUMN_MAJOR)
return m_data[rowIndex + columnIndex * Rows];
else
{
static_assert(false, "Invalid matrix access convention");
std::unreachable();
}
}
[[nodiscard]] constexpr Type &At(const size_t rowIndex, const size_t columnIndex)
@@ -240,6 +259,18 @@ namespace omath
return result;
}
[[nodiscard]]
constexpr const std::array<Type, Rows*Columns>& RawArray() const
{
return m_data;
}
[[nodiscard]]
constexpr std::array<Type, Rows*Columns>& RawArray()
{
return const_cast<std::array<Type, Rows*Columns>>(std::as_const(*this).RawArray());
}
[[nodiscard]]
std::string ToString() const noexcept
{