diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index ff720d5f..a0d1951e 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -50,45 +50,56 @@ template using is_eigen_base = all_of< template struct type_caster::value && !is_eigen_ref::value>> { typedef typename Type::Scalar Scalar; - static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; + static constexpr bool rowMajor = Type::IsRowMajor; static constexpr bool isVector = Type::IsVectorAtCompileTime; + static constexpr auto nRows = Type::RowsAtCompileTime, nCols = Type::ColsAtCompileTime; bool load(handle src, bool) { auto buf = array_t::ensure(src); if (!buf) return false; - if (buf.ndim() == 1) { - typedef Eigen::InnerStride<> Strides; - if (!isVector && - !(Type::RowsAtCompileTime == Eigen::Dynamic && - Type::ColsAtCompileTime == Eigen::Dynamic)) + using namespace Eigen; + + using Strides = Stride; + using AMap = Map; + + if (buf.ndim() == 1) { // A one-dimensional array + Index n_elts = buf.shape(0); + size_t str = buf.strides(0) / sizeof(Scalar); + Strides stride(str, str); // Whether we map to inner or outer is irrelevant + if (isVector) { + if (Type::SizeAtCompileTime != Dynamic && Type::SizeAtCompileTime != n_elts) + return false; // Vector size mismatch + value = AMap(buf.mutable_data(), nRows == 1 ? 1 : n_elts, nCols == 1 ? 1 : n_elts, stride); + } + else if (Type::SizeAtCompileTime != Dynamic) { + // The type has a fixed size, but is not a vector: abort return false; - - if (Type::SizeAtCompileTime != Eigen::Dynamic && - buf.shape(0) != (size_t) Type::SizeAtCompileTime) - return false; - - Strides::Index n_elts = (Strides::Index) buf.shape(0); - Strides::Index unity = 1; - - value = Eigen::Map( - buf.mutable_data(), - rowMajor ? unity : n_elts, - rowMajor ? n_elts : unity, - Strides(buf.strides(0) / sizeof(Scalar)) - ); + } + else if (nRows == Dynamic && nCols == Dynamic) { + // Fully dynamic size. numpy doesn't distinguish between a row vector and column + // vector, so we'll (arbitrarily) choose a column vector. + value = AMap(buf.mutable_data(), n_elts, 1, stride); + } + else if (nRows != Dynamic) { + // Since this isn't a vector, nRows must be != 1. We allow this only if it exactly + // equals the number of elements (nCols is Dynamic, and so 1 is allowed). + if (nRows != n_elts) return false; + value = AMap(buf.mutable_data(), n_elts, 1, stride); + } + else { // nCols != Dynamic; same as above, but for fixed columns + if (nCols != n_elts) return false; + value = AMap(buf.mutable_data(), 1, n_elts, stride); + } } else if (buf.ndim() == 2) { - typedef Eigen::Stride Strides; - if ((Type::RowsAtCompileTime != Eigen::Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) || - (Type::ColsAtCompileTime != Eigen::Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime)) + if ((Type::RowsAtCompileTime != Dynamic && buf.shape(0) != (size_t) Type::RowsAtCompileTime) || + (Type::ColsAtCompileTime != Dynamic && buf.shape(1) != (size_t) Type::ColsAtCompileTime)) return false; - value = Eigen::Map( - buf.mutable_data(), - typename Strides::Index(buf.shape(0)), - typename Strides::Index(buf.shape(1)), + value = AMap( + buf.mutable_data(), buf.shape(0), buf.shape(1), Strides(buf.strides(rowMajor ? 0 : 1) / sizeof(Scalar), buf.strides(rowMajor ? 1 : 0) / sizeof(Scalar)) ); @@ -176,7 +187,7 @@ struct type_caster::value>> { typedef typename Type::Scalar Scalar; typedef typename std::remove_reference().outerIndexPtr())>::type StorageIndex; typedef typename Type::Index Index; - static constexpr bool rowMajor = Type::Flags & Eigen::RowMajorBit; + static constexpr bool rowMajor = Type::IsRowMajor; bool load(handle src, bool) { if (!src) @@ -227,7 +238,7 @@ struct type_caster::value>> { ).release(); } - PYBIND11_TYPE_CASTER(Type, _<(Type::Flags & Eigen::RowMajorBit) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") + PYBIND11_TYPE_CASTER(Type, _<(Type::IsRowMajor) != 0>("scipy.sparse.csr_matrix[", "scipy.sparse.csc_matrix[") + npy_format_descriptor::name() + _("]")); }; diff --git a/tests/test_eigen.cpp b/tests/test_eigen.cpp index 588cdceb..b10d025d 100644 --- a/tests/test_eigen.cpp +++ b/tests/test_eigen.cpp @@ -37,6 +37,10 @@ test_initializer eigen([](py::module &m) { typedef Eigen::Matrix FixedMatrixC; typedef Eigen::Matrix DenseMatrixR; typedef Eigen::Matrix DenseMatrixC; + typedef Eigen::Matrix FourRowMatrixC; + typedef Eigen::Matrix FourColMatrixC; + typedef Eigen::Matrix FourRowMatrixR; + typedef Eigen::Matrix FourColMatrixR; typedef Eigen::SparseMatrix SparseMatrixR; typedef Eigen::SparseMatrix SparseMatrixC; @@ -131,4 +135,17 @@ test_initializer eigen([](py::module &m) { m.def("sparse_passthrough_c", [](const SparseMatrixC &m) -> SparseMatrixC { return m; }); + + m.def("partial_passthrough_four_rm_r", [](const FourRowMatrixR &m) -> FourRowMatrixR { + return m; + }); + m.def("partial_passthrough_four_rm_c", [](const FourColMatrixR &m) -> FourColMatrixR { + return m; + }); + m.def("partial_passthrough_four_cm_r", [](const FourRowMatrixC &m) -> FourRowMatrixC { + return m; + }); + m.def("partial_passthrough_four_cm_c", [](const FourColMatrixC &m) -> FourColMatrixC { + return m; + }); }); diff --git a/tests/test_eigen.py b/tests/test_eigen.py index b0092fc8..5d4d94a5 100644 --- a/tests/test_eigen.py +++ b/tests/test_eigen.py @@ -41,6 +41,24 @@ def test_dense(): assert_equal_ref(dense_passthrough_r(dense_c())) assert_equal_ref(dense_passthrough_c(dense_r())) +@pytest.requires_eigen_and_numpy +def test_partially_fixed(): + from pybind11_tests import partial_passthrough_four_rm_r, partial_passthrough_four_rm_c, partial_passthrough_four_cm_r, partial_passthrough_four_cm_c + + ref2 = np.array([[0,1,2,3], [4,5,6,7], [8,9,10,11], [12,13,14,15]]) + np.testing.assert_array_equal(partial_passthrough_four_rm_r(ref2), ref2) + np.testing.assert_array_equal(partial_passthrough_four_rm_c(ref2), ref2) + np.testing.assert_array_equal(partial_passthrough_four_rm_r(ref2[:, 1]), ref2[:, [1]]) + np.testing.assert_array_equal(partial_passthrough_four_rm_c(ref2[0, :]), ref2[[0], :]) + np.testing.assert_array_equal(partial_passthrough_four_rm_r(ref2[:, (0, 2)]), ref2[:, (0,2)]) + np.testing.assert_array_equal(partial_passthrough_four_rm_c(ref2[(3,1,2), :]), ref2[(3,1,2), :]) + + np.testing.assert_array_equal(partial_passthrough_four_cm_r(ref2), ref2) + np.testing.assert_array_equal(partial_passthrough_four_cm_c(ref2), ref2) + np.testing.assert_array_equal(partial_passthrough_four_cm_r(ref2[:, 1]), ref2[:, [1]]) + np.testing.assert_array_equal(partial_passthrough_four_cm_c(ref2[0, :]), ref2[[0], :]) + np.testing.assert_array_equal(partial_passthrough_four_cm_r(ref2[:, (0, 2)]), ref2[:, (0,2)]) + np.testing.assert_array_equal(partial_passthrough_four_cm_c(ref2[(3,1,2), :]), ref2[(3,1,2), :]) @pytest.requires_eigen_and_numpy def test_nonunit_stride_from_python(): @@ -49,16 +67,16 @@ def test_nonunit_stride_from_python(): counting_mat = np.arange(9.0, dtype=np.float32).reshape((3, 3)) first_row = counting_mat[0, :] first_col = counting_mat[:, 0] - assert np.array_equal(double_row(first_row), 2.0 * first_row) - assert np.array_equal(double_col(first_row), 2.0 * first_row) - assert np.array_equal(double_row(first_col), 2.0 * first_col) - assert np.array_equal(double_col(first_col), 2.0 * first_col) + np.testing.assert_array_equal(double_row(first_row), 2.0 * first_row) + np.testing.assert_array_equal(double_col(first_row), 2.0 * first_row) + np.testing.assert_array_equal(double_row(first_col), 2.0 * first_col) + np.testing.assert_array_equal(double_col(first_col), 2.0 * first_col) counting_3d = np.arange(27.0, dtype=np.float32).reshape((3, 3, 3)) slices = [counting_3d[0, :, :], counting_3d[:, 0, :], counting_3d[:, :, 0]] for slice_idx, ref_mat in enumerate(slices): - assert np.array_equal(double_mat_cm(ref_mat), 2.0 * ref_mat) - assert np.array_equal(double_mat_rm(ref_mat), 2.0 * ref_mat) + np.testing.assert_array_equal(double_mat_cm(ref_mat), 2.0 * ref_mat) + np.testing.assert_array_equal(double_mat_rm(ref_mat), 2.0 * ref_mat) @pytest.requires_eigen_and_numpy