001/*
002 * (C) Copyright 2009 Nuxeo SAS (http://nuxeo.com/) and contributors.
003 *
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the GNU Lesser General Public License
006 * (LGPL) version 2.1 which accompanies this distribution, and is available at
007 * http://www.gnu.org/licenses/lgpl.html
008 *
009 * This library is distributed in the hope that it will be useful,
010 * but WITHOUT ANY WARRANTY; without even the implied warranty of
011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
012 * Lesser General Public License for more details.
013 *
014 * Contributors:
015 *     Olivier Grisel
016 */
017
018package org.nuxeo.ecm.platform.categorization.categorizer.tfidf;
019
020import java.io.BufferedReader;
021import java.io.File;
022import java.io.FileInputStream;
023import java.io.FileNotFoundException;
024import java.io.FileOutputStream;
025import java.io.IOException;
026import java.io.InputStream;
027import java.io.InputStreamReader;
028import java.io.ObjectInputStream;
029import java.io.ObjectOutputStream;
030import java.io.OutputStream;
031import java.io.Serializable;
032import java.nio.charset.Charset;
033import java.util.ArrayList;
034import java.util.Collections;
035import java.util.Comparator;
036import java.util.LinkedHashMap;
037import java.util.LinkedList;
038import java.util.List;
039import java.util.Map;
040import java.util.Set;
041import java.util.SortedMap;
042import java.util.TreeMap;
043import java.util.TreeSet;
044import java.util.concurrent.ConcurrentHashMap;
045import java.util.zip.GZIPInputStream;
046import java.util.zip.GZIPOutputStream;
047
048import org.apache.commons.logging.Log;
049import org.apache.commons.logging.LogFactory;
050import org.apache.lucene.analysis.Analyzer;
051import org.apache.lucene.analysis.TokenStream;
052import org.apache.lucene.analysis.standard.StandardAnalyzer;
053import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
054import org.apache.lucene.util.Version;
055import org.nuxeo.ecm.platform.categorization.service.Categorizer;
056
057/**
058 * Maintains a map of TF counts vectors in memory (just for a few reference documents or topics) along with the common
059 * IDF estimate of all previously seen text content.
060 * <p>
061 * See: http://en.wikipedia.org/wiki/Tfidf
062 * <p>
063 * Classification is then achieved using the cosine similarity between the TF-IDF of the document to classify and the
064 * registered topics.
065 */
066public class TfIdfCategorizer extends PrimitiveVectorHelper implements Categorizer, Serializable {
067
    /** Version identifier used by Java serialization when the model is saved and loaded. */
    private static final long serialVersionUID = 1L;

    /** Class-wide logger. */
    public static final Log log = LogFactory.getLog(TfIdfCategorizer.class);

    /** Names of the topics registered through {@link #update(String, List)}, kept in natural string order. */
    protected final Set<String> topicNames = new TreeSet<String>();

    /**
     * Raw term counts per topic name; the values are {@code long[]} vectors of length {@link #dim}. Emptied by
     * {@link #disableUpdate()}.
     */
    protected final Map<String, Object> topicTermCount = new ConcurrentHashMap<String, Object>();

    /** Lazily computed TF-IDF vector ({@code float[]}) per topic name, see {@link #tfidf(String)}. */
    protected final Map<String, Object> cachedTopicTfIdf = new ConcurrentHashMap<String, Object>();

    /** Lazily computed norm of the TF-IDF vector per topic name, see {@link #tfidfNorm(String)}. */
    protected final Map<String, Float> cachedTopicTfIdfNorm = new ConcurrentHashMap<String, Float>();

    /** Term counts accumulated over all the topics; set to {@code null} by {@link #disableUpdate()}. */
    protected long[] allTermCounts;

    /** Length of the count vectors; also handed to the vectorizer as its dimension. */
    protected final int dim;

    /** Lazily computed IDF vector, see {@link #getIdf()}; reset to {@code null} when a cache entry is invalidated. */
    protected float[] cachedIdf;

    /** Sum of all the term counts accumulated by {@link #update(String, List)}. */
    protected long totalTermCount = 0;

    /** Turns a list of terms into a count vector (presumably of length {@link #dim} - confirm in HashingVectorizer). */
    protected final HashingVectorizer vectorizer;

    /** Lucene analyzer used by {@link #tokenize(String)}; transient, created on demand by {@link #getAnalyzer()}. */
    protected transient Analyzer analyzer;

    /**
     * Default precision threshold of {@link #guessCategories(String, int, Double)}: minimum ratio between the
     * similarity of a suggested topic and the median similarity.
     */
    protected Double ratioOverMedian = 3.0;

    /** Set by {@link #disableUpdate()}; updating and saving the model are refused once it is {@code true}. */
    protected boolean updateDisabled = false;
095
096    public TfIdfCategorizer() {
097        this(524288); // 2 ** 19
098    }
099
100    public TfIdfCategorizer(int dim) {
101        this.dim = dim;
102        allTermCounts = new long[dim];
103        vectorizer = new HashingVectorizer().dimension(dim);
104    }
105
106    public HashingVectorizer getVectorizer() {
107        return vectorizer;
108    }
109
110    public Analyzer getAnalyzer() {
111        if (analyzer == null) {
112            // TODO: make it possible to configure the stop words
113            analyzer = new StandardAnalyzer(Version.LUCENE_47);
114        }
115        return analyzer;
116    }
117
118    /**
119     * Precompute all the TF-IDF vectors and unload the original count vectors to spare some memory. Updates won't be
120     * possible any more.
121     */
122    public synchronized void disableUpdate() {
123        updateDisabled = true;
124        // compute all the frequencies
125        getIdf();
126        for (String topicName : topicNames) {
127            tfidf(topicName);
128            tfidfNorm(topicName);
129        }
130        // upload the count vectors
131        topicTermCount.clear();
132        allTermCounts = null;
133    }
134
135    /**
136     * Update the model to take into account the statistical properties of a document that is known to be relevant to
137     * the given topic. Warning: this method is not thread safe: it should not be used concurrently with @see
138     * #getSimilarities(List)
139     *
140     * @param topicName the name of the document topic or category
141     * @param terms the list of document tokens (use a lucene analyzer to extract theme for instance)
142     */
143    public void update(String topicName, List<String> terms) {
144        if (updateDisabled) {
145            throw new IllegalStateException("updates are no longer authorized once #disableUpdate has been called");
146        }
147        long[] counts = vectorizer.count(terms);
148        totalTermCount += sum(counts);
149        long[] topicCounts = (long[]) topicTermCount.get(topicName);
150        if (topicCounts == null) {
151            topicCounts = new long[dim];
152            topicTermCount.put(topicName, topicCounts);
153            topicNames.add(topicName);
154        }
155        add(topicCounts, counts);
156        add(allTermCounts, counts);
157        invalidateCache(topicName);
158    }
159
160    /**
161     * Update the model to take into account the statistical properties of a document that is known to be relevant to
162     * the given topic. Warning: this method is not thread safe: it should not be used concurrently with @see
163     * #getSimilarities(List)
164     *
165     * @param topicName the name of the document topic or category
166     * @param textContent textual content to be tokenized and analyzed
167     */
168    public void update(String topicName, String textContent) {
169        update(topicName, tokenize(textContent));
170    }
171
172    protected void invalidateCache(String topicName) {
173        cachedTopicTfIdf.remove(topicName);
174        cachedTopicTfIdfNorm.remove(topicName);
175        cachedIdf = null;
176    }
177
178    protected void invalidateCache() {
179        for (String topicName : topicNames) {
180            invalidateCache(topicName);
181        }
182    }
183
184    /**
185     * For each registered topic, compute the cosine similarity of the TFIDF vector of the topic and the one of the
186     * document given by a list of tokens.
187     *
188     * @param terms a tokenized document.
189     * @return a map of topic names to float values from 0 to 1 sorted by reverse value.
190     */
191    public Map<String, Float> getSimilarities(List<String> terms) {
192        SortedMap<String, Float> similarities = new TreeMap<String, Float>();
193
194        float[] tfidf1 = getTfIdf(vectorizer.count(terms));
195        float norm1 = normOf(tfidf1);
196        if (norm1 == 0) {
197            return similarities;
198        }
199
200        for (String topicName : topicNames) {
201            float[] tfidf2 = tfidf(topicName);
202            float norm2 = tfidfNorm(topicName);
203            if (norm2 == 0) {
204                continue;
205            }
206            similarities.put(topicName, dot(tfidf1, tfidf2) / (norm1 * norm2));
207        }
208        return sortByDecreasingValue(similarities);
209    }
210
211    /**
212     * For each registered topic, compute the cosine similarity of the TFIDF vector of the topic and the one of the
213     * document.
214     *
215     * @param the document to be tokenized and analyzed
216     * @return a map of topic names to float values from 0 to 1 sorted by reverse value.
217     */
218    public Map<String, Float> getSimilarities(String allThePets) {
219        return getSimilarities(tokenize(allThePets));
220    }
221
222    protected float tfidfNorm(String topicName) {
223        Float norm = cachedTopicTfIdfNorm.get(topicName);
224        if (norm == null) {
225            norm = normOf(tfidf(topicName));
226            cachedTopicTfIdfNorm.put(topicName, norm);
227        }
228        return norm.floatValue();
229    }
230
231    protected float[] tfidf(String topicName) {
232        float[] tfidf = (float[]) cachedTopicTfIdf.get(topicName);
233        if (tfidf == null) {
234            tfidf = getTfIdf((long[]) topicTermCount.get(topicName));
235            cachedTopicTfIdf.put(topicName, tfidf);
236        }
237        return tfidf;
238    }
239
240    protected float[] getTfIdf(long[] counts) {
241        float[] idf = getIdf();
242        float[] tfidf = new float[counts.length];
243        long sum = sum(counts);
244        if (sum == 0) {
245            return tfidf;
246        }
247        for (int i = 0; i < counts.length; i++) {
248            tfidf[i] = ((float) counts[i]) / sum * idf[i];
249        }
250        return tfidf;
251    }
252
253    protected float[] getIdf() {
254        if (cachedIdf == null) {
255            float[] idf = new float[allTermCounts.length];
256            for (int i = 0; i < allTermCounts.length; i++) {
257                if (allTermCounts[i] == 0) {
258                    idf[i] = 0;
259                } else {
260                    idf[i] = (float) Math.log1p(((float) totalTermCount) / allTermCounts[i]);
261                }
262            }
263            // atomic update to ensure thread-safeness
264            cachedIdf = idf;
265        }
266        return cachedIdf;
267    }
268
269    public int getDimension() {
270        return dim;
271    }
272
273    /**
274     * Utility method to initialize the parameters from a set of UTF-8 encoded text files with names used as topic
275     * names.
276     * <p>
277     * The content of the file to assumed to be lines of terms separated by whitespaces without punctuation.
278     *
279     * @param folder
280     */
281    public void learnFiles(File folder) throws IOException {
282        if (!folder.isDirectory()) {
283            throw new IOException(String.format("%s is not a folder", folder.getAbsolutePath()));
284        }
285        for (File file : folder.listFiles()) {
286            if (file.isDirectory()) {
287                continue;
288            }
289            String topicName = file.getName();
290            if (topicName.contains(".")) {
291                topicName = topicName.substring(0, topicName.indexOf('.'));
292            }
293            log.info(String.format("About to analyze file %s", file.getAbsolutePath()));
294            FileInputStream is = new FileInputStream(file);
295            try {
296                BufferedReader reader = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")));
297                String line = reader.readLine();
298                int i = 0;
299                while (line != null) {
300                    update(topicName, line);
301                    line = reader.readLine();
302                    i++;
303                    if (i % 10000 == 0) {
304                        log.info(String.format("Analyzed %d lines from '%s'", i, file.getAbsolutePath()));
305                    }
306                }
307            } finally {
308                is.close();
309            }
310        }
311    }
312
313    /**
314     * Save the model to a compressed binary format on the filesystem.
315     *
316     * @param file where to write the model
317     */
318    public void saveToFile(File file) throws IOException {
319        FileOutputStream out = new FileOutputStream(file);
320        try {
321            saveToStream(out);
322        } finally {
323            out.close();
324        }
325    }
326
327    /**
328     * Save a compressed binary representation of the trained model.
329     *
330     * @param out the output stream to write to
331     */
332    public void saveToStream(OutputStream out) throws IOException {
333        if (updateDisabled) {
334            throw new IllegalStateException("model in disabled update mode cannot be saved");
335        }
336        invalidateCache();
337        GZIPOutputStream gzOut = new GZIPOutputStream(out);
338        ObjectOutputStream objOut = new ObjectOutputStream(gzOut);
339        objOut.writeObject(this);
340        gzOut.finish();
341    }
342
343    /**
344     * Load a TfIdfCategorizer instance from it's compressed binary representation.
345     *
346     * @param in the input stream to read from
347     * @return a new instance with parameters coming from the saved version
348     */
349    public static TfIdfCategorizer load(InputStream in) throws IOException, ClassNotFoundException {
350        GZIPInputStream gzIn = new GZIPInputStream(in);
351        ObjectInputStream objIn = new ObjectInputStream(gzIn);
352        TfIdfCategorizer cat = (TfIdfCategorizer) objIn.readObject();
353        log.info(String.format("Sucessfully loaded model with %d topics, dimension %d and density %f",
354                cat.getTopicNames().size(), cat.getDimension(), cat.getDensity()));
355        return cat;
356    }
357
358    public double getDensity() {
359        long sum = 0;
360        for (Object singleTopicTermCount : topicTermCount.values()) {
361            for (long c : (long[]) singleTopicTermCount) {
362                sum += c != 0L ? 1 : 0;
363            }
364        }
365        for (long c : allTermCounts) {
366            sum += c != 0 ? 1 : 0;
367        }
368        return ((double) sum) / ((topicNames.size() + 1) * getDimension());
369    }
370
371    public Set<String> getTopicNames() {
372        return topicNames;
373    }
374
375    /**
376     * Load a TfIdfCategorizer instance from it's compressed binary representation from a named resource in the
377     * classloading path of the current thread.
378     *
379     * @param modelPath the path of the file model in the classloading path
380     * @return a new instance with parameters coming from the saved version
381     */
382    public static TfIdfCategorizer load(String modelPath) throws IOException, ClassNotFoundException {
383        ClassLoader loader = Thread.currentThread().getContextClassLoader();
384        return load(loader.getResourceAsStream(modelPath));
385    }
386
387    public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {
388        if (args.length < 2 || args.length > 3) {
389            System.out.println("Train a model:\n" + "First argument is the model filename (e.g. my-model.gz)\n"
390                    + "Second argument is the path to a folder with UTF-8 text files\n"
391                    + "Third optional argument is the dimension of the model");
392            System.exit(0);
393        }
394        File modelFile = new File(args[0]);
395        TfIdfCategorizer categorizer;
396        if (modelFile.exists()) {
397            log.info("Loading model from: " + modelFile.getAbsolutePath());
398            FileInputStream is = new FileInputStream(modelFile);
399            try {
400                categorizer = load(is);
401            } finally {
402                is.close();
403            }
404        } else {
405            if (args.length == 3) {
406                categorizer = new TfIdfCategorizer(Integer.valueOf(args[2]));
407            } else {
408                categorizer = new TfIdfCategorizer();
409            }
410            log.info("Initializing new model with dimension: " + categorizer.getDimension());
411        }
412        categorizer.learnFiles(new File(args[1]));
413        log.info("Saving trained model to: " + modelFile.getAbsolutePath());
414        categorizer.saveToFile(modelFile);
415    }
416
417    public List<String> guessCategories(String textContent, int maxSuggestions) {
418        return guessCategories(textContent, maxSuggestions, null);
419    }
420
421    public List<String> guessCategories(String textContent, int maxSuggestions, Double precisionThreshold) {
422        precisionThreshold = precisionThreshold == null ? ratioOverMedian : precisionThreshold;
423        Map<String, Float> sims = getSimilarities(tokenize(textContent));
424        Float median = findMedian(sims);
425        List<String> suggested = new ArrayList<String>();
426        for (Map.Entry<String, Float> sim : sims.entrySet()) {
427            double ratio = median != 0 ? sim.getValue() / median : 100;
428            if (suggested.size() >= maxSuggestions || ratio < precisionThreshold) {
429                break;
430            }
431            suggested.add(sim.getKey());
432        }
433        return suggested;
434    }
435
436    public List<String> tokenize(String textContent) {
437        try {
438            List<String> terms = new ArrayList<String>();
439            TokenStream tokenStream = getAnalyzer().tokenStream(null, textContent);
440            CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
441            tokenStream.reset();
442            while (tokenStream.incrementToken()) {
443                terms.add(charTermAttribute.toString());
444            }
445            tokenStream.end();
446            tokenStream.close();
447            return terms;
448        } catch (IOException e) {
449            throw new IllegalStateException(e);
450        }
451    }
452
453    public static Map<String, Float> sortByDecreasingValue(Map<String, Float> map) {
454        List<Map.Entry<String, Float>> list = new LinkedList<Map.Entry<String, Float>>(map.entrySet());
455        Collections.sort(list, new Comparator<Map.Entry<String, Float>>() {
456            public int compare(Map.Entry<String, Float> e1, Map.Entry<String, Float> e2) {
457                return -e1.getValue().compareTo(e2.getValue());
458            }
459        });
460        Map<String, Float> result = new LinkedHashMap<String, Float>();
461        for (Map.Entry<String, Float> e : list) {
462            result.put(e.getKey(), e.getValue());
463        }
464        return result;
465    }
466
467    public static Float findMedian(Map<String, Float> sortedMap) {
468        int remaining = sortedMap.size() / 2;
469        Float median = 0.0f;
470        for (Float value : sortedMap.values()) {
471            median = value;
472            if (remaining-- <= 0) {
473                break;
474            }
475        }
476        return median;
477    }
478
479}