dispenso 1.4.1
A library for task parallelism
parallel_for.h
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cmath>
#include <cstdlib>
#include <limits>
#include <memory>
#include <new>
#include <tuple>
#include <type_traits>

#include <dispenso/detail/can_invoke.h>
#include <dispenso/detail/per_thread_info.h>
#include <dispenso/task_set.h>

namespace dispenso {

/**
 * Chunking strategies for parallel_for: kStatic splits the range into one evenly sized chunk
 * per worker up front, while kAuto selects smaller chunks dynamically to enable load balancing.
 **/
enum class ParForChunking { kStatic, kAuto };

/**
 * Options to control the behavior of parallel_for.
 **/
struct ParForOptions {
  /**
   * The maximum number of threads to use; 0 forces serial execution on the calling thread.
   **/
  uint32_t maxThreads = std::numeric_limits<int32_t>::max();

  /**
   * Whether to block until the loop completes. When true, the calling thread also participates
   * in the work.
   **/
  bool wait = true;

  /**
   * The chunking strategy to use when the loop is given plain index bounds rather than a
   * ChunkedRange.
   **/
  ParForChunking defaultChunking = ParForChunking::kStatic;

  /**
   * The minimum number of iterations to run per scheduled chunk.
   **/
  uint32_t minItemsPerChunk = 1;

  /**
   * If true, entries already present in a caller-supplied state container are reused rather
   * than cleared and regenerated.
   **/
  bool reuseExistingState = false;
};
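
// Illustrative usage sketch (editor's example; not part of the original header): configuring a
// non-blocking, load-balanced loop. Only the options defined above are assumed.
//
//   dispenso::ParForOptions options;
//   options.maxThreads = 4;      // use at most 4 workers
//   options.wait = false;        // return immediately; synchronize via the TaskSet
//   options.defaultChunking = dispenso::ParForChunking::kAuto;
//   options.minItemsPerChunk = 1024;  // amortize per-chunk scheduling overhead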

/**
 * A description of a range of integers, along with the chunk size or chunking strategy to use
 * when executing the range in parallel.
 **/
template <typename IntegerT = ssize_t>
struct ChunkedRange {
  // We need to utilize 64-bit integers to avoid overflow, e.g. passing -2**30, 2**30 as int32
  // will result in overflow unless we cast to 64-bit. Note that if we have a range of e.g.
  // -2**63+1 to 2**63-1, we cannot hold the result in an int64_t. We could in a uint64_t, but
  // it is quite tricky to make that work, and ranges larger than an int64_t can hold are not
  // expected in practice, since people want their computations to finish before the heat death
  // of the sun (slight exaggeration).
  using size_type = std::conditional_t<std::is_signed<IntegerT>::value, int64_t, uint64_t>;

  struct Static {};
  struct Auto {};
  static constexpr IntegerT kStatic = std::numeric_limits<IntegerT>::max();

  // Construct a range with an explicit chunk size.
  ChunkedRange(IntegerT s, IntegerT e, IntegerT c) : start(s), end(e), chunk(c) {}
  // Construct a range split into one evenly sized chunk per worker (static chunking).
  ChunkedRange(IntegerT s, IntegerT e, Static) : ChunkedRange(s, e, kStatic) {}
  // Construct a range whose chunk size is chosen automatically for load balancing.
  ChunkedRange(IntegerT s, IntegerT e, Auto) : ChunkedRange(s, e, 0) {}

  bool isStatic() const {
    return chunk == kStatic;
  }

  bool isAuto() const {
    return chunk == 0;
  }

  bool empty() const {
    return end <= start;
  }

  size_type size() const {
    return static_cast<size_type>(end) - start;
  }

  template <typename OtherInt>
  std::tuple<size_type, size_type>
  calcChunkSize(OtherInt numLaunched, bool oneOnCaller, size_type minChunkSize) const {
    size_type workingThreads = static_cast<size_type>(numLaunched) + size_type{oneOnCaller};
    assert(workingThreads > 0);

    if (!chunk) {
      size_type dynFactor = std::min<size_type>(16, size() / workingThreads);
      size_type chunkSize;
      do {
        size_type roughChunks = dynFactor * workingThreads;
        chunkSize = (size() + roughChunks - 1) / roughChunks;
        --dynFactor;
      } while (chunkSize < minChunkSize);
      return {chunkSize, (size() + chunkSize - 1) / chunkSize};
    } else if (chunk == kStatic) {
      // This should never be called. The static distribution versions of the parallel_for
      // functions should be invoked instead.
      std::abort();
    }
    return {chunk, (size() + chunk - 1) / chunk};
  }
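
  // Worked example (editor's note): for an auto-chunked range (chunk == 0) of 1000 items with
  // 7 launched tasks plus the calling thread (workingThreads == 8) and minChunkSize == 1:
  //   dynFactor   = min(16, 1000 / 8) = 16
  //   roughChunks = 16 * 8 = 128
  //   chunkSize   = (1000 + 127) / 128 = 8   (>= minChunkSize, so the loop exits)
  // giving {chunkSize, numChunks} = {8, 125}: roughly 16 chunks per worker, small enough to
  // allow dynamic load balancing.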

  IntegerT start;
  IntegerT end;
  IntegerT chunk;
};

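// Illustrative usage sketch (editor's example; not part of the original header): the three ways
// to construct a ChunkedRange over [0, 1000).
//
//   using CR = dispenso::ChunkedRange<int>;
//   CR explicitChunks(0, 1000, 50);          // chunks of exactly 50 iterations
//   CR staticChunks(0, 1000, CR::Static());  // one evenly sized chunk per worker
//   CR autoChunks(0, 1000, CR::Auto());      // chunk size picked for load balancing
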
/**
 * Create a ChunkedRange from a start, an end, and a chunking strategy.
 **/
template <typename IntegerA, typename IntegerB>
inline ChunkedRange<std::common_type_t<IntegerA, IntegerB>>
makeChunkedRange(IntegerA start, IntegerB end, ParForChunking chunking = ParForChunking::kStatic) {
  using IntegerT = std::common_type_t<IntegerA, IntegerB>;
  return (chunking == ParForChunking::kStatic)
      ? ChunkedRange<IntegerT>(start, end, typename ChunkedRange<IntegerT>::Static())
      : ChunkedRange<IntegerT>(start, end, typename ChunkedRange<IntegerT>::Auto());
}

/**
 * Create a ChunkedRange from a start, an end, and an explicit chunk size.
 **/
template <typename IntegerA, typename IntegerB, typename IntegerC>
inline ChunkedRange<std::common_type_t<IntegerA, IntegerB>>
makeChunkedRange(IntegerA start, IntegerB end, IntegerC chunkSize) {
  return ChunkedRange<std::common_type_t<IntegerA, IntegerB>>(start, end, chunkSize);
}

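// Illustrative usage sketch (editor's example; not part of the original header). The bounds'
// common integer type is deduced, so mixed-type bounds need no casts; `values` below is a
// hypothetical std::vector.
//
//   auto r1 = dispenso::makeChunkedRange(0, 1000);  // default: static chunking
//   auto r2 = dispenso::makeChunkedRange(0, 1000, dispenso::ParForChunking::kAuto);
//   auto r3 = dispenso::makeChunkedRange(size_t{0}, values.size(), 256);  // chunk size 256
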
namespace detail {

struct NoOpIter {
  int& operator*() const {
    static int i = 0;
    return i;
  }
  NoOpIter& operator++() {
    return *this;
  }
  NoOpIter operator++(int) {
    return *this;
  }
};

struct NoOpContainer {
  size_t size() const {
    return 0;
  }

  bool empty() const {
    return true;
  }

  void clear() {}

  NoOpIter begin() {
    return {};
  }

  void emplace_back(int) {}

  int& front() {
    static int i;
    return i;
  }
};

struct NoOpStateGen {
  int operator()() const {
    return 0;
  }
};

template <
    typename TaskSetT,
    typename IntegerT,
    typename F,
    typename StateContainer,
    typename StateGen>
void parallel_for_staticImpl(
    TaskSetT& taskSet,
    StateContainer& states,
    const StateGen& defaultState,
    const ChunkedRange<IntegerT>& range,
    F&& f,
    ssize_t maxThreads,
    bool wait,
    bool reuseExistingState) {
  using size_type = typename ChunkedRange<IntegerT>::size_type;

  size_type numThreads = std::min<size_type>(taskSet.numPoolThreads() + wait, maxThreads);
  // Reduce threads used if they exceed work to be done.
  numThreads = std::min(numThreads, range.size());

  if (!reuseExistingState) {
    states.clear();
  }

  size_t numToEmplace = states.size() < static_cast<size_t>(numThreads)
      ? static_cast<size_t>(numThreads) - states.size()
      : 0;

  for (; numToEmplace--;) {
    states.emplace_back(defaultState());
  }

  auto chunking =
      detail::staticChunkSize(static_cast<ssize_t>(range.size()), static_cast<ssize_t>(numThreads));
  IntegerT chunkSize = static_cast<IntegerT>(chunking.ceilChunkSize);

  bool perfectlyChunked = static_cast<size_type>(chunking.transitionTaskIndex) == numThreads;

  // Equivalent to (!perfectlyChunked) ? chunking.transitionTaskIndex : numThreads - 1.
  // The first firstLoopLen tasks receive ceil-sized chunks; the remainder get one item fewer.
  size_type firstLoopLen = chunking.transitionTaskIndex - perfectlyChunked;

  auto stateIt = states.begin();
  IntegerT start = range.start;
  size_type t;
  for (t = 0; t < firstLoopLen; ++t) {
    IntegerT next = static_cast<IntegerT>(start + chunkSize);
    taskSet.schedule([it = stateIt++, start, next, f]() {
      auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
      f(*it, start, next);
    });
    start = next;
  }

  // Reduce the remaining chunk sizes by 1.
  chunkSize = static_cast<IntegerT>(chunkSize - !perfectlyChunked);
  // Finish submitting all but the last item.
  for (; t < numThreads - 1; ++t) {
    IntegerT next = static_cast<IntegerT>(start + chunkSize);
    taskSet.schedule([it = stateIt++, start, next, f]() {
      auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
      f(*it, start, next);
    });
    start = next;
  }

  if (wait) {
    f(*stateIt, start, range.end);
    taskSet.wait();
  } else {
    taskSet.schedule(
        [stateIt, start, end = range.end, f]() {
          auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
          f(*stateIt, start, end);
        },
        ForceQueuingTag());
  }
}
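
// Worked example (editor's note, assuming detail::staticChunkSize(10, 4) reports
// ceilChunkSize == 3 and transitionTaskIndex == 2): distributing 10 items across 4 workers
// schedules [0, 3) and [3, 6) at the ceiling size, then drops chunkSize by one so the next
// worker takes [6, 8); with wait == true the calling thread runs the final chunk [8, 10)
// itself before taskSet.wait().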

} // namespace detail

/**
 * Invoke f in parallel over the given chunked range, making one element of the states
 * container available to each worker; each worker may mutate its state without
 * synchronization. If options.wait is false, the states container must remain valid until all
 * tasks complete.
 **/
template <
    typename TaskSetT,
    typename IntegerT,
    typename F,
    typename StateContainer,
    typename StateGen>
void parallel_for(
    TaskSetT& taskSet,
    StateContainer& states,
    const StateGen& defaultState,
    const ChunkedRange<IntegerT>& range,
    F&& f,
    ParForOptions options = {}) {
  if (range.empty()) {
    if (options.wait) {
      taskSet.wait();
    }
    return;
  }

  using size_type = typename ChunkedRange<IntegerT>::size_type;

  // Ensure minItemsPerChunk is sane.
  uint32_t minItemsPerChunk = std::max<uint32_t>(1, options.minItemsPerChunk);

  // maxThreads == 0 indicates serial execution per the API spec; clamp to at least 1 here and
  // let the serial path below handle the 0 case.
  size_type maxThreads = std::max<int32_t>(options.maxThreads, 1);

  bool isStatic = range.isStatic();

  const size_type N = taskSet.numPoolThreads();
  if (N == 0 || !options.maxThreads || range.size() <= minItemsPerChunk ||
      detail::PerPoolPerThreadInfo::isParForRecursive(&taskSet.pool())) {
    if (!options.reuseExistingState) {
      states.clear();
    }
    if (states.empty()) {
      states.emplace_back(defaultState());
    }
    f(*states.begin(), range.start, range.end);
    if (options.wait) {
      taskSet.wait();
    }
    return;
  }

  // Adjust the worker count down if it would produce chunks smaller than minItemsPerChunk.
  if (minItemsPerChunk > 1) {
    size_type maxWorkers = range.size() / minItemsPerChunk;
    if (maxWorkers < maxThreads) {
      maxThreads = static_cast<uint32_t>(maxWorkers);
    }
    if (range.size() / (maxThreads + options.wait) < minItemsPerChunk && range.isAuto()) {
      isStatic = true;
    }
  } else if (range.size() <= N + options.wait) {
    if (range.isAuto()) {
      isStatic = true;
    } else if (!range.isStatic()) {
      maxThreads = range.size() - options.wait;
    }
  }

  if (isStatic) {
    detail::parallel_for_staticImpl(
        taskSet,
        states,
        defaultState,
        range,
        std::forward<F>(f),
        static_cast<ssize_t>(maxThreads),
        options.wait,
        options.reuseExistingState);
    return;
  }

  // We want maxThreads workers (potentially including the calling thread), capped by N.
  const size_type numToLaunch = std::min<size_type>(maxThreads - options.wait, N);

  if (!options.reuseExistingState) {
    states.clear();
  }

  size_t numToEmplace = static_cast<size_type>(states.size()) < (numToLaunch + options.wait)
      ? (static_cast<size_t>(numToLaunch) + options.wait) - states.size()
      : 0;
  for (; numToEmplace--;) {
    states.emplace_back(defaultState());
  }

  if (numToLaunch == 1 && !options.wait) {
    taskSet.schedule(
        [&s = states.front(), range, f = std::move(f)]() { f(s, range.start, range.end); });

    return;
  }

  auto chunkInfo = range.calcChunkSize(numToLaunch, options.wait, minItemsPerChunk);
  auto chunkSize = std::get<0>(chunkInfo);
  auto numChunks = std::get<1>(chunkInfo);

  if (options.wait) {
    alignas(kCacheLineSize) std::atomic<decltype(numChunks)> index(0);
    auto worker = [start = range.start, end = range.end, &index, f, chunkSize, numChunks](auto& s) {
      auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();

      while (true) {
        auto cur = index.fetch_add(1, std::memory_order_relaxed);
        if (cur >= numChunks) {
          break;
        }
        auto sidx = static_cast<IntegerT>(start + cur * chunkSize);
        if (cur + 1 == numChunks) {
          f(s, sidx, end);
        } else {
          auto eidx = static_cast<IntegerT>(sidx + chunkSize);
          f(s, sidx, eidx);
        }
      }
    };

    auto it = states.begin();
    for (size_type i = 0; i < numToLaunch; ++i) {
      taskSet.schedule([&s = *it++, worker]() { worker(s); });
    }
    worker(*it);
    taskSet.wait();
  } else {
    struct Atomic {
      Atomic() : index(0) {}
      alignas(kCacheLineSize) std::atomic<decltype(numChunks)> index;
      char buffer[kCacheLineSize - sizeof(index)];
    };

    void* ptr = detail::alignedMalloc(sizeof(Atomic), alignof(Atomic));
    auto* atm = new (ptr) Atomic();

    std::shared_ptr<Atomic> wrapper(atm, detail::AlignedFreeDeleter<Atomic>());
    auto worker = [start = range.start,
                   end = range.end,
                   wrapper = std::move(wrapper),
                   f,
                   chunkSize,
                   numChunks](auto& s) {
      auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
      while (true) {
        auto cur = wrapper->index.fetch_add(1, std::memory_order_relaxed);
        if (cur >= numChunks) {
          break;
        }
        auto sidx = static_cast<IntegerT>(start + cur * chunkSize);
        if (cur + 1 == numChunks) {
          f(s, sidx, end);
        } else {
          auto eidx = static_cast<IntegerT>(sidx + chunkSize);
          f(s, sidx, eidx);
        }
      }
    };

    auto it = states.begin();
    for (size_type i = 0; i < numToLaunch; ++i) {
      taskSet.schedule([&s = *it++, worker]() { worker(s); }, ForceQueuingTag());
    }
  }
}
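
// Illustrative usage sketch (editor's example; not part of the original header): per-worker
// accumulators over an auto-chunked range, reduced serially afterwards.
//
//   std::vector<int64_t> partials;  // parallel_for materializes one element per worker
//   dispenso::TaskSet taskSet(dispenso::globalThreadPool());
//   dispenso::parallel_for(
//       taskSet,
//       partials,
//       []() { return int64_t{0}; },
//       dispenso::makeChunkedRange(0, 1 << 20, dispenso::ParForChunking::kAuto),
//       [](int64_t& partial, int s, int e) {
//         for (int i = s; i < e; ++i) {
//           partial += i;
//         }
//       });
//   int64_t total = 0;
//   for (int64_t p : partials) {
//     total += p;
//   }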

/**
 * Invoke f in parallel over the given chunked range; f receives the (start, end) bounds of
 * each chunk.
 **/
template <typename TaskSetT, typename IntegerT, typename F>
void parallel_for(
    TaskSetT& taskSet,
    const ChunkedRange<IntegerT>& range,
    F&& f,
    ParForOptions options = {}) {
  detail::NoOpContainer container;
  parallel_for(
      taskSet,
      container,
      detail::NoOpStateGen(),
      range,
      [f = std::move(f)](int /*noop*/, auto i, auto j) { f(i, j); },
      options);
}

/**
 * Invoke f in parallel over the given chunked range on the global thread pool, blocking until
 * the loop completes.
 **/
template <typename IntegerT, typename F>
void parallel_for(const ChunkedRange<IntegerT>& range, F&& f, ParForOptions options = {}) {
  TaskSet taskSet(globalThreadPool());
  options.wait = true;
  parallel_for(taskSet, range, std::forward<F>(f), options);
}
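
// Illustrative usage sketch (editor's example; not part of the original header); `process` is
// a hypothetical per-element function.
//
//   dispenso::parallel_for(
//       dispenso::makeChunkedRange(0, 1000, dispenso::ParForChunking::kAuto),
//       [](int s, int e) {
//         for (int i = s; i < e; ++i) {
//           process(i);
//         }
//       });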

/**
 * Invoke f in parallel over the given chunked range on the global thread pool, with per-worker
 * state, blocking until the loop completes.
 **/
template <typename F, typename IntegerT, typename StateContainer, typename StateGen>
void parallel_for(
    StateContainer& states,
    const StateGen& defaultState,
    const ChunkedRange<IntegerT>& range,
    F&& f,
    ParForOptions options = {}) {
  TaskSet taskSet(globalThreadPool());
  options.wait = true;
  parallel_for(taskSet, states, defaultState, range, std::forward<F>(f), options);
}

/**
 * Invoke f on each index in [start, end) in parallel. This overload participates when f is
 * invocable with a single index.
 **/
template <
    typename TaskSetT,
    typename IntegerA,
    typename IntegerB,
    typename F,
    std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
    std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
    std::enable_if_t<detail::CanInvoke<F(IntegerA)>::value, bool> = true>
void parallel_for(
    TaskSetT& taskSet,
    IntegerA start,
    IntegerB end,
    F&& f,
    ParForOptions options = {}) {
  using IntegerT = std::common_type_t<IntegerA, IntegerB>;

  auto range = makeChunkedRange(start, end, options.defaultChunking);
  parallel_for(
      taskSet,
      range,
      [f = std::move(f)](IntegerT s, IntegerT e) {
        for (IntegerT i = s; i < e; ++i) {
          f(i);
        }
      },
      options);
}
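
// Illustrative usage sketch (editor's example; not part of the original header): overlapping
// the loop with other work on the calling thread. `process` and `doOtherWork` are hypothetical
// functions.
//
//   dispenso::TaskSet taskSet(dispenso::globalThreadPool());
//   dispenso::ParForOptions options;
//   options.wait = false;
//   dispenso::parallel_for(
//       taskSet, 0, 1000, [](int i) { process(i); }, options);
//   doOtherWork();   // runs concurrently with the loop
//   taskSet.wait();  // synchronize before using the results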

template <
    typename TaskSetT,
    typename IntegerA,
    typename IntegerB,
    typename F,
    std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
    std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
    std::enable_if_t<detail::CanInvoke<F(IntegerA, IntegerB)>::value, bool> = true>
void parallel_for(
    TaskSetT& taskSet,
    IntegerA start,
    IntegerB end,
    F&& f,
    ParForOptions options = {}) {
  auto range = makeChunkedRange(start, end, options.defaultChunking);
  parallel_for(taskSet, range, std::forward<F>(f), options);
}

/**
 * Invoke f on each index in [start, end) in parallel on the global thread pool, blocking until
 * the loop completes.
 **/
template <
    typename IntegerA,
    typename IntegerB,
    typename F,
    std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
    std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true>
void parallel_for(IntegerA start, IntegerB end, F&& f, ParForOptions options = {}) {
  TaskSet taskSet(globalThreadPool());
  options.wait = true;
  parallel_for(taskSet, start, end, std::forward<F>(f), options);
}
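
// Illustrative usage sketch (editor's example; not part of the original header): the simplest
// form, blocking until every index has been processed.
//
//   std::vector<float> roots(1000);
//   dispenso::parallel_for(
//       0, 1000, [&roots](int i) { roots[i] = std::sqrt(static_cast<float>(i)); });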

/**
 * Invoke f on each index in [start, end) in parallel, with one element of the states container
 * made available per worker. This overload participates when f is invocable with a state
 * reference and a single index.
 **/
template <
    typename TaskSetT,
    typename IntegerA,
    typename IntegerB,
    typename F,
    typename StateContainer,
    typename StateGen,
    std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
    std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
    std::enable_if_t<
        detail::CanInvoke<F(typename StateContainer::reference, IntegerA)>::value,
        bool> = true>
void parallel_for(
    TaskSetT& taskSet,
    StateContainer& states,
    const StateGen& defaultState,
    IntegerA start,
    IntegerB end,
    F&& f,
    ParForOptions options = {}) {
  using IntegerT = std::common_type_t<IntegerA, IntegerB>;
  auto range = makeChunkedRange(start, end, options.defaultChunking);
  parallel_for(
      taskSet,
      states,
      defaultState,
      range,
      [f = std::move(f)](auto& state, IntegerT s, IntegerT e) {
        for (IntegerT i = s; i < e; ++i) {
          f(state, i);
        }
      },
      options);
}
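
// Illustrative usage sketch (editor's example; not part of the original header): element-wise
// iteration with per-worker counters, avoiding contention on a shared atomic. `data` and
// `threshold` are hypothetical.
//
//   std::vector<uint64_t> counts;
//   dispenso::TaskSet taskSet(dispenso::globalThreadPool());
//   dispenso::parallel_for(
//       taskSet,
//       counts,
//       []() { return uint64_t{0}; },
//       size_t{0},
//       data.size(),
//       [&](uint64_t& count, size_t i) { count += data[i] > threshold; });
//   // Sum `counts` afterwards for the total.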

template <
    typename TaskSetT,
    typename IntegerA,
    typename IntegerB,
    typename F,
    typename StateContainer,
    typename StateGen,
    std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
    std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
    std::enable_if_t<
        detail::CanInvoke<F(typename StateContainer::reference, IntegerA, IntegerB)>::value,
        bool> = true>
void parallel_for(
    TaskSetT& taskSet,
    StateContainer& states,
    const StateGen& defaultState,
    IntegerA start,
    IntegerB end,
    F&& f,
    ParForOptions options = {}) {
  auto range = makeChunkedRange(start, end, options.defaultChunking);
  parallel_for(taskSet, states, defaultState, range, std::forward<F>(f), options);
}

/**
 * Invoke f on each index in [start, end) in parallel on the global thread pool, with
 * per-worker state, blocking until the loop completes.
 **/
template <
    typename IntegerA,
    typename IntegerB,
    typename F,
    typename StateContainer,
    typename StateGen,
    std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
    std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true>
void parallel_for(
    StateContainer& states,
    const StateGen& defaultState,
    IntegerA start,
    IntegerB end,
    F&& f,
    ParForOptions options = {}) {
  TaskSet taskSet(globalThreadPool());
  options.wait = true;
  parallel_for(taskSet, states, defaultState, start, end, std::forward<F>(f), options);
}

} // namespace dispenso