TensorEvalTo.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_EVAL_TO_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_EVAL_TO_H
12 
13 namespace Eigen {
14 
22 namespace internal {
23 template<typename XprType>
24 struct traits<TensorEvalToOp<XprType> >
25 {
26  // Type promotion to handle the case where the types of the lhs and the rhs are different.
27  typedef typename XprType::Scalar Scalar;
28  typedef traits<XprType> XprTraits;
29  typedef typename packet_traits<Scalar>::type Packet;
30  typedef typename XprTraits::StorageKind StorageKind;
31  typedef typename XprTraits::Index Index;
32  typedef typename XprType::Nested Nested;
33  typedef typename remove_reference<Nested>::type _Nested;
34  static const int NumDimensions = XprTraits::NumDimensions;
35  static const int Layout = XprTraits::Layout;
36 
37  enum {
38  Flags = 0,
39  };
40 };
41 
42 template<typename XprType>
43 struct eval<TensorEvalToOp<XprType>, Eigen::Dense>
44 {
45  typedef const TensorEvalToOp<XprType>& type;
46 };
47 
48 template<typename XprType>
49 struct nested<TensorEvalToOp<XprType>, 1, typename eval<TensorEvalToOp<XprType> >::type>
50 {
51  typedef TensorEvalToOp<XprType> type;
52 };
53 
54 } // end namespace internal
55 
56 
57 
58 
59 template<typename XprType>
60 class TensorEvalToOp : public TensorBase<TensorEvalToOp<XprType> >
61 {
62  public:
63  typedef typename Eigen::internal::traits<TensorEvalToOp>::Scalar Scalar;
64  typedef typename Eigen::internal::traits<TensorEvalToOp>::Packet Packet;
65  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
66  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
67  typedef typename internal::remove_const<typename XprType::PacketReturnType>::type PacketReturnType;
68  typedef typename Eigen::internal::nested<TensorEvalToOp>::type Nested;
69  typedef typename Eigen::internal::traits<TensorEvalToOp>::StorageKind StorageKind;
70  typedef typename Eigen::internal::traits<TensorEvalToOp>::Index Index;
71 
72  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvalToOp(CoeffReturnType* buffer, const XprType& expr)
73  : m_xpr(expr), m_buffer(buffer) {}
74 
75  EIGEN_DEVICE_FUNC
76  const typename internal::remove_all<typename XprType::Nested>::type&
77  expression() const { return m_xpr; }
78 
79  EIGEN_DEVICE_FUNC CoeffReturnType* buffer() const { return m_buffer; }
80 
81  protected:
82  typename XprType::Nested m_xpr;
83  CoeffReturnType* m_buffer;
84 };
85 
86 
87 
88 template<typename ArgType, typename Device>
89 struct TensorEvaluator<const TensorEvalToOp<ArgType>, Device>
90 {
91  typedef TensorEvalToOp<ArgType> XprType;
92  typedef typename ArgType::Scalar Scalar;
93  typedef typename ArgType::Packet Packet;
94  typedef typename TensorEvaluator<ArgType, Device>::Dimensions Dimensions;
95 
96  enum {
97  IsAligned = true,
98  PacketAccess = true,
99  Layout = TensorEvaluator<ArgType, Device>::Layout,
100  CoordAccess = false, // to be implemented
101  };
102 
103  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
104  : m_impl(op.expression(), device), m_device(device), m_buffer(op.buffer())
105  { }
106 
107  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ~TensorEvaluator() {
108  }
109 
110  typedef typename XprType::Index Index;
111  typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
112  typedef typename internal::remove_const<typename XprType::PacketReturnType>::type PacketReturnType;
113 
114  EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_impl.dimensions(); }
115 
116  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* scalar) {
117  eigen_assert(scalar == NULL);
118  return m_impl.evalSubExprsIfNeeded(m_buffer);
119  }
120 
121  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalScalar(Index i) {
122  m_buffer[i] = m_impl.coeff(i);
123  }
124  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalPacket(Index i) {
125  internal::pstoret<CoeffReturnType, PacketReturnType, Aligned>(m_buffer + i, m_impl.template packet<TensorEvaluator<ArgType, Device>::IsAligned ? Aligned : Unaligned>(i));
126  }
127 
128  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
129  m_impl.cleanup();
130  }
131 
132  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
133  {
134  return m_buffer[index];
135  }
136 
137  template<int LoadMode>
138  EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
139  {
140  return internal::ploadt<Packet, LoadMode>(m_buffer + index);
141  }
142 
143  EIGEN_DEVICE_FUNC CoeffReturnType* data() const { return NULL; }
144 
145  private:
146  TensorEvaluator<ArgType, Device> m_impl;
147  const Device& m_device;
148  CoeffReturnType* m_buffer;
149 };
150 
151 
152 } // end namespace Eigen
153 
154 #endif // EIGEN_CXX11_TENSOR_TENSOR_EVAL_TO_H
Namespace containing all symbols from the Eigen library.
Definition: CXX11Meta.h:13