Skip to content

Commit d301a8e

Browse files
committed
Explicitly specify adjoint access mode for recovery.
Ensure correct behaviour when calling resetTask in the initial implicit task.
1 parent 20669b1 commit d301a8e

3 files changed

Lines changed: 15 additions & 18 deletions

File tree

include/opdi/logic/logicInterface.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,10 @@ namespace opdi {
8383
virtual void freeState(void* state) = 0;
8484
virtual void recoverState(void* state) = 0;
8585

86-
virtual void setAdjointAccessMode(AdjointAccessMode adjointAccess) = 0;
86+
virtual void setAdjointAccessMode(AdjointAccessMode mode) = 0;
8787
virtual AdjointAccessMode getAdjointAccessMode() const = 0;
8888

89-
/** @brief Complement positional tape resets by OpDiLib-specific cleanup.
90-
*
91-
* Cleanup performed for the task of the calling thread. Reverts changes of the adjoint access mode, up to the
92-
* last adjoint access mode that was set at this position.
93-
*/
94-
virtual void resetTask(void* position) = 0;
89+
virtual void resetTask(void* position, AdjointAccessMode mode) = 0;
9590

9691
virtual void addReverseBarrier() = 0;
9792
virtual void addReverseFlush() = 0;

include/opdi/logic/omp/implicitTaskOmpLogic.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -144,20 +144,22 @@ void opdi::ImplicitTaskOmpLogic::onImplicitTaskEnd(void* dataPtr) {
144144
}
145145
}
146146

147-
void opdi::ImplicitTaskOmpLogic::resetTask(void* position) {
147+
void opdi::ImplicitTaskOmpLogic::resetTask(void* position, opdi::LogicInterface::AdjointAccessMode mode) {
148148

149-
void* parallelDataPtr = backend->getParallelData();
149+
void* taskDataPtr = backend->getTaskData();
150150

151-
if (parallelDataPtr) {
152-
opdi::ParallelOmpLogic::Data* parallelData = reinterpret_cast<opdi::ParallelOmpLogic::Data*>(parallelDataPtr);
151+
if (taskDataPtr != nullptr) {
152+
opdi::ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast<opdi::ImplicitTaskOmpLogic::Data*>(taskDataPtr);
153153

154-
Data* taskData = reinterpret_cast<Data*>(parallelData->childTasks[omp_get_thread_num()]);
154+
if (!taskData->initialImplicitTask) {
155+
assert(tool->comparePosition(taskData->positions.front(), position) <= 0);
155156

156-
assert(tool->comparePosition(taskData->positions.front(), position) <= 0);
157-
158-
while (tool->comparePosition(taskData->positions.back(), position) > 0) {
159-
taskData->positions.pop_back();
160-
taskData->adjointAccessModes.pop_back();
157+
while (tool->comparePosition(taskData->positions.back(), position) > 0) {
158+
taskData->positions.pop_back();
159+
taskData->adjointAccessModes.pop_back();
160+
}
161161
}
162+
163+
taskData->adjointAccessModes.back() = mode;
162164
}
163165
}

include/opdi/logic/omp/implicitTaskOmpLogic.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,6 @@ namespace opdi {
6666
void* parallelDataPtr);
6767
virtual void onImplicitTaskEnd(void* dataPtr);
6868

69-
virtual void resetTask(void* position);
69+
virtual void resetTask(void* position, AdjointAccessMode mode);
7070
};
7171
}

0 commit comments

Comments
 (0)