001/*
002 * (C) Copyright 2017 Nuxeo SA (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     bdelbosc
018 */
019package org.nuxeo.lib.stream.computation.log;
020
021import static java.util.concurrent.Executors.newFixedThreadPool;
022
023import java.time.Duration;
024import java.util.ArrayList;
025import java.util.List;
026import java.util.Objects;
027import java.util.Set;
028import java.util.concurrent.ExecutorService;
029import java.util.concurrent.ThreadFactory;
030import java.util.concurrent.TimeUnit;
031import java.util.concurrent.atomic.AtomicInteger;
032import java.util.function.Supplier;
033import java.util.stream.Collectors;
034
035import org.apache.commons.logging.Log;
036import org.apache.commons.logging.LogFactory;
037import org.nuxeo.lib.stream.codec.Codec;
038import org.nuxeo.lib.stream.computation.Computation;
039import org.nuxeo.lib.stream.computation.ComputationMetadataMapping;
040import org.nuxeo.lib.stream.computation.Record;
041import org.nuxeo.lib.stream.computation.Watermark;
042import org.nuxeo.lib.stream.log.LogManager;
043import org.nuxeo.lib.stream.log.LogPartition;
044
045/**
046 * Pool of ComputationRunner
047 *
048 * @since 9.3
049 */
050public class ComputationPool {
051    private static final Log log = LogFactory.getLog(ComputationPool.class);
052
053    protected final ComputationMetadataMapping metadata;
054
055    protected final int threads;
056
057    protected final LogManager manager;
058
059    protected final Supplier<Computation> supplier;
060
061    protected final List<List<LogPartition>> defaultAssignments;
062
063    protected final List<ComputationRunner> runners;
064
065    protected final Codec<Record> inputCodec;
066
067    protected final Codec<Record> outputCodec;
068
069    protected ExecutorService threadPool;
070
071    public ComputationPool(Supplier<Computation> supplier, ComputationMetadataMapping metadata,
072            List<List<LogPartition>> defaultAssignments, LogManager manager, Codec<Record> inputCodec,
073            Codec<Record> outputCodec) {
074        Objects.requireNonNull(inputCodec);
075        Objects.requireNonNull(outputCodec);
076        this.supplier = supplier;
077        this.manager = manager;
078        this.metadata = metadata;
079        this.threads = defaultAssignments.size();
080        this.inputCodec = inputCodec;
081        this.outputCodec = outputCodec;
082        this.defaultAssignments = defaultAssignments;
083        this.runners = new ArrayList<>(threads);
084    }
085
086    public String getComputationName() {
087        return metadata.name();
088    }
089
090    @SuppressWarnings("FutureReturnValueIgnored")
091    public void start() {
092        log.info(metadata.name() + ": Starting pool");
093        threadPool = newFixedThreadPool(threads, new NamedThreadFactory(metadata.name() + "Pool"));
094        defaultAssignments.forEach(assignments -> {
095            ComputationRunner runner = new ComputationRunner(supplier, metadata, assignments, manager, inputCodec,
096                    outputCodec);
097            threadPool.submit(runner);
098            runners.add(runner);
099        });
100        // close the pool no new admission
101        threadPool.shutdown();
102        log.debug(metadata.name() + ": Pool started, threads: " + threads);
103    }
104
105    public boolean isTerminated() {
106        return threadPool.isTerminated();
107    }
108
109    public boolean waitForAssignments(Duration timeout) throws InterruptedException {
110        log.info(metadata.name() + ": Wait for partitions assignments");
111        if (threadPool == null || threadPool.isTerminated()) {
112            return true;
113        }
114        for (ComputationRunner runner : runners) {
115            if (!runner.waitForAssignments(timeout)) {
116                return false;
117            }
118        }
119        return true;
120    }
121
122    public boolean drainAndStop(Duration timeout) {
123        if (threadPool == null || threadPool.isTerminated()) {
124            return true;
125        }
126        log.info(metadata.name() + ": Draining");
127        runners.forEach(ComputationRunner::drain);
128        boolean ret = awaitPoolTermination(timeout);
129        stop(Duration.ofSeconds(1));
130        return ret;
131    }
132
133    public boolean stop(Duration timeout) {
134        if (threadPool == null || threadPool.isTerminated()) {
135            return true;
136        }
137        log.info(metadata.name() + ": Stopping");
138        runners.forEach(ComputationRunner::stop);
139        boolean ret = awaitPoolTermination(timeout);
140        shutdown();
141        return ret;
142    }
143
144    public void shutdown() {
145        if (threadPool != null && !threadPool.isTerminated()) {
146            log.info(metadata.name() + ": Shutting down");
147            threadPool.shutdownNow();
148            // give a chance to end threads with valid tailer when shutdown is followed by streams.close()
149            try {
150                threadPool.awaitTermination(1, TimeUnit.SECONDS);
151            } catch (InterruptedException e) {
152                Thread.currentThread().interrupt();
153                log.warn(metadata.name() + ": Interrupted in shutdown");
154            }
155        }
156        runners.clear();
157        threadPool = null;
158    }
159
160    protected boolean awaitPoolTermination(Duration timeout) {
161        try {
162            if (!threadPool.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
163                log.warn(metadata.name() + ": Timeout on wait for pool termination");
164                return false;
165            }
166        } catch (InterruptedException e) {
167            Thread.currentThread().interrupt();
168            log.warn(metadata.name() + ": Interrupted while waiting for pool termination");
169            return false;
170        }
171        return true;
172    }
173
174    public long getLowWatermark() {
175        // Collect all the low watermark of the pool, filtering 0 (or 1 which is completed of 0)
176        Set<Watermark> watermarks = runners.stream()
177                                           .map(ComputationRunner::getLowWatermark)
178                                           .filter(wm -> wm.getValue() > 1)
179                                           .collect(Collectors.toSet());
180        // Take the lowest watermark of unprocessed (not completed) records
181        long ret = watermarks.stream().filter(wm -> !wm.isCompleted()).mapToLong(Watermark::getValue).min().orElse(0);
182        boolean pending = true;
183        if (ret == 0) {
184            pending = false;
185            // There is no known pending records we take the max completed low watermark
186            ret = watermarks.stream().filter(Watermark::isCompleted).mapToLong(Watermark::getValue).max().orElse(0);
187        }
188        if (log.isTraceEnabled() && ret > 0)
189            log.trace(metadata.name() + ": low: " + ret + " " + (pending ? "Pending" : "Completed"));
190        return ret;
191    }
192
193    protected static class NamedThreadFactory implements ThreadFactory {
194        protected final AtomicInteger count = new AtomicInteger(0);
195
196        protected final String prefix;
197
198        public NamedThreadFactory(String prefix) {
199            this.prefix = prefix;
200        }
201
202        @SuppressWarnings("NullableProblems")
203        @Override
204        public Thread newThread(Runnable r) {
205            Thread t = new Thread(r, String.format("%s-%02d", prefix, count.getAndIncrement()));
206            t.setUncaughtExceptionHandler((t1, e) -> log.error("Uncaught exception: " + e.getMessage(), e));
207            return t;
208        }
209    }
210
211}