/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.demo.knn;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import org.apache.lucene.analysis.TokenFilter;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.VectorUtil;

/**
 * Looks up each token in a dictionary and sums the token vectors. Unrecognized tokens are ignored.
 * The accumulated vector is l2-normalized when {@link #end()} is called.
 *
 * <p>Typical use: call {@link #reset()}, call {@link #incrementToken()} until it returns {@code
 * false}, call {@link #end()}, then read the vector with {@link #getResult()}.
 */
public final class KnnVectorDictFilter extends TokenFilter {

  /** Source of the per-token vectors. */
  private final KnnVectorDict dict;

  /** Text of the token most recently produced by the wrapped stream. */
  private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);

  /** Decoded floats of the vector for the current token; overwritten on every token. */
  private final float[] scratchFloats;

  /** Running sum of the token vectors seen since the last {@link #reset()}. */
  private final float[] result;

  /** Raw bytes of the vector for the current token, as filled in by the dictionary. */
  private final byte[] scratchBytes;

  /** Little-endian float view over {@link #scratchBytes}, used to decode it. */
  private final FloatBuffer scratchBuffer;

  /**
   * Sole constructor. All working arrays are sized from the dictionary's dimension.
   *
   * @param input the input token stream to filter.
   * @param dict a token to vector dictionary, used to look up the token vectors.
   */
  public KnnVectorDictFilter(TokenStream input, KnnVectorDict dict) {
    super(input);
    this.dict = dict;
    final int dimension = dict.getDimension();
    this.result = new float[dimension];
    this.scratchFloats = new float[dimension];
    this.scratchBytes = new byte[dimension * Float.BYTES];
    // The float view shares storage with scratchBytes, so filling the byte array is enough to
    // make the new values visible through the buffer.
    this.scratchBuffer =
        ByteBuffer.wrap(scratchBytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
  }

  /**
   * Advances the wrapped stream by one token and adds that token's vector to the running sum.
   *
   * @return {@code false} once the wrapped stream is exhausted, {@code true} otherwise.
   * @throws IOException if the wrapped stream or the dictionary lookup fails.
   */
  @Override
  public boolean incrementToken() throws IOException {
    if (!input.incrementToken()) {
      return false;
    }
    // Fetch the raw vector bytes for the current term into the shared scratch array.
    dict.get(new BytesRef(termAtt.toString()), scratchBytes);
    // The bulk read below advances the buffer position, so start from the beginning each time.
    scratchBuffer.rewind();
    scratchBuffer.get(scratchFloats);
    VectorUtil.add(result, scratchFloats);
    return true;
  }

  /**
   * Resets the wrapped stream and zeroes the running sum.
   *
   * @throws IOException if resetting the wrapped stream fails.
   */
  @Override
  public void reset() throws IOException {
    super.reset();
    Arrays.fill(result, 0f);
  }

  /**
   * Ends the wrapped stream and l2-normalizes the accumulated vector in place.
   *
   * @throws IOException if ending the wrapped stream fails.
   */
  @Override
  public void end() throws IOException {
    super.end();
    VectorUtil.l2normalize(result, false);
  }

  /**
   * Get the vector computed from the input.
   *
   * <p>The returned array is this filter's own accumulator, not a copy: it holds the plain sum
   * until {@link #end()} normalizes it, and it is zeroed again by {@link #reset()}.
   *
   * @return the resultant sum of the vectors of each term.
   */
  public float[] getResult() {
    return result;
  }
}