Skip to content

Commit 082a179

Browse files
committed
Recheck and revise treatment of nullptr tapes/layers.
Add assertions about nullptr layers. Tool returning nullptr tapes effectively deactivates OpDiLib. Nullptr tool effectively deactivates OpDiLib. Improve support for using OpenMP prior to initializing OpDiLib.
1 parent 3cbf9aa commit 082a179

8 files changed

Lines changed: 55 additions & 26 deletions

File tree

include/opdi/logic/omp/flushOmpLogic.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ namespace opdi {
5151
public:
5252

5353
virtual void addReverseFlush() {
54-
5554
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
5655

5756
Handle* handle = new Handle;

include/opdi/logic/omp/implicitTaskOmpLogic.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(bool isInitialImplicitTask
6060
// assume that the tape does not change. OpDiLib uses the initial implicit task's data primarily to track its
6161
// adjoint access mode.
6262
if (!isInitialImplicitTask) {
63+
64+
assert(tool != nullptr);
65+
6366
if (indexInTeam == 0) {
6467
if (parallelData->maximumSizeOfTeam < actualSizeOfTeam) {
6568
OPDI_ERROR("Actual number of threads exceeds maximum number of threads.");
@@ -68,6 +71,7 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(bool isInitialImplicitTask
6871
}
6972

7073
implicitTaskData->oldTape = tool->getThreadLocalTape();
74+
assert(implicitTaskData->oldTape != nullptr);
7175
implicitTaskData->parallelData = parallelData;
7276

7377
void* newTape = this->tapePool.getTape(parallelData->encounteringTaskTape, indexInTeam);
@@ -149,6 +153,8 @@ void opdi::ImplicitTaskOmpLogic::onImplicitTaskEnd(void* implicitTaskDataPtr) {
149153
#endif
150154

151155
if (!implicitTaskData->isInitialImplicitTask) {
156+
assert(tool != nullptr);
157+
152158
tool->setThreadLocalTape(implicitTaskData->oldTape);
153159

154160
implicitTaskData->positions.push_back(tool->allocPosition());
@@ -185,6 +191,8 @@ void opdi::ImplicitTaskOmpLogic::resetImplicitTask(void* position, opdi::LogicIn
185191
ImplicitTaskData* implicitTaskData = static_cast<ImplicitTaskData*>(implicitTaskDataPtr);
186192

187193
if (!implicitTaskData->isInitialImplicitTask) {
194+
assert(tool != nullptr);
195+
188196
assert(tool->comparePosition(implicitTaskData->positions.front(), position) <= 0);
189197

190198
while (tool->comparePosition(implicitTaskData->positions.back(), position) > 0) {

include/opdi/logic/omp/maskedOmpLogic.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,20 @@ void opdi::MaskedOmpLogic::deleteFunc(void* dataPtr) {
5454
void opdi::MaskedOmpLogic::onMasked(ScopeEndpoint endpoint) {
5555

5656
#if OPDI_OMP_LOGIC_INSTRUMENT
57-
if (tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
57+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
5858

59-
for (auto& instrument : ompLogicInstruments) {
60-
instrument->onMasked(endpoint);
61-
}
59+
for (auto& instrument : ompLogicInstruments) {
60+
instrument->onMasked(endpoint);
61+
}
6262

63-
Data* data = new Data;
64-
data->endpoint = endpoint;
63+
Data* data = new Data;
64+
data->endpoint = endpoint;
6565

66-
Handle* handle = new Handle;
67-
handle->data = static_cast<void*>(data);
68-
handle->reverseFunc = MaskedOmpLogic::reverseFunc;
69-
handle->deleteFunc = MaskedOmpLogic::deleteFunc;
70-
tool->pushExternalFunction(tool->getThreadLocalTape(), handle);
66+
Handle* handle = new Handle;
67+
handle->data = static_cast<void*>(data);
68+
handle->reverseFunc = MaskedOmpLogic::reverseFunc;
69+
handle->deleteFunc = MaskedOmpLogic::deleteFunc;
70+
tool->pushExternalFunction(tool->getThreadLocalTape(), handle);
7171
}
7272
#else
7373
OPDI_UNUSED(endpoint);

include/opdi/logic/omp/mutexOmpLogic.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,13 @@ void opdi::MutexOmpLogic::onMutexReleased(MutexKind mutexKind, WaitId waitId) {
212212
}
213213
}
214214

215-
// not thread safe! only use outside parallel regions
215+
// not thread-safe! only use outside of parallel regions
216216
void opdi::MutexOmpLogic::registerInactiveMutex(MutexKind mutexKind, WaitId waitId) {
217217
checkKind(mutexKind);
218218
this->recordings[mutexKind].inactive.insert(waitId);
219219
}
220220

221+
// not thread-safe! only use outside of parallel regions
221222
void opdi::MutexOmpLogic::prepareEvaluate() {
222223
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {
223224
MutexOmpLogic::evaluationCounters[mutexKind] = this->recordings[mutexKind].counters;
@@ -239,6 +240,7 @@ void opdi::MutexOmpLogic::prepareEvaluate() {
239240
#endif
240241
}
241242

243+
// not thread-safe! only use outside of parallel regions
242244
void opdi::MutexOmpLogic::postEvaluate() {
243245
#ifdef __SANITIZE_THREAD__
244246
/* destroy lock annotations */
@@ -253,12 +255,14 @@ void opdi::MutexOmpLogic::postEvaluate() {
253255
#endif
254256
}
255257

258+
// not thread-safe! only use outside of parallel regions
256259
void opdi::MutexOmpLogic::reset() {
257260
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {
258261
this->recordings[mutexKind].counters.clear();
259262
}
260263
}
261264

265+
// not thread-safe! only use outside of parallel regions
262266
void* opdi::MutexOmpLogic::exportState() {
263267
State* state = new State;
264268
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {
@@ -272,6 +276,7 @@ void opdi::MutexOmpLogic::freeState(void* statePtr) {
272276
delete state;
273277
}
274278

279+
// not thread safe! only use outside parallel regions
275280
void opdi::MutexOmpLogic::recoverState(void* statePtr) {
276281
State* state = static_cast<State*>(statePtr);
277282
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {

include/opdi/logic/omp/ompLogic.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ namespace opdi {
5757

5858
virtual void init() {
5959

60+
assert(backend != nullptr);
61+
6062
MutexOmpLogic::internalInit();
6163
ImplicitTaskOmpLogic::internalInit();
6264

@@ -71,6 +73,9 @@ namespace opdi {
7173
}
7274

7375
virtual void finalize() {
76+
77+
assert(backend != nullptr);
78+
7479
// finalize initial implicit task
7580
ImplicitTaskData* initialImplicitTaskData = static_cast<ImplicitTaskData*>(backend->getImplicitTaskData());
7681

include/opdi/logic/omp/parallelOmpLogic.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ int opdi::ParallelOmpLogic::skipParallelRegion = 0;
3939

4040
void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
4141

42+
assert(tool != nullptr);
43+
4244
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
4345

4446
#if OPDI_OMP_LOGIC_INSTRUMENT
@@ -105,6 +107,8 @@ void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
105107

106108
void opdi::ParallelOmpLogic::deleteFunc(void* parallelDataPtr) {
107109

110+
assert(tool != nullptr);
111+
108112
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
109113

110114
ParallelOmpLogic::internalBeginSkippedParallelRegion();
@@ -153,23 +157,25 @@ void opdi::ParallelOmpLogic::internalSetAdjointAccessMode(ImplicitTaskData* impl
153157
implicitTaskData->adjointAccessModes.back() = mode;
154158
}
155159
else {
156-
void* position = tool->allocPosition();
157-
tool->getTapePosition(implicitTaskData->newTape, position);
160+
if (tool != nullptr) {
161+
void* position = tool->allocPosition();
162+
tool->getTapePosition(implicitTaskData->newTape, position);
158163

159-
if (tool->comparePosition(implicitTaskData->positions.back(), position) == 0) {
160-
implicitTaskData->adjointAccessModes.back() = mode;
161-
tool->freePosition(position);
162-
}
163-
else {
164-
implicitTaskData->adjointAccessModes.push_back(mode);
165-
implicitTaskData->positions.push_back(position);
164+
if (tool->comparePosition(implicitTaskData->positions.back(), position) == 0) {
165+
implicitTaskData->adjointAccessModes.back() = mode;
166+
tool->freePosition(position);
167+
}
168+
else {
169+
implicitTaskData->adjointAccessModes.push_back(mode);
170+
implicitTaskData->positions.push_back(position);
171+
}
166172
}
167173
}
168174
}
169175

170176
void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTaskDataPtr, int maximumSizeOfTeam) {
171177

172-
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelRegion == 0) {
178+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelRegion == 0) {
173179

174180
ImplicitTaskData* encounteringTaskData = static_cast<ImplicitTaskData*>(encounteringTaskDataPtr);
175181

@@ -209,6 +215,8 @@ void opdi::ParallelOmpLogic::onParallelEnd(void* parallelDataPtr) {
209215

210216
if (parallelDataPtr != nullptr) {
211217

218+
assert(tool != nullptr);
219+
212220
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
213221

214222
#if OPDI_OMP_LOGIC_INSTRUMENT
@@ -251,6 +259,8 @@ void opdi::ParallelOmpLogic::onParallelEnd(void* parallelDataPtr) {
251259

252260
void opdi::ParallelOmpLogic::setAdjointAccessMode(opdi::LogicInterface::AdjointAccessMode mode) {
253261

262+
assert(backend != nullptr);
263+
254264
#if OPDI_VARIABLE_ADJOINT_ACCESS_MODE
255265
void* implicitTaskDataPtr = backend->getImplicitTaskData();
256266
if (implicitTaskDataPtr != nullptr) { // nullptr if called during tape evaluation
@@ -266,6 +276,9 @@ void opdi::ParallelOmpLogic::setAdjointAccessMode(opdi::LogicInterface::AdjointA
266276
}
267277

268278
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode() const {
279+
280+
assert(backend != nullptr);
281+
269282
void* implicitTaskDataPtr = backend->getImplicitTaskData();
270283
if (implicitTaskDataPtr != nullptr) { // nullptr if called during tape evaluation
271284
return internalGetAdjointAccessMode(static_cast<ImplicitTaskData*>(implicitTaskDataPtr));

include/opdi/logic/omp/syncRegionOmpLogic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ bool opdi::SyncRegionOmpLogic::requiresReverseBarrier(SyncRegionKind kind, Scope
8282

8383
void opdi::SyncRegionOmpLogic::onSyncRegion(SyncRegionKind kind, ScopeEndpoint endpoint) {
8484

85-
if (tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
85+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
8686

8787
#if OPDI_OMP_LOGIC_INSTRUMENT
8888
for (auto& instrument : ompLogicInstruments) {

include/opdi/logic/omp/workOmpLogic.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ void opdi::WorkOmpLogic::deleteFunc(void* dataPtr) {
5353
void opdi::WorkOmpLogic::onWork(WorksharingKind kind, ScopeEndpoint endpoint) {
5454

5555
#if OPDI_OMP_LOGIC_INSTRUMENT
56-
57-
if (tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
56+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
5857

5958
for (auto& instrument : ompLogicInstruments) {
6059
instrument->onWork(kind, endpoint);

0 commit comments

Comments
 (0)