001/*
002 * (C) Copyright 2017 Nuxeo (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     Florent Guillaume
018 */
019package org.nuxeo.ecm.core.redis.contribs;
020
021import static redis.clients.jedis.Protocol.Keyword.MESSAGE;
022import static redis.clients.jedis.Protocol.Keyword.PMESSAGE;
023import static redis.clients.jedis.Protocol.Keyword.PSUBSCRIBE;
024import static redis.clients.jedis.Protocol.Keyword.PUNSUBSCRIBE;
025import static redis.clients.jedis.Protocol.Keyword.SUBSCRIBE;
026import static redis.clients.jedis.Protocol.Keyword.UNSUBSCRIBE;
027
028import java.lang.reflect.Method;
029import java.util.Arrays;
030import java.util.List;
031import java.util.Map;
032import java.util.concurrent.CountDownLatch;
033import java.util.concurrent.TimeUnit;
034import java.util.function.BiConsumer;
035
036import org.apache.commons.logging.Log;
037import org.apache.commons.logging.LogFactory;
038import org.nuxeo.ecm.core.api.NuxeoException;
039import org.nuxeo.ecm.core.redis.RedisAdmin;
040import org.nuxeo.ecm.core.redis.RedisExecutor;
041import org.nuxeo.runtime.api.Framework;
042import org.nuxeo.runtime.pubsub.AbstractPubSubProvider;
043import org.nuxeo.runtime.pubsub.PubSubProvider;
044
045import redis.clients.jedis.Client;
046import redis.clients.jedis.JedisPubSub;
047import redis.clients.jedis.exceptions.JedisException;
048import redis.clients.util.SafeEncoder;
049
050/**
051 * Redis implementation of {@link PubSubProvider}.
052 *
053 * @since 9.1
054 */
055public class RedisPubSubProvider extends AbstractPubSubProvider {
056
057    // package-private to avoid synthetic accessor for nested class
058    static final Log log = LogFactory.getLog(RedisPubSubProvider.class);
059
060    /** Maximum delay to wait for a channel subscription on startup. */
061    public static final long TIMEOUT_SUBSCRIBE_SECONDS = 5;
062
063    protected Dispatcher dispatcher;
064
065    protected Thread thread;
066
067    @Override
068    public void initialize(Map<String, String> options, Map<String, List<BiConsumer<String, byte[]>>> subscribers) {
069        super.initialize(options, subscribers);
070        log.debug("Initializing");
071        namespace = Framework.getService(RedisAdmin.class).namespace();
072        dispatcher = new Dispatcher(namespace + "*");
073        thread = new Thread(dispatcher::run, "Nuxeo-PubSub-Redis");
074        thread.setUncaughtExceptionHandler((t, e) -> log.error("Uncaught error on thread " + t.getName(), e));
075        thread.setPriority(Thread.NORM_PRIORITY);
076        thread.setDaemon(true);
077        thread.start();
078        if (!dispatcher.awaitSubscribed(TIMEOUT_SUBSCRIBE_SECONDS, TimeUnit.SECONDS)) {
079            thread.interrupt();
080            throw new NuxeoException(
081                    "Failed to subscribe to Redis pubsub after " + TIMEOUT_SUBSCRIBE_SECONDS + "s");
082        }
083        log.debug("Initialized");
084    }
085
086    @Override
087    public void close() {
088        log.debug("Closing");
089        if (dispatcher != null) {
090            thread.interrupt();
091            thread = null;
092            dispatcher.close();
093            dispatcher = null;
094        }
095        log.debug("Closed");
096    }
097
098    /**
099     * Subscribes to the provided Redis channel pattern and dispatches received messages. Method {@code #run} must be
100     * called in a new thread.
101     */
102    public class Dispatcher extends JedisPubSub {
103
104        // we look this up during construction in the main thread,
105        // because service lookup is unavailable from alternative threads during startup
106        protected RedisExecutor redisExecutor;
107
108        protected final String pattern;
109
110        protected final CountDownLatch subscribedLatch;
111
112        protected volatile boolean stop;
113
114        public Dispatcher(String pattern) {
115            redisExecutor = Framework.getService(RedisExecutor.class);
116            this.pattern = pattern;
117            this.subscribedLatch = new CountDownLatch(1);
118        }
119
120        /**
121         * To be called from the main thread to wait for subscription to be effective.
122         */
123        public boolean awaitSubscribed(long timeout, TimeUnit unit) {
124            try {
125                return subscribedLatch.await(timeout, unit);
126            } catch (InterruptedException e) {
127                Thread.currentThread().interrupt();
128                throw new NuxeoException(e);
129            }
130        }
131
132        /**
133         * To be called from a new thread to do the actual Redis subscription and to dispatch messages.
134         */
135        public void run() {
136            log.debug("Subscribing to: " + pattern);
137            // we can't do service lookup during startup here because we're in a separate thread
138            RedisExecutor redisExecutor = this.redisExecutor;
139            this.redisExecutor = null;
140            redisExecutor.psubscribe(this, pattern);
141        }
142
143        /**
144         * To be called from the main thread to stop the subscription.
145         */
146        public void close() {
147            stop = true;
148            // send an empty message so that the dispatcher thread can be woken up and stop
149            publish("", new byte[0]);
150        }
151
152        @Override
153        public void onPSubscribe(String pattern, int subscribedChannels) {
154            subscribedLatch.countDown();
155            if (log.isDebugEnabled()) {
156                log.debug("Subscribed to: " + pattern);
157            }
158        }
159
160        public void onMessage(String channel, byte[] message) {
161            if (message == null) {
162                message = new byte[0];
163            }
164            if (log.isTraceEnabled()) {
165                log.trace("Message received from channel: " + channel + " (" + message.length + " bytes)");
166            }
167            String topic = channel.substring(namespace.length());
168            localPublish(topic, message);
169        }
170
171        public void onPMessage(String pattern, String channel, byte[] message) {
172            onMessage(channel, message);
173        }
174
175        @Override
176        public void proceed(Client client, String... channels) {
177            client.subscribe(channels);
178            flush(client);
179            processBinary(client);
180        }
181
182        @Override
183        public void proceedWithPatterns(Client client, String... patterns) {
184            client.psubscribe(patterns);
185            flush(client);
186            processBinary(client);
187        }
188
189        // stupid Jedis has a protected flush method
190        protected void flush(Client client) {
191            try {
192                Method m = redis.clients.jedis.Connection.class.getDeclaredMethod("flush");
193                m.setAccessible(true);
194                m.invoke(client);
195            } catch (ReflectiveOperationException e) {
196                throw new NuxeoException(e);
197            }
198        }
199
200        // patched process() to pass the raw binary message to onMessage and onPMessage
201        protected void processBinary(Client client) {
202            for (;;) {
203                List<Object> reply = client.getRawObjectMultiBulkReply();
204                if (stop) {
205                    return;
206                }
207                Object type = reply.get(0);
208                if (!(type instanceof byte[])) {
209                    throw new JedisException("Unknown message type: " + type);
210                }
211                byte[] btype = (byte[]) type;
212                if (Arrays.equals(MESSAGE.raw, btype)) {
213                    byte[] bchannel = (byte[]) reply.get(1);
214                    byte[] bmesg = (byte[]) reply.get(2);
215                    onMessage(toString(bchannel), bmesg);
216                } else if (Arrays.equals(PMESSAGE.raw, btype)) {
217                    byte[] bpattern = (byte[]) reply.get(1);
218                    byte[] bchannel = (byte[]) reply.get(2);
219                    byte[] bmesg = (byte[]) reply.get(3);
220                    onPMessage(toString(bpattern), toString(bchannel), bmesg);
221                } else if (Arrays.equals(SUBSCRIBE.raw, btype)) {
222                    byte[] bchannel = (byte[]) reply.get(1);
223                    onSubscribe(toString(bchannel), 0);
224                } else if (Arrays.equals(PSUBSCRIBE.raw, btype)) {
225                    byte[] bpattern = (byte[]) reply.get(1);
226                    onPSubscribe(toString(bpattern), 0);
227                } else if (Arrays.equals(UNSUBSCRIBE.raw, btype)) {
228                    byte[] bchannel = (byte[]) reply.get(1);
229                    onUnsubscribe(toString(bchannel), 0);
230                } else if (Arrays.equals(PUNSUBSCRIBE.raw, btype)) {
231                    byte[] bpattern = (byte[]) reply.get(1);
232                    onPUnsubscribe(toString(bpattern), 0);
233                } else {
234                    throw new JedisException("Unknown message: " + toString(btype));
235                }
236            }
237        }
238
239        protected String toString(byte[] bytes) {
240            return bytes == null ? null : SafeEncoder.encode(bytes);
241        }
242
243    }
244
245    // ===== PubSubService =====
246
247    @Override
248    public void publish(String topic, byte[] message) {
249        String channel = namespace + topic;
250        byte[] bchannel = SafeEncoder.encode(channel);
251        RedisExecutor redisExecutor = Framework.getService(RedisExecutor.class);
252        if (redisExecutor != null) {
253            redisExecutor.execute(jedis -> jedis.publish(bchannel, message));
254        }
255    }
256
257}