Skip to content

Commit e1c8335

Browse files
committed
Tool layer choices that temporarily disable OpDiLib.
Add EmptyTool that effectively disables OpDiLib. Nullptr tool effectively disables OpDiLib. Add test drivers for primal and forward AD computations in the presence of OpDiLib. Adapt test system accordingly. Revise treatment of nullptr and add associated assertions. Merge branch 'feature/emptyTool' into develop
2 parents af5d6a9 + 8fff904 commit e1c8335

131 files changed

Lines changed: 1858 additions & 52 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/opdi/backend/macro/macros.hpp

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,43 +97,57 @@
9797
{ \
9898
bool constexpr opdiInternalBarrierIndicator = true; \
9999
bool constexpr opdiInternalBroadcastIndicator = true; \
100-
void* opdiInternalTapePosition1 = opdi::tool->allocPosition(); \
101-
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1); \
100+
void* opdiInternalTapePosition1 = nullptr; \
101+
if (opdi::tool != nullptr) { \
102+
opdiInternalTapePosition1 = opdi::tool->allocPosition(); \
103+
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1); \
104+
} \
102105
/* broadcast-related barrier */ \
103106
opdi::logic->onSyncRegion(opdi::LogicInterface::SyncRegionKind::BarrierImplementation, \
104107
opdi::LogicInterface::ScopeEndpoint::Begin); \
105108
opdi::logic->onSyncRegion(opdi::LogicInterface::SyncRegionKind::BarrierImplementation, \
106109
opdi::LogicInterface::ScopeEndpoint::End); \
107-
void* opdiInternalTapePosition2 = opdi::tool->allocPosition(); \
108-
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition2); \
110+
void* opdiInternalTapePosition2 = nullptr; \
111+
if (opdi::tool != nullptr) { \
112+
opdiInternalTapePosition2 = opdi::tool->allocPosition(); \
113+
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition2); \
114+
} \
109115
opdi::ImplicitBarrierTools::beginRegionWithImplicitBarrier(); \
110116
{ \
111117
opdi::SingleProbe localSingleProbe; /* worksharing events */ \
112118
OPDI_PRAGMA(omp single __VA_ARGS__) \
113119
{ \
114120
/* delay broadcast-related barrier for executor */ \
115-
opdi::tool->erase(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1, opdiInternalTapePosition2);
121+
if (opdi::tool != nullptr) \
122+
opdi::tool->erase(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1, opdiInternalTapePosition2);
116123

117124
#define OPDI_SINGLE_COPYPRIVATE_NOWAIT(...) \
118125
{ \
119126
bool constexpr opdiInternalBarrierIndicator = false; \
120127
bool constexpr opdiInternalBroadcastIndicator = true; \
121-
void* opdiInternalTapePosition1 = opdi::tool->allocPosition(); \
122-
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1); \
128+
void* opdiInternalTapePosition1 = nullptr; \
129+
if (opdi::tool != nullptr) { \
130+
opdiInternalTapePosition1 = opdi::tool->allocPosition(); \
131+
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1); \
132+
} \
123133
/* broadcast-related barrier */ \
124134
opdi::logic->onSyncRegion(opdi::LogicInterface::SyncRegionKind::BarrierImplementation, \
125135
opdi::LogicInterface::ScopeEndpoint::Begin); \
126136
opdi::logic->onSyncRegion(opdi::LogicInterface::SyncRegionKind::BarrierImplementation, \
127137
opdi::LogicInterface::ScopeEndpoint::End); \
128-
void* opdiInternalTapePosition2 = opdi::tool->allocPosition(); \
129-
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition2); \
138+
void* opdiInternalTapePosition2 = nullptr; \
139+
if (opdi::tool != nullptr) { \
140+
opdiInternalTapePosition2 = opdi::tool->allocPosition(); \
141+
opdi::tool->getTapePosition(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition2); \
142+
} \
130143
opdi::ImplicitBarrierTools::beginRegionWithImplicitBarrier(); \
131144
{ \
132145
opdi::SingleProbe localSingleProbe; \
133146
OPDI_PRAGMA(omp single nowait __VA_ARGS__) \
134147
{ \
135148
/* delay broadcast-related barrier for executor */ \
136-
opdi::tool->erase(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1, opdiInternalTapePosition2);
149+
if (opdi::tool != nullptr) \
150+
opdi::tool->erase(opdi::tool->getThreadLocalTape(), opdiInternalTapePosition1, opdiInternalTapePosition2);
137151

138152
#define OPDI_END_SINGLE \
139153
/* broadcast-related barrier */ \
@@ -145,7 +159,7 @@
145159
} \
146160
} \
147161
} \
148-
if (opdiInternalBroadcastIndicator) { \
162+
if (opdi::tool != nullptr && opdiInternalBroadcastIndicator) { \
149163
opdi::tool->freePosition(opdiInternalTapePosition1); \
150164
opdi::tool->freePosition(opdiInternalTapePosition2); \
151165
} \

include/opdi/logic/omp/flushOmpLogic.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ namespace opdi {
5151
public:
5252

5353
virtual void addReverseFlush() {
54-
5554
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
5655

5756
Handle* handle = new Handle;

include/opdi/logic/omp/implicitTaskOmpLogic.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(bool isInitialImplicitTask
6060
// assume that the tape does not change. OpDiLib uses the initial implicit task's data primarily to track its
6161
// adjoint access mode.
6262
if (!isInitialImplicitTask) {
63+
64+
assert(tool != nullptr);
65+
6366
if (indexInTeam == 0) {
6467
if (parallelData->maximumSizeOfTeam < actualSizeOfTeam) {
6568
OPDI_ERROR("Actual number of threads exceeds maximum number of threads.");
@@ -68,6 +71,7 @@ void* opdi::ImplicitTaskOmpLogic::onImplicitTaskBegin(bool isInitialImplicitTask
6871
}
6972

7073
implicitTaskData->oldTape = tool->getThreadLocalTape();
74+
assert(implicitTaskData->oldTape != nullptr);
7175
implicitTaskData->parallelData = parallelData;
7276

7377
void* newTape = this->tapePool.getTape(parallelData->encounteringTaskTape, indexInTeam);
@@ -149,6 +153,8 @@ void opdi::ImplicitTaskOmpLogic::onImplicitTaskEnd(void* implicitTaskDataPtr) {
149153
#endif
150154

151155
if (!implicitTaskData->isInitialImplicitTask) {
156+
assert(tool != nullptr);
157+
152158
tool->setThreadLocalTape(implicitTaskData->oldTape);
153159

154160
implicitTaskData->positions.push_back(tool->allocPosition());
@@ -185,6 +191,8 @@ void opdi::ImplicitTaskOmpLogic::resetImplicitTask(void* position, opdi::LogicIn
185191
ImplicitTaskData* implicitTaskData = static_cast<ImplicitTaskData*>(implicitTaskDataPtr);
186192

187193
if (!implicitTaskData->isInitialImplicitTask) {
194+
assert(tool != nullptr);
195+
188196
assert(tool->comparePosition(implicitTaskData->positions.front(), position) <= 0);
189197

190198
while (tool->comparePosition(implicitTaskData->positions.back(), position) > 0) {

include/opdi/logic/omp/instrument/ompLogicOutputInstrument.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,23 @@ namespace opdi {
4141
}
4242

4343
virtual void reverseImplicitTaskBegin(ImplicitTaskData* data) {
44+
assert(tool != nullptr);
4445
TapedOutput::print("R IMTB l", data->level,
4546
"t", data->indexInTeam,
4647
"tape", data->newTape,
4748
"pos", tool->positionToString(data->positions.back()));
4849
}
4950

5051
virtual void reverseImplicitTaskEnd(ImplicitTaskData* data) {
52+
assert(tool != nullptr);
5153
TapedOutput::print("R IMTE l", data->level,
5254
"t", data->indexInTeam,
5355
"tape", data->newTape,
5456
"pos", tool->positionToString(data->positions.front()));
5557
}
5658

5759
virtual void reverseImplicitTaskPart(ImplicitTaskData* data, std::size_t part) {
60+
assert(tool != nullptr);
5861
TapedOutput::print("R IMTP l", data->level,
5962
"t", data->indexInTeam,
6063
"tape", data->newTape,
@@ -68,6 +71,7 @@ namespace opdi {
6871
TapedOutput::print("F IMTB IIT");
6972
}
7073
else {
74+
assert(tool != nullptr);
7175
TapedOutput::print("F IMTB l", data->level,
7276
"t", data->indexInTeam,
7377
"tape", data->newTape,
@@ -81,6 +85,7 @@ namespace opdi {
8185
TapedOutput::print("F IMTE IIT");
8286
}
8387
else {
88+
assert(tool != nullptr);
8489
TapedOutput::print("F IMTE l", data->level,
8590
"t", data->indexInTeam,
8691
"tape", data->newTape,

include/opdi/logic/omp/maskedOmpLogic.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,20 @@ void opdi::MaskedOmpLogic::deleteFunc(void* dataPtr) {
5454
void opdi::MaskedOmpLogic::onMasked(ScopeEndpoint endpoint) {
5555

5656
#if OPDI_OMP_LOGIC_INSTRUMENT
57-
if (tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
57+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
5858

59-
for (auto& instrument : ompLogicInstruments) {
60-
instrument->onMasked(endpoint);
61-
}
59+
for (auto& instrument : ompLogicInstruments) {
60+
instrument->onMasked(endpoint);
61+
}
6262

63-
Data* data = new Data;
64-
data->endpoint = endpoint;
63+
Data* data = new Data;
64+
data->endpoint = endpoint;
6565

66-
Handle* handle = new Handle;
67-
handle->data = static_cast<void*>(data);
68-
handle->reverseFunc = MaskedOmpLogic::reverseFunc;
69-
handle->deleteFunc = MaskedOmpLogic::deleteFunc;
70-
tool->pushExternalFunction(tool->getThreadLocalTape(), handle);
66+
Handle* handle = new Handle;
67+
handle->data = static_cast<void*>(data);
68+
handle->reverseFunc = MaskedOmpLogic::reverseFunc;
69+
handle->deleteFunc = MaskedOmpLogic::deleteFunc;
70+
tool->pushExternalFunction(tool->getThreadLocalTape(), handle);
7171
}
7272
#else
7373
OPDI_UNUSED(endpoint);

include/opdi/logic/omp/mutexOmpLogic.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,13 @@ void opdi::MutexOmpLogic::onMutexReleased(MutexKind mutexKind, WaitId waitId) {
212212
}
213213
}
214214

215-
// not thread safe! only use outside parallel regions
215+
// not thread-safe! only use outside of parallel regions
216216
void opdi::MutexOmpLogic::registerInactiveMutex(MutexKind mutexKind, WaitId waitId) {
217217
checkKind(mutexKind);
218218
this->recordings[mutexKind].inactive.insert(waitId);
219219
}
220220

221+
// not thread-safe! only use outside of parallel regions
221222
void opdi::MutexOmpLogic::prepareEvaluate() {
222223
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {
223224
MutexOmpLogic::evaluationCounters[mutexKind] = this->recordings[mutexKind].counters;
@@ -239,6 +240,7 @@ void opdi::MutexOmpLogic::prepareEvaluate() {
239240
#endif
240241
}
241242

243+
// not thread-safe! only use outside of parallel regions
242244
void opdi::MutexOmpLogic::postEvaluate() {
243245
#ifdef __SANITIZE_THREAD__
244246
/* destroy lock annotations */
@@ -253,12 +255,14 @@ void opdi::MutexOmpLogic::postEvaluate() {
253255
#endif
254256
}
255257

258+
// not thread-safe! only use outside of parallel regions
256259
void opdi::MutexOmpLogic::reset() {
257260
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {
258261
this->recordings[mutexKind].counters.clear();
259262
}
260263
}
261264

265+
// not thread-safe! only use outside of parallel regions
262266
void* opdi::MutexOmpLogic::exportState() {
263267
State* state = new State;
264268
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {
@@ -272,6 +276,7 @@ void opdi::MutexOmpLogic::freeState(void* statePtr) {
272276
delete state;
273277
}
274278

279+
// not thread safe! only use outside parallel regions
275280
void opdi::MutexOmpLogic::recoverState(void* statePtr) {
276281
State* state = static_cast<State*>(statePtr);
277282
for (std::size_t mutexKind = 0; mutexKind < nMutexKind; ++mutexKind) {

include/opdi/logic/omp/ompLogic.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ namespace opdi {
5757

5858
virtual void init() {
5959

60+
assert(backend != nullptr);
61+
6062
MutexOmpLogic::internalInit();
6163
ImplicitTaskOmpLogic::internalInit();
6264

@@ -71,6 +73,9 @@ namespace opdi {
7173
}
7274

7375
virtual void finalize() {
76+
77+
assert(backend != nullptr);
78+
7479
// finalize initial implicit task
7580
ImplicitTaskData* initialImplicitTaskData = static_cast<ImplicitTaskData*>(backend->getImplicitTaskData());
7681

include/opdi/logic/omp/parallelOmpLogic.cpp

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ int opdi::ParallelOmpLogic::skipParallelRegion = 0;
3939

4040
void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
4141

42+
assert(tool != nullptr);
43+
4244
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
4345

4446
#if OPDI_OMP_LOGIC_INSTRUMENT
@@ -105,6 +107,8 @@ void opdi::ParallelOmpLogic::reverseFunc(void* parallelDataPtr) {
105107

106108
void opdi::ParallelOmpLogic::deleteFunc(void* parallelDataPtr) {
107109

110+
assert(tool != nullptr);
111+
108112
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
109113

110114
ParallelOmpLogic::internalBeginSkippedParallelRegion();
@@ -153,23 +157,25 @@ void opdi::ParallelOmpLogic::internalSetAdjointAccessMode(ImplicitTaskData* impl
153157
implicitTaskData->adjointAccessModes.back() = mode;
154158
}
155159
else {
156-
void* position = tool->allocPosition();
157-
tool->getTapePosition(implicitTaskData->newTape, position);
160+
if (tool != nullptr) {
161+
void* position = tool->allocPosition();
162+
tool->getTapePosition(implicitTaskData->newTape, position);
158163

159-
if (tool->comparePosition(implicitTaskData->positions.back(), position) == 0) {
160-
implicitTaskData->adjointAccessModes.back() = mode;
161-
tool->freePosition(position);
162-
}
163-
else {
164-
implicitTaskData->adjointAccessModes.push_back(mode);
165-
implicitTaskData->positions.push_back(position);
164+
if (tool->comparePosition(implicitTaskData->positions.back(), position) == 0) {
165+
implicitTaskData->adjointAccessModes.back() = mode;
166+
tool->freePosition(position);
167+
}
168+
else {
169+
implicitTaskData->adjointAccessModes.push_back(mode);
170+
implicitTaskData->positions.push_back(position);
171+
}
166172
}
167173
}
168174
}
169175

170176
void* opdi::ParallelOmpLogic::onParallelBegin(void* encounteringTaskDataPtr, int maximumSizeOfTeam) {
171177

172-
if (tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelRegion == 0) {
178+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && ParallelOmpLogic::skipParallelRegion == 0) {
173179

174180
ImplicitTaskData* encounteringTaskData = static_cast<ImplicitTaskData*>(encounteringTaskDataPtr);
175181

@@ -209,6 +215,8 @@ void opdi::ParallelOmpLogic::onParallelEnd(void* parallelDataPtr) {
209215

210216
if (parallelDataPtr != nullptr) {
211217

218+
assert(tool != nullptr);
219+
212220
ParallelData* parallelData = static_cast<ParallelData*>(parallelDataPtr);
213221

214222
#if OPDI_OMP_LOGIC_INSTRUMENT
@@ -251,6 +259,8 @@ void opdi::ParallelOmpLogic::onParallelEnd(void* parallelDataPtr) {
251259

252260
void opdi::ParallelOmpLogic::setAdjointAccessMode(opdi::LogicInterface::AdjointAccessMode mode) {
253261

262+
assert(backend != nullptr);
263+
254264
#if OPDI_VARIABLE_ADJOINT_ACCESS_MODE
255265
void* implicitTaskDataPtr = backend->getImplicitTaskData();
256266
if (implicitTaskDataPtr != nullptr) { // nullptr if called during tape evaluation
@@ -266,6 +276,9 @@ void opdi::ParallelOmpLogic::setAdjointAccessMode(opdi::LogicInterface::AdjointA
266276
}
267277

268278
opdi::LogicInterface::AdjointAccessMode opdi::ParallelOmpLogic::getAdjointAccessMode() const {
279+
280+
assert(backend != nullptr);
281+
269282
void* implicitTaskDataPtr = backend->getImplicitTaskData();
270283
if (implicitTaskDataPtr != nullptr) { // nullptr if called during tape evaluation
271284
return internalGetAdjointAccessMode(static_cast<ImplicitTaskData*>(implicitTaskDataPtr));

include/opdi/logic/omp/syncRegionOmpLogic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ bool opdi::SyncRegionOmpLogic::requiresReverseBarrier(SyncRegionKind kind, Scope
8282

8383
void opdi::SyncRegionOmpLogic::onSyncRegion(SyncRegionKind kind, ScopeEndpoint endpoint) {
8484

85-
if (tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
85+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
8686

8787
#if OPDI_OMP_LOGIC_INSTRUMENT
8888
for (auto& instrument : ompLogicInstruments) {

include/opdi/logic/omp/workOmpLogic.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ void opdi::WorkOmpLogic::deleteFunc(void* dataPtr) {
5353
void opdi::WorkOmpLogic::onWork(WorksharingKind kind, ScopeEndpoint endpoint) {
5454

5555
#if OPDI_OMP_LOGIC_INSTRUMENT
56-
57-
if (tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
56+
if (tool != nullptr && tool->getThreadLocalTape() != nullptr && tool->isActive(tool->getThreadLocalTape())) {
5857

5958
for (auto& instrument : ompLogicInstruments) {
6059
instrument->onWork(kind, endpoint);

0 commit comments

Comments
 (0)