001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.lucene.demo.knn; 018 019import java.io.IOException; 020import java.io.Reader; 021import java.io.StringReader; 022import org.apache.lucene.analysis.Analyzer; 023import org.apache.lucene.analysis.LowerCaseFilter; 024import org.apache.lucene.analysis.TokenStream; 025import org.apache.lucene.analysis.Tokenizer; 026import org.apache.lucene.analysis.standard.StandardTokenizer; 027 028/** 029 * This class provides {@link #computeEmbedding(String)} and {@link #computeEmbedding(Reader)} for 030 * calculating "semantic" embedding vectors for textual input. 031 */ 032public class DemoEmbeddings { 033 034 private final Analyzer analyzer; 035 036 /** 037 * Sole constructor 038 * 039 * @param vectorDict a token to vector dictionary 040 */ 041 public DemoEmbeddings(KnnVectorDict vectorDict) { 042 analyzer = 043 new Analyzer() { 044 @Override 045 protected TokenStreamComponents createComponents(String fieldName) { 046 Tokenizer tokenizer = new StandardTokenizer(); 047 TokenStream output = 048 new KnnVectorDictFilter(new LowerCaseFilter(tokenizer), vectorDict); 049 return new TokenStreamComponents(tokenizer, output); 050 } 051 }; 052 } 053 054 /** 055 * Tokenize and lower-case the input, look up the tokens in the dictionary, and sum the token 056 * vectors. Unrecognized tokens are ignored. The resulting vector is normalized to unit length. 057 * 058 * @param input the input to analyze 059 * @return the KnnVector for the input 060 */ 061 public float[] computeEmbedding(String input) throws IOException { 062 return computeEmbedding(new StringReader(input)); 063 } 064 065 /** 066 * Tokenize and lower-case the input, look up the tokens in the dictionary, and sum the token 067 * vectors. Unrecognized tokens are ignored. The resulting vector is normalized to unit length. 068 * 069 * @param input the input to analyze 070 * @return the KnnVector for the input 071 */ 072 public float[] computeEmbedding(Reader input) throws IOException { 073 try (TokenStream tokens = analyzer.tokenStream("dummyField", input)) { 074 tokens.reset(); 075 while (tokens.incrementToken()) {} 076 tokens.end(); 077 return ((KnnVectorDictFilter) tokens).getResult(); 078 } 079 } 080}