2626#include < cassert>
2727#include < omp.h>
2828
29+ #include " ../../backend/backendInterface.hpp"
2930#include " ../../config.hpp"
3031#include " ../../tool/toolInterface.hpp"
3132
@@ -131,6 +132,31 @@ void opdi::ParallelOmpLogic::deleteFunc(void* dataPtr) {
131132 delete data;
132133}
133134
135+ opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::internalGetAdjointAccessMode (void * taskDataPtr) const {
136+ if (taskDataPtr == nullptr ) { // initial implicit task
137+ return InitialImplicitTaskAdjointAccessControl::currentMode ();
138+ }
139+ else {
140+ ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast <ImplicitTaskOmpLogic::Data*>(taskDataPtr);
141+
142+ return taskData->adjointAccessModes .back ();
143+ }
144+ }
145+
146+ void opdi::ParallelOmpLogic::internalSetAdjointAccessMode (void * taskDataPtr, AdjointAccessMode mode) {
147+
148+ if (taskDataPtr == nullptr ) { // initial implicit task
149+ InitialImplicitTaskAdjointAccessControl::currentMode () = mode;
150+ }
151+ else {
152+ ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast <ImplicitTaskOmpLogic::Data*>(taskDataPtr);
153+
154+ taskData->adjointAccessModes .push_back (mode);
155+ taskData->positions .push_back (tool->allocPosition ());
156+ tool->getTapePosition (taskData->tape , taskData->positions .back ());
157+ }
158+ }
159+
134160void * opdi::ParallelOmpLogic::onParallelBegin (void * encounteringTask, int maxThreads) {
135161
136162 if (tool->getThreadLocalTape () != nullptr && ParallelOmpLogic::skipParallelHandling == 0 ) {
@@ -146,7 +172,7 @@ void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTask, int maxThr
146172 data->activeParallelRegion = tool->isActive (tool->getThreadLocalTape ());
147173 data->parentTask = encounteringTask;
148174 data->parentTape = tool->getThreadLocalTape ();
149- data->parentAdjointAccessMode = getAdjointAccessMode ( );
175+ data->parentAdjointAccessMode = internalGetAdjointAccessMode (encounteringTask );
150176 data->childTasks .resize (maxThreads);
151177
152178 #if OPDI_OMP_LOGIC_INSTRUMENT
@@ -169,35 +195,34 @@ void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTask, int maxThr
169195
170196void opdi::ParallelOmpLogic::onParallelEnd (void * dataPtr) {
171197
172- Data* data = (Data*) dataPtr;
198+ Data* parallelData = (Data*) dataPtr;
173199
174- if (data != nullptr ) {
200+ if (parallelData != nullptr ) {
175201
176202 #if OPDI_OMP_LOGIC_INSTRUMENT
177203 for (auto & instrument : ompLogicInstruments) {
178- instrument->onParallelEnd (data );
204+ instrument->onParallelEnd (parallelData );
179205 }
180206 #endif
181207
182- if (data ->activeParallelRegion ) {
208+ if (parallelData ->activeParallelRegion ) {
183209
184210 Handle* handle = new Handle;
185- handle->data = (void *) data ;
211+ handle->data = (void *) parallelData ;
186212 handle->reverseFunc = ParallelOmpLogic::reverseFunc;
187213 handle->deleteFunc = ParallelOmpLogic::deleteFunc;
188214
189- tool->pushExternalFunction (data ->parentTape , handle);
215+ tool->pushExternalFunction (parallelData ->parentTape , handle);
190216
191217 // do not delete data, it is deleted with the handle
192218
193- // see if the adjoint access mode changed inside the parallel region
194- // if so, we have to make sure that it carries over to the containing parallel region
195- if (data->outerAdjointAccessMode != AdjointAccessControl::currentMode ()) {
196- this ->setAdjointAccessMode (AdjointAccessControl::currentMode ());
197- }
219+ // transport adjoint access mode of thread 0 to parent task
220+ ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast <ImplicitTaskOmpLogic::Data*>(parallelData->childTasks [0 ]);
221+
222+ this ->internalSetAdjointAccessMode (parallelData->parentTask , taskData->adjointAccessModes .back ());
198223
199224 } else {
200- deleteFunc (data );
225+ deleteFunc (parallelData );
201226 }
202227 }
203228 #if OPDI_OMP_LOGIC_INSTRUMENT
@@ -212,29 +237,37 @@ void opdi::ParallelOmpLogic::onParallelEnd(void* dataPtr) {
212237void opdi::ParallelOmpLogic::setAdjointAccessMode (opdi::LogicInterface::AdjointAccessMode mode) {
213238
214239 #if OPDI_VARIABLE_ADJOINT_ACCESS_MODE
215- AdjointAccessControl::currentMode () = mode;
216-
217240 #if OPDI_OMP_LOGIC_INSTRUMENT
218241 for (auto & instrument : ompLogicInstruments) {
219242 instrument->onSetAdjointAccessMode (mode);
220243 }
221244 #endif
222245
223- Data* data = (Data*) backend->getParallelData ();
224- int threadNum = omp_get_thread_num ();
225-
226- if (data != nullptr ) {
227- ImplicitTaskOmpLogic::Data* taskData = reinterpret_cast <ImplicitTaskOmpLogic::Data*>(data->childTasks [threadNum]);
228- assert (tool->getThreadLocalTape () == taskData->tape );
246+ void * parallelDataPtr = backend->getParallelData ();
247+ void * taskDataPtr = nullptr ;
229248
230- taskData-> adjointAccessModes . push_back (mode);
231- taskData-> positions . push_back (tool-> allocPosition () );
232- tool-> getTapePosition (taskData-> tape , taskData-> positions . back ()) ;
249+ if (parallelDataPtr != nullptr ) { // not initial implicit task
250+ Data* parallelData = reinterpret_cast <Data*>(parallelDataPtr );
251+ taskDataPtr = parallelData-> childTasks [ omp_get_thread_num ()] ;
233252 }
253+
254+ assert (taskDataPtr == nullptr || ((ImplicitTaskOmpLogic::Data*)taskDataPtr)->tape == tool->getThreadLocalTape ());
255+
256+ internalSetAdjointAccessMode (taskDataPtr, mode);
234257 #endif
235258}
236259
237- opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode () {
238- return AdjointAccessControl::currentMode ();
260+ opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode () const {
261+ void * parallelDataPtr = backend->getParallelData ();
262+ void * taskDataPtr = nullptr ;
263+
264+ if (parallelDataPtr != nullptr ) { // not initial implicit task
265+ Data* parallelData = reinterpret_cast <Data*>(parallelDataPtr);
266+ taskDataPtr = parallelData->childTasks [omp_get_thread_num ()];
267+ }
268+
269+ assert (taskDataPtr == nullptr || ((ImplicitTaskOmpLogic::Data*)taskDataPtr)->tape == tool->getThreadLocalTape ());
270+
271+ return internalGetAdjointAccessMode (taskDataPtr);
239272}
240273
0 commit comments