001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.lucene.demo.knn;
018
019import java.io.IOException;
020import java.io.Reader;
021import java.io.StringReader;
022import org.apache.lucene.analysis.Analyzer;
023import org.apache.lucene.analysis.LowerCaseFilter;
024import org.apache.lucene.analysis.TokenStream;
025import org.apache.lucene.analysis.Tokenizer;
026import org.apache.lucene.analysis.standard.StandardTokenizer;
027
028/**
029 * This class provides {@link #computeEmbedding(String)} and {@link #computeEmbedding(Reader)} for
030 * calculating "semantic" embedding vectors for textual input.
031 */
032public class DemoEmbeddings {
033
034  private final Analyzer analyzer;
035
036  /**
037   * Sole constructor
038   *
039   * @param vectorDict a token to vector dictionary
040   */
041  public DemoEmbeddings(KnnVectorDict vectorDict) {
042    analyzer =
043        new Analyzer() {
044          @Override
045          protected TokenStreamComponents createComponents(String fieldName) {
046            Tokenizer tokenizer = new StandardTokenizer();
047            TokenStream output =
048                new KnnVectorDictFilter(new LowerCaseFilter(tokenizer), vectorDict);
049            return new TokenStreamComponents(tokenizer, output);
050          }
051        };
052  }
053
054  /**
055   * Tokenize and lower-case the input, look up the tokens in the dictionary, and sum the token
056   * vectors. Unrecognized tokens are ignored. The resulting vector is normalized to unit length.
057   *
058   * @param input the input to analyze
059   * @return the KnnVector for the input
060   */
061  public float[] computeEmbedding(String input) throws IOException {
062    return computeEmbedding(new StringReader(input));
063  }
064
065  /**
066   * Tokenize and lower-case the input, look up the tokens in the dictionary, and sum the token
067   * vectors. Unrecognized tokens are ignored. The resulting vector is normalized to unit length.
068   *
069   * @param input the input to analyze
070   * @return the KnnVector for the input
071   */
072  public float[] computeEmbedding(Reader input) throws IOException {
073    try (TokenStream tokens = analyzer.tokenStream("dummyField", input)) {
074      tokens.reset();
075      while (tokens.incrementToken()) {}
076      tokens.end();
077      return ((KnnVectorDictFilter) tokens).getResult();
078    }
079  }
080}