dispenso 1.4.1
A library for task parallelism
Loading...
Searching...
No Matches
parallel_for.h
Go to the documentation of this file.
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
14#pragma once
15
16#include <cmath>
17#include <limits>
18#include <memory>
19
20#include <dispenso/detail/can_invoke.h>
21#include <dispenso/detail/per_thread_info.h>
22#include <dispenso/task_set.h>
23
24namespace dispenso {
25
#if DISPENSO_HAS_CONCEPTS
/**
 * Concept for functors invocable as f(begin, end) over a subrange of indices.
 **/
template <typename F, typename IntegerT>
concept ParallelForRangeFunc = std::invocable<F, IntegerT, IntegerT>;

/**
 * Concept for functors invocable as f(index) for a single index.
 **/
template <typename F, typename IntegerT>
concept ParallelForIndexFunc = std::invocable<F, IntegerT>;

/**
 * Concept for functors invocable as f(state, begin, end), where state is a
 * reference to per-worker state.
 **/
template <typename F, typename StateRef, typename IntegerT>
concept ParallelForStateRangeFunc = std::invocable<F, StateRef, IntegerT, IntegerT>;

/**
 * Concept for functors invocable as f(state, index), where state is a
 * reference to per-worker state.
 **/
template <typename F, typename StateRef, typename IntegerT>
concept ParallelForStateIndexFunc = std::invocable<F, StateRef, IntegerT>;
#endif // DISPENSO_HAS_CONCEPTS
63
/**
 * Chunking strategies for parallel_for: kStatic splits the range evenly across
 * workers up front; kAuto computes chunk sizes at run time for load balancing.
 **/
enum class ParForChunking { kStatic, kAuto };

/**
 * Options to control the behavior of the parallel_for variants.
 **/
struct ParForOptions {
  /**
   * Maximum number of threads to use.  Defaults to effectively unbounded;
   * a value of 0 requests serial execution on the calling thread.
   **/
  uint32_t maxThreads = std::numeric_limits<int32_t>::max();
  /**
   * If true, the calling thread participates in the work and the call blocks
   * until the loop is complete.  If false, work is queued and the call returns
   * immediately.
   **/
  bool wait = true;

  /**
   * Chunking strategy used by the (start, end) convenience overloads that do
   * not take an explicit ChunkedRange.
   **/
  ParForChunking defaultChunking = ParForChunking::kStatic;

  /**
   * Minimum number of items per chunk; limits over-subdivision of small
   * ranges.  Values below 1 are clamped to 1 by parallel_for.
   **/
  uint32_t minItemsPerChunk = 1;

  /**
   * If true, stateful overloads do not clear the supplied state container,
   * allowing state objects to be reused across successive calls.
   **/
  bool reuseExistingState = false;
};
114
/**
 * A half-open range of indices [start, end) together with a chunking policy
 * (explicit chunk size, static, or automatic) used by parallel_for to carve
 * the range into per-task pieces.
 **/
template <typename IntegerT = ssize_t>
struct ChunkedRange {
  // We need to utilize 64-bit integers to avoid overflow, e.g. passing -2**30, 2**30 as int32 will
  // result in overflow unless we cast to 64-bit. Note that if we have a range of e.g. -2**63+1 to
  // 2**63-1, we cannot hold the result in an int64_t. We could in a uint64_t, but it is quite
  // tricky to make this work. However, I do not expect ranges larger than can be held in int64_t
  // since people want their computations to finish before the heat death of the sun (slight
  // exaggeration).
  using size_type = std::conditional_t<std::is_signed<IntegerT>::value, int64_t, uint64_t>;

  // Tag types used to select the chunking policy at construction.
  struct Static {};
  struct Auto {};
  // Sentinel chunk value meaning "static chunking"; a chunk of 0 means "auto".
  static constexpr IntegerT kStatic = std::numeric_limits<IntegerT>::max();

  // Construct with an explicit per-chunk item count c.
  ChunkedRange(IntegerT s, IntegerT e, IntegerT c) : start(s), end(e), chunk(c) {}
  // Construct with static chunking (range divided evenly up front).
  ChunkedRange(IntegerT s, IntegerT e, Static) : ChunkedRange(s, e, kStatic) {}
  // Construct with automatic chunking (chunk size computed in calcChunkSize).
  ChunkedRange(IntegerT s, IntegerT e, Auto) : ChunkedRange(s, e, 0) {}

  bool isStatic() const {
    return chunk == kStatic;
  }

  bool isAuto() const {
    return chunk == 0;
  }

  // True when the range holds no indices (end <= start).
  bool empty() const {
    return end <= start;
  }

  // Number of indices in the range, widened to size_type to avoid overflow.
  size_type size() const {
    return static_cast<size_type>(end) - start;
  }

  // Returns {chunkSize, numChunks} for auto or explicit chunking, given the
  // number of launched tasks plus (optionally) the calling thread.  Must not
  // be called for static ranges (aborts); static distribution is handled by
  // parallel_for_staticImpl instead.
  template <typename OtherInt>
  std::tuple<size_type, size_type>
  calcChunkSize(OtherInt numLaunched, bool oneOnCaller, size_type minChunkSize) const {
    size_type workingThreads = static_cast<size_type>(numLaunched) + size_type{oneOnCaller};
    assert(workingThreads > 0);

    if (!chunk) {
      // Auto chunking: start from up to 16 chunks per worker and shrink the
      // multiplier until each chunk is at least minChunkSize items.
      size_type dynFactor = std::min<size_type>(16, size() / workingThreads);
      size_type chunkSize;
      do {
        size_type roughChunks = dynFactor * workingThreads;
        chunkSize = (size() + roughChunks - 1) / roughChunks;
        --dynFactor;
      } while (chunkSize < minChunkSize);
      return {chunkSize, (size() + chunkSize - 1) / chunkSize};
    } else if (chunk == kStatic) {
      // This should never be called. The static distribution versions of the parallel_for
      // functions should be invoked instead.
      std::abort();
    }
    return {chunk, (size() + chunk - 1) / chunk};
  }

  IntegerT start;
  IntegerT end;
  IntegerT chunk;
};
206
214template <typename IntegerA, typename IntegerB>
215inline ChunkedRange<std::common_type_t<IntegerA, IntegerB>>
216makeChunkedRange(IntegerA start, IntegerB end, ParForChunking chunking = ParForChunking::kStatic) {
217 using IntegerT = std::common_type_t<IntegerA, IntegerB>;
218 return (chunking == ParForChunking::kStatic)
219 ? ChunkedRange<IntegerT>(start, end, typename ChunkedRange<IntegerT>::Static())
220 : ChunkedRange<IntegerT>(start, end, typename ChunkedRange<IntegerT>::Auto());
221}
222
230template <typename IntegerA, typename IntegerB, typename IntegerC>
231inline ChunkedRange<std::common_type_t<IntegerA, IntegerB>>
232makeChunkedRange(IntegerA start, IntegerB end, IntegerC chunkSize) {
233 return ChunkedRange<std::common_type_t<IntegerA, IntegerB>>(start, end, chunkSize);
234}
235
236namespace detail {
237
// Inert iterator used when the caller provides no per-worker state.  Every
// instance dereferences to one shared dummy int.
struct NoOpIter {
  int& operator*() const {
    static int dummy = 0;
    return dummy;
  }
  NoOpIter& operator++() {
    return *this;
  }
  NoOpIter operator++(int) {
    return *this;
  }
};

// Inert "container" exposing just the interface the stateful parallel_for
// overloads touch, while storing nothing and always reporting empty.
struct NoOpContainer {
  bool empty() const {
    return true;
  }

  size_t size() const {
    return 0;
  }

  void clear() {}

  void emplace_back(int) {}

  NoOpIter begin() {
    return NoOpIter{};
  }

  int& front() {
    static int slot;
    return slot;
  }
};

// Generator producing a throwaway int "state" value.
struct NoOpStateGen {
  int operator()() const {
    return 0;
  }
};
279
// Statically partition range across up to maxThreads workers (one state object
// per worker drawn from states/defaultState), scheduling one contiguous
// subrange per worker on taskSet.  When wait is true the calling thread runs
// the final subrange itself and blocks on taskSet; otherwise the final
// subrange is force-queued as well.
template <
    typename TaskSetT,
    typename IntegerT,
    typename F,
    typename StateContainer,
    typename StateGen>
void parallel_for_staticImpl(
    TaskSetT& taskSet,
    StateContainer& states,
    const StateGen& defaultState,
    const ChunkedRange<IntegerT>& range,
    F&& f,
    ssize_t maxThreads,
    bool wait,
    bool reuseExistingState) {
  using size_type = typename ChunkedRange<IntegerT>::size_type;

  // When waiting, the caller counts as one extra worker beyond the pool.
  size_type numThreads = std::min<size_type>(taskSet.numPoolThreads() + wait, maxThreads);
  // Reduce threads used if they exceed work to be done.
  numThreads = std::min(numThreads, range.size());

  if (!reuseExistingState) {
    states.clear();
  }

  // Top up the state container so there is one state per worker.
  size_t numToEmplace = states.size() < static_cast<size_t>(numThreads)
      ? static_cast<size_t>(numThreads) - states.size()
      : 0;

  for (; numToEmplace--;) {
    states.emplace_back(defaultState());
  }

  auto chunking =
      detail::staticChunkSize(static_cast<ssize_t>(range.size()), static_cast<ssize_t>(numThreads));
  IntegerT chunkSize = static_cast<IntegerT>(chunking.ceilChunkSize);

  // If the transition task index equals numThreads, every worker gets the
  // ceiling chunk size; otherwise tasks past the transition get one item less.
  bool perfectlyChunked = static_cast<size_type>(chunking.transitionTaskIndex) == numThreads;

  // (!perfectlyChunked) ? chunking.transitionTaskIndex : numThreads - 1;
  size_type firstLoopLen = chunking.transitionTaskIndex - perfectlyChunked;

  auto stateIt = states.begin();
  IntegerT start = range.start;
  size_type t;
  // Schedule the ceiling-sized chunks.
  for (t = 0; t < firstLoopLen; ++t) {
    IntegerT next = static_cast<IntegerT>(start + chunkSize);
    taskSet.schedule([it = stateIt++, start, next, f]() {
      auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
      f(*it, start, next);
    });
    start = next;
  }

  // Reduce the remaining chunk sizes by 1.
  chunkSize = static_cast<IntegerT>(chunkSize - !perfectlyChunked);
  // Finish submitting all but the last item.
  for (; t < numThreads - 1; ++t) {
    IntegerT next = static_cast<IntegerT>(start + chunkSize);
    taskSet.schedule([it = stateIt++, start, next, f]() {
      auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
      f(*it, start, next);
    });
    start = next;
  }

  if (wait) {
    // Run the final subrange on the calling thread, then block for the rest.
    f(*stateIt, start, range.end);
    taskSet.wait();
  } else {
    // Fire-and-forget: force-queue the final subrange too.
    taskSet.schedule(
        [stateIt, start, end = range.end, f]() {
          auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
          f(*stateIt, start, end);
        },
        ForceQueuingTag());
  }
}
358
359} // namespace detail
360
377template <
378 typename TaskSetT,
379 typename IntegerT,
380 typename F,
381 typename StateContainer,
382 typename StateGen>
384 TaskSetT& taskSet,
385 StateContainer& states,
386 const StateGen& defaultState,
387 const ChunkedRange<IntegerT>& range,
388 F&& f,
389 ParForOptions options = {}) {
390 if (range.empty()) {
391 if (options.wait) {
392 taskSet.wait();
393 }
394 return;
395 }
396
397 using size_type = typename ChunkedRange<IntegerT>::size_type;
398
399 // Ensure minItemsPerChunk is sane
400 uint32_t minItemsPerChunk = std::max<uint32_t>(1, options.minItemsPerChunk);
401
402 // 0 indicates serial execution per API spec
403 size_type maxThreads = std::max<int32_t>(options.maxThreads, 1);
404
405 bool isStatic = range.isStatic();
406
407 const size_type N = taskSet.numPoolThreads();
408 if (N == 0 || !options.maxThreads || range.size() <= minItemsPerChunk ||
409 detail::PerPoolPerThreadInfo::isParForRecursive(&taskSet.pool())) {
410 if (!options.reuseExistingState) {
411 states.clear();
412 }
413 if (states.empty()) {
414 states.emplace_back(defaultState());
415 }
416 f(*states.begin(), range.start, range.end);
417 if (options.wait) {
418 taskSet.wait();
419 }
420 return;
421 }
422
423 // Adjust down workers if we would have too-small chunks
424 if (minItemsPerChunk > 1) {
425 size_type maxWorkers = range.size() / minItemsPerChunk;
426 if (maxWorkers < maxThreads) {
427 maxThreads = static_cast<uint32_t>(maxWorkers);
428 }
429 if (range.size() / (maxThreads + options.wait) < minItemsPerChunk && range.isAuto()) {
430 isStatic = true;
431 }
432 } else if (range.size() <= N + options.wait) {
433 if (range.isAuto()) {
434 isStatic = true;
435 } else if (!range.isStatic()) {
436 maxThreads = range.size() - options.wait;
437 }
438 }
439
440 if (isStatic) {
441 detail::parallel_for_staticImpl(
442 taskSet,
443 states,
444 defaultState,
445 range,
446 std::forward<F>(f),
447 static_cast<ssize_t>(maxThreads),
448 options.wait,
449 options.reuseExistingState);
450 return;
451 }
452
453 // wanting maxThreads workers (potentially including the calling thread), capped by N
454 const size_type numToLaunch = std::min<size_type>(maxThreads - options.wait, N);
455
456 if (!options.reuseExistingState) {
457 states.clear();
458 }
459
460 size_t numToEmplace = static_cast<size_type>(states.size()) < (numToLaunch + options.wait)
461 ? (static_cast<size_t>(numToLaunch) + options.wait) - states.size()
462 : 0;
463 for (; numToEmplace--;) {
464 states.emplace_back(defaultState());
465 }
466
467 if (numToLaunch == 1 && !options.wait) {
468 taskSet.schedule(
469 [&s = states.front(), range, f = std::move(f)]() { f(s, range.start, range.end); });
470
471 return;
472 }
473
474 auto chunkInfo = range.calcChunkSize(numToLaunch, options.wait, minItemsPerChunk);
475 auto chunkSize = std::get<0>(chunkInfo);
476 auto numChunks = std::get<1>(chunkInfo);
477
478 if (options.wait) {
479 alignas(kCacheLineSize) std::atomic<decltype(numChunks)> index(0);
480 auto worker = [start = range.start, end = range.end, &index, f, chunkSize, numChunks](auto& s) {
481 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
482
483 while (true) {
484 auto cur = index.fetch_add(1, std::memory_order_relaxed);
485 if (cur >= numChunks) {
486 break;
487 }
488 auto sidx = static_cast<IntegerT>(start + cur * chunkSize);
489 if (cur + 1 == numChunks) {
490 f(s, sidx, end);
491 } else {
492 auto eidx = static_cast<IntegerT>(sidx + chunkSize);
493 f(s, sidx, eidx);
494 }
495 }
496 };
497
498 auto it = states.begin();
499 for (size_type i = 0; i < numToLaunch; ++i) {
500 taskSet.schedule([&s = *it++, worker]() { worker(s); });
501 }
502 worker(*it);
503 taskSet.wait();
504 } else {
505 struct Atomic {
506 Atomic() : index(0) {}
507 alignas(kCacheLineSize) std::atomic<decltype(numChunks)> index;
508 char buffer[kCacheLineSize - sizeof(index)];
509 };
510
511 void* ptr = detail::alignedMalloc(sizeof(Atomic), alignof(Atomic));
512 auto* atm = new (ptr) Atomic();
513
514 std::shared_ptr<Atomic> wrapper(atm, detail::AlignedFreeDeleter<Atomic>());
515 auto worker = [start = range.start,
516 end = range.end,
517 wrapper = std::move(wrapper),
518 f,
519 chunkSize,
520 numChunks](auto& s) {
521 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
522 while (true) {
523 auto cur = wrapper->index.fetch_add(1, std::memory_order_relaxed);
524 if (cur >= numChunks) {
525 break;
526 }
527 auto sidx = static_cast<IntegerT>(start + cur * chunkSize);
528 if (cur + 1 == numChunks) {
529 f(s, sidx, end);
530 } else {
531 auto eidx = static_cast<IntegerT>(sidx + chunkSize);
532 f(s, sidx, eidx);
533 }
534 }
535 };
536
537 auto it = states.begin();
538 for (size_type i = 0; i < numToLaunch; ++i) {
539 taskSet.schedule([&s = *it++, worker]() { worker(s); }, ForceQueuingTag());
540 }
541 }
542}
543
553template <typename TaskSetT, typename IntegerT, typename F>
554DISPENSO_REQUIRES(ParallelForRangeFunc<F, IntegerT>)
555void parallel_for(
556 TaskSetT& taskSet,
557 const ChunkedRange<IntegerT>& range,
558 F&& f,
559 ParForOptions options = {}) {
560 detail::NoOpContainer container;
561 parallel_for(
562 taskSet,
563 container,
564 detail::NoOpStateGen(),
565 range,
566 [f = std::move(f)](int /*noop*/, auto i, auto j) { f(i, j); },
567 options);
568}
569
579template <typename IntegerT, typename F>
580DISPENSO_REQUIRES(ParallelForRangeFunc<F, IntegerT>)
581void parallel_for(const ChunkedRange<IntegerT>& range, F&& f, ParForOptions options = {}) {
582 TaskSet taskSet(globalThreadPool());
583 options.wait = true;
584 parallel_for(taskSet, range, std::forward<F>(f), options);
585}
586
604template <typename F, typename IntegerT, typename StateContainer, typename StateGen>
606 StateContainer& states,
607 const StateGen& defaultState,
608 const ChunkedRange<IntegerT>& range,
609 F&& f,
610 ParForOptions options = {}) {
611 TaskSet taskSet(globalThreadPool());
612 options.wait = true;
613 parallel_for(taskSet, states, defaultState, range, std::forward<F>(f), options);
614}
615
626template <
627 typename TaskSetT,
628 typename IntegerA,
629 typename IntegerB,
630 typename F,
631 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
632 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
633 std::enable_if_t<detail::CanInvoke<F(IntegerA)>::value, bool> = true>
635 TaskSetT& taskSet,
636 IntegerA start,
637 IntegerB end,
638 F&& f,
639 ParForOptions options = {}) {
640 using IntegerT = std::common_type_t<IntegerA, IntegerB>;
641
642 auto range = makeChunkedRange(start, end, options.defaultChunking);
643 parallel_for(
644 taskSet,
645 range,
646 [f = std::move(f)](IntegerT s, IntegerT e) {
647 for (IntegerT i = s; i < e; ++i) {
648 f(i);
649 }
650 },
651 options);
652}
653
655template <
656 typename TaskSetT,
657 typename IntegerA,
658 typename IntegerB,
659 typename F,
660 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
661 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
662 std::enable_if_t<detail::CanInvoke<F(IntegerA, IntegerB)>::value, bool> = true>
663void parallel_for(
664 TaskSetT& taskSet,
665 IntegerA start,
666 IntegerB end,
667 F&& f,
668 ParForOptions options = {}) {
669 auto range = makeChunkedRange(start, end, options.defaultChunking);
670 parallel_for(taskSet, range, std::forward<F>(f), options);
671}
672
683template <
684 typename IntegerA,
685 typename IntegerB,
686 typename F,
687 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
688 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true>
689void parallel_for(IntegerA start, IntegerB end, F&& f, ParForOptions options = {}) {
690 TaskSet taskSet(globalThreadPool());
691 options.wait = true;
692 parallel_for(taskSet, start, end, std::forward<F>(f), options);
693}
694
713template <
714 typename TaskSetT,
715 typename IntegerA,
716 typename IntegerB,
717 typename F,
718 typename StateContainer,
719 typename StateGen,
720 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
721 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
722 std::enable_if_t<
723 detail::CanInvoke<F(typename StateContainer::reference, IntegerA)>::value,
724 bool> = true>
726 TaskSetT& taskSet,
727 StateContainer& states,
728 const StateGen& defaultState,
729 IntegerA start,
730 IntegerB end,
731 F&& f,
732 ParForOptions options = {}) {
733 using IntegerT = std::common_type_t<IntegerA, IntegerB>;
734 auto range = makeChunkedRange(start, end, options.defaultChunking);
735 parallel_for(
736 taskSet,
737 states,
738 defaultState,
739 range,
740 [f = std::move(f)](auto& state, IntegerT s, IntegerT e) {
741 for (IntegerT i = s; i < e; ++i) {
742 f(state, i);
743 }
744 },
745 options);
746}
747
749template <
750 typename TaskSetT,
751 typename IntegerA,
752 typename IntegerB,
753 typename F,
754 typename StateContainer,
755 typename StateGen,
756 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
757 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
758 std::enable_if_t<
759 detail::CanInvoke<F(typename StateContainer::reference, IntegerA, IntegerB)>::value,
760 bool> = true>
761void parallel_for(
762 TaskSetT& taskSet,
763 StateContainer& states,
764 const StateGen& defaultState,
765 IntegerA start,
766 IntegerB end,
767 F&& f,
768 ParForOptions options = {}) {
769 auto range = makeChunkedRange(start, end, options.defaultChunking);
770 parallel_for(taskSet, states, defaultState, range, std::forward<F>(f), options);
771}
772
792template <
793 typename IntegerA,
794 typename IntegerB,
795 typename F,
796 typename StateContainer,
797 typename StateGen,
798 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
799 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true>
801 StateContainer& states,
802 const StateGen& defaultState,
803 IntegerA start,
804 IntegerB end,
805 F&& f,
806 ParForOptions options = {}) {
807 TaskSet taskSet(globalThreadPool());
808 options.wait = true;
809 parallel_for(taskSet, states, defaultState, start, end, std::forward<F>(f), options);
810}
811
812} // namespace dispenso
void parallel_for(TaskSetT &taskSet, StateContainer &states, const StateGen &defaultState, const ChunkedRange< IntegerT > &range, F &&f, ParForOptions options={})
ChunkedRange< std::common_type_t< IntegerA, IntegerB > > makeChunkedRange(IntegerA start, IntegerB end, ParForChunking chunking=ParForChunking::kStatic)
constexpr size_t kCacheLineSize
A constant that defines a safe number of bytes+alignment to avoid false sharing.
Definition platform.h:89
ParForChunking defaultChunking