001/*
002 * (C) Copyright 2020 Nuxeo (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     bdelbosc
018 */
019package org.nuxeo.runtime.codec;
020
021import java.io.ByteArrayOutputStream;
022import java.io.IOException;
023import java.nio.ByteBuffer;
024import java.util.ArrayList;
025import java.util.List;
026
027import org.apache.avro.Schema;
028import org.apache.avro.SchemaBuilder;
029import org.apache.avro.generic.GenericData;
030import org.apache.avro.generic.GenericRecord;
031import org.apache.avro.message.RawMessageDecoder;
032import org.apache.avro.message.RawMessageEncoder;
033import org.apache.avro.reflect.ReflectData;
034import org.apache.logging.log4j.LogManager;
035import org.apache.logging.log4j.Logger;
036import org.nuxeo.lib.stream.StreamRuntimeException;
037import org.nuxeo.lib.stream.codec.AvroConfluentCodec;
038import org.nuxeo.lib.stream.codec.Codec;
039import org.nuxeo.lib.stream.computation.Record;
040import org.nuxeo.lib.stream.computation.Watermark;
041
042import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient;
043import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException;
044import io.confluent.kafka.serializers.KafkaAvroSerializer;
045
046/**
047 * Instead of having an Avro Record envelop that contains a data encoded in Avro, this structure is a flat Avro message
048 * joining schemas of the Record and data.
049 *
050 * This encoding can then be read by any Confluent Avro reader.
051 *
052 * @since 11.4
053 */
054public class AvroRecordCodec<T extends Record> implements Codec<T> {
055    private static final Logger log = LogManager.getLogger(AvroRecordCodec.class);
056
057    public static final String NAME = "avroRecord";
058
059    public static final String RECORD_KEY = "recordKey";
060
061    public static final String RECORD_WATERMARK = "recordWatermark";
062
063    public static final String RECORD_TIMESTAMP = "recordTimestamp";
064
065    public static final String RECORD_FLAGS = "recordFlags";
066
067    protected final Schema schema;
068
069    protected final int schemaId;
070
071    protected final Schema messageSchema;
072
073    protected final int messageSchemaId;
074
075    protected final RawMessageDecoder<GenericRecord> messageDecoder;
076
077    protected final RawMessageEncoder<GenericRecord> messageEncoder;
078
079    protected final KafkaAvroSerializer serializer;
080
081    protected final RawMessageEncoder<GenericRecord> encoder;
082
083    protected final SchemaRegistryClient client;
084
085    public AvroRecordCodec(Schema messageSchema, String schemaRegistryUrls) {
086        this.messageSchema = messageSchema;
087        this.client = AvroConfluentCodec.getRegistryClient(schemaRegistryUrls);
088        this.serializer = new KafkaAvroSerializer(client);
089        // extends the schema to support record fields
090        this.schema = addRecordFieldsToSchema(messageSchema);
091        log.trace("msg schema: {}", () -> this.messageSchema.toString(true));
092        log.trace("rec + msg schema: {}", () -> this.schema.toString(true));
093        // register schemas
094        try {
095            this.messageSchemaId = client.register(messageSchema.getName(), messageSchema);
096            this.schemaId = client.register(schema.getName(), schema);
097        } catch (RestClientException | IOException e) {
098            throw new StreamRuntimeException(e);
099        }
100        // create encoder and decoder
101        this.encoder = new RawMessageEncoder<>(GenericData.get(), schema);
102        this.messageDecoder = new RawMessageDecoder<>(GenericData.get(), messageSchema);
103        this.messageEncoder = new RawMessageEncoder<>(GenericData.get(), messageSchema);
104    }
105
106    public AvroRecordCodec(String messageClassName, String schemaRegistryUrls) throws ClassNotFoundException {
107        this(ReflectData.get().getSchema(Class.forName(messageClassName)), schemaRegistryUrls);
108    }
109
110    @Override
111    public String getName() {
112        return NAME;
113    }
114
115    @Override
116    public byte[] encode(T record) {
117        try {
118            // decode the message as generic record
119            GenericRecord message = messageDecoder.decode(record.getData(), null);
120            // Create a new generic record that contains both record and message fields
121            GenericRecord newRecord = createRecordFromMessage(message);
122            // populate record fields
123            newRecord.put(RECORD_KEY, record.getKey());
124            newRecord.put(RECORD_WATERMARK, record.getWatermark());
125            newRecord.put(RECORD_TIMESTAMP, Watermark.ofValue(record.getWatermark()).getTimestamp());
126            newRecord.put(RECORD_FLAGS, Byte.valueOf(record.getFlagsAsByte()).intValue());
127            // encode
128            ByteArrayOutputStream out = new ByteArrayOutputStream();
129            out.write(AvroConfluentCodec.MAGIC_BYTE);
130            try {
131                out.write(ByteBuffer.allocate(AvroConfluentCodec.ID_SIZE).putInt(schemaId).array());
132                out.write(encoder.encode(newRecord).array());
133            } catch (IOException e) {
134                throw new StreamRuntimeException(e);
135            }
136            return out.toByteArray();
137        } catch (IOException e) {
138            throw new IllegalArgumentException(e);
139        }
140    }
141
142    protected GenericRecord createRecordFromMessage(GenericRecord message) {
143        GenericData.Record ret = new GenericData.Record(schema);
144        for (Schema.Field field : message.getSchema().getFields()) {
145            Object value = message.get(field.pos());
146            ret.put(field.name(), value);
147        }
148        return ret;
149    }
150
151    protected Schema addRecordFieldsToSchema(Schema schema) {
152        List<Schema.Field> fields = new ArrayList<>();
153        for (Schema.Field field : schema.getFields()) {
154            fields.add(new Schema.Field(field.name(), field.schema(), field.doc(), field.defaultVal()));
155        }
156        fields.add(new Schema.Field(RECORD_KEY, SchemaBuilder.builder().stringType(), "record key", null));
157        fields.add(new Schema.Field(RECORD_WATERMARK, SchemaBuilder.builder().longType(), "record watermark", 0L));
158        fields.add(new Schema.Field(RECORD_TIMESTAMP, SchemaBuilder.builder().longType(), "record timestamp", 0L));
159        fields.add(new Schema.Field(RECORD_FLAGS, SchemaBuilder.builder().intType(), "record flags", 0));
160        return Schema.createRecord(schema.getName() + "Record", schema.getDoc(), schema.getNamespace(), false, fields);
161    }
162
163    @Override
164    public T decode(byte[] data) {
165        ByteBuffer buffer = ByteBuffer.wrap(data);
166        if (buffer.get() != AvroConfluentCodec.MAGIC_BYTE) {
167            throw new IllegalArgumentException("Invalid Avro Confluent message, expecting magic byte");
168        }
169        int id = buffer.getInt();
170        Schema writeSchema;
171        try {
172            writeSchema = client.getById(id);
173        } catch (IOException | RestClientException e) {
174            throw new StreamRuntimeException("Cannot retrieve write schema id: " + id, e);
175        }
176        RawMessageDecoder<GenericRecord> decoder = new RawMessageDecoder<>(GenericData.get(), writeSchema, schema);
177        try {
178            GenericRecord rec = decoder.decode(buffer.slice(), null);
179            log.trace("GenericRecord: {}", rec);
180            String key = rec.get(RECORD_KEY).toString();
181            long wm = (Long) rec.get(RECORD_WATERMARK);
182            int flag = (Integer) rec.get(RECORD_FLAGS);
183            byte[] msgData = messageEncoder.encode(rec).array();
184            Record ret = new Record(key, msgData, wm);
185            ret.setFlags((byte) flag);
186            return (T) ret;
187        } catch (IOException | IndexOutOfBoundsException e) {
188            throw new IllegalArgumentException(e);
189        }
190    }
191
192}