dispenso
A library for task parallelism
 
Loading...
Searching...
No Matches
parallel_for.h
Go to the documentation of this file.
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cmath>
#include <limits>
#include <memory>

#include <dispenso/detail/can_invoke.h>
#include <dispenso/detail/per_thread_info.h>
#include <dispenso/task_set.h>

namespace dispenso {
33enum class ParForChunking { kStatic, kAuto };
34
44 uint32_t maxThreads = std::numeric_limits<int32_t>::max();
53 bool wait = true;
54
59 ParForChunking defaultChunking = ParForChunking::kStatic;
60
67
73 bool reuseExistingState = false;
74};
75
/**
 * A half-open integer range [start, end) bundled with a chunking policy:
 * chunk == kStatic selects static per-worker chunking, chunk == 0 selects
 * automatic (dynamic) chunking, and any other value is an explicit chunk size.
 * NOTE(review): the struct declaration line and the Static/Auto constructors
 * were elided in the documentation extraction; they are reconstructed here
 * from the tag types, kStatic, and the isStatic()/isAuto() predicates.
 */
template <typename IntegerT = ssize_t>
struct ChunkedRange {
  // We need to utilize 64-bit integers to avoid overflow, e.g. passing -2**30, 2**30 as int32 will
  // result in overflow unless we cast to 64-bit. Note that if we have a range of e.g. -2**63+1 to
  // 2**63-1, we cannot hold the result in an int64_t. We could in a uint64_t, but it is quite
  // tricky to make this work. However, I do not expect ranges larger than can be held in int64_t
  // since people want their computations to finish before the heat death of the sun (slight
  // exaggeration).
  using size_type = std::conditional_t<std::is_signed<IntegerT>::value, int64_t, uint64_t>;

  // Tag types selecting the chunking policy at construction time.
  struct Static {};
  struct Auto {};
  // Sentinel chunk value meaning "static chunking".
  static constexpr IntegerT kStatic = std::numeric_limits<IntegerT>::max();

  // Explicit chunk size.
  ChunkedRange(IntegerT s, IntegerT e, IntegerT c) : start(s), end(e), chunk(c) {}
  // Static chunking (one contiguous chunk per worker).
  ChunkedRange(IntegerT s, IntegerT e, Static) : start(s), end(e), chunk(kStatic) {}
  // Automatic chunking (chunk size chosen at run time; 0 is the sentinel).
  ChunkedRange(IntegerT s, IntegerT e, Auto) : start(s), end(e), chunk(0) {}

  bool isStatic() const {
    return chunk == kStatic;
  }

  bool isAuto() const {
    return chunk == 0;
  }

  bool empty() const {
    return end <= start;
  }

  size_type size() const {
    return static_cast<size_type>(end) - start;
  }

  /**
   * Compute (chunkSize, numChunks) for dynamic execution.
   *
   * @param numLaunched number of tasks launched to the pool.
   * @param oneOnCaller whether the calling thread also works.
   * @param minChunkSize smallest acceptable chunk size (auto mode only).
   */
  template <typename OtherInt>
  std::tuple<size_type, size_type>
  calcChunkSize(OtherInt numLaunched, bool oneOnCaller, size_type minChunkSize) const {
    size_type workingThreads = static_cast<size_type>(numLaunched) + size_type{oneOnCaller};
    assert(workingThreads > 0);

    if (!chunk) {
      // Auto mode: aim for up to 16 chunks per worker, shrinking the factor
      // until each chunk is at least minChunkSize iterations.
      size_type dynFactor = std::min<size_type>(16, size() / workingThreads);
      size_type chunkSize;
      do {
        size_type roughChunks = dynFactor * workingThreads;
        chunkSize = (size() + roughChunks - 1) / roughChunks;
        --dynFactor;
      } while (chunkSize < minChunkSize);
      return {chunkSize, (size() + chunkSize - 1) / chunkSize};
    } else if (chunk == kStatic) {
      // This should never be called. The static distribution versions of the parallel_for
      // functions should be invoked instead.
      std::abort();
    }
    return {chunk, (size() + chunk - 1) / chunk};
  }

