Skip to content

Commit 67d7d18

Browse files
committed
Expose mechanism to skip AD handling for a parallel region.
Add beginSkippedParallelRegion, endSkippedParallelRegion to the logic interface. Implementation in OmpLogic. Refactor skipParallelHandling -> skipParallelRegion.
1 parent 49c2bee commit 67d7d18

3 files changed

Lines changed: 32 additions & 8 deletions

File tree

include/opdi/logic/logicInterface.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ namespace opdi {
9494

9595
virtual void addReverseBarrier() = 0;
9696
virtual void addReverseFlush() = 0;
97+
98+
virtual void beginSkippedParallelRegion() = 0;
99+
virtual void endSkippedParallelRegion() = 0;
97100
};
98101

99102
extern LogicInterface* logic;

include/opdi/logic/omp/parallelOmpLogic.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#include "implicitTaskOmpLogic.hpp"
3636
#include "parallelOmpLogic.hpp"
3737

38-
int opdi::ParallelOmpLogic::skipParallelHandling = 0;
38+
int opdi::ParallelOmpLogic::skipParallelRegion = 0;
3939

4040
void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
4141

@@ -47,7 +47,7 @@ void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
4747
}
4848
#endif
4949

50-
++ParallelOmpLogic::skipParallelHandling;
50+
ParallelOmpLogic::internalBeginSkippedParallelRegion();
5151

5252
#pragma omp parallel num_threads(parallelData->actualSizeOfTeam)
5353
{
@@ -94,7 +94,7 @@ void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
9494
#endif
9595
}
9696

97-
--ParallelOmpLogic::skipParallelHandling;
97+
ParallelOmpLogic::internalEndSkippedParallelRegion();
9898

9999
#if OPDI_OMP_LOGIC_INSTRUMENT
100100
for (auto& instrument : ompLogicInstruments) {
@@ -107,7 +107,7 @@ void opdi::ParallelOmpLogic::deleteFunc(void* parallelDataPtr) {
107107

108108
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
109109

110-
++ParallelOmpLogic::skipParallelHandling;
110+
ParallelOmpLogic::internalBeginSkippedParallelRegion();
111111

112112
// this triggers possibly pending implicit task end events
113113
#pragma omp parallel num_threads(parallelData->actualSizeOfTeam)
@@ -134,7 +134,7 @@ void opdi::ParallelOmpLogic::deleteFunc(void* parallelDataPtr) {
134134
delete implicitTaskData;
135135
}
136136

137-
--ParallelOmpLogic::skipParallelHandling;
137+
ParallelOmpLogic::internalEndSkippedParallelRegion();
138138

139139
tool->freePosition(parallelData->encounteringTaskTapePosition);
140140

@@ -169,7 +169,7 @@ void opdi::ParallelOmpLogic::internalSetAdjointAccessMode(ImplicitTaskData* impl
169169

170170
void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTaskDataPtr, int maximumSizeOfTeam) {
171171

172-
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelHandling == 0) {
172+
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelRegion == 0) {
173173

174174
ImplicitTaskData* encounteringTaskData = static_cast<ImplicitTaskData*>(encounteringTaskDataPtr);
175175

@@ -274,3 +274,18 @@ opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccess
274274
}
275275
}
276276

277+
void opdi::ParallelOmpLogic::internalBeginSkippedParallelRegion() {
278+
++ParallelOmpLogic::skipParallelRegion;
279+
}
280+
281+
void opdi::ParallelOmpLogic::internalEndSkippedParallelRegion() {
282+
--ParallelOmpLogic::skipParallelRegion;
283+
}
284+
285+
void opdi::ParallelOmpLogic::beginSkippedParallelRegion() {
286+
ParallelOmpLogic::internalBeginSkippedParallelRegion();
287+
}
288+
289+
void opdi::ParallelOmpLogic::endSkippedParallelRegion() {
290+
ParallelOmpLogic::internalEndSkippedParallelRegion();
291+
}

include/opdi/logic/omp/parallelOmpLogic.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ namespace opdi {
5454

5555
private:
5656

57-
static int skipParallelHandling;
58-
#pragma omp threadprivate(skipParallelHandling)
57+
static int skipParallelRegion;
58+
#pragma omp threadprivate(skipParallelRegion)
59+
60+
static void internalBeginSkippedParallelRegion();
61+
static void internalEndSkippedParallelRegion();
5962

6063
static void reverseFunc(void* parallelData);
6164
static void deleteFunc(void* parallelData);
@@ -70,5 +73,8 @@ namespace opdi {
7073

7174
virtual void setAdjointAccessMode(AdjointAccessMode mode);
7275
virtual AdjointAccessMode getAdjointAccessMode() const;
76+
77+
virtual void beginSkippedParallelRegion();
78+
virtual void endSkippedParallelRegion();
7379
};
7480
}

0 commit comments

Comments
 (0)