/*
 * (C) Copyright 2009 Nuxeo SAS (http://nuxeo.com/) and contributors.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the GNU Lesser General Public License
 * (LGPL) version 2.1 which accompanies this distribution, and is available at
 * http://www.gnu.org/licenses/lgpl.html
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * Contributors:
 *     Olivier Grisel
 */

package org.nuxeo.ecm.platform.categorization.categorizer.tfidf;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.nuxeo.ecm.platform.categorization.service.Categorizer;

/**
 * Maintains a map of TF count vectors in memory (just for a few reference documents or topics) along with the common
 * IDF estimate of all previously seen text content.
 * <p>
 * See: http://en.wikipedia.org/wiki/Tfidf
 * <p>
 * Classification is then achieved using the cosine similarity between the TF-IDF vector of the document to classify
 * and those of the registered topics.
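 * <p>
 * A minimal usage sketch (the topic names and training strings below are made-up examples, not part of the original
 * code):
 *
 * <pre>{@code
 * TfIdfCategorizer categorizer = new TfIdfCategorizer();
 * categorizer.update("sports", "football match goal referee stadium");
 * categorizer.update("politics", "election parliament vote campaign");
 * List<String> guesses = categorizer.guessCategories("the referee whistled in the stadium", 1);
 * }</pre>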
 */
public class TfIdfCategorizer extends PrimitiveVectorHelper implements Categorizer, Serializable {

    private static final long serialVersionUID = 1L;

    public static final Log log = LogFactory.getLog(TfIdfCategorizer.class);

    protected final Set<String> topicNames = new TreeSet<String>();

    protected final Map<String, Object> topicTermCount = new ConcurrentHashMap<String, Object>();

    protected final Map<String, Object> cachedTopicTfIdf = new ConcurrentHashMap<String, Object>();

    protected final Map<String, Float> cachedTopicTfIdfNorm = new ConcurrentHashMap<String, Float>();

    protected long[] allTermCounts;

    protected final int dim;

    protected float[] cachedIdf;

    protected long totalTermCount = 0;

    protected final HashingVectorizer vectorizer;

    protected transient Analyzer analyzer;

    protected Double ratioOverMedian = 3.0;

    protected boolean updateDisabled = false;

    public TfIdfCategorizer() {
        this(524288); // 2 ** 19
    }

    public TfIdfCategorizer(int dim) {
        this.dim = dim;
        allTermCounts = new long[dim];
        vectorizer = new HashingVectorizer().dimension(dim);
    }

    public HashingVectorizer getVectorizer() {
        return vectorizer;
    }

    public Analyzer getAnalyzer() {
        if (analyzer == null) {
            // TODO: make it possible to configure the stop words
            analyzer = new StandardAnalyzer(Version.LUCENE_47);
        }
        return analyzer;
    }

    /**
     * Precomputes all the TF-IDF vectors and unloads the original count vectors to spare some memory. Updates won't
     * be possible any more.
     */
    public synchronized void disableUpdate() {
        updateDisabled = true;
        // compute all the frequencies
        getIdf();
        for (String topicName : topicNames) {
            tfidf(topicName);
            tfidfNorm(topicName);
        }
        // unload the count vectors
        topicTermCount.clear();
        allTermCounts = null;
    }

    /**
     * Updates the model to take into account the statistical properties of a document that is known to be relevant
     * to the given topic.
     * <p>
     * Warning: this method is not thread-safe: it should not be used concurrently with
     * {@link #getSimilarities(List)}.
     *
     * @param topicName the name of the document topic or category
     * @param terms the list of document tokens (use a Lucene analyzer to extract them, for instance)
     */
    public void update(String topicName, List<String> terms) {
        if (updateDisabled) {
            throw new IllegalStateException("updates are no longer authorized once #disableUpdate has been called");
        }
        long[] counts = vectorizer.count(terms);
        totalTermCount += sum(counts);
        long[] topicCounts = (long[]) topicTermCount.get(topicName);
        if (topicCounts == null) {
            topicCounts = new long[dim];
            topicTermCount.put(topicName, topicCounts);
            topicNames.add(topicName);
        }
        add(topicCounts, counts);
        add(allTermCounts, counts);
        invalidateCache(topicName);
    }

    /**
     * Updates the model to take into account the statistical properties of a document that is known to be relevant
     * to the given topic.
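     * <p>
     * For instance (hypothetical topic and text): {@code categorizer.update("sports", "football match goal")}.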
     * <p>
     * Warning: this method is not thread-safe: it should not be used concurrently with
     * {@link #getSimilarities(List)}.
     *
     * @param topicName the name of the document topic or category
     * @param textContent textual content to be tokenized and analyzed
     */
    public void update(String topicName, String textContent) {
        update(topicName, tokenize(textContent));
    }

    protected void invalidateCache(String topicName) {
        cachedTopicTfIdf.remove(topicName);
        cachedTopicTfIdfNorm.remove(topicName);
        cachedIdf = null;
    }

    protected void invalidateCache() {
        for (String topicName : topicNames) {
            invalidateCache(topicName);
        }
    }

    /**
     * For each registered topic, computes the cosine similarity of the TF-IDF vector of the topic and the one of the
     * document given as a list of tokens.
     *
     * @param terms a tokenized document
     * @return a map of topic names to float values from 0 to 1, sorted by decreasing value
     */
    public Map<String, Float> getSimilarities(List<String> terms) {
        SortedMap<String, Float> similarities = new TreeMap<String, Float>();

        float[] tfidf1 = getTfIdf(vectorizer.count(terms));
        float norm1 = normOf(tfidf1);
        if (norm1 == 0) {
            return similarities;
        }

        for (String topicName : topicNames) {
            float[] tfidf2 = tfidf(topicName);
            float norm2 = tfidfNorm(topicName);
            if (norm2 == 0) {
                continue;
            }
            similarities.put(topicName, dot(tfidf1, tfidf2) / (norm1 * norm2));
        }
        return sortByDecreasingValue(similarities);
    }

    /**
     * For each registered topic, computes the cosine similarity of the TF-IDF vector of the topic and the one of the
     * document.
     *
     * @param allThePets the document to be tokenized and analyzed
     * @return a map of topic names to float values from 0 to 1, sorted by decreasing value
     */
    public Map<String, Float> getSimilarities(String allThePets) {
        return getSimilarities(tokenize(allThePets));
    }

    protected float tfidfNorm(String topicName) {
        Float norm = cachedTopicTfIdfNorm.get(topicName);
        if (norm == null) {
            norm = normOf(tfidf(topicName));
            cachedTopicTfIdfNorm.put(topicName, norm);
        }
        return norm.floatValue();
    }

    protected float[] tfidf(String topicName) {
        float[] tfidf = (float[]) cachedTopicTfIdf.get(topicName);
        if (tfidf == null) {
            tfidf = getTfIdf((long[]) topicTermCount.get(topicName));
            cachedTopicTfIdf.put(topicName, tfidf);
        }
        return tfidf;
    }

    protected float[] getTfIdf(long[] counts) {
        float[] idf = getIdf();
        float[] tfidf = new float[counts.length];
        long sum = sum(counts);
        if (sum == 0) {
            return tfidf;
        }
        for (int i = 0; i < counts.length; i++) {
            tfidf[i] = ((float) counts[i]) / sum * idf[i];
        }
        return tfidf;
    }

    protected float[] getIdf() {
        if (cachedIdf == null) {
            float[] idf = new float[allTermCounts.length];
            for (int i = 0; i < allTermCounts.length; i++) {
                if (allTermCounts[i] == 0) {
                    idf[i] = 0;
                } else {
                    idf[i] = (float) Math.log1p(((float) totalTermCount) / allTermCounts[i]);
                }
            }
            // atomic reference update to ensure thread-safety
            cachedIdf = idf;
        }
        return cachedIdf;
    }

    public int getDimension() {
        return dim;
    }

    /**
     * Utility method to initialize the parameters from a set of UTF-8 encoded text files whose names are used as
     * topic names.
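     * <p>
     * For instance, a folder containing files named {@code sports.txt} and {@code politics.txt} (hypothetical names)
     * registers the two topics {@code sports} and {@code politics}: the part of the filename before the first dot is
     * used as the topic name.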
     * <p>
     * The content of each file is assumed to be lines of terms separated by whitespace, without punctuation.
     *
     * @param folder the folder holding one text file per topic
     */
    public void learnFiles(File folder) throws IOException {
        if (!folder.isDirectory()) {
            throw new IOException(String.format("%s is not a folder", folder.getAbsolutePath()));
        }
        for (File file : folder.listFiles()) {
            if (file.isDirectory()) {
                continue;
            }
            String topicName = file.getName();
            if (topicName.contains(".")) {
                topicName = topicName.substring(0, topicName.indexOf('.'));
            }
            log.info(String.format("About to analyze file %s", file.getAbsolutePath()));
            FileInputStream is = new FileInputStream(file);
            try {
                BufferedReader reader = new BufferedReader(new InputStreamReader(is, Charset.forName("UTF-8")));
                String line = reader.readLine();
                int i = 0;
                while (line != null) {
                    update(topicName, line);
                    line = reader.readLine();
                    i++;
                    if (i % 10000 == 0) {
                        log.info(String.format("Analyzed %d lines from '%s'", i, file.getAbsolutePath()));
                    }
                }
            } finally {
                is.close();
            }
        }
    }

    /**
     * Saves the model to a compressed binary format on the filesystem.
     *
     * @param file where to write the model
     */
    public void saveToFile(File file) throws IOException {
        FileOutputStream out = new FileOutputStream(file);
        try {
            saveToStream(out);
        } finally {
            out.close();
        }
    }

    /**
     * Saves a compressed binary representation of the trained model.
     *
     * @param out the output stream to write to
     */
    public void saveToStream(OutputStream out) throws IOException {
        if (updateDisabled) {
            throw new IllegalStateException("a model in disabled update mode cannot be saved");
        }
        invalidateCache();
        GZIPOutputStream gzOut = new GZIPOutputStream(out);
        ObjectOutputStream objOut = new ObjectOutputStream(gzOut);
        objOut.writeObject(this);
        gzOut.finish();
    }

    /**
     * Loads a TfIdfCategorizer instance from its compressed binary representation.
     *
     * @param in the input stream to read from
     * @return a new instance with parameters coming from the saved version
     */
    public static TfIdfCategorizer load(InputStream in) throws IOException, ClassNotFoundException {
        GZIPInputStream gzIn = new GZIPInputStream(in);
        ObjectInputStream objIn = new ObjectInputStream(gzIn);
        TfIdfCategorizer cat = (TfIdfCategorizer) objIn.readObject();
        log.info(String.format("Successfully loaded model with %d topics, dimension %d and density %f",
                cat.getTopicNames().size(), cat.getDimension(), cat.getDensity()));
        return cat;
    }

    public double getDensity() {
        long sum = 0;
        for (Object singleTopicTermCount : topicTermCount.values()) {
            for (long c : (long[]) singleTopicTermCount) {
                sum += c != 0L ? 1 : 0;
            }
        }
        for (long c : allTermCounts) {
            sum += c != 0 ? 1 : 0;
        }
        return ((double) sum) / ((topicNames.size() + 1) * getDimension());
    }

    public Set<String> getTopicNames() {
        return topicNames;
    }

    /**
     * Loads a TfIdfCategorizer instance from its compressed binary representation, read from a named resource in the
     * context classloader of the current thread.
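     * <p>
     * For instance (hypothetical resource path):
     *
     * <pre>{@code
     * TfIdfCategorizer categorizer = TfIdfCategorizer.load("models/en-topics.gz");
     * }</pre>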
     *
     * @param modelPath the path of the model file in the classloading path
     * @return a new instance with parameters coming from the saved version
     */
    public static TfIdfCategorizer load(String modelPath) throws IOException, ClassNotFoundException {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        return load(loader.getResourceAsStream(modelPath));
    }

    public static void main(String[] args) throws FileNotFoundException, IOException, ClassNotFoundException {
        if (args.length < 2 || args.length > 3) {
            System.out.println("Train a model:\n" + "First argument is the model filename (e.g. my-model.gz)\n"
                    + "Second argument is the path to a folder with UTF-8 text files\n"
                    + "Third optional argument is the dimension of the model");
            System.exit(0);
        }
        File modelFile = new File(args[0]);
        TfIdfCategorizer categorizer;
        if (modelFile.exists()) {
            log.info("Loading model from: " + modelFile.getAbsolutePath());
            FileInputStream is = new FileInputStream(modelFile);
            try {
                categorizer = load(is);
            } finally {
                is.close();
            }
        } else {
            if (args.length == 3) {
                categorizer = new TfIdfCategorizer(Integer.valueOf(args[2]));
            } else {
                categorizer = new TfIdfCategorizer();
            }
            log.info("Initializing new model with dimension: " + categorizer.getDimension());
        }
        categorizer.learnFiles(new File(args[1]));
        log.info("Saving trained model to: " + modelFile.getAbsolutePath());
        categorizer.saveToFile(modelFile);
    }

    public List<String> guessCategories(String textContent, int maxSuggestions) {
        return guessCategories(textContent, maxSuggestions, null);
    }

    public List<String> guessCategories(String textContent, int maxSuggestions, Double precisionThreshold) {
        precisionThreshold = precisionThreshold == null ? ratioOverMedian : precisionThreshold;
        Map<String, Float> sims = getSimilarities(tokenize(textContent));
        Float median = findMedian(sims);
        List<String> suggested = new ArrayList<String>();
        for (Map.Entry<String, Float> sim : sims.entrySet()) {
            double ratio = median != 0 ? sim.getValue() / median : 100;
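            // entries are sorted by decreasing similarity: keep suggesting topics while their similarity stands
            // well above the median of all topic similarities, and stop at the first one that falls below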
            if (suggested.size() >= maxSuggestions || ratio < precisionThreshold) {
                break;
            }
            suggested.add(sim.getKey());
        }
        return suggested;
    }

    public List<String> tokenize(String textContent) {
        try {
            List<String> terms = new ArrayList<String>();
            TokenStream tokenStream = getAnalyzer().tokenStream(null, textContent);
            CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
            tokenStream.reset();
            while (tokenStream.incrementToken()) {
                terms.add(charTermAttribute.toString());
            }
            tokenStream.end();
            tokenStream.close();
            return terms;
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    public static Map<String, Float> sortByDecreasingValue(Map<String, Float> map) {
        List<Map.Entry<String, Float>> list = new LinkedList<Map.Entry<String, Float>>(map.entrySet());
        Collections.sort(list, new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> e1, Map.Entry<String, Float> e2) {
                return -e1.getValue().compareTo(e2.getValue());
            }
        });
        Map<String, Float> result = new LinkedHashMap<String, Float>();
        for (Map.Entry<String, Float> e : list) {
            result.put(e.getKey(), e.getValue());
        }
        return result;
    }

    public static Float findMedian(Map<String, Float> sortedMap) {
        int remaining = sortedMap.size() / 2;
        Float median = 0.0f;
        for (Float value : sortedMap.values()) {
            median = value;
            if (remaining-- <= 0) {
                break;
            }
        }
        return median;
    }

}