@@ -40,45 +40,62 @@ void opdi::ImplicitTaskOmpLogic::internalFinalize() {
4040 this ->tapePool .finalize ();
4141}
4242
43- void * opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin (int actualParallelism, int index, void * parallelDataPtr) {
43+ void * opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin (bool initialImplicitTask, int actualParallelism, int index,
44+ void * parallelDataPtr) {
4445
4546 ParallelData* parallelData = (ParallelData*) parallelDataPtr;
4647
47- if (parallelData != nullptr ) {
48- if (index == 0 ) {
49- parallelData->actualThreads = actualParallelism;
50- }
48+ // check if the handling of the parallel region was skipped
49+ if (parallelData != nullptr || initialImplicitTask) {
5150
5251 Data* data = new Data;
52+ data->initialImplicitTask = initialImplicitTask;
5353 data->level = omp_get_level ();
5454 data->index = index;
55- data->oldTape = tool->getThreadLocalTape ();
56- data->parallelData = parallelData;
5755
58- void * newTape = this ->tapePool .getTape (parallelData->parentTape , index);
56+ // OpDiLib does not interfere with the initial implicit task AD-wise, e.g., does not track its tape / does not assume
57+ // that the tape does not change. OpDiLib uses the initial implicit task's data primarily to track its adjoint access
58+ // mode.
59+ if (!initialImplicitTask) {
60+ if (index == 0 ) {
61+ parallelData->actualThreads = actualParallelism;
62+ }
5963
60- if (parallelData->activeParallelRegion ) {
61- tool->setActive (newTape, true );
62- }
64+ data->oldTape = tool->getThreadLocalTape ();
65+ data->parallelData = parallelData;
6366
64- data-> tape = newTape;
67+ void * newTape = this -> tapePool . getTape (parallelData-> parentTape , index) ;
6568
66- data->positions .push_back (tool->allocPosition ());
67- tool->getTapePosition (newTape, data->positions .back ());
69+ if (parallelData->activeParallelRegion ) {
70+ tool->setActive (newTape, true );
71+ }
6872
69- tool-> setThreadLocalTape ( newTape) ;
73+ data-> tape = newTape;
7074
71- AdjointAccessControl::pushMode (parallelData->parentAdjointAccessMode );
72- data->adjointAccessModes .push_back (parallelData->parentAdjointAccessMode );
75+ data->positions .push_back (tool->allocPosition ());
76+ tool->getTapePosition (newTape, data->positions .back ());
77+
78+ tool->setThreadLocalTape (newTape);
79+
80+ AdjointAccessControl::pushMode (parallelData->parentAdjointAccessMode );
81+ data->adjointAccessModes .push_back (parallelData->parentAdjointAccessMode );
82+
83+ parallelData->childTasks [index] = data;
84+ }
85+ else {
86+ data->oldTape = nullptr ;
87+ data->tape = nullptr ;
88+ data->parallelData = nullptr ;
89+
90+ data->adjointAccessModes .push_back (ImplicitTaskOmpLogic::defaultAdjointAccessMode);
91+ }
7392
7493 #if OPDI_OMP_LOGIC_INSTRUMENT
7594 for (auto & instrument : ompLogicInstruments) {
7695 instrument->onImplicitTaskBegin (data);
7796 }
7897 #endif
7998
80- parallelData->childTasks [index] = data;
81-
8299 return data;
83100 }
84101
@@ -90,35 +107,41 @@ void opdi::ImplicitTaskOmpLogic::onImplicitTaskEnd(void* dataPtr) {
90107 if (dataPtr != nullptr ) {
91108 Data* data = (Data*) dataPtr;
92109
110+ #if OPDI_OMP_LOGIC_INSTRUMENT
111+ for (auto & instrument : ompLogicInstruments) {
112+ instrument->onImplicitTaskEnd (data);
113+ }
114+ #endif
115+
93116 AdjointAccessMode lastAccessMode = AdjointAccessControl::currentMode ();
94117 AdjointAccessControl::popMode ();
95118 AdjointAccessControl::currentMode () = lastAccessMode;
96119
97- tool->setThreadLocalTape (data->oldTape );
120+ if (!data->initialImplicitTask ) {
121+ tool->setThreadLocalTape (data->oldTape );
98122
99- data->positions .push_back (tool->allocPosition ());
100- tool->getTapePosition (data->tape , data->positions .back ());
123+ data->positions .push_back (tool->allocPosition ());
124+ tool->getTapePosition (data->tape , data->positions .back ());
101125
102- if (!data->parallelData ->activeParallelRegion ) {
103- if (tool->comparePosition (data->positions .front (), data->positions .back ()) != 0 ) {
104- OPDI_WARNING (" Something became active during a passive parallel region. This is not supported and will not be " ,
105- " differentiated correctly." );
126+ if (!data->parallelData ->activeParallelRegion ) {
127+ if (tool->comparePosition (data->positions .front (), data->positions .back ()) != 0 ) {
128+ OPDI_WARNING (" Something became active during a passive parallel region. This is not supported and will not be " ,
129+ " differentiated correctly." );
130+ }
106131 }
107- }
108132
109- #if OPDI_OMP_LOGIC_INSTRUMENT
110- for (auto & instrument : ompLogicInstruments) {
111- instrument->onImplicitTaskEnd (data);
112- }
113- #endif
133+ tool->setActive (data->tape , false );
114134
115- tool->setActive (data->tape , false );
135+ // ensure that the most recent activity change *per thread* reflects the current activity
136+ if (data->oldTape == data->parallelData ->parentTape && data->parallelData ->activeParallelRegion ) {
137+ tool->setActive (data->oldTape , true );
138+ }
116139
117- // ensure that the most recent activity change *per thread* reflects the current activity
118- if (data->oldTape == data->parallelData ->parentTape && data->parallelData ->activeParallelRegion ) {
119- tool->setActive (data->oldTape , true );
140+ // do not delete data, it is deleted as part of parallel regions
141+ }
142+ else {
143+ // delete task data, there is no parallel region to do so
144+ delete data;
120145 }
121-
122- // do not delete data, it is deleted as part of parallel regions
123146 }
124147}
0 commit comments