001/*
002 * (C) Copyright 2017 Nuxeo (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     Florent Guillaume
018 */
019package org.nuxeo.ecm.core.redis.contribs;
020
021import static redis.clients.jedis.Protocol.Keyword.MESSAGE;
022import static redis.clients.jedis.Protocol.Keyword.PMESSAGE;
023import static redis.clients.jedis.Protocol.Keyword.PSUBSCRIBE;
024import static redis.clients.jedis.Protocol.Keyword.PUNSUBSCRIBE;
025import static redis.clients.jedis.Protocol.Keyword.SUBSCRIBE;
026import static redis.clients.jedis.Protocol.Keyword.UNSUBSCRIBE;
027
028import java.lang.reflect.Method;
029import java.util.Arrays;
030import java.util.List;
031import java.util.Map;
032import java.util.concurrent.CountDownLatch;
033import java.util.concurrent.TimeUnit;
034import java.util.function.BiConsumer;
035
036import org.apache.commons.logging.Log;
037import org.apache.commons.logging.LogFactory;
038import org.nuxeo.ecm.core.api.NuxeoException;
039import org.nuxeo.ecm.core.redis.RedisAdmin;
040import org.nuxeo.ecm.core.redis.RedisExecutor;
041import org.nuxeo.runtime.api.Framework;
042import org.nuxeo.runtime.pubsub.AbstractPubSubProvider;
043import org.nuxeo.runtime.pubsub.PubSubProvider;
044
045import redis.clients.jedis.Client;
046import redis.clients.jedis.JedisPubSub;
047import redis.clients.jedis.exceptions.JedisException;
048import redis.clients.util.SafeEncoder;
049
050/**
051 * Redis implementation of {@link PubSubProvider}.
052 *
053 * @since 9.1
054 */
055public class RedisPubSubProvider extends AbstractPubSubProvider {
056
057    // package-private to avoid synthetic accessor for nested class
058    static final Log log = LogFactory.getLog(RedisPubSubProvider.class);
059
060    /** Maximum delay to wait for a channel subscription on startup. */
061    public static final long TIMEOUT_SUBSCRIBE_SECONDS = 5;
062
063    protected String namespace;
064
065    protected Dispatcher dispatcher;
066
067    protected Thread thread;
068
069    @Override
070    public void initialize(Map<String, List<BiConsumer<String, byte[]>>> subscribers) {
071        super.initialize(subscribers);
072        log.debug("Initializing");
073        namespace = Framework.getService(RedisAdmin.class).namespace();
074        dispatcher = new Dispatcher(namespace + "*");
075        thread = new Thread(dispatcher::run, "Nuxeo-PubSub-Redis");
076        thread.setUncaughtExceptionHandler((t, e) -> log.error("Uncaught error on thread " + t.getName(), e));
077        thread.setPriority(Thread.NORM_PRIORITY);
078        thread.setDaemon(true);
079        thread.start();
080        if (!dispatcher.awaitSubscribed(TIMEOUT_SUBSCRIBE_SECONDS, TimeUnit.SECONDS)) {
081            thread.interrupt();
082            throw new NuxeoException(
083                    "Failed to subscribe to Redis pubsub after " + TIMEOUT_SUBSCRIBE_SECONDS + "s");
084        }
085        log.debug("Initialized");
086    }
087
088    @Override
089    public void close() {
090        log.debug("Closing");
091        if (dispatcher != null) {
092            thread.interrupt();
093            thread = null;
094            dispatcher.close();
095            dispatcher = null;
096        }
097        log.debug("Closed");
098    }
099
100    /**
101     * Subscribes to the provided Redis channel pattern and dispatches received messages. Method {@code #run} must be
102     * called in a new thread.
103     */
104    public class Dispatcher extends JedisPubSub {
105
106        // we look this up during construction in the main thread,
107        // because service lookup is unavailable from alternative threads during startup
108        protected RedisExecutor redisExecutor;
109
110        protected final String pattern;
111
112        protected final CountDownLatch subscribedLatch;
113
114        protected volatile boolean stop;
115
116        public Dispatcher(String pattern) {
117            redisExecutor = Framework.getService(RedisExecutor.class);
118            this.pattern = pattern;
119            this.subscribedLatch = new CountDownLatch(1);
120        }
121
122        /**
123         * To be called from the main thread to wait for subscription to be effective.
124         */
125        public boolean awaitSubscribed(long timeout, TimeUnit unit) {
126            try {
127                return subscribedLatch.await(timeout, unit);
128            } catch (InterruptedException e) {
129                Thread.currentThread().interrupt();
130                throw new NuxeoException(e);
131            }
132        }
133
134        /**
135         * To be called from a new thread to do the actual Redis subscription and to dispatch messages.
136         */
137        public void run() {
138            log.debug("Subscribing to: " + pattern);
139            // we can't do service lookup during startup here because we're in a separate thread
140            RedisExecutor redisExecutor = this.redisExecutor;
141            this.redisExecutor = null;
142            redisExecutor.psubscribe(this, pattern);
143        }
144
145        /**
146         * To be called from the main thread to stop the subscription.
147         */
148        public void close() {
149            stop = true;
150            // send an empty message so that the dispatcher thread can be woken up and stop
151            publish("", new byte[0]);
152        }
153
154        @Override
155        public void onPSubscribe(String pattern, int subscribedChannels) {
156            subscribedLatch.countDown();
157            if (log.isDebugEnabled()) {
158                log.debug("Subscribed to: " + pattern);
159            }
160        }
161
162        public void onMessage(String channel, byte[] message) {
163            if (message == null) {
164                message = new byte[0];
165            }
166            if (log.isTraceEnabled()) {
167                log.trace("Message received from channel: " + channel + " (" + message.length + " bytes)");
168            }
169            String topic = channel.substring(namespace.length());
170            localPublish(topic, message);
171        }
172
173        public void onPMessage(String pattern, String channel, byte[] message) {
174            onMessage(channel, message);
175        }
176
177        @Override
178        public void proceed(Client client, String... channels) {
179            client.subscribe(channels);
180            flush(client);
181            processBinary(client);
182        }
183
184        @Override
185        public void proceedWithPatterns(Client client, String... patterns) {
186            client.psubscribe(patterns);
187            flush(client);
188            processBinary(client);
189        }
190
191        // stupid Jedis has a protected flush method
192        protected void flush(Client client) {
193            try {
194                Method m = redis.clients.jedis.Connection.class.getDeclaredMethod("flush");
195                m.setAccessible(true);
196                m.invoke(client);
197            } catch (ReflectiveOperationException e) {
198                throw new NuxeoException(e);
199            }
200        }
201
202        // patched process() to pass the raw binary message to onMessage and onPMessage
203        protected void processBinary(Client client) {
204            for (;;) {
205                List<Object> reply = client.getRawObjectMultiBulkReply();
206                if (stop) {
207                    return;
208                }
209                Object type = reply.get(0);
210                if (!(type instanceof byte[])) {
211                    throw new JedisException("Unknown message type: " + type);
212                }
213                byte[] btype = (byte[]) type;
214                if (Arrays.equals(MESSAGE.raw, btype)) {
215                    byte[] bchannel = (byte[]) reply.get(1);
216                    byte[] bmesg = (byte[]) reply.get(2);
217                    onMessage(toString(bchannel), bmesg);
218                } else if (Arrays.equals(PMESSAGE.raw, btype)) {
219                    byte[] bpattern = (byte[]) reply.get(1);
220                    byte[] bchannel = (byte[]) reply.get(2);
221                    byte[] bmesg = (byte[]) reply.get(3);
222                    onPMessage(toString(bpattern), toString(bchannel), bmesg);
223                } else if (Arrays.equals(SUBSCRIBE.raw, btype)) {
224                    byte[] bchannel = (byte[]) reply.get(1);
225                    onSubscribe(toString(bchannel), 0);
226                } else if (Arrays.equals(PSUBSCRIBE.raw, btype)) {
227                    byte[] bpattern = (byte[]) reply.get(1);
228                    onPSubscribe(toString(bpattern), 0);
229                } else if (Arrays.equals(UNSUBSCRIBE.raw, btype)) {
230                    byte[] bchannel = (byte[]) reply.get(1);
231                    onUnsubscribe(toString(bchannel), 0);
232                } else if (Arrays.equals(PUNSUBSCRIBE.raw, btype)) {
233                    byte[] bpattern = (byte[]) reply.get(1);
234                    onPUnsubscribe(toString(bpattern), 0);
235                } else {
236                    throw new JedisException("Unknown message: " + toString(btype));
237                }
238            }
239        }
240
241        protected String toString(byte[] bytes) {
242            return bytes == null ? null : SafeEncoder.encode(bytes);
243        }
244
245    }
246
247    // ===== PubSubService =====
248
249    @Override
250    public void publish(String topic, byte[] message) {
251        String channel = namespace + topic;
252        byte[] bchannel = SafeEncoder.encode(channel);
253        RedisExecutor redisExecutor = Framework.getService(RedisExecutor.class);
254        if (redisExecutor != null) {
255            redisExecutor.execute(jedis -> jedis.publish(bchannel, message));
256        }
257    }
258
259}