TensorContractionThreadPool.h
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H

// evaluator for thread pool device
#ifdef EIGEN_USE_THREADS
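
// The sketch below illustrates how a contraction reaches this evaluator: it is
// selected when the contraction expression is assigned through a
// ThreadPoolDevice. This is only a usage sketch, and the exact way a ThreadPool
// and ThreadPoolDevice are constructed differs between Eigen versions:
//
//   Eigen::ThreadPool pool(8);                  // 8 worker threads (assumed API)
//   Eigen::ThreadPoolDevice device(&pool, 8);
//   Eigen::Tensor<float, 2> a(1024, 512), b(512, 256), c(1024, 256);
//   a.setRandom();
//   b.setRandom();
//   // contract dimension 1 of a against dimension 0 of b (a matrix product)
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   c.device(device) = a.contract(b, dims);     // eventually reaches evalProduct below
//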
namespace Eigen {
namespace internal {

template<typename LhsScalar, typename LhsMapper, typename Index>
struct packLhsArg {
  LhsScalar* blockA;
  const LhsMapper& lhs;
  const Index m_start;
  const Index k_start;
  const Index mc;
  const Index kc;
};

template<typename LhsScalar, typename RhsScalar, typename RhsMapper, typename OutputMapper, typename Index>
struct packRhsAndKernelArg {
  const std::vector<LhsScalar*>* blockAs;
  RhsScalar* blockB;
  const RhsMapper& rhs;
  OutputMapper& output;
  const Index m;
  const Index k;
  const Index n;
  const Index mc;
  const Index kc;
  const Index nc;
  const Index num_threads;
  const Index num_blockAs;
  const Index max_m;
  const Index k_block_idx;
  const Index m_block_idx;
  const Index n_block_idx;
  const Index m_blocks;
  const Index n_blocks;
  std::vector<Notification*>* kernel_notifications;
  const std::vector<Notification*>* lhs_notifications;
  const bool need_to_pack;
};

}  // end namespace internal

template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {

  typedef ThreadPoolDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Packet Packet;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

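  // Why the swap works (a worked 2-D example): a RowMajor m x k matrix has the
  // same memory layout as a ColMajor k x m matrix, i.e. its transpose. So for
  // RowMajor inputs the buffers of A and B already hold A^T and B^T when read
  // as ColMajor, and computing B^T * A^T = (A * B)^T = C^T in ColMajor yields a
  // buffer whose RowMajor interpretation is exactly C.
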
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
  typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;

  static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
  }

  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;

    // rows in left side
    const Index m = this->m_i_size;

    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    const int lhs_packet_size = internal::packet_traits<LhsScalar>::size;
    const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // TODO: packing could be faster sometimes if we supported row major tensor mappers
    typedef internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr,
                                    Traits::LhsProgress, ColMajor> LhsPacker;
    typedef internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> RhsPacker;

    // TODO: replace false, false with conjugate values?
    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
                                  Traits::mr, Traits::nr, false, false> GebpKernel;

    typedef internal::packLhsArg<LhsScalar, LhsMapper, Index> packLArg;
    typedef internal::packRhsAndKernelArg<LhsScalar, RhsScalar, RhsMapper, OutputMapper, Index> packRKArg;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);

    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // compute block sizes (which depend on number of threads)
    const Index num_threads = this->m_device.numThreads();
    Index mc = m;
    Index nc = n;
    Index kc = k;
    internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc, num_threads);
    eigen_assert(mc <= m);
    eigen_assert(nc <= n);
    eigen_assert(kc <= k);

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
    const Index k_blocks = CEIL_DIV(k, kc);
    const Index n_blocks = CEIL_DIV(n, nc);
    const Index m_blocks = CEIL_DIV(m, mc);
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;
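
    // Illustrative numbers (not computed by the code): with m = 1200, n = 900,
    // k = 500 and blocking sizes mc = 400, nc = 300, kc = 250, this gives
    // m_blocks = 3, n_blocks = 3, k_blocks = 2, with sizeA = 400 * 250 LhsScalars
    // per LHS block and sizeB = 250 * 300 RhsScalars per RHS block.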

    /* cout << "m: " << m << " n: " << n << " k: " << k << endl;
    cout << "mc: " << mc << " nc: " << nc << " kc: " << kc << endl;
    cout << "m_blocks: " << m_blocks << " n_blocks: " << n_blocks << " k_blocks: " << k_blocks << endl;
    cout << "num threads: " << num_threads << endl;
    */

    // note: m_device.allocate should return 16 byte aligned pointers, but if blockA and blockB
    // aren't 16 byte aligned, segfaults will happen due to the SIMD instructions.
    // note: you can get away with allocating just a single blockA and using offsets into it,
    // and still meet the alignment requirements, under the assumption that
    // (Traits::mr * sizeof(ResScalar)) % 16 == 0
    const Index numBlockAs = numext::mini(num_threads, m_blocks);
    std::vector<LhsScalar *> blockAs;
    blockAs.reserve(num_threads);
    for (int i = 0; i < num_threads; i++) {
      blockAs.push_back(static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar))));
    }

    // To circumvent alignment issues, I'm just going to separately allocate the memory for each thread
    // TODO: is this too much memory to allocate? This simplifies coding a lot, but is wasteful.
    // Other options: (1) reuse memory when a thread finishes. con: tricky
    //                (2) allocate block B memory in each thread. con: overhead
    std::vector<RhsScalar *> blockBs;
    blockBs.reserve(n_blocks);
    for (int i = 0; i < n_blocks; i++) {
      blockBs.push_back(static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar))));
    }

    // lhs_notifications starts with all null Notifications
    std::vector<Notification*> lhs_notifications(num_threads, nullptr);

    // this should really be numBlockAs * n_blocks;
    const Index num_kernel_notifications = num_threads * n_blocks;
    std::vector<Notification*> kernel_notifications(num_kernel_notifications,
                                                    nullptr);

    for (Index k_block_idx = 0; k_block_idx < k_blocks; k_block_idx++) {
      const Index k_start = k_block_idx * kc;
      // make sure we don't overshoot right edge of left matrix
      const Index actual_kc = numext::mini(k_start + kc, k) - k_start;

      for (Index m_block_idx = 0; m_block_idx < m_blocks; m_block_idx += numBlockAs) {
        const Index num_blocks = numext::mini(m_blocks - m_block_idx, numBlockAs);

        for (Index mt_block_idx = m_block_idx; mt_block_idx < m_block_idx + num_blocks; mt_block_idx++) {
          const Index m_start = mt_block_idx * mc;
          const Index actual_mc = numext::mini(m_start + mc, m) - m_start;
          eigen_assert(actual_mc > 0);

          Index blockAId = (k_block_idx * m_blocks + mt_block_idx) % num_threads;

          for (int i = 0; i < n_blocks; ++i) {
            Index notification_id = (blockAId * n_blocks + i);
            // Wait for any current kernels using this slot to complete
            // before using it.
            if (kernel_notifications[notification_id]) {
              wait_until_ready(kernel_notifications[notification_id]);
              delete kernel_notifications[notification_id];
            }
            kernel_notifications[notification_id] = new Notification();
          }
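
          // Slot layout: the kernel notification for (blockAId, n_block_idx)
          // lives at index blockAId * n_blocks + n_block_idx. With, say,
          // num_threads = 4 and n_blocks = 3 (illustrative numbers), LHS slot
          // blockAId = 2 owns kernel notification slots 6, 7 and 8, one per
          // RHS block.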
          const packLArg arg = {
            blockAs[blockAId], // blockA
            lhs,               // lhs
            m_start,           // m
            k_start,           // k
            actual_mc,         // mc
            actual_kc,         // kc
          };

          // Delete any existing notification since we may be
          // replacing it. The algorithm should ensure that there are
          // no existing waiters on this notification.
          delete lhs_notifications[blockAId];
          lhs_notifications[blockAId] =
              this->m_device.enqueue(&Self::packLhs<packLArg, LhsPacker>, arg);
        }

        // now start kernels.
        const Index m_base_start = m_block_idx * mc;
        const bool need_to_pack = m_block_idx == 0;

        for (Index n_block_idx = 0; n_block_idx < n_blocks; n_block_idx++) {
          const Index n_start = n_block_idx * nc;
          const Index actual_nc = numext::mini(n_start + nc, n) - n_start;

          // First make sure the previous kernels are all done before overwriting rhs. Also wait
          // if we're going to start a new k block. In both cases need_to_pack is true.
          if (need_to_pack) {
            for (Index i = num_blocks; i < num_threads; ++i) {
              Index blockAId = (k_block_idx * m_blocks + i + m_block_idx) % num_threads;
              Index future_id = (blockAId * n_blocks + n_block_idx);
              wait_until_ready(kernel_notifications[future_id]);
            }
          }

          packRKArg arg = {
            &blockAs,              // blockA
            blockBs[n_block_idx],  // blockB
            rhs,                   // rhs
            output,                // output
            m_base_start,          // m
            k_start,               // k
            n_start,               // n
            mc,                    // mc
            actual_kc,             // kc
            actual_nc,             // nc
            num_threads,
            numBlockAs,
            m,
            k_block_idx,
            m_block_idx,
            n_block_idx,           // n_block_idx
            m_blocks,              // m_blocks
            n_blocks,              // n_blocks
            &kernel_notifications, // kernel notifications
            &lhs_notifications,    // lhs notifications
            need_to_pack,          // need_to_pack
          };

          // We asynchronously kick off this function, which ends up
          // notifying the appropriate kernel_notifications objects,
          // which this thread waits on before exiting.
          this->m_device.enqueueNoNotification(&Self::packRhsAndKernel<packRKArg, RhsPacker, GebpKernel>, arg);
        }
      }
    }

    // Make sure all the kernels are done.
    for (size_t i = 0; i < kernel_notifications.size(); ++i) {
      wait_until_ready(kernel_notifications[i]);
      delete kernel_notifications[i];
    }

    // No need to wait for lhs notifications since they should have
    // already been waited on. Just clean them up.
    for (size_t i = 0; i < lhs_notifications.size(); ++i) {
      delete lhs_notifications[i];
    }

    // deallocate all of the memory for both A and B's
    for (size_t i = 0; i < blockAs.size(); i++) {
      this->m_device.deallocate(blockAs[i]);
    }
    for (size_t i = 0; i < blockBs.size(); i++) {
      this->m_device.deallocate(blockBs[i]);
    }

#undef CEIL_DIV
  }

  /*
   * Packs an LHS block of size (mt, kc) starting at lhs(m, k). Before packing
   * the LHS block, check that all of the kernels that worked on the same
   * mt_block_idx in the previous m_block are done.
   */
  template <typename packLArg, typename LhsPacker>
  static void packLhs(const packLArg arg) {
    // perform actual packing
    LhsPacker pack_lhs;
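    // gemm_pack_lhs's call operator takes (dest, src, depth, rows), so the
    // arguments below are depth = arg.kc followed by rows = arg.mc.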
    pack_lhs(arg.blockA, arg.lhs.getSubMapper(arg.m_start, arg.k_start), arg.kc, arg.mc);
  }

  /*
   * Packs an RHS block of size (kc, nc) starting at (k, n) after checking that
   * all kernels in the previous block are done.
   * Then, for each LHS block, we wait on its packing notification and call
   * GEBP on the packed data (held in (*blockAs)[blockAId]) together with the
   * full packed RHS block.
   * The output of this GEBP is written to output(m + i * mc, n).
   */
  template <typename packRKArg, typename RhsPacker, typename GebpKernel>
  static void packRhsAndKernel(packRKArg arg) {
    if (arg.need_to_pack) {
      RhsPacker pack_rhs;
      pack_rhs(arg.blockB, arg.rhs.getSubMapper(arg.k, arg.n), arg.kc, arg.nc);
    }

    GebpKernel gebp;
    for (Index mt_block_idx = 0; mt_block_idx < arg.num_blockAs; mt_block_idx++) {
      const Index m_base_start = arg.m + arg.mc * mt_block_idx;
      if (m_base_start < arg.max_m) {
        Index blockAId = (arg.k_block_idx * arg.m_blocks + mt_block_idx + arg.m_block_idx) % arg.num_threads;
        wait_until_ready((*arg.lhs_notifications)[blockAId]);
        const Index actual_mc = numext::mini(m_base_start + arg.mc, arg.max_m) - m_base_start;
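        // gebp's arguments are (output, blockA, blockB, rows, depth, cols, alpha,
        // strideA, strideB, offsetA, offsetB): alpha = 1.0 accumulates into the
        // zero-initialized output, and the trailing -1, -1, 0, 0 select the
        // default strides and offsets for the packed blocks.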
        gebp(arg.output.getSubMapper(m_base_start, arg.n),
             (*arg.blockAs)[blockAId], arg.blockB,
             actual_mc, arg.kc, arg.nc, 1.0, -1, -1, 0, 0);

        // Notify that the kernel is done.
        const Index set_idx = blockAId * arg.n_blocks + arg.n_block_idx;
        (*arg.kernel_notifications)[set_idx]->Notify();
      }
    }
  }
};

} // end namespace Eigen

#endif // EIGEN_USE_THREADS
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_THREAD_POOL_H