001/*
002 * (C) Copyright 2018 Nuxeo SA (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     bdelbosc
018 */
019package org.nuxeo.lib.stream.computation;
020
021import java.util.ArrayList;
022import java.util.List;
023
024import org.apache.commons.logging.Log;
025import org.apache.commons.logging.LogFactory;
026
027/**
028 * An abstract {@link Computation} that processes records by batch.
029 * <p>
030 * The batch capacity and threshold are defined in the computation policy.
031 *
032 * @since 10.3
033 */
034public abstract class AbstractBatchComputation extends AbstractComputation {
035
036    private static final Log log = LogFactory.getLog(AbstractBatchComputation.class);
037
038    public static final String TIMER_BATCH = "batch";
039
040    protected List<Record> batchRecords;
041
042    protected String currentInputStream;
043
044    protected boolean newBatch = true;
045
046    protected long thresholdMillis;
047
048    protected boolean removeLastRecordOnRetry;
049
050    /**
051     * Constructor
052     *
053     * @param name the name of the computation
054     * @param nbInputStreams the number of input streams
055     * @param nbOutputStreams the number of output streams
056     */
057    public AbstractBatchComputation(String name, int nbInputStreams, int nbOutputStreams) {
058        super(name, nbInputStreams, nbOutputStreams);
059    }
060
061    /**
062     * Called when:<br>
063     * - the batch capacity is reached<br/>
064     * - the time threshold is reached<br/>
065     * - the inputStreamName has changed<br/>
066     * If this method raises an exception the retry policy is applied.
067     *
068     * @param context used to send records to output streams, note that the checkpoint is managed automatically.
069     * @param inputStreamName the input streams where the records are coming from
070     * @param records the batch of records
071     */
072    protected abstract void batchProcess(ComputationContext context, String inputStreamName, List<Record> records);
073
074    /**
075     * Called when the retry policy has failed.
076     */
077    public abstract void batchFailure(ComputationContext context, String inputStreamName, List<Record> records);
078
079    @Override
080    public void init(ComputationContext context) {
081        thresholdMillis = context.getPolicy().getBatchThreshold().toMillis();
082        context.setTimer(TIMER_BATCH, System.currentTimeMillis() + thresholdMillis);
083        batchRecords = new ArrayList<>(context.getPolicy().batchCapacity);
084    }
085
086    @Override
087    public void processTimer(ComputationContext context, String key, long timestamp) {
088        if (!TIMER_BATCH.equals(key)) {
089            return;
090        }
091        if (!batchRecords.isEmpty()) {
092            batchProcess(context);
093        }
094        context.setTimer(TIMER_BATCH, System.currentTimeMillis() + thresholdMillis);
095    }
096
097    @Override
098    public void processRecord(ComputationContext context, String inputStreamName, Record record) {
099        if (!inputStreamName.equals(currentInputStream) && !batchRecords.isEmpty()) {
100            batchProcess(context);
101        }
102        if (newBatch) {
103            currentInputStream = inputStreamName;
104            newBatch = false;
105        }
106        batchRecords.add(record);
107        if (batchRecords.size() >= context.getPolicy().getBatchCapacity()) {
108            removeLastRecordOnRetry = true;
109            batchProcess(context);
110            removeLastRecordOnRetry = false;
111        }
112    }
113
114    private void batchProcess(ComputationContext context) {
115        batchProcess(context, currentInputStream, batchRecords);
116        checkpointBatch(context);
117    }
118
119    protected void checkpointBatch(ComputationContext context) {
120        context.askForCheckpoint();
121        batchRecords.clear();
122        newBatch = true;
123    }
124
125    @Override
126    public void processRetry(ComputationContext context, Throwable failure) {
127        if (removeLastRecordOnRetry) {
128            // the batchProcess has failed, processRecord will be retried with the same record
129            // but first we have to remove the record from the batch
130            batchRecords.remove(batchRecords.size() -1);
131            removeLastRecordOnRetry = false;
132        }
133        log.warn(String.format("Computation: %s fails to process batch of %d records, last record: %s, retrying ...",
134                metadata.name(), batchRecords.size(), context.getLastOffset()), failure);
135    }
136
137    @Override
138    public void processFailure(ComputationContext context, Throwable failure) {
139        log.error(String.format(
140                "Computation: %s fails to process batch of %d records after retries, last record: %s, policy: %s",
141                metadata.name(), batchRecords.size(), context.getLastOffset(), context.getPolicy()), failure);
142        batchFailure(context, currentInputStream, batchRecords);
143        batchRecords.clear();
144        newBatch = true;
145    }
146
147}