|
21 | 21 | import org.apache.flink.api.common.ExecutionConfig; |
22 | 22 | import org.apache.flink.api.common.eventtime.WatermarkStrategy; |
23 | 23 | import org.apache.flink.api.common.functions.FlatMapFunction; |
| 24 | +import org.apache.flink.api.common.state.CheckpointListener; |
24 | 25 | import org.apache.flink.api.common.typeinfo.BasicTypeInfo; |
25 | 26 | import org.apache.flink.api.common.typeinfo.TypeInformation; |
26 | 27 | import org.apache.flink.api.common.typeinfo.Types; |
@@ -254,21 +255,44 @@ public TypeInformation getTypeInformation() { |
254 | 255 | } |
255 | 256 | } |
256 | 257 |
|
| 258 | + /** |
| 259 | + * A filter that only passes through elements received before the first checkpoint completes. |
| 260 | + * |
| 261 | + * <p>The filter stops collecting elements in {@link #notifyCheckpointComplete(long)} rather |
| 262 | + * than in {@link #snapshotState(FunctionSnapshotContext)}, to avoid a race condition where the |
| 263 | + * checkpoint barrier arrives at this operator before all upstream elements (emitted in the same |
| 264 | + * checkpoint cycle) have been processed. Using {@code notifyCheckpointComplete} ensures that |
| 265 | + * the checkpoint has fully propagated through the pipeline before we stop collecting. |
| 266 | + */ |
257 | 267 | private static class FirstCheckpointFilter |
258 | | - implements FlatMapFunction<Long, Long>, CheckpointedFunction { |
| 268 | + implements FlatMapFunction<Long, Long>, CheckpointedFunction, CheckpointListener { |
259 | 269 |
|
260 | | - private volatile boolean firstCheckpoint = true; |
| 270 | + private volatile boolean firstCheckpointCompleted = false; |
| 271 | + private long firstCheckpointId = Long.MIN_VALUE; |
261 | 272 |
|
262 | 273 | @Override |
263 | 274 | public void flatMap(Long value, Collector<Long> out) throws Exception { |
264 | | - if (firstCheckpoint) { |
| 275 | + if (!firstCheckpointCompleted) { |
265 | 276 | out.collect(value); |
266 | 277 | } |
267 | 278 | } |
268 | 279 |
|
269 | 280 | @Override |
270 | 281 | public void snapshotState(FunctionSnapshotContext context) throws Exception { |
271 | | - firstCheckpoint = false; |
| 282 | + // Record the ID of the first checkpoint so we can stop collecting when it completes. |
| 283 | + if (firstCheckpointId == Long.MIN_VALUE) { |
| 284 | + firstCheckpointId = context.getCheckpointId(); |
| 285 | + } |
| 286 | + } |
| 287 | + |
| 288 | + @Override |
| 289 | + public void notifyCheckpointComplete(long checkpointId) throws Exception { |
| 290 | + // Stop collecting elements once the first checkpoint has completed. |
| 291 | + if (!firstCheckpointCompleted |
| 292 | + && checkpointId >= firstCheckpointId |
| 293 | + && firstCheckpointId != Long.MIN_VALUE) { |
| 294 | + firstCheckpointCompleted = true; |
| 295 | + } |
272 | 296 | } |
273 | 297 |
|
274 | 298 | @Override |
|
0 commit comments