Skip to content

Commit 8f48e7a

Browse files
committed
Track adjoint access mode entirely via task data.
1 parent 6d84a76 commit 8f48e7a

5 files changed

Lines changed: 10 additions & 101 deletions

File tree

include/opdi/logic/omp/adjointAccessControl.hpp

Lines changed: 0 additions & 53 deletions
This file was deleted.

include/opdi/logic/omp/implicitTaskOmpLogic.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(bool initialImplicitTask,
8080

8181
tool->setThreadLocalTape(newTape);
8282

83-
AdjointAccessControl::pushMode(parallelData->parentAdjointAccessMode);
8483
data->adjointAccessModes.push_back(parallelData->parentAdjointAccessMode);
8584

8685
parallelData->childTasks[index] = data;
@@ -116,10 +115,6 @@ void opdi::ImplicitTaskOmpLogic::onImplicitTaskEnd(void* dataPtr) {
116115
}
117116
#endif
118117

119-
AdjointAccessMode lastAccessMode = AdjointAccessControl::currentMode();
120-
AdjointAccessControl::popMode();
121-
AdjointAccessControl::currentMode() = lastAccessMode;
122-
123118
if (!data->initialImplicitTask) {
124119
tool->setThreadLocalTape(data->oldTape);
125120

include/opdi/logic/omp/ompLogic.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,11 @@
2424
*/
2525

2626
#include "instrument/ompLogicInstrumentInterface.hpp"
27-
#include "adjointAccessControl.hpp"
2827

2928
#if OPDI_DEFAULT_ADJOINT_ACCESS_MODE == OPDI_ADJOINT_ACCESS_ATOMIC
30-
opdi::LogicInterface::AdjointAccessMode opdi::InitialImplicitTaskAdjointAccessControl::currentAdjointAccess =
31-
opdi::LogicInterface::AdjointAccessMode::Atomic;
32-
3329
opdi::LogicInterface::AdjointAccessMode const opdi::ImplicitTaskOmpLogic::defaultAdjointAccessMode
3430
= opdi::LogicInterface::AdjointAccessMode::Atomic;
3531
#elif OPDI_DEFAULT_ADJOINT_ACCESS_MODE == OPDI_ADJOINT_ACCESS_CLASSICAL
36-
opdi::LogicInterface::AdjointAccessMode opdi::InitialImplicitTaskAdjointAccessControl::currentAdjointAccess
37-
= opdi::LogicInterface::AdjointAccessMode::Classical;
38-
3932
opdi::LogicInterface::AdjointAccessMode const opdi::ImplicitTaskOmpLogic::defaultAdjointAccessMode
4033
= opdi::LogicInterface::AdjointAccessMode::Classical;
4134
#else

include/opdi/logic/omp/parallelOmpLogic.cpp

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -133,24 +133,17 @@ void opdi::ParallelOmpLogic::deleteFunc(void* dataPtr) {
133133
}
134134

135135
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::internalGetAdjointAccessMode(void* taskDataPtr) const {
136-
if (taskDataPtr == nullptr) { // initial implicit task
137-
return InitialImplicitTaskAdjointAccessControl::currentMode();
138-
}
139-
else {
140-
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(taskDataPtr);
141-
142-
return taskData->adjointAccessModes.back();
143-
}
136+
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(taskDataPtr);
137+
return taskData->adjointAccessModes.back();
144138
}
145139

146140
void opdi::ParallelOmpLogic::internalSetAdjointAccessMode(void* taskDataPtr, AdjointAccessMode mode) {
141+
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(taskDataPtr);
147142

148-
if (taskDataPtr == nullptr) { // initial implicit task
149-
InitialImplicitTaskAdjointAccessControl::currentMode() = mode;
143+
if (taskData->initialImplicitTask) {
144+
taskData->adjointAccessModes.back() = mode;
150145
}
151146
else {
152-
ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<ImplicitTaskOmpLogic::Data*>(taskDataPtr);
153-
154147
taskData->adjointAccessModes.push_back(mode);
155148
taskData->positions.push_back(tool->allocPosition());
156149
tool->getTapePosition(taskData->tape, taskData->positions.back());
@@ -245,31 +238,15 @@ void opdi::ParallelOmpLogic::setAdjointAccessMode(opdi::LogicInterface::AdjointA
245238
}
246239
#endif
247240

248-
void* parallelDataPtr = backend->getParallelData();
249-
void* taskDataPtr = nullptr;
250-
251-
if (parallelDataPtr != nullptr) { // not initial implicit task
252-
Data* parallelData = reinterpret_cast<Data*>(parallelDataPtr);
253-
taskDataPtr = parallelData->childTasks[omp_get_thread_num()];
254-
}
255-
256-
assert(taskDataPtr == nullptr || ((ImplicitTaskOmpLogic::Data*)taskDataPtr)->tape == tool->getThreadLocalTape());
257-
241+
void* taskDataPtr = backend->getTaskData();
242+
assert(taskDataPtr != nullptr);
258243
internalSetAdjointAccessMode(taskDataPtr, mode);
259244
#endif
260245
}
261246

262247
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode() const {
263-
void* parallelDataPtr = backend->getParallelData();
264-
void* taskDataPtr = nullptr;
265-
266-
if (parallelDataPtr != nullptr) { // not initial implicit task
267-
Data* parallelData = reinterpret_cast<Data*>(parallelDataPtr);
268-
taskDataPtr = parallelData->childTasks[omp_get_thread_num()];
269-
}
270-
271-
assert(taskDataPtr == nullptr || ((ImplicitTaskOmpLogic::Data*)taskDataPtr)->tape == tool->getThreadLocalTape());
272-
248+
void* taskDataPtr = backend->getTaskData();
249+
assert(taskDataPtr != nullptr);
273250
return internalGetAdjointAccessMode(taskDataPtr);
274251
}
275252

include/opdi/logic/omp/parallelOmpLogic.hpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,9 @@
3131

3232
#include "../logicInterface.hpp"
3333

34-
#include "adjointAccessControl.hpp"
35-
3634
namespace opdi {
3735

38-
struct ParallelOmpLogic : public InitialImplicitTaskAdjointAccessControl,
39-
public virtual LogicInterface,
36+
struct ParallelOmpLogic : public virtual LogicInterface,
4037
public virtual TapePool {
4138
public:
4239

0 commit comments

Comments
 (0)