001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *     http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.lucene.demo.knn;
018
019import static org.apache.lucene.util.fst.FST.readMetadata;
020
021import java.io.BufferedReader;
022import java.io.Closeable;
023import java.io.IOException;
024import java.nio.ByteBuffer;
025import java.nio.ByteOrder;
026import java.nio.FloatBuffer;
027import java.nio.file.Files;
028import java.nio.file.Path;
029import java.util.Arrays;
030import java.util.regex.Pattern;
031import org.apache.lucene.store.Directory;
032import org.apache.lucene.store.IOContext;
033import org.apache.lucene.store.IndexInput;
034import org.apache.lucene.store.IndexOutput;
035import org.apache.lucene.util.BytesRef;
036import org.apache.lucene.util.IntsRefBuilder;
037import org.apache.lucene.util.VectorUtil;
038import org.apache.lucene.util.fst.FST;
039import org.apache.lucene.util.fst.FSTCompiler;
040import org.apache.lucene.util.fst.PositiveIntOutputs;
041import org.apache.lucene.util.fst.Util;
042
043/**
044 * Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is
045 * stored as an FST: token-to-ordinal plus a dense binary file holding the vectors.
046 */
047public class KnnVectorDict implements Closeable {
048
049  private final FST<Long> fst;
050  private final IndexInput vectors;
051  private final int dimension;
052
053  /**
054   * Sole constructor
055   *
056   * @param directory Lucene directory from which knn directory should be read.
057   * @param dictName the base name of the directory files that store the knn vector dictionary. A
058   *     file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the
059   *     '.bin' file.
060   */
061  public KnnVectorDict(Directory directory, String dictName) throws IOException {
062    try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) {
063      fst = new FST<>(readMetadata(fstIn, PositiveIntOutputs.getSingleton()), fstIn);
064    }
065
066    vectors = directory.openInput(dictName + ".bin", IOContext.READ);
067    long size = vectors.length();
068    vectors.seek(size - Integer.BYTES);
069    dimension = vectors.readInt();
070    if ((size - Integer.BYTES) % (dimension * (long) Float.BYTES) != 0) {
071      throw new IllegalStateException(
072          "vector file size " + size + " is not consonant with the vector dimension " + dimension);
073    }
074  }
075
076  /**
077   * Get the vector corresponding to the given token. NOTE: the returned array is shared and its
078   * contents will be overwritten by subsequent calls. The caller is responsible to copy the data as
079   * needed.
080   *
081   * @param token the token to look up
082   * @param output the array in which to write the corresponding vector. Its length must be {@link
083   *     #getDimension()} * {@link Float#BYTES}. It will be filled with zeros if the token is not
084   *     present in the dictionary.
085   * @throws IllegalArgumentException if the output array is incorrectly sized
086   * @throws IOException if there is a problem reading the dictionary
087   */
088  public void get(BytesRef token, byte[] output) throws IOException {
089    if (output.length != dimension * Float.BYTES) {
090      throw new IllegalArgumentException(
091          "the output array must be of length "
092              + (dimension * Float.BYTES)
093              + ", got "
094              + output.length);
095    }
096    Long ord = Util.get(fst, token);
097    if (ord == null) {
098      Arrays.fill(output, (byte) 0);
099    } else {
100      vectors.seek(ord * dimension * Float.BYTES);
101      vectors.readBytes(output, 0, output.length);
102    }
103  }
104
105  /**
106   * Get the dimension of the vectors returned by this.
107   *
108   * @return the vector dimension
109   */
110  public int getDimension() {
111    return dimension;
112  }
113
114  @Override
115  public void close() throws IOException {
116    vectors.close();
117  }
118
119  /**
120   * Convert from a GloVe-formatted dictionary file to a KnnVectorDict file pair.
121   *
122   * @param gloveInput the path to the input dictionary. The dictionary is delimited by newlines,
123   *     and each line is space-delimited. The first column has the token, and the remaining columns
124   *     are the vector components, as text. The dictionary must be sorted by its leading tokens
125   *     (considered as bytes).
126   * @param directory a Lucene directory to write the dictionary to.
127   * @param dictName Base name for the knn dictionary files.
128   */
129  public static void build(Path gloveInput, Directory directory, String dictName)
130      throws IOException {
131    new Builder().build(gloveInput, directory, dictName);
132  }
133
134  private static class Builder {
135    private static final Pattern SPACE_RE = Pattern.compile(" ");
136
137    private final IntsRefBuilder intsRefBuilder = new IntsRefBuilder();
138    private final FSTCompiler<Long> fstCompiler;
139    private float[] scratch;
140    private ByteBuffer byteBuffer;
141    private long ordinal = 1;
142    private int numFields;
143
144    Builder() throws IOException {
145      fstCompiler =
146          new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, PositiveIntOutputs.getSingleton())
147              .build();
148    }
149
150    void build(Path gloveInput, Directory directory, String dictName) throws IOException {
151      try (BufferedReader in = Files.newBufferedReader(gloveInput);
152          IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT);
153          IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) {
154        writeFirstLine(in, binOut);
155        while (addOneLine(in, binOut)) {
156          // continue;
157        }
158        fstCompiler.compile().save(fstOut, fstOut);
159        binOut.writeInt(numFields - 1);
160      }
161    }
162
163    private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException {
164      String[] fields = readOneLine(in);
165      if (fields == null) {
166        return;
167      }
168      numFields = fields.length;
169      byteBuffer =
170          ByteBuffer.allocate((numFields - 1) * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
171      scratch = new float[numFields - 1];
172      writeVector(fields, out);
173    }
174
175    private String[] readOneLine(BufferedReader in) throws IOException {
176      String line = in.readLine();
177      if (line == null) {
178        return null;
179      }
180      return SPACE_RE.split(line, 0);
181    }
182
183    private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException {
184      String[] fields = readOneLine(in);
185      if (fields == null) {
186        return false;
187      }
188      if (fields.length != numFields) {
189        throw new IllegalStateException(
190            "different field count at line "
191                + ordinal
192                + " got "
193                + fields.length
194                + " when expecting "
195                + numFields);
196      }
197      fstCompiler.add(Util.toIntsRef(new BytesRef(fields[0]), intsRefBuilder), ordinal++);
198      writeVector(fields, out);
199      return true;
200    }
201
202    private void writeVector(String[] fields, IndexOutput out) throws IOException {
203      byteBuffer.position(0);
204      FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
205      for (int i = 1; i < fields.length; i++) {
206        scratch[i - 1] = Float.parseFloat(fields[i]);
207      }
208      VectorUtil.l2normalize(scratch);
209      floatBuffer.put(scratch);
210      byte[] bytes = byteBuffer.array();
211      out.writeBytes(bytes, bytes.length);
212    }
213  }
214
215  /** Return the size of the dictionary in bytes */
216  public long ramBytesUsed() {
217    return fst.ramBytesUsed() + vectors.length();
218  }
219}