001/*
002 * (C) Copyright 2009 Nuxeo SA (http://nuxeo.com/) and others.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 *
016 * Contributors:
017 *     Olivier Grisel
018 */
019
020package org.nuxeo.ecm.platform.categorization.categorizer.tfidf;
021
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.nuxeo.ecm.platform.categorization.service.Categorizer;
058
059/**
060 * Maintains a map of TF counts vectors in memory (just for a few reference documents or topics) along with the common
061 * IDF estimate of all previously seen text content.
062 * <p>
063 * See: http://en.wikipedia.org/wiki/Tfidf
064 * <p>
065 * Classification is then achieved using the cosine similarity between the TF-IDF of the document to classify and the
066 * registered topics.
067 */
068public class TfIdfCategorizer extends PrimitiveVectorHelper implements Categorizer, Serializable {
069
070    private static final long serialVersionUID = 1L;
071
072    public static final Log log = LogFactory.getLog(TfIdfCategorizer.class);
073
074    protected final Set<String> topicNames = new TreeSet<String>();
075
076    protected final Map<String, Object> topicTermCount = new ConcurrentHashMap<String, Object>();
077
078    protected final Map<String, Object> cachedTopicTfIdf = new ConcurrentHashMap<String, Object>();
079
080    protected final Map<String, Float> cachedTopicTfIdfNorm = new ConcurrentHashMap<String, Float>();
081
082    protected long[] allTermCounts;
083
084    protected final int dim;
085
086    protected float[] cachedIdf;
087
088    protected long totalTermCount = 0;
089
090    protected final HashingVectorizer vectorizer;
091
092    protected transient Analyzer analyzer;
093
094    protected Double ratioOverMedian = 3.0;
095
096    protected boolean updateDisabled = false;
097
098    public TfIdfCategorizer() {
099        this(524288); // 2 ** 19
100    }
101
102    public TfIdfCategorizer(int dim) {
103        this.dim = dim;
104        allTermCounts = new long[dim];
105        vectorizer = new HashingVectorizer().dimension(dim);
106    }
107
108    public HashingVectorizer getVectorizer() {
109        return vectorizer;
110    }
111
112    public Analyzer getAnalyzer() {
113        if (analyzer == null) {
114            // TODO: make it possible to configure the stop words
115            analyzer = new StandardAnalyzer(Version.LUCENE_47);
116        }
117        return analyzer;
118    }
119
120    /**
121     * Precompute all the TF-IDF vectors and unload the original count vectors to spare some memory. Updates won't be
122     * possible any more.
123     */
124    public synchronized void disableUpdate() {
125        updateDisabled = true;
126        // compute all the frequencies
127        getIdf();
128        for (String topicName : topicNames) {
129            tfidf(topicName);
130            tfidfNorm(topicName);
131        }
132        // upload the count vectors
133        topicTermCount.clear();
134        allTermCounts = null;
135    }
136
137    /**
138     * Update the model to take into account the statistical properties of a document that is known to be relevant to
139     * the given topic. Warning: this method is not thread safe: it should not be used concurrently with @see
140     * #getSimilarities(List)
141     *
142     * @param topicName the name of the document topic or category
143     * @param terms the list of document tokens (use a lucene analyzer to extract theme for instance)
144     */
145    public void update(String topicName, List<String> terms) {
146        if (updateDisabled) {
147            throw new IllegalStateException("updates are no longer authorized once #disableUpdate has been called");
148        }
149        long[] counts = vectorizer.count(terms);
150        totalTermCount += sum(counts);
151        long[] topicCounts = (long[]) topicTermCount.get(topicName);
152        if (topicCounts == null) {
153            topicCounts = new long[dim];
154            topicTermCount.put(topicName, topicCounts);
155            topicNames.add(topicName);
156        }
157        add(topicCounts, counts);
158        add(allTermCounts, counts);
159        invalidateCache(topicName);
160    }
161
162    /**
163     * Update the model to take into account the statistical properties of a document that is known to be relevant to
164     * the given topic. Warning: this method is not thread safe: it should not be used concurrently with @see
165     * #getSimilarities(List)
166     *
167     * @param topicName the name of the document topic or category
168     * @param textContent textual content to be tokenized and analyzed
169     */
170    public void update(String topicName, String textContent) {
171        update(topicName, tokenize(textContent));
172    }
173
174    protected void invalidateCache(String topicName) {
175        cachedTopicTfIdf.remove(topicName);
176        cachedTopicTfIdfNorm.remove(topicName);
177        cachedIdf = null;
178    }
179
180    protected void invalidateCache() {
181        for (String topicName : topicNames) {
182            invalidateCache(topicName);
183        }
184    }
185
186    /**
187     * For each registered topic, compute the cosine similarity of the TFIDF vector of the topic and the one of the
188     * document given by a list of tokens.
189     *
190     * @param terms a tokenized document.
191     * @return a map of topic names to float values from 0 to 1 sorted by reverse value.
192     */
193    public Map<String, Float> getSimilarities(List<String> terms) {
194        SortedMap<String, Float> similarities = new TreeMap<String, Float>();
195
196        float[] tfidf1 = getTfIdf(vectorizer.count(terms));
197        float norm1 = normOf(tfidf1);
198        if (norm1 == 0) {
199            return similarities;
200        }
201
202        for (String topicName : topicNames) {
203            float[] tfidf2 = tfidf(topicName);
204            float norm2 = tfidfNorm(topicName);
205            if (norm2 == 0) {
206                continue;
207            }
208            similarities.put(topicName, dot(tfidf1, tfidf2) / (norm1 * norm2));
209        }
210        return sortByDecreasingValue(similarities);
211    }
212
213    /**
214     * For each registered topic, compute the cosine similarity of the TFIDF vector of the topic and the one of the
215     * document.
216     *
217     * @param the document to be tokenized and analyzed
218     * @return a map of topic names to float values from 0 to 1 sorted by reverse value.
219     */
220    public Map<String, Float> getSimilarities(String allThePets) {
221        return getSimilarities(tokenize(allThePets));
222    }
223
224    protected float tfidfNorm(String topicName) {
225        Float norm = cachedTopicTfIdfNorm.get(topicName);
226        if (norm == null) {
227            norm = normOf(tfidf(topicName));
228            cachedTopicTfIdfNorm.put(topicName, norm);
229        }
230        return norm.floatValue();
231    }
232
233    protected float[] tfidf(String topicName) {
234        float[] tfidf = (float[]) cachedTopicTfIdf.get(topicName);
235        if (tfidf == null) {
236            tfidf = getTfIdf((long[]) topicTermCount.get(topicName));
237            cachedTopicTfIdf.put(topicName, tfidf);
238        }
239        return tfidf;
240    }
241
242    protected float[] getTfIdf(long[] counts) {
243        float[] idf = getIdf();
244        float[] tfidf = new float[counts.length];
245        long sum = sum(counts);
246        if (sum == 0) {
247            return tfidf;
248        }
249        for (int i = 0; i < counts.length; i++) {
250            tfidf[i] = ((float) counts[i]) / sum * idf[i];
251        }
252        return tfidf;
253    }
254
255    protected float[] getIdf() {
256        if (cachedIdf == null) {
257            float[] idf = new float[allTermCounts.length];
258            for (int i = 0; i < allTermCounts.length; i++) {
259                if (allTermCounts[i] == 0) {
260                    idf[i] = 0;
261                } else {
262                    idf[i] = (float) Math.log1p(((float) totalTermCount) / allTermCounts[i]);
263                }
264            }
265            // atomic update to ensure thread-safeness
266            cachedIdf = idf;
267        }
268        return cachedIdf;
269    }
270
271    public int getDimension() {
272        return dim;
273    }
274
275    /**
276     * Utility method to initialize the parameters from a set of UTF-8 encoded text files with names used as topic
277     * names.
278     * <p>
279     * The content of the file to assumed to be lines of terms separated by whitespaces without punctuation.
280     *
281     * @param folder
282     */
283    public void learnFiles(File folder) throws IOException {
284        if (!folder.isDirectory()) {
285            throw new IOException(String.format("%s is not a folder", folder.getAbsolutePath()));
286        }
287        for (File file : folder.listFiles()) {
288            if (file.isDirectory()) {
289                continue;
290            }
291            String topicName = file.getName();
292            if (topicName.contains(".")) {
293                topicName = topicName.substring(0, topicName.indexOf('.'));
294            }
295            log.info(String.format("About to analyze file %s", file.getAbsolutePath()));
296            FileInputStream is = new FileInputStream(file);
297            try {
298                BufferedReader reader = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")));
299                String line = reader.readLine();
300                int i = 0;
301                while (line != null) {
302                    update(topicName, line);
303                    line = reader.readLine();
304                    i++;
305                    if (i % 10000 == 0) {
306                        log.info(String.format("Analyzed %d lines from '%s'", i, file.getAbsolutePath()));
307                    }
308                }
309            } finally {
310                is.close();
311            }
312        }
313    }
314
315    /**
316     * Save the model to a compressed binary format on the filesystem.
317     *
318     * @param file where to write the model
319     */
320    public void saveToFile(File file) throws IOException {
321        FileOutputStream out = new FileOutputStream(file);
322        try {
323            saveToStream(out);
324        } finally {
325            out.close();
326        }
327    }
328
329    /**
330     * Save a compressed binary representation of the trained model.
331     *
332     * @param out the output stream to write to
333     */
334    public void saveToStream(OutputStream out) throws IOException {
335        if (updateDisabled) {
336            throw new IllegalStateException("model in disabled update mode cannot be saved");
337        }
338        invalidateCache();
339        GZIPOutputStream gzOut = new GZIPOutputStream(out);
340        ObjectOutputStream objOut = new ObjectOutputStream(gzOut);
341        objOut.writeObject(this);
342        gzOut.finish();
343    }
344
345    /**
346     * Load a TfIdfCategorizer instance from it's compressed binary representation.
347     *
348     * @param in the input stream to read from
349     * @return a new instance with parameters coming from the saved version
350     */
351    public static TfIdfCategorizer load(InputStream in) throws IOException, ClassNotFoundException {
352        GZIPInputStream gzIn = new GZIPInputStream(in);
353        ObjectInputStream objIn = new ObjectInputStream(gzIn);
354        TfIdfCategorizer cat = (TfIdfCategorizer) objIn.readObject();
355        log.info(String.format("Sucessfully loaded model with %d topics, dimension %d and density %f",
356                cat.getTopicNames().size(), cat.getDimension(), cat.getDensity()));
357        return cat;
358    }
359
360    public double getDensity() {
361        long sum = 0;
362        for (Object singleTopicTermCount : topicTermCount.values()) {
363            for (long c : (long[]) singleTopicTermCount) {
364                sum += c != 0L ? 1 : 0;
365            }
366        }
367        for (long c : allTermCounts) {
368            sum += c != 0 ? 1 : 0;
369        }
370        return ((double) sum) / ((topicNames.size() + 1) * getDimension());
371    }
372
373    public Set<String> getTopicNames() {
374        return topicNames;
375    }
376
377    /**
378     * Load a TfIdfCategorizer instance from it's compressed binary representation from a named resource in the
379     * classloading path of the current thread.
380     *
381     * @param modelPath the path of the file model in the classloading path
382     * @return a new instance with parameters coming from the saved version
383     */
384    public static TfIdfCategorizer load(String modelPath) throws IOException, ClassNotFoundException {
385        ClassLoader loader = Thread.currentThread().getContextClassLoader();
386        return load(loader.getResourceAsStream(modelPath));
387    }
388
389    public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {
390        if (args.length < 2 || args.length > 3) {
391            System.out.println("Train a model:\n" + "First argument is the model filename (e.g. my-model.gz)\n"
392                    + "Second argument is the path to a folder with UTF-8 text files\n"
393                    + "Third optional argument is the dimension of the model");
394            System.exit(0);
395        }
396        File modelFile = new File(args[0]);
397        TfIdfCategorizer categorizer;
398        if (modelFile.exists()) {
399            log.info("Loading model from: " + modelFile.getAbsolutePath());
400            FileInputStream is = new FileInputStream(modelFile);
401            try {
402                categorizer = load(is);
403            } finally {
404                is.close();
405            }
406        } else {
407            if (args.length == 3) {
408                categorizer = new TfIdfCategorizer(Integer.valueOf(args[2]));
409            } else {
410                categorizer = new TfIdfCategorizer();
411            }
412            log.info("Initializing new model with dimension: " + categorizer.getDimension());
413        }
414        categorizer.learnFiles(new File(args[1]));
415        log.info("Saving trained model to: " + modelFile.getAbsolutePath());
416        categorizer.saveToFile(modelFile);
417    }
418
419    public List<String> guessCategories(String textContent, int maxSuggestions) {
420        return guessCategories(textContent, maxSuggestions, null);
421    }
422
423    public List<String> guessCategories(String textContent, int maxSuggestions, Double precisionThreshold) {
424        precisionThreshold = precisionThreshold == null ? ratioOverMedian : precisionThreshold;
425        Map<String, Float> sims = getSimilarities(tokenize(textContent));
426        Float median = findMedian(sims);
427        List<String> suggested = new ArrayList<String>();
428        for (Map.Entry<String, Float> sim : sims.entrySet()) {
429            double ratio = median != 0 ? sim.getValue() / median : 100;
430            if (suggested.size() >= maxSuggestions || ratio < precisionThreshold) {
431                break;
432            }
433            suggested.add(sim.getKey());
434        }
435        return suggested;
436    }
437
438    public List<String> tokenize(String textContent) {
439        try {
440            List<String> terms = new ArrayList<String>();
441            TokenStream tokenStream = getAnalyzer().tokenStream(null, textContent);
442            CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
443            tokenStream.reset();
444            while (tokenStream.incrementToken()) {
445                terms.add(charTermAttribute.toString());
446            }
447            tokenStream.end();
448            tokenStream.close();
449            return terms;
450        } catch (IOException e) {
451            throw new IllegalStateException(e);
452        }
453    }
454
455    public static Map<String, Float> sortByDecreasingValue(Map<String, Float> map) {
456        List<Map.Entry<String, Float>> list = new LinkedList<Map.Entry<String, Float>>(map.entrySet());
457        Collections.sort(list, new Comparator<Map.Entry<String, Float>>() {
458            public int compare(Map.Entry<String, Float> e1, Map.Entry<String, Float> e2) {
459                return -e1.getValue().compareTo(e2.getValue());
460            }
461        });
462        Map<String, Float> result = new LinkedHashMap<String, Float>();
463        for (Map.Entry<String, Float> e : list) {
464            result.put(e.getKey(), e.getValue());
465        }
466        return result;
467    }
468
469    public static Float findMedian(Map<String, Float> sortedMap) {
470        int remaining = sortedMap.size() / 2;
471        Float median = 0.0f;
472        for (Float value : sortedMap.values()) {
473            median = value;
474            if (remaining-- <= 0) {
475                break;
476            }
477        }
478        return median;
479    }
480
481}