001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.lucene.demo.knn;
018
019import java.io.IOException;
020import java.nio.ByteBuffer;
021import java.nio.ByteOrder;
022import java.nio.FloatBuffer;
023import java.util.Arrays;
024import org.apache.lucene.analysis.TokenFilter;
025import org.apache.lucene.analysis.TokenStream;
026import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
027import org.apache.lucene.util.BytesRef;
028import org.apache.lucene.util.VectorUtil;
029
030/**
031 * Looks up each tokens in a dictionary, and sums the token vectors. Unrecognized tokens are
032 * ignored. The resulting vector is normalized to unit length.
033 */
034public final class KnnVectorDictFilter extends TokenFilter {
035
036  private final KnnVectorDict dict;
037  private final CharTermAttribute termAtt = addAttribute(CharTermAttribute.class);
038  private final float[] scratchFloats;
039  private final float[] result;
040  private final byte[] scratchBytes;
041  private final FloatBuffer scratchBuffer;
042
043  /**
044   * sole constructor
045   *
046   * @param input the input token stream to filter.
047   * @param dict a token to vector dictionary, used to look up the token vectors.
048   */
049  public KnnVectorDictFilter(TokenStream input, KnnVectorDict dict) {
050    super(input);
051    this.dict = dict;
052    result = new float[dict.getDimension()];
053    scratchBytes = new byte[dict.getDimension() * Float.BYTES];
054    scratchBuffer = ByteBuffer.wrap(scratchBytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
055    scratchFloats = new float[dict.getDimension()];
056  }
057
058  @Override
059  public boolean incrementToken() throws IOException {
060    if (input.incrementToken() == false) {
061      return false;
062    }
063    BytesRef term = new BytesRef(termAtt.toString());
064    dict.get(term, scratchBytes);
065    scratchBuffer.position(0);
066    scratchBuffer.get(scratchFloats);
067    VectorUtil.add(result, scratchFloats);
068    return true;
069  }
070
071  @Override
072  public void reset() throws IOException {
073    super.reset();
074    Arrays.fill(result, 0);
075  }
076
077  @Override
078  public void end() throws IOException {
079    super.end();
080    VectorUtil.l2normalize(result, false);
081  }
082
083  /**
084   * Get the vector computed from the input
085   *
086   * @return the resultant sum of the vectors of each term.
087   */
088  public float[] getResult() {
089    return result;
090  }
091}