001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.lucene.demo.knn;
018
019import java.io.BufferedReader;
020import java.io.Closeable;
021import java.io.IOException;
022import java.nio.ByteBuffer;
023import java.nio.ByteOrder;
024import java.nio.FloatBuffer;
025import java.nio.file.Files;
026import java.nio.file.Path;
027import java.util.Arrays;
028import java.util.regex.Pattern;
029import org.apache.lucene.store.Directory;
030import org.apache.lucene.store.IOContext;
031import org.apache.lucene.store.IndexInput;
032import org.apache.lucene.store.IndexOutput;
033import org.apache.lucene.util.BytesRef;
034import org.apache.lucene.util.IntsRefBuilder;
035import org.apache.lucene.util.VectorUtil;
036import org.apache.lucene.util.fst.FST;
037import org.apache.lucene.util.fst.FSTCompiler;
038import org.apache.lucene.util.fst.PositiveIntOutputs;
039import org.apache.lucene.util.fst.Util;
040
041/**
042 * Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is
043 * stored as an FST: token-to-ordinal plus a dense binary file holding the vectors.
044 */
045public class KnnVectorDict implements Closeable {
046
047  private final FST<Long> fst;
048  private final IndexInput vectors;
049  private final int dimension;
050
051  /**
052   * Sole constructor
053   *
054   * @param directory Lucene directory from which knn directory should be read.
055   * @param dictName the base name of the directory files that store the knn vector dictionary. A
056   *     file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the
057   *     '.bin' file.
058   */
059  public KnnVectorDict(Directory directory, String dictName) throws IOException {
060    try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) {
061      fst = new FST<>(fstIn, fstIn, PositiveIntOutputs.getSingleton());
062    }
063
064    vectors = directory.openInput(dictName + ".bin", IOContext.READ);
065    long size = vectors.length();
066    vectors.seek(size - Integer.BYTES);
067    dimension = vectors.readInt();
068    if ((size - Integer.BYTES) % (dimension * Float.BYTES) != 0) {
069      throw new IllegalStateException(
070          "vector file size " + size + " is not consonant with the vector dimension " + dimension);
071    }
072  }
073
074  /**
075   * Get the vector corresponding to the given token. NOTE: the returned array is shared and its
076   * contents will be overwritten by subsequent calls. The caller is responsible to copy the data as
077   * needed.
078   *
079   * @param token the token to look up
080   * @param output the array in which to write the corresponding vector. Its length must be {@link
081   *     #getDimension()} * {@link Float#BYTES}. It will be filled with zeros if the token is not
082   *     present in the dictionary.
083   * @throws IllegalArgumentException if the output array is incorrectly sized
084   * @throws IOException if there is a problem reading the dictionary
085   */
086  public void get(BytesRef token, byte[] output) throws IOException {
087    if (output.length != dimension * Float.BYTES) {
088      throw new IllegalArgumentException(
089          "the output array must be of length "
090              + (dimension * Float.BYTES)
091              + ", got "
092              + output.length);
093    }
094    Long ord = Util.get(fst, token);
095    if (ord == null) {
096      Arrays.fill(output, (byte) 0);
097    } else {
098      vectors.seek(ord * dimension * Float.BYTES);
099      vectors.readBytes(output, 0, output.length);
100    }
101  }
102
103  /**
104   * Get the dimension of the vectors returned by this.
105   *
106   * @return the vector dimension
107   */
108  public int getDimension() {
109    return dimension;
110  }
111
112  @Override
113  public void close() throws IOException {
114    vectors.close();
115  }
116
117  /**
118   * Convert from a GloVe-formatted dictionary file to a KnnVectorDict file pair.
119   *
120   * @param gloveInput the path to the input dictionary. The dictionary is delimited by newlines,
121   *     and each line is space-delimited. The first column has the token, and the remaining columns
122   *     are the vector components, as text. The dictionary must be sorted by its leading tokens
123   *     (considered as bytes).
124   * @param directory a Lucene directory to write the dictionary to.
125   * @param dictName Base name for the knn dictionary files.
126   */
127  public static void build(Path gloveInput, Directory directory, String dictName)
128      throws IOException {
129    new Builder().build(gloveInput, directory, dictName);
130  }
131
132  private static class Builder {
133    private static final Pattern SPACE_RE = Pattern.compile(" ");
134
135    private final IntsRefBuilder intsRefBuilder = new IntsRefBuilder();
136    private final FSTCompiler<Long> fstCompiler =
137        new FSTCompiler<>(FST.INPUT_TYPE.BYTE1, PositiveIntOutputs.getSingleton());
138    private float[] scratch;
139    private ByteBuffer byteBuffer;
140    private long ordinal = 1;
141    private int numFields;
142
143    void build(Path gloveInput, Directory directory, String dictName) throws IOException {
144      try (BufferedReader in = Files.newBufferedReader(gloveInput);
145          IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT);
146          IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) {
147        writeFirstLine(in, binOut);
148        while (addOneLine(in, binOut)) {
149          // continue;
150        }
151        fstCompiler.compile().save(fstOut, fstOut);
152        binOut.writeInt(numFields - 1);
153      }
154    }
155
156    private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException {
157      String[] fields = readOneLine(in);
158      if (fields == null) {
159        return;
160      }
161      numFields = fields.length;
162      byteBuffer =
163          ByteBuffer.allocate((numFields - 1) * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
164      scratch = new float[numFields - 1];
165      writeVector(fields, out);
166    }
167
168    private String[] readOneLine(BufferedReader in) throws IOException {
169      String line = in.readLine();
170      if (line == null) {
171        return null;
172      }
173      return SPACE_RE.split(line, 0);
174    }
175
176    private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException {
177      String[] fields = readOneLine(in);
178      if (fields == null) {
179        return false;
180      }
181      if (fields.length != numFields) {
182        throw new IllegalStateException(
183            "different field count at line "
184                + ordinal
185                + " got "
186                + fields.length
187                + " when expecting "
188                + numFields);
189      }
190      fstCompiler.add(Util.toIntsRef(new BytesRef(fields[0]), intsRefBuilder), ordinal++);
191      writeVector(fields, out);
192      return true;
193    }
194
195    private void writeVector(String[] fields, IndexOutput out) throws IOException {
196      byteBuffer.position(0);
197      FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
198      for (int i = 1; i < fields.length; i++) {
199        scratch[i - 1] = Float.parseFloat(fields[i]);
200      }
201      VectorUtil.l2normalize(scratch);
202      floatBuffer.put(scratch);
203      byte[] bytes = byteBuffer.array();
204      out.writeBytes(bytes, bytes.length);
205    }
206  }
207
208  /** Return the size of the dictionary in bytes */
209  public long ramBytesUsed() {
210    return fst.ramBytesUsed() + vectors.length();
211  }
212}