/*
 * (C) Copyright 2018 Nuxeo (http://nuxeo.com/) and others.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Contributors:
 *       Kevin Leturc <kleturc@nuxeo.com>
 */
package org.nuxeo.ecm.core.bulk;

import static java.lang.Integer.max;
import static java.lang.Math.min;
import static org.nuxeo.ecm.core.bulk.BulkComponent.BULK_KV_STORE_NAME;
import static org.nuxeo.ecm.core.bulk.BulkServiceImpl.PROCESSED_DOCUMENTS;
import static org.nuxeo.ecm.core.bulk.BulkServiceImpl.SCROLLED_DOCUMENT_COUNT;
import static org.nuxeo.ecm.core.bulk.BulkServiceImpl.SCROLL_END_TIME;
import static org.nuxeo.ecm.core.bulk.BulkServiceImpl.SCROLL_START_TIME;
import static org.nuxeo.ecm.core.bulk.BulkServiceImpl.SET_STREAM_NAME;
import static org.nuxeo.ecm.core.bulk.BulkServiceImpl.STATE;
import static org.nuxeo.ecm.core.bulk.BulkStatus.State.COMPLETED;
import static org.nuxeo.ecm.core.bulk.BulkStatus.State.RUNNING;
import static org.nuxeo.ecm.core.bulk.BulkStatus.State.SCHEDULED;
import static org.nuxeo.ecm.core.bulk.BulkStatus.State.SCROLLING_RUNNING;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.nuxeo.ecm.core.api.CloseableCoreSession;
import org.nuxeo.ecm.core.api.CoreInstance;
import org.nuxeo.ecm.core.api.NuxeoException;
import org.nuxeo.ecm.core.api.ScrollResult;
import org.nuxeo.lib.stream.codec.Codec;
import org.nuxeo.lib.stream.computation.AbstractComputation;
import org.nuxeo.lib.stream.computation.ComputationContext;
import org.nuxeo.lib.stream.computation.Record;
import org.nuxeo.lib.stream.computation.Topology;
import org.nuxeo.runtime.api.Framework;
import org.nuxeo.runtime.codec.CodecService;
import org.nuxeo.runtime.kv.KeyValueService;
import org.nuxeo.runtime.kv.KeyValueStore;
import org.nuxeo.runtime.stream.StreamProcessorTopology;
import org.nuxeo.runtime.transaction.TransactionHelper;

/**
 * Computation that consumes a {@link BulkCommand} and produces document ids. This scroller executes the command query
 * against the repository (by scrolling) and then produces buckets of document ids to the appropriate bulk action
 * stream.
 *
 * @since 10.2
 */
public class StreamBulkProcessor implements StreamProcessorTopology {

    private static final Log log = LogFactory.getLog(StreamBulkProcessor.class);

    public static final String AVRO_CODEC = "avro";

    public static final String SCROLLER_COMPUTATION_NAME = "bulkDocumentScroller";

    public static final String COUNTER_COMPUTATION_NAME = "bulkCounter";

    public static final String KVWRITER_COMPUTATION_NAME = "keyValueWriter";

    public static final String COUNTER_STREAM_NAME = "counter";

    public static final String KVWRITER_STREAM_NAME = "keyValueWriter";

    public static final String SCROLL_BATCH_SIZE_OPT = "scrollBatchSize";

    public static final String SCROLL_KEEP_ALIVE_SECONDS_OPT = "scrollKeepAlive";

    public static final String BUCKET_SIZE_OPT = "bucketSize";

    public static final String COUNTER_THRESHOLD_MS_OPT = "counterThresholdMs";

    public static final int DEFAULT_SCROLL_BATCH_SIZE = 100;

    public static final int DEFAULT_SCROLL_KEEPALIVE_SECONDS = 60;

    public static final int DEFAULT_BUCKET_SIZE = 50;

    public static final int DEFAULT_COUNTER_THRESHOLD_MS = 30000;

    @Override
    public Topology getTopology(Map<String, String> options) {
        // retrieve options
        int scrollBatchSize = getOptionAsInteger(options, SCROLL_BATCH_SIZE_OPT, DEFAULT_SCROLL_BATCH_SIZE);
        int scrollKeepAliveSeconds = getOptionAsInteger(options, SCROLL_KEEP_ALIVE_SECONDS_OPT,
                DEFAULT_SCROLL_KEEPALIVE_SECONDS);
        int bucketSize = getOptionAsInteger(options, BUCKET_SIZE_OPT, DEFAULT_BUCKET_SIZE);
        int counterThresholdMs = getOptionAsInteger(options, COUNTER_THRESHOLD_MS_OPT, DEFAULT_COUNTER_THRESHOLD_MS);

        // retrieve bulk actions to deduce output streams
        BulkAdminService service = Framework.getService(BulkAdminService.class);
        List<String> actions = service.getActions();
        List<String> mapping = new ArrayList<>();
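        // the scroller input is the stream where bulk commands are written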
        mapping.add("i1:" + SET_STREAM_NAME);
        int i = 1;
        for (String action : actions) {
            mapping.add(String.format("o%s:%s", i, action));
            i++;
        }
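        // the last output is the key/value writer stream used to persist status updates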
        mapping.add(String.format("o%s:%s", i, KVWRITER_STREAM_NAME));

        return Topology.builder()
                       .addComputation( //
                               () -> new BulkDocumentScrollerComputation(SCROLLER_COMPUTATION_NAME, mapping.size(),
                                       scrollBatchSize, scrollKeepAliveSeconds, bucketSize), //
                               mapping)
                       .addComputation(() -> new CounterComputation(COUNTER_COMPUTATION_NAME, counterThresholdMs),
                               Arrays.asList("i1:" + COUNTER_STREAM_NAME, "o1:" + KVWRITER_STREAM_NAME))
                       .addComputation(() -> new KeyValueWriterComputation(KVWRITER_COMPUTATION_NAME),
                               Collections.singletonList("i1:" + KVWRITER_STREAM_NAME))
                       .build();
    }

    public static class BulkDocumentScrollerComputation extends AbstractComputation {

        protected final int scrollBatchSize;

        protected final int scrollKeepAliveSeconds;

        protected final int bucketSize;

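        // buffer of scrolled document ids waiting to be flushed as buckets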
        protected final List<String> documentIds;

        /**
         * @param name the computation name
         * @param nbOutputStreams the number of registered bulk action streams
         * @param scrollBatchSize the batch size to scroll
         * @param scrollKeepAliveSeconds the scroll lifetime in seconds
         * @param bucketSize the number of documents to send per bucket
         */
        public BulkDocumentScrollerComputation(String name, int nbOutputStreams, int scrollBatchSize,
                int scrollKeepAliveSeconds, int bucketSize) {
            super(name, 1, nbOutputStreams);
            this.scrollBatchSize = scrollBatchSize;
            this.scrollKeepAliveSeconds = scrollKeepAliveSeconds;
            this.bucketSize = bucketSize;
            documentIds = new ArrayList<>(max(scrollBatchSize, bucketSize));
        }

        @Override
        public void processRecord(ComputationContext context, String inputStreamName, Record record) {
            TransactionHelper.runInTransaction(() -> processRecord(context, record));
        }

        protected void processRecord(ComputationContext context, Record record) {
            KeyValueStore kvStore = Framework.getService(KeyValueService.class).getKeyValueStore(BULK_KV_STORE_NAME);
            try {
                String commandId = record.getKey();
                BulkCommand command = BulkCommands.fromBytes(record.getData());
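                // atomically switch the command from SCHEDULED to SCROLLING_RUNNING; if the transition fails,
                // another scroller has already taken this command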
                if (!kvStore.compareAndSet(commandId + STATE, SCHEDULED.toString(), SCROLLING_RUNNING.toString())) {
                    log.error("Discard record: " + record + " because its processing has already started");
                    context.askForCheckpoint();
                    return;
                }
                LoginContext loginContext;
                try {
                    loginContext = Framework.loginAsUser(command.getUsername());
                    try (CloseableCoreSession session = CoreInstance.openCoreSession(command.getRepository())) {
                        // scroll documents
                        Long scrollStartTime = Instant.now().toEpochMilli();
                        ScrollResult<String> scroll = session.scroll(command.getQuery(), scrollBatchSize,
                                scrollKeepAliveSeconds);
                        long documentCount = 0;
                        long bucketNumber = 0;
                        while (scroll.hasResults()) {
                            List<String> docIds = scroll.getResults();
                            documentIds.addAll(docIds);
                            while (documentIds.size() >= bucketSize) {
                                // the number of documents sent so far makes the record key unique
                                // keys are prefixed with the bulk command id, suffixes are:
                                // bucketSize / 2 * bucketSize / ... / total document count
                                bucketNumber++;
                                produceBucket(context, command.getAction(), commandId, bucketNumber * bucketSize);
                            }

                            documentCount += docIds.size();
                            // next batch
                            scroll = session.scroll(scroll.getScrollId());
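                            // renew the transaction between batches to avoid transaction timeouts on long scrolls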
                            TransactionHelper.commitOrRollbackTransaction();
                            TransactionHelper.startTransaction();
                        }

                        // send the remaining document ids
                        // there is at most one partial bucket left because full buckets are flushed while scrolling
                        if (!documentIds.isEmpty()) {
                            produceBucket(context, command.getAction(), commandId, documentCount);
                        }

                        Long scrollEndTime = Instant.now().toEpochMilli();

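                        // record the scroll boundaries, state and count; the key/value writer computation persists them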
                        BulkUpdate updates = new BulkUpdate();
                        updates.put(commandId + SCROLL_START_TIME, scrollStartTime.toString());
                        updates.put(commandId + SCROLL_END_TIME, scrollEndTime.toString());
                        updates.put(commandId + STATE, RUNNING.toString());
                        updates.put(commandId + SCROLLED_DOCUMENT_COUNT, String.valueOf(documentCount));
                        Codec<BulkUpdate> updateCodec = Framework.getService(CodecService.class).getCodec(AVRO_CODEC,
                                BulkUpdate.class);
                        context.produceRecord(KVWRITER_STREAM_NAME, commandId, updateCodec.encode(updates));

                    } finally {
                        loginContext.logout();
                    }
                } catch (LoginException e) {
                    throw new NuxeoException(e);
                }
            } catch (NuxeoException e) {
                log.error("Discard invalid record: " + record, e);
            }
        }

        /**
         * Produces a bucket as a record to the appropriate bulk action stream.
         */
        protected void produceBucket(ComputationContext context, String action, String commandId, long nbDocSent) {
            List<String> docIds = documentIds.subList(0, min(bucketSize, documentIds.size()));
            // send these ids as keys to the appropriate stream
            context.produceRecord(action, BulkRecords.of(commandId, nbDocSent, docIds));
            context.askForCheckpoint();
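            // clearing the sub-list view also removes these ids from the documentIds buffer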
            docIds.clear();
        }
    }

    public static class CounterComputation extends AbstractComputation {

        protected final int counterThresholdMs;

        protected final Map<String, Long> counters;

        public CounterComputation(String counterComputationName, int counterThresholdMs) {
            super(counterComputationName, 1, 1);
            this.counterThresholdMs = counterThresholdMs;
            this.counters = new HashMap<>();
        }

        @Override
        public void init(ComputationContext context) {
            log.debug(String.format("Starting computation: %s reading on: %s, threshold: %dms",
                    COUNTER_COMPUTATION_NAME, COUNTER_STREAM_NAME, counterThresholdMs));
            context.setTimer("counter", System.currentTimeMillis() + counterThresholdMs);
        }

        @Override
        public void processTimer(ComputationContext context, String key, long timestamp) {
            KeyValueStore kvStore = Framework.getService(KeyValueService.class).getKeyValueStore(BULK_KV_STORE_NAME);
            BulkUpdate updates = new BulkUpdate();
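            // flush the in-memory counters: add them to the persisted counts and mark a command COMPLETED once all
            // scrolled documents have been processed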
            counters.forEach((bulkId, processedDocs) -> {
                Long previousProcessedDocs = kvStore.getLong(bulkId + PROCESSED_DOCUMENTS);
                if (previousProcessedDocs == null) {
                    previousProcessedDocs = 0L;
                }
                long currentProcessedDocs = previousProcessedDocs + processedDocs;
                // the scrolled document count is only written once scrolling is done, so it may still be absent
                Long scrolledDocs = kvStore.getLong(bulkId + SCROLLED_DOCUMENT_COUNT);
                if (scrolledDocs != null && currentProcessedDocs == scrolledDocs.longValue()) {
                    updates.put(bulkId + STATE, COMPLETED.toString());
                }
                updates.put(bulkId + PROCESSED_DOCUMENTS, String.valueOf(currentProcessedDocs));
            });
            Codec<BulkUpdate> updateCodec = Framework.getService(CodecService.class).getCodec(AVRO_CODEC,
                    BulkUpdate.class);
            context.produceRecord(KVWRITER_STREAM_NAME, key, updateCodec.encode(updates));
            counters.clear();
            context.askForCheckpoint();
            context.setTimer("counter", System.currentTimeMillis() + counterThresholdMs);
        }

        @Override
        public void processRecord(ComputationContext context, String inputStreamName, Record record) {
            Codec<BulkCounter> counterCodec = Framework.getService(CodecService.class).getCodec(AVRO_CODEC,
                    BulkCounter.class);
            BulkCounter counter = counterCodec.decode(record.getData());
            String bulkId = counter.getBulkId();

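            // accumulate processed document counts per command until the next timer flush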
            counters.computeIfPresent(bulkId, (k, processedDocs) -> processedDocs + counter.getProcessedDocuments());
            counters.putIfAbsent(bulkId, counter.getProcessedDocuments());
        }
    }

    public static class KeyValueWriterComputation extends AbstractComputation {

        public KeyValueWriterComputation(String name) {
            super(name, 1, 0);
        }

        @Override
        public void processRecord(ComputationContext context, String inputStreamName, Record record) {
            KeyValueStore kvStore = Framework.getService(KeyValueService.class).getKeyValueStore(BULK_KV_STORE_NAME);
            Codec<BulkUpdate> updateCodec = Framework.getService(CodecService.class).getCodec(AVRO_CODEC,
                    BulkUpdate.class);

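            // decode the updates and persist each entry in the bulk key/value store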
            BulkUpdate updates = updateCodec.decode(record.getData());
            updates.getValues().forEach(kvStore::put);
            context.askForCheckpoint();
        }
    }

    // TODO copied from StreamAuditWriter - where can we put that?
    protected int getOptionAsInteger(Map<String, String> options, String option, int defaultValue) {
        String value = options.get(option);
        return value == null ? defaultValue : Integer.parseInt(value);
    }
}