Skip to content

Commit 9897831

Browse files
committed
Backend-independent treatment of parallel + firstprivate.
Move special treatment to the logic layer. Expose a mechanism to skip AD handling for a parallel region, allowing backend-independent access to default tapes. Users control the activity of default tapes; it should match the activity of the encountering task's tape. Add new ParallelForLastprivate and ParallelCopyin tests. Merge branch 'revisit/parallelFirstprivate' into develop.
2 parents 16d5836 + efd802f commit 9897831

30 files changed

Lines changed: 871 additions & 48 deletions

include/opdi/backend/macro/probes.hpp

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,48 +41,19 @@ namespace opdi {
4141

4242
void* parallelData;
4343
void* taskData;
44-
void* masterPosition;
4544
bool needsAction;
4645

47-
TaskProbe() : parallelData(nullptr), taskData(nullptr), needsAction(false) {
48-
this->masterPosition = tool->allocPosition();
49-
opdi::tool->getTapePosition(tool->getThreadLocalTape(), this->masterPosition);
50-
}
46+
TaskProbe() : parallelData(nullptr), taskData(nullptr), needsAction(false) {}
5147

52-
TaskProbe(void* parallelData) : parallelData(parallelData), taskData(nullptr), needsAction(false) {
53-
this->masterPosition = tool->allocPosition();
54-
tool->getTapePosition(tool->getThreadLocalTape(), this->masterPosition);
55-
}
48+
TaskProbe(void* parallelData) : parallelData(parallelData), taskData(nullptr), needsAction(false) {}
5649

5750
TaskProbe(TaskProbe const& other) : parallelData(other.parallelData), needsAction(true) {
5851

59-
this->masterPosition = tool->allocPosition();
60-
if (omp_get_thread_num() == 0) {
61-
tool->copyPosition(this->masterPosition, other.masterPosition);
62-
}
63-
else {
64-
tool->getZeroPosition(tool->getThreadLocalTape(), this->masterPosition);
65-
}
66-
67-
void* oldTape = tool->getThreadLocalTape();
68-
69-
void* currentPosition = tool->allocPosition();
70-
tool->getTapePosition(oldTape, currentPosition);
71-
7252
DataTools::pushParallelData(this->parallelData);
7353
this->taskData = logic->onImplicitTaskBegin(false, omp_get_num_threads(), omp_get_thread_num(),
7454
this->parallelData);
7555
DataTools::pushTaskData(this->taskData);
7656

77-
// check if copy statements have been recorded before the correct tape was set
78-
// if so, move them to the correct tape
79-
if (tool->comparePosition(currentPosition, masterPosition) > 0) {
80-
tool->append(tool->getThreadLocalTape(), oldTape, masterPosition, currentPosition);
81-
tool->erase(oldTape, masterPosition, currentPosition);
82-
}
83-
84-
tool->freePosition(currentPosition);
85-
8657
ProbeScopeStatus::beginImplicitTaskProbeScope();
8758
}
8859

@@ -93,8 +64,6 @@ namespace opdi {
9364
DataTools::popTaskData();
9465
DataTools::popParallelData();
9566
}
96-
97-
tool->freePosition(this->masterPosition);
9867
}
9968
};
10069

include/opdi/logic/logicInterface.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ namespace opdi {
9494

9595
virtual void addReverseBarrier() = 0;
9696
virtual void addReverseFlush() = 0;
97+
98+
virtual void beginSkippedParallelRegion() = 0;
99+
virtual void endSkippedParallelRegion() = 0;
97100
};
98101

99102
extern LogicInterface* logic;

include/opdi/logic/omp/implicitTaskOmpLogic.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,31 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(bool isInitialImplicitTask
9090
implicitTaskData->adjointAccessModes.push_back(parallelData->encounteringTaskAdjointAccessMode);
9191

9292
parallelData->childTaskData[indexInTeam] = implicitTaskData;
93+
94+
// check for copies due to firstprivate/copyin that were recorded on the wrong tapes
95+
// move them to the correct tapes if needed
96+
97+
void* oldTapePosition = tool->allocPosition();
98+
tool->getTapePosition(implicitTaskData->oldTape, oldTapePosition);
99+
100+
void* referencePosition = tool->allocPosition();
101+
if (indexInTeam == 0) {
102+
tool->copyPosition(referencePosition, parallelData->encounteringTaskTapePosition);
103+
}
104+
else {
105+
tool->getZeroPosition(implicitTaskData->oldTape, referencePosition);
106+
}
107+
108+
if (tool->comparePosition(oldTapePosition, referencePosition) > 0) {
109+
// users should ensure that activity of default tapes and encountering task's tape match
110+
assert(parallelData->isActiveParallelRegion);
111+
112+
tool->append(newTape, implicitTaskData->oldTape, referencePosition, oldTapePosition);
113+
tool->erase(implicitTaskData->oldTape, referencePosition, oldTapePosition);
114+
}
115+
116+
tool->freePosition(referencePosition);
117+
tool->freePosition(oldTapePosition);
93118
}
94119
else {
95120
implicitTaskData->oldTape = nullptr;

include/opdi/logic/omp/parallelOmpLogic.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#include "implicitTaskOmpLogic.hpp"
3636
#include "parallelOmpLogic.hpp"
3737

38-
int opdi::ParallelOmpLogic::skipParallelHandling = 0;
38+
int opdi::ParallelOmpLogic::skipParallelRegion = 0;
3939

4040
void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
4141

@@ -47,7 +47,7 @@ void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
4747
}
4848
#endif
4949

50-
++ParallelOmpLogic::skipParallelHandling;
50+
ParallelOmpLogic::internalBeginSkippedParallelRegion();
5151

5252
#pragma omp parallel num_threads(parallelData->actualSizeOfTeam)
5353
{
@@ -94,7 +94,7 @@ void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
9494
#endif
9595
}
9696

97-
--ParallelOmpLogic::skipParallelHandling;
97+
ParallelOmpLogic::internalEndSkippedParallelRegion();
9898

9999
#if OPDI_OMP_LOGIC_INSTRUMENT
100100
for (auto& instrument : ompLogicInstruments) {
@@ -107,7 +107,7 @@ void opdi::ParallelOmpLogic::deleteFunc(void* parallelDataPtr) {
107107

108108
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
109109

110-
++ParallelOmpLogic::skipParallelHandling;
110+
ParallelOmpLogic::internalBeginSkippedParallelRegion();
111111

112112
// this triggers possibly pending implicit task end events
113113
#pragma omp parallel num_threads(parallelData->actualSizeOfTeam)
@@ -134,7 +134,9 @@ void opdi::ParallelOmpLogic::deleteFunc(void* parallelDataPtr) {
134134
delete implicitTaskData;
135135
}
136136

137-
--ParallelOmpLogic::skipParallelHandling;
137+
ParallelOmpLogic::internalEndSkippedParallelRegion();
138+
139+
tool->freePosition(parallelData->encounteringTaskTapePosition);
138140

139141
// delete data of the parallel region
140142
delete parallelData;
@@ -167,7 +169,7 @@ void opdi::ParallelOmpLogic::internalSetAdjointAccessMode(ImplicitTaskData* impl
167169

168170
void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTaskDataPtr, int maximumSizeOfTeam) {
169171

170-
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelHandling == 0) {
172+
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelRegion == 0) {
171173

172174
ImplicitTaskData* encounteringTaskData = static_cast<ImplicitTaskData*>(encounteringTaskDataPtr);
173175

@@ -180,6 +182,8 @@ void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTaskDataPtr, int
180182
parallelData->isActiveParallelRegion = tool->isActive(tool->getThreadLocalTape());
181183
parallelData->encounteringTaskData = encounteringTaskData;
182184
parallelData->encounteringTaskTape = tool->getThreadLocalTape();
185+
parallelData->encounteringTaskTapePosition = tool->allocPosition();
186+
tool->getTapePosition(parallelData->encounteringTaskTape, parallelData->encounteringTaskTapePosition);
183187
parallelData->encounteringTaskAdjointAccessMode = internalGetAdjointAccessMode(encounteringTaskData);
184188
parallelData->childTaskData.resize(maximumSizeOfTeam);
185189

@@ -270,3 +274,18 @@ opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccess
270274
}
271275
}
272276

277+
void opdi::ParallelOmpLogic::internalBeginSkippedParallelRegion() {
278+
++ParallelOmpLogic::skipParallelRegion;
279+
}
280+
281+
void opdi::ParallelOmpLogic::internalEndSkippedParallelRegion() {
282+
--ParallelOmpLogic::skipParallelRegion;
283+
}
284+
285+
void opdi::ParallelOmpLogic::beginSkippedParallelRegion() {
286+
ParallelOmpLogic::internalBeginSkippedParallelRegion();
287+
}
288+
289+
void opdi::ParallelOmpLogic::endSkippedParallelRegion() {
290+
ParallelOmpLogic::internalEndSkippedParallelRegion();
291+
}

include/opdi/logic/omp/parallelOmpLogic.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ namespace opdi {
4242
bool isActiveParallelRegion;
4343
ImplicitTaskData* encounteringTaskData;
4444
void* encounteringTaskTape;
45+
void* encounteringTaskTapePosition;
4546
LogicInterface::AdjointAccessMode encounteringTaskAdjointAccessMode;
4647
std::vector<ImplicitTaskData*> childTaskData;
4748
};
@@ -53,8 +54,11 @@ namespace opdi {
5354

5455
private:
5556

56-
static int skipParallelHandling;
57-
#pragma omp threadprivate(skipParallelHandling)
57+
static int skipParallelRegion;
58+
#pragma omp threadprivate(skipParallelRegion)
59+
60+
static void internalBeginSkippedParallelRegion();
61+
static void internalEndSkippedParallelRegion();
5862

5963
static void reverseFunc(void* parallelData);
6064
static void deleteFunc(void* parallelData);
@@ -69,5 +73,8 @@ namespace opdi {
6973

7074
virtual void setAdjointAccessMode(AdjointAccessMode mode);
7175
virtual AdjointAccessMode getAdjointAccessMode() const;
76+
77+
virtual void beginSkippedParallelRegion();
78+
virtual void endSkippedParallelRegion();
7279
};
7380
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Point 0 :
2+
19.4044
3+
1.88152e+06
4+
-72463.7
5+
7448.71
6+
199533
7+
Point 1 :
8+
-22.677
9+
-2.57708e+06
10+
868308
11+
1.47639e+06
12+
-6.31777e+06
13+
Point 2 :
14+
-36.8071
15+
-805435
16+
-1.05138e+06
17+
-347778
18+
-1.36383e+06
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Point 0 :
2+
19.4044
3+
1.88152e+06
4+
-72463.7
5+
7448.71
6+
199533
7+
Point 1 :
8+
-22.677
9+
-2.57708e+06
10+
868308
11+
1.47639e+06
12+
-6.31777e+06
13+
Point 2 :
14+
-36.8071
15+
-805435
16+
-1.05138e+06
17+
-347778
18+
-1.36383e+06
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Point 0 :
2+
-2.30162
3+
159.64
4+
25.0196
5+
-654.451
6+
-132.877
7+
Point 1 :
8+
32.8041
9+
-12507.5
10+
-7865.59
11+
-128.973
12+
-3249.85
13+
Point 2 :
14+
-83.9704
15+
-1378.79
16+
474.401
17+
-1044.34
18+
-63769.4
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Point 0 :
2+
19.4044
3+
1.88152e+06
4+
-72463.7
5+
7448.71
6+
199533
7+
Point 1 :
8+
-22.677
9+
-2.57708e+06
10+
868308
11+
1.47639e+06
12+
-6.31777e+06
13+
Point 2 :
14+
-36.8071
15+
-805435
16+
-1.05138e+06
17+
-347778
18+
-1.36383e+06
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
Point 0 :
2+
19.4044
3+
1.88152e+06
4+
-72463.7
5+
7448.71
6+
199533
7+
Point 1 :
8+
-22.677
9+
-2.57708e+06
10+
868308
11+
1.47639e+06
12+
-6.31777e+06
13+
Point 2 :
14+
-36.8071
15+
-805435
16+
-1.05138e+06
17+
-347778
18+
-1.36383e+06

0 commit comments

Comments
 (0)