  IntegerT start;
  IntegerT end;
  IntegerT chunk;
};
167
175template <typename IntegerA, typename IntegerB>
176inline ChunkedRange<std::common_type_t<IntegerA, IntegerB>>
177makeChunkedRange(IntegerA start, IntegerB end, ParForChunking chunking = ParForChunking::kStatic) {
178 using IntegerT = std::common_type_t<IntegerA, IntegerB>;
179 return (chunking == ParForChunking::kStatic)
182}
183
191template <typename IntegerA, typename IntegerB, typename IntegerC>
192inline ChunkedRange<std::common_type_t<IntegerA, IntegerB>>
196
namespace detail {

/**
 * An iterator that goes nowhere and dereferences to a shared dummy int.
 * Lets stateless parallel_for reuse the state-based implementation at no cost.
 * NOTE(review): the dereference target is a function-local static, so every
 * NoOpIter aliases the same int.
 */
struct NoOpIter {
  int& operator*() const {
    static int i = 0;
    return i;
  }
  // Pre-increment: no-op.
  NoOpIter& operator++() {
    return *this;
  }
  // Post-increment: no-op; returns a copy per iterator convention.
  NoOpIter operator++(int) {
    return *this;
  }
};
211
212struct NoOpContainer {
213 size_t size() const {
214 return 0;
215 }
216
217 bool empty() const {
218 return true;
219 }
220
221 void clear() {}
222
223 NoOpIter begin() {
224 return {};
225 }
226
227 void emplace_back(int) {}
228
229 int& front() {
230 static int i;
231 return i;
232 }
233};
234
/**
 * State generator yielding a throwaway int state for NoOpContainer.
 */
struct NoOpStateGen {
  int operator()() const {
    return 0;
  }
};
240
241template <
242 typename TaskSetT,
243 typename IntegerT,
244 typename F,
245 typename StateContainer,
246 typename StateGen>
247void parallel_for_staticImpl(
248 TaskSetT& taskSet,
249 StateContainer& states,
250 const StateGen& defaultState,
251 const ChunkedRange<IntegerT>& range,
252 F&& f,
253 ssize_t maxThreads,
254 bool wait,
255 bool reuseExistingState) {
256 using size_type = typename ChunkedRange<IntegerT>::size_type;
257
258 size_type numThreads = std::min<size_type>(taskSet.numPoolThreads() + wait, maxThreads);
259 // Reduce threads used if they exceed work to be done.
260 numThreads = std::min(numThreads, range.size());
261
262 if (!reuseExistingState) {
263 states.clear();
264 }
265
266 size_t numToEmplace = states.size() < static_cast<size_t>(numThreads)
267 ? static_cast<size_t>(numThreads) - states.size()
268 : 0;
269
270 for (; numToEmplace--;) {
271 states.emplace_back(defaultState());
272 }
273
274 auto chunking =
275 detail::staticChunkSize(static_cast<ssize_t>(range.size()), static_cast<ssize_t>(numThreads));
276 IntegerT chunkSize = static_cast<IntegerT>(chunking.ceilChunkSize);
277
278 bool perfectlyChunked = static_cast<size_type>(chunking.transitionTaskIndex) == numThreads;
279
280 // (!perfectlyChunked) ? chunking.transitionTaskIndex : numThreads - 1;
281 size_type firstLoopLen = chunking.transitionTaskIndex - perfectlyChunked;
282
283 auto stateIt = states.begin();
284 IntegerT start = range.start;
285 size_type t;
286 for (t = 0; t < firstLoopLen; ++t) {
287 IntegerT next = static_cast<IntegerT>(start + chunkSize);
288 taskSet.schedule([it = stateIt++, start, next, f]() {
289 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
290 f(*it, start, next);
291 });
292 start = next;
293 }
294
295 // Reduce the remaining chunk sizes by 1.
296 chunkSize = static_cast<IntegerT>(chunkSize - !perfectlyChunked);
297 // Finish submitting all but the last item.
298 for (; t < numThreads - 1; ++t) {
299 IntegerT next = static_cast<IntegerT>(start + chunkSize);
300 taskSet.schedule([it = stateIt++, start, next, f]() {
301 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
302 f(*it, start, next);
303 });
304 start = next;
305 }
306
307 if (wait) {
308 f(*stateIt, start, range.end);
309 taskSet.wait();
310 } else {
311 taskSet.schedule(
312 [stateIt, start, end = range.end, f]() {
313 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
314 f(*stateIt, start, end);
315 },
316 ForceQueuingTag());
317 }
318}

} // namespace detail

338template <
339 typename TaskSetT,
340 typename IntegerT,
341 typename F,
342 typename StateContainer,
343 typename StateGen>
347 const StateGen& defaultState,
349 F&& f,
350 ParForOptions options = {}) {
351 if (range.empty()) {
352 if (options.wait) {
353 taskSet.wait();
354 }
355 return;
356 }
357
358 using size_type = typename ChunkedRange<IntegerT>::size_type;
359
360 // Ensure minItemsPerChunk is sane
361 uint32_t minItemsPerChunk = std::max<uint32_t>(1, options.minItemsPerChunk);
362
363 // 0 indicates serial execution per API spec
364 size_type maxThreads = std::max<int32_t>(options.maxThreads, 1);
365
366 bool isStatic = range.isStatic();
367
368 const size_type N = taskSet.numPoolThreads();
369 if (N == 0 || !options.maxThreads || range.size() <= minItemsPerChunk ||
370 detail::PerPoolPerThreadInfo::isParForRecursive(&taskSet.pool())) {
371 if (!options.reuseExistingState) {
372 states.clear();
373 }
374 if (states.empty()) {
375 states.emplace_back(defaultState());
376 }
377 f(*states.begin(), range.start, range.end);
378 if (options.wait) {
379 taskSet.wait();
380 }
381 return;
382 }
383
384 // Adjust down workers if we would have too-small chunks
385 if (minItemsPerChunk > 1) {
386 size_type maxWorkers = range.size() / minItemsPerChunk;
387 if (maxWorkers < maxThreads) {
388 maxThreads = static_cast<uint32_t>(maxWorkers);
389 }
390 if (range.size() / (maxThreads + options.wait) < minItemsPerChunk && range.isAuto()) {
391 isStatic = true;
392 }
393 } else if (range.size() <= N + options.wait) {
394 if (range.isAuto()) {
395 isStatic = true;
396 } else if (!range.isStatic()) {
397 maxThreads = range.size() - options.wait;
398 }
399 }
400
401 if (isStatic) {
402 detail::parallel_for_staticImpl(
403 taskSet,
404 states,
405 defaultState,
406 range,
407 std::forward<F>(f),
408 static_cast<ssize_t>(maxThreads),
409 options.wait,
410 options.reuseExistingState);
411 return;
412 }
413
414 // wanting maxThreads workers (potentially including the calling thread), capped by N
415 const size_type numToLaunch = std::min<size_type>(maxThreads - options.wait, N);
416
417 if (!options.reuseExistingState) {
418 states.clear();
419 }
420
421 size_t numToEmplace = static_cast<size_type>(states.size()) < (numToLaunch + options.wait)
422 ? (static_cast<size_t>(numToLaunch) + options.wait) - states.size()
423 : 0;
424 for (; numToEmplace--;) {
425 states.emplace_back(defaultState());
426 }
427
428 if (numToLaunch == 1 && !options.wait) {
429 taskSet.schedule(
430 [&s = states.front(), range, f = std::move(f)]() { f(s, range.start, range.end); });
431
432 return;
433 }
434
435 auto chunkInfo = range.calcChunkSize(numToLaunch, options.wait, minItemsPerChunk);
436 auto chunkSize = std::get<0>(chunkInfo);
437 auto numChunks = std::get<1>(chunkInfo);
438
439 if (options.wait) {
440 alignas(kCacheLineSize) std::atomic<decltype(numChunks)> index(0);
441 auto worker = [start = range.start, end = range.end, &index, f, chunkSize, numChunks](auto& s) {
442 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
443
444 while (true) {
445 auto cur = index.fetch_add(1, std::memory_order_relaxed);
446 if (cur >= numChunks) {
447 break;
448 }
449 auto sidx = static_cast<IntegerT>(start + cur * chunkSize);
450 if (cur + 1 == numChunks) {
451 f(s, sidx, end);
452 } else {
453 auto eidx = static_cast<IntegerT>(sidx + chunkSize);
454 f(s, sidx, eidx);
455 }
456 }
457 };
458
459 auto it = states.begin();
460 for (size_type i = 0; i < numToLaunch; ++i) {
461 taskSet.schedule([&s = *it++, worker]() { worker(s); });
462 }
463 worker(*it);
464 taskSet.wait();
465 } else {
466 struct Atomic {
467 Atomic() : index(0) {}
468 alignas(kCacheLineSize) std::atomic<decltype(numChunks)> index;
469 char buffer[kCacheLineSize - sizeof(index)];
470 };
471
472 void* ptr = detail::alignedMalloc(sizeof(Atomic), alignof(Atomic));
473 auto* atm = new (ptr) Atomic();
474
475 std::shared_ptr<Atomic> wrapper(atm, detail::AlignedFreeDeleter<Atomic>());
476 auto worker = [start = range.start,
477 end = range.end,
478 wrapper = std::move(wrapper),
479 f,
480 chunkSize,
481 numChunks](auto& s) {
482 auto recurseInfo = detail::PerPoolPerThreadInfo::parForRecurse();
483 while (true) {
484 auto cur = wrapper->index.fetch_add(1, std::memory_order_relaxed);
485 if (cur >= numChunks) {
486 break;
487 }
488 auto sidx = static_cast<IntegerT>(start + cur * chunkSize);
489 if (cur + 1 == numChunks) {
490 f(s, sidx, end);
491 } else {
492 auto eidx = static_cast<IntegerT>(sidx + chunkSize);
493 f(s, sidx, eidx);
494 }
495 }
496 };
497
498 auto it = states.begin();
499 for (size_type i = 0; i < numToLaunch; ++i) {
500 taskSet.schedule([&s = *it++, worker]() { worker(s); }, ForceQueuingTag());
501 }
502 }
503}
504
514template <typename TaskSetT, typename IntegerT, typename F>
518 F&& f,
519 ParForOptions options = {}) {
520 detail::NoOpContainer container;
521 parallel_for(
522 taskSet,
523 container,
524 detail::NoOpStateGen(),
525 range,
526 [f = std::move(f)](int /*noop*/, auto i, auto j) { f(i, j); },
527 options);
528}
529
539template <typename IntegerT, typename F>
541 TaskSet taskSet(globalThreadPool());
542 options.wait = true;
543 parallel_for(taskSet, range, std::forward<F>(f), options);
544}
545
563template <typename F, typename IntegerT, typename StateContainer, typename StateGen>
566 const StateGen& defaultState,
568 F&& f,
569 ParForOptions options = {}) {
570 TaskSet taskSet(globalThreadPool());
571 options.wait = true;
572 parallel_for(taskSet, states, defaultState, range, std::forward<F>(f), options);
573}
574
585template <
586 typename TaskSetT,
587 typename IntegerA,
588 typename IntegerB,
589 typename F,
590 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
591 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
592 std::enable_if_t<detail::CanInvoke<F(IntegerA)>::value, bool> = true>
595 IntegerA start,
596 IntegerB end,
597 F&& f,
598 ParForOptions options = {}) {
599 using IntegerT = std::common_type_t<IntegerA, IntegerB>;
600
601 auto range = makeChunkedRange(start, end, options.defaultChunking);
602 parallel_for(
603 taskSet,
604 range,
605 [f = std::move(f)](IntegerT s, IntegerT e) {
606 for (IntegerT i = s; i < e; ++i) {
607 f(i);
608 }
609 },
610 options);
611}
612
613template <
614 typename TaskSetT,
615 typename IntegerA,
616 typename IntegerB,
617 typename F,
618 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
619 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
620 std::enable_if_t<detail::CanInvoke<F(IntegerA, IntegerB)>::value, bool> = true>
621void parallel_for(
622 TaskSetT& taskSet,
623 IntegerA start,
624 IntegerB end,
625 F&& f,
626 ParForOptions options = {}) {
627 auto range = makeChunkedRange(start, end, options.defaultChunking);
628 parallel_for(taskSet, range, std::forward<F>(f), options);
629}
630
641template <
642 typename IntegerA,
643 typename IntegerB,
644 typename F,
645 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
646 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true>
648 TaskSet taskSet(globalThreadPool());
649 options.wait = true;
650 parallel_for(taskSet, start, end, std::forward<F>(f), options);
651}
652
671template <
672 typename TaskSetT,
673 typename IntegerA,
674 typename IntegerB,
675 typename F,
676 typename StateContainer,
677 typename StateGen,
678 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
679 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
680 std::enable_if_t<
681 detail::CanInvoke<F(typename StateContainer::reference, IntegerA)>::value,
682 bool> = true>
686 const StateGen& defaultState,
687 IntegerA start,
688 IntegerB end,
689 F&& f,
690 ParForOptions options = {}) {
691 using IntegerT = std::common_type_t<IntegerA, IntegerB>;
692 auto range = makeChunkedRange(start, end, options.defaultChunking);
693 parallel_for(
694 taskSet,
695 states,
696 defaultState,
697 range,
698 [f = std::move(f)](auto& state, IntegerT s, IntegerT e) {
699 for (IntegerT i = s; i < e; ++i) {
700 f(state, i);
701 }
702 },
703 options);
704}
705
706template <
707 typename TaskSetT,
708 typename IntegerA,
709 typename IntegerB,
710 typename F,
711 typename StateContainer,
712 typename StateGen,
713 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
714 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true,
715 std::enable_if_t<
716 detail::CanInvoke<F(typename StateContainer::reference, IntegerA, IntegerB)>::value,
717 bool> = true>
718void parallel_for(
719 TaskSetT& taskSet,
720 StateContainer& states,
721 const StateGen& defaultState,
722 IntegerA start,
723 IntegerB end,
724 F&& f,
725 ParForOptions options = {}) {
726 auto range = makeChunkedRange(start, end, options.defaultChunking);
727 parallel_for(taskSet, states, defaultState, range, std::forward<F>(f), options);
728}
729
749template <
750 typename IntegerA,
751 typename IntegerB,
752 typename F,
753 typename StateContainer,
754 typename StateGen,
755 std::enable_if_t<std::is_integral<IntegerA>::value, bool> = true,
756 std::enable_if_t<std::is_integral<IntegerB>::value, bool> = true>
759 const StateGen& defaultState,
760 IntegerA start,
761 IntegerB end,
762 F&& f,
763 ParForOptions options = {}) {
764 TaskSet taskSet(globalThreadPool());
765 options.wait = true;
766 parallel_for(taskSet, states, defaultState, start, end, std::forward<F>(f), options);
767}
768
} // namespace dispenso
void parallel_for(TaskSetT &taskSet, StateContainer &states, const StateGen &defaultState, const ChunkedRange< IntegerT > &range, F &&f, ParForOptions options={})
ChunkedRange< std::common_type_t< IntegerA, IntegerB > > makeChunkedRange(IntegerA start, IntegerB end, ParForChunking chunking=ParForChunking::kStatic)
detail::OpResult< T > OpResult
Definition pipeline.h:29
ChunkedRange(IntegerT s, IntegerT e, Auto)
ChunkedRange(IntegerT s, IntegerT e, IntegerT c)
ChunkedRange(IntegerT s, IntegerT e, Static)
ParForChunking defaultChunking