Skip to content

Commit 27e9b10

Browse files
committed
Track adjoint access mode via task data.
1 parent b90bca3 commit 27e9b10

8 files changed

Lines changed: 79 additions & 61 deletions

include/opdi/logic/logicInterface.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ namespace opdi {
8383
virtual void recoverState(void* state) = 0;
8484

8585
virtual void setAdjointAccessMode(AdjointAccessMode adjointAccess) = 0;
86-
virtual AdjointAccessMode getAdjointAccessMode() = 0;
86+
virtual AdjointAccessMode getAdjointAccessMode() const = 0;
8787

8888
virtual void addReverseBarrier() = 0;
8989
virtual void addReverseFlush() = 0;

include/opdi/logic/omp/adjointAccessControl.hpp

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,27 @@
2727

2828
#include <list>
2929

30-
#include "../../backend/backendInterface.hpp"
31-
#include "../../tool/toolInterface.hpp"
32-
3330
#include "../logicInterface.hpp"
3431

3532
namespace opdi {
3633

37-
struct AdjointAccessControl : public virtual LogicInterface {
34+
// The initial implicit task has no associated task data or parallel region data. This struct is a stand-in that tracks
35+
// the initial implicit task's adjoint access mode. The initial implicit task itself uses the adjoint access mode implied by
36+
// the AD tool, but the mode stored here is handed down to the parallel regions it spawns.
37+
struct InitialImplicitTaskAdjointAccessControl {
3838
public:
3939
using AdjointAccessMode = LogicInterface::AdjointAccessMode;
4040

4141
private:
42-
static std::list<AdjointAccessMode> currentAdjointAccess;
43-
#pragma omp threadprivate(currentAdjointAccess)
42+
static AdjointAccessMode currentAdjointAccess;
4443

4544
protected:
4645
AdjointAccessMode& currentMode() {
47-
return currentAdjointAccess.back();
46+
return currentAdjointAccess;
4847
}
4948

5049
const AdjointAccessMode& currentMode() const {
51-
return currentAdjointAccess.back();
52-
}
53-
54-
void pushMode(AdjointAccessMode adjointAccess) {
55-
currentAdjointAccess.push_back(adjointAccess);
56-
}
57-
58-
void popMode() {
59-
currentAdjointAccess.pop_back();
50+
return currentAdjointAccess;
6051
}
6152
};
6253
}

include/opdi/logic/omp/implicitTaskOmpLogic.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,17 +68,16 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(int actualParallelism, int
6868

6969
tool->setThreadLocalTape(newTape);
7070

71-
AdjointAccessControl::pushMode(parallelData->parentAdjointAccessMode);
7271
data->adjointAccessModes.push_back(parallelData->parentAdjointAccessMode);
7372

73+
parallelData->childTasks[index] = data;
74+
7475
#if OPDI_OMP_LOGIC_INSTRUMENT
7576
for (auto& instrument : ompLogicInstruments) {
7677
instrument->onImplicitTaskBegin(data);
7778
}
7879
#endif
7980

80-
parallelData->childTasks[index] = data;
81-
8281
return data;
8382
}
8483

@@ -90,10 +89,6 @@ void opdi::ImplicitTaskOmpLogic::onImplicitTaskEnd(void* dataPtr) {
9089
if (dataPtr != nullptr) {
9190
Data* data = (Data*) dataPtr;
9291

93-
AdjointAccessMode lastAccessMode = AdjointAccessControl::currentMode();
94-
AdjointAccessControl::popMode();
95-
AdjointAccessControl::currentMode() = lastAccessMode;
96-
9792
tool->setThreadLocalTape(data->oldTape);
9893

9994
data->positions.push_back(tool->allocPosition());

include/opdi/logic/omp/implicitTaskOmpLogic.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,11 @@
3131

3232
#include "../logicInterface.hpp"
3333

34-
#include "adjointAccessControl.hpp"
3534
#include "parallelOmpLogic.hpp"
3635

3736
namespace opdi {
3837

39-
struct ImplicitTaskOmpLogic : public virtual LogicInterface,
40-
public virtual AdjointAccessControl {
38+
struct ImplicitTaskOmpLogic : public virtual LogicInterface {
4139
protected:
4240
TapePool tapePool;
4341

include/opdi/logic/omp/ompLogic.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@
2727
#include "adjointAccessControl.hpp"
2828

2929
#if OPDI_DEFAULT_ADJOINT_ACCESS_MODE == OPDI_ADJOINT_ACCESS_ATOMIC
30-
std::list<opdi::LogicInterface::AdjointAccessMode> opdi::AdjointAccessControl::currentAdjointAccess
31-
{opdi::LogicInterface::AdjointAccessMode::Atomic};
30+
opdi::LogicInterface::AdjointAccessMode opdi::InitialImplicitTaskAdjointAccessControl::currentAdjointAccess =
31+
opdi::LogicInterface::AdjointAccessMode::Atomic;
3232
#elif OPDI_DEFAULT_ADJOINT_ACCESS_MODE == OPDI_ADJOINT_ACCESS_CLASSICAL
33-
std::list<opdi::LogicInterface::AdjointAccessMode> opdi::AdjointAccessControl::currentAdjointAccess
34-
{opdi::LogicInterface::AdjointAccessMode::Classical};
33+
opdi::LogicInterface::AdjointAccessMode opdi::InitialImplicitTaskAdjointAccessControl::currentAdjointAccess =
34+
opdi::LogicInterface::AdjointAccessMode::Classical;
3535
#else
3636
#error Unknown adjoint access mode.
3737
#endif

include/opdi/logic/omp/ompLogic.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030

3131
#include "../logicInterface.hpp"
3232

33-
#include "adjointAccessControl.hpp"
3433
#include "flushOmpLogic.hpp"
3534
#include "implicitTaskOmpLogic.hpp"
3635
#include "masterOmpLogic.hpp"
@@ -48,7 +47,6 @@ namespace opdi {
4847
public ParallelOmpLogic,
4948
public SyncRegionOmpLogic,
5049
public WorkOmpLogic,
51-
public virtual AdjointAccessControl,
5250
public virtual LogicInterface
5351
{
5452
public:

include/opdi/logic/omp/parallelOmpLogic.cpp

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cassert>
2727
#include <omp.h>
2828

29+
#include "../../backend/backendInterface.hpp"
2930
#include "../../config.hpp"
3031
#include "../../tool/toolInterface.hpp"
3132

@@ -131,6 +132,31 @@ void opdi::ParallelOmpLogic::deleteFunc(void* dataPtr) {
131132
delete data;
132133
}
133134

135+
// Looks up the adjoint access mode that currently applies to a given implicit task.
//
// taskDataPtr: pointer to the task's ImplicitTaskOmpLogic::Data, or nullptr to denote the initial implicit task.
// Returns the mode held by InitialImplicitTaskAdjointAccessControl when taskDataPtr is null; otherwise the last
// entry of the task's adjointAccessModes.
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::internalGetAdjointAccessMode(void* taskDataPtr) const {
136+
if (taskDataPtr == nullptr) { // initial implicit task
137+
return InitialImplicitTaskAdjointAccessControl::currentMode();
138+
}
139+
else {
140+
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(taskDataPtr);
141+
142+
// the most recently recorded mode is the one in effect
// NOTE(review): assumes adjointAccessModes is non-empty for every non-null task — confirm against the callers
return taskData->adjointAccessModes.back();
143+
}
144+
}
145+
146+
// Records a new adjoint access mode for a given implicit task.
//
// taskDataPtr: pointer to the task's ImplicitTaskOmpLogic::Data, or nullptr to denote the initial implicit task.
// mode: the adjoint access mode to record.
// For the initial implicit task, the mode held by InitialImplicitTaskAdjointAccessControl is overwritten. For any
// other task, the mode is appended to the task's adjointAccessModes and a newly allocated tape position is appended
// to the task's positions.
void opdi::ParallelOmpLogic::internalSetAdjointAccessMode(void* taskDataPtr, AdjointAccessMode mode) {
147+
148+
if (taskDataPtr == nullptr) { // initial implicit task
149+
InitialImplicitTaskAdjointAccessControl::currentMode() = mode;
150+
}
151+
else {
152+
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(taskDataPtr);
153+
154+
taskData->adjointAccessModes.push_back(mode);
155+
// allocate a position object and fill it via the tool's getTapePosition for the task's tape
// NOTE(review): presumably this marks where on the tape the mode change takes effect — confirm against the reverse pass
taskData->positions.push_back(tool->allocPosition());
156+
tool->getTapePosition(taskData->tape, taskData->positions.back());
157+
}
158+
}
159+
134160
void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTask, int maxThreads) {
135161

136162
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelHandling == 0) {
@@ -146,7 +172,7 @@ void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTask, int maxThr
146172
data->activeParallelRegion = tool->isActive(tool->getThreadLocalTape());
147173
data->parentTask = encounteringTask;
148174
data->parentTape = tool->getThreadLocalTape();
149-
data->parentAdjointAccessMode = getAdjointAccessMode();
175+
data->parentAdjointAccessMode = internalGetAdjointAccessMode(encounteringTask);
150176
data->childTasks.resize(maxThreads);
151177

152178
#if OPDI_OMP_LOGIC_INSTRUMENT
@@ -169,35 +195,34 @@ void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTask, int maxThr
169195

170196
void opdi::ParallelOmpLogic::onParallelEnd(void* dataPtr) {
171197

172-
Data* data = (Data*) dataPtr;
198+
Data* parallelData = (Data*) dataPtr;
173199

174-
if (data != nullptr) {
200+
if (parallelData != nullptr) {
175201

176202
#if OPDI_OMP_LOGIC_INSTRUMENT
177203
for (auto& instrument : ompLogicInstruments) {
178-
instrument->onParallelEnd(data);
204+
instrument->onParallelEnd(parallelData);
179205
}
180206
#endif
181207

182-
if (data->activeParallelRegion) {
208+
if (parallelData->activeParallelRegion) {
183209

184210
Handle* handle = new Handle;
185-
handle->data = (void*) data;
211+
handle->data = (void*) parallelData;
186212
handle->reverseFunc = ParallelOmpLogic::reverseFunc;
187213
handle->deleteFunc = ParallelOmpLogic::deleteFunc;
188214

189-
tool->pushExternalFunction(data->parentTape, handle);
215+
tool->pushExternalFunction(parallelData->parentTape, handle);
190216

191217
// do not delete the data here; it is deleted together with the handle
192218

193-
// see if the adjoint access mode changed inside the parallel region
194-
// if so, we have to make sure that it carries over to the containing parallel region
195-
if (data->outerAdjointAccessMode != AdjointAccessControl::currentMode()) {
196-
this->setAdjointAccessMode(AdjointAccessControl::currentMode());
197-
}
219+
// propagate the adjoint access mode of thread 0's implicit task to the parent task
220+
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(parallelData->childTasks[0]);
221+
222+
this->internalSetAdjointAccessMode(parallelData->parentTask, taskData->adjointAccessModes.back());
198223

199224
} else {
200-
deleteFunc(data);
225+
deleteFunc(parallelData);
201226
}
202227
}
203228
#if OPDI_OMP_LOGIC_INSTRUMENT
@@ -212,29 +237,37 @@ void opdi::ParallelOmpLogic::onParallelEnd(void* dataPtr) {
212237
void opdi::ParallelOmpLogic::setAdjointAccessMode(opdi::LogicInterface::AdjointAccessMode mode) {
213238

214239
#if OPDI_VARIABLE_ADJOINT_ACCESS_MODE
215-
AdjointAccessControl::currentMode() = mode;
216-
217240
#if OPDI_OMP_LOGIC_INSTRUMENT
218241
for (auto& instrument : ompLogicInstruments) {
219242
instrument->onSetAdjointAccessMode(mode);
220243
}
221244
#endif
222245

223-
Data* data = (Data*) backend->getParallelData();
224-
int threadNum = omp_get_thread_num();
225-
226-
if (data != nullptr) {
227-
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(data->childTasks[threadNum]);
228-
assert(tool->getThreadLocalTape() == taskData->tape);
246+
void* parallelDataPtr = backend->getParallelData();
247+
void* taskDataPtr = nullptr;
229248

230-
taskData->adjointAccessModes.push_back(mode);
231-
taskData->positions.push_back(tool->allocPosition());
232-
tool->getTapePosition(taskData->tape, taskData->positions.back());
249+
if (parallelDataPtr != nullptr) { // not initial implicit task
250+
Data* parallelData = reinterpret_cast<Data*>(parallelDataPtr);
251+
taskDataPtr = parallelData->childTasks[omp_get_thread_num()];
233252
}
253+
254+
assert(taskDataPtr == nullptr || ((ImplicitTaskOmpLogic::Data*)taskDataPtr)->tape == tool->getThreadLocalTape());
255+
256+
internalSetAdjointAccessMode(taskDataPtr, mode);
234257
#endif
235258
}
236259

237-
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode() {
238-
return AdjointAccessControl::currentMode();
260+
// Returns the adjoint access mode of the implicit task executed by the calling thread.
//
// The task is identified via the backend's current parallel data: if there is none, the calling task is treated as
// the initial implicit task; otherwise the task data is taken from the parallel region's childTasks, indexed by
// omp_get_thread_num(). The actual lookup is delegated to internalGetAdjointAccessMode.
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode() const {
261+
void* parallelDataPtr = backend->getParallelData();
262+
void* taskDataPtr = nullptr; // stays null for the initial implicit task
263+
264+
if (parallelDataPtr != nullptr) { // not initial implicit task
265+
Data* parallelData = reinterpret_cast<Data*>(parallelDataPtr);
266+
taskDataPtr = parallelData->childTasks[omp_get_thread_num()];
267+
}
268+
269+
// sanity check: the task data that was found must belong to the tape that is currently set as thread-local tape
assert(taskDataPtr == nullptr || ((ImplicitTaskOmpLogic::Data*)taskDataPtr)->tape == tool->getThreadLocalTape());
270+
271+
return internalGetAdjointAccessMode(taskDataPtr);
239272
}
240273

include/opdi/logic/omp/parallelOmpLogic.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
namespace opdi {
3737

38-
struct ParallelOmpLogic : public virtual AdjointAccessControl,
38+
struct ParallelOmpLogic : public InitialImplicitTaskAdjointAccessControl,
3939
public virtual LogicInterface,
4040
public virtual TapePool {
4141
public:
@@ -61,12 +61,15 @@ namespace opdi {
6161
static void reverseFunc(void* dataPtr);
6262
static void deleteFunc(void* dataPtr);
6363

64+
AdjointAccessMode internalGetAdjointAccessMode(void* taskDataPtr) const;
65+
void internalSetAdjointAccessMode(void* taskDataPtr, AdjointAccessMode mode);
66+
6467
public:
6568

6669
virtual void* onParallelBegin(void* encounteringTask, int maxThreads);
6770
virtual void onParallelEnd(void* dataPtr);
6871

6972
virtual void setAdjointAccessMode(AdjointAccessMode mode);
70-
virtual AdjointAccessMode getAdjointAccessMode();
73+
virtual AdjointAccessMode getAdjointAccessMode() const;
7174
};
7275
}

0 commit comments

Comments
 (0)