diff --git a/include/omath/Mat.hpp b/include/omath/Mat.hpp index cf6c497..18907fa 100644 --- a/include/omath/Mat.hpp +++ b/include/omath/Mat.hpp @@ -17,7 +17,15 @@ namespace omath { size_t rows, columns; }; - template + + enum class MatStoreType : uint8_t + { + ROW_MAJOR = 0, + COLUMN_MAJOR + }; + + template + requires (std::is_floating_point_v || std::is_integral_v) class Mat final { public: @@ -26,8 +34,7 @@ namespace omath Clear(); } - - constexpr Mat(const std::initializer_list > &rows) + constexpr Mat(const std::initializer_list>& rows) { if (rows.size() != Rows) throw std::invalid_argument("Initializer list rows size does not match template parameter Rows"); @@ -42,18 +49,21 @@ namespace omath auto colIt = rowIt->begin(); for (size_t j = 0; j < Columns; ++j, ++colIt) { - At(i, j) = *colIt; + At(i, j) = std::move(*colIt); } } } + constexpr explicit Mat(const Type* rawData) + { + std::copy_n(rawData, Rows * Columns, m_data.begin()); + } constexpr Mat(const Mat &other) noexcept { m_data = other.m_data; } - constexpr Mat(Mat &&other) noexcept { m_data = std::move(other.m_data); @@ -77,13 +87,22 @@ namespace omath return {Rows, Columns}; } - [[nodiscard]] constexpr const Type &At(const size_t rowIndex, const size_t columnIndex) const { if (rowIndex >= Rows || columnIndex >= Columns) throw std::out_of_range("Index out of range"); - return m_data[rowIndex * Columns + columnIndex]; + if constexpr (StoreType == MatStoreType::ROW_MAJOR) + return m_data[rowIndex * Columns + columnIndex]; + + else if constexpr (StoreType == MatStoreType::COLUMN_MAJOR) + return m_data[rowIndex + columnIndex * Rows]; + + else + { + static_assert(false, "Invalid matrix access convention"); + std::unreachable(); + } } [[nodiscard]] constexpr Type &At(const size_t rowIndex, const size_t columnIndex) @@ -240,6 +259,18 @@ namespace omath return result; } + [[nodiscard]] + constexpr const std::array& RawArray() const + { + return m_data; + } + + [[nodiscard]] + constexpr std::array& RawArray() + { + return const_cast>(std::as_const(*this).RawArray()); + } + [[nodiscard]] std::string ToString() const noexcept {