10 #ifndef THYRA_TPETRA_MULTIVECTOR_HPP
11 #define THYRA_TPETRA_MULTIVECTOR_HPP
13 #include "Thyra_TpetraMultiVector_decl.hpp"
14 #include "Thyra_TpetraVectorSpace.hpp"
15 #include "Thyra_TpetraVector.hpp"
16 #include "Teuchos_Assert.hpp"
17 #include "Kokkos_Core.hpp"
26 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
31 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
38 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
42 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
49 initializeImpl(tpetraVectorSpace, domainSpace, tpetraMultiVector);
53 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
57 return tpetraMultiVector_.getNonconstObj();
61 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
65 return tpetraMultiVector_;
72 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
83 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
87 tpetraMultiVector_.getNonconstObj()->putScalar(alpha);
91 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
95 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
100 tpetraMultiVector_.getNonconstObj()->assign(*tmv);
107 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
111 tpetraMultiVector_.getNonconstObj()->scale(alpha);
115 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
121 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
127 tpetraMultiVector_.getNonconstObj()->update(alpha, *tmv, ST::one());
134 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
149 bool allCastsSuccessful =
true;
151 auto mvIter = mv.begin();
152 auto tmvIter = tmvs.begin();
153 for (; mvIter != mv.end(); ++mvIter, ++tmvIter) {
154 tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromPtr(*mvIter));
158 allCastsSuccessful =
false;
166 auto len = tmvs.
size();
168 tpetraMultiVector_.getNonconstObj()->scale(beta);
169 }
else if (len == 1 && allCastsSuccessful) {
170 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], beta);
171 }
else if (len == 2 && allCastsSuccessful) {
172 tpetraMultiVector_.getNonconstObj()->update(alpha[0], *tmvs[0], alpha[1], *tmvs[1], beta);
173 }
else if (allCastsSuccessful) {
175 auto tmvIter = tmvs.begin();
176 auto alphaIter = alpha.
begin();
181 for (; tmvIter != tmvs.end(); ++tmvIter) {
182 if (tmvIter->getRawPtr() == tpetraMultiVector_.getConstObj().getRawPtr()) {
189 tmvIter = tmvs.
begin();
193 if ((tmvs.size() % 2) == 0) {
194 tpetraMultiVector_.getNonconstObj()->scale(beta);
196 tpetraMultiVector_.getNonconstObj()->update(*alphaIter, *(*tmvIter), beta);
200 for (; tmvIter != tmvs.end(); tmvIter+=2, alphaIter+=2) {
201 tpetraMultiVector_.getNonconstObj()->update(
202 *alphaIter, *(*tmvIter), *(alphaIter+1), *(*(tmvIter+1)), ST::one());
210 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
216 auto tmv = this->getConstTpetraMultiVector(Teuchos::rcpFromRef(mv));
221 tpetraMultiVector_.getConstObj()->dot(*tmv, prods);
228 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
233 tpetraMultiVector_.getConstObj()->norm1(norms);
237 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
242 tpetraMultiVector_.getConstObj()->norm2(norms);
246 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
251 tpetraMultiVector_.getConstObj()->normInf(norms);
255 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
262 return constTpetraVector<Scalar>(
264 tpetraMultiVector_->getVector(j)
269 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
276 return tpetraVector<Scalar>(
278 tpetraMultiVector_.getNonconstObj()->getVectorNonConst(j)
283 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
289 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
290 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) const called!\n";
292 const Range1D colRng = this->validateColRange(col_rng_in);
295 this->getConstTpetraMultiVector()->subView(colRng);
298 tpetraVectorSpace<Scalar>(
299 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
300 tpetraView->getNumVectors(),
301 tpetraView->getMap()->getComm()
305 return constTpetraMultiVector(
313 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
319 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
320 std::cerr <<
"\nTpetraMultiVector::subView(Range1D) called!\n";
322 const Range1D colRng = this->validateColRange(col_rng_in);
325 this->getTpetraMultiVector()->subViewNonConst(colRng);
328 tpetraVectorSpace<Scalar>(
329 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
330 tpetraView->getNumVectors(),
331 tpetraView->getMap()->getComm()
335 return tpetraMultiVector(
343 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
349 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
350 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) const called!\n";
355 cols[i] = static_cast<std::size_t>(cols_in[i]);
358 this->getConstTpetraMultiVector()->subView(cols());
361 tpetraVectorSpace<Scalar>(
362 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
363 tpetraView->getNumVectors(),
364 tpetraView->getMap()->getComm()
368 return constTpetraMultiVector(
376 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
382 #ifdef THYRA_DEFAULT_SPMD_MULTI_VECTOR_VERBOSE_TO_ERROR_OUT
383 std::cerr <<
"\nTpetraMultiVector::subView(ArrayView) called!\n";
388 cols[i] = static_cast<std::size_t>(cols_in[i]);
391 this->getTpetraMultiVector()->subViewNonConst(cols());
394 tpetraVectorSpace<Scalar>(
395 Tpetra::createLocalMapWithNode<LocalOrdinal,GlobalOrdinal,Node>(
396 tpetraView->getNumVectors(),
397 tpetraView->getMap()->getComm()
401 return tpetraMultiVector(
409 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
416 const Ordinal primary_global_offset
421 primary_op, multi_vecs, targ_multi_vecs, reduct_objs, primary_global_offset);
425 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
438 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
451 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
510 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
514 return tpetraVectorSpace_;
518 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
523 *localValues = tpetraMultiVector_.getNonconstObj()->get1dViewNonConst();
524 *leadingDim = tpetraMultiVector_->getStride();
528 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
533 *localValues = tpetraMultiVector_->get1dView();
534 *leadingDim = tpetraMultiVector_->getStride();
538 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
558 "Error, conjugation without transposition is not allowed for complex scalar types!");
576 Y_tpetra->multiply(trans,
Teuchos::NO_TRANS, alpha, *tpetraMultiVector_.getConstObj(), *X_tpetra, beta);
587 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
588 template<
class TpetraMultiVector_t>
602 tpetraVectorSpace_ = tpetraVectorSpace;
603 domainSpace_ = domainSpace;
604 tpetraMultiVector_.initialize(tpetraMultiVector);
605 this->updateSpmdSpace();
609 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
610 RCP<Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
614 using Teuchos::rcp_dynamic_cast;
618 RCP<TMV> tmv = rcp_dynamic_cast<TMV>(mv);
620 return tmv->getTpetraMultiVector();
623 RCP<TV> tv = rcp_dynamic_cast<TV>(mv);
625 return tv->getTpetraVector();
628 return Teuchos::null;
631 template <
class Scalar,
class LocalOrdinal,
class GlobalOrdinal,
class Node>
632 RCP<const Tpetra::MultiVector<Scalar,LocalOrdinal,GlobalOrdinal,Node> >
636 using Teuchos::rcp_dynamic_cast;
640 RCP<const TMV> tmv = rcp_dynamic_cast<
const TMV>(mv);
642 return tmv->getConstTpetraMultiVector();
645 RCP<const TV> tv = rcp_dynamic_cast<
const TV>(mv);
647 return tv->getConstTpetraVector();
650 return Teuchos::null;
657 #endif // THYRA_TPETRA_MULTIVECTOR_HPP
virtual void updateImpl(Scalar alpha, const MultiVectorBase< Scalar > &mv)
Concrete implementation of Thyra::MultiVector in terms of Tpetra::MultiVector.
TpetraMultiVector()
Construct to uninitialized.
RCP< const VectorBase< Scalar > > colImpl(Ordinal j) const
void getNonconstLocalMultiVectorDataImpl(const Ptr< ArrayRCP< Scalar > > &localValues, const Ptr< Ordinal > &leadingDim)
Concrete implementation of an SPMD vector space for Tpetra.
RCP< Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getTpetraMultiVector()
Extract the underlying non-const Tpetra::MultiVector object.
virtual void assignImpl(Scalar alpha)
EOpTransp
Enumeration for determining how a linear operator is applied.
RCP< MultiVectorBase< Scalar > > nonconstContigSubViewImpl(const Range1D &colRng)
void euclideanApply(const EOpTransp M_trans, const MultiVectorBase< Scalar > &X, const Ptr< MultiVectorBase< Scalar > > &Y, const Scalar alpha, const Scalar beta) const
Uses GEMM() and Teuchos::reduceAll() to implement.
virtual void assignMultiVecImpl(const MultiVectorBase< Scalar > &mv)
Default implementation of assign(MV) using RTOps.
#define TEUCHOS_TEST_FOR_EXCEPTION(throw_exception_test, Exception, msg)
Use the non-transposed operator.
virtual void euclideanApply(const EOpTransp M_trans, const MultiVectorBase< Scalar > &X, const Ptr< MultiVectorBase< Scalar > > &Y, const Scalar alpha, const Scalar beta) const
virtual void mvMultiReductApplyOpImpl(const RTOpPack::RTOpT< Scalar > &primary_op, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &multi_vecs, const ArrayView< const Ptr< MultiVectorBase< Scalar > > > &targ_multi_vecs, const ArrayView< const Ptr< RTOpPack::ReductTarget > > &reduct_objs, const Ordinal primary_global_offset) const
virtual void linearCombinationImpl(const ArrayView< const Scalar > &alpha, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &mv, const Scalar &beta)
Default implementation of linear_combination using RTOps.
Use the transposed operator with complex-conjugate elements (same as TRANS for real scalar types).
RCP< const MultiVectorBase< Scalar > > nonContigSubViewImpl(const ArrayView< const int > &cols_in) const
Use the non-transposed operator with complex-conjugate elements (same as NOTRANS for real scalar types).
void initialize(const RCP< const TpetraVectorSpace< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraVectorSpace, const RCP< const ScalarProdVectorSpaceBase< Scalar > > &domainSpace, const RCP< Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraMultiVector)
Initialize.
Use the transposed operator.
#define TEUCHOS_ASSERT_IN_RANGE_UPPER_EXCLUSIVE(index, lower_inclusive, upper_exclusive)
virtual void scaleImpl(Scalar alpha)
TEUCHOS_DEPRECATED RCP< T > rcp(T *p, Dealloc_T dealloc, bool owns_mem)
Teuchos::Ordinal Ordinal
Type for the dimension of a vector space.
RCP< const MultiVectorBase< Scalar > > contigSubViewImpl(const Range1D &colRng) const
virtual void mvMultiReductApplyOpImpl(const RTOpPack::RTOpT< Scalar > &primary_op, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &multi_vecs, const ArrayView< const Ptr< MultiVectorBase< Scalar > > > &targ_multi_vecs, const ArrayView< const Ptr< RTOpPack::ReductTarget > > &reduct_objs, const Ordinal primary_global_offset) const
Interface for a collection of column vectors called a multi-vector.
RCP< MultiVectorBase< Scalar > > nonconstNonContigSubViewImpl(const ArrayView< const int > &cols_in)
void constInitialize(const RCP< const TpetraVectorSpace< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraVectorSpace, const RCP< const ScalarProdVectorSpaceBase< Scalar > > &domainSpace, const RCP< const Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > &tpetraMultiVector)
Initialize.
virtual void norms2Impl(const ArrayView< typename ScalarTraits< Scalar >::magnitudeType > &norms) const
virtual void linearCombinationImpl(const ArrayView< const Scalar > &alpha, const ArrayView< const Ptr< const MultiVectorBase< Scalar > > > &mv, const Scalar &beta)
Concrete Thyra::SpmdVectorBase using Tpetra::Vector.
void acquireDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::ConstSubMultiVectorView< Scalar > *sub_mv) const
virtual void dotsImpl(const MultiVectorBase< Scalar > &mv, const ArrayView< Scalar > &prods) const
RCP< VectorBase< Scalar > > nonconstColImpl(Ordinal j)
void commitNonconstDetachedMultiVectorViewImpl(RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
bool nonnull(const boost::shared_ptr< T > &p)
void getLocalMultiVectorDataImpl(const Ptr< ArrayRCP< const Scalar > > &localValues, const Ptr< Ordinal > &leadingDim) const
virtual void assignMultiVecImpl(const MultiVectorBase< Scalar > &mv)
virtual void norms1Impl(const ArrayView< typename ScalarTraits< Scalar >::magnitudeType > &norms) const
RCP< const Tpetra::MultiVector< Scalar, LocalOrdinal, GlobalOrdinal, Node > > getConstTpetraMultiVector() const
Extract the underlying const Tpetra::MultiVector object.
RCP< const SpmdVectorSpaceBase< Scalar > > spmdSpaceImpl() const
#define TEUCHOS_ASSERT(assertion_test)
#define TEUCHOS_ASSERT_EQUALITY(val1, val2)
void acquireNonconstDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
virtual void normsInfImpl(const ArrayView< typename ScalarTraits< Scalar >::magnitudeType > &norms) const
RCP< const ScalarProdVectorSpaceBase< Scalar > > domainScalarProdVecSpc() const
void acquireNonconstDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
void acquireDetachedMultiVectorViewImpl(const Range1D &rowRng, const Range1D &colRng, RTOpPack::ConstSubMultiVectorView< Scalar > *sub_mv) const
virtual void dotsImpl(const MultiVectorBase< Scalar > &mv, const ArrayView< Scalar > &prods) const
Default implementation of dots using RTOps.
void commitNonconstDetachedMultiVectorViewImpl(RTOpPack::SubMultiVectorView< Scalar > *sub_mv)
virtual void updateImpl(Scalar alpha, const MultiVectorBase< Scalar > &mv)
Default implementation of update using RTOps.