/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.demo.knn;

import static org.apache.lucene.util.fst.FST.readMetadata;

import java.io.BufferedReader;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.regex.Pattern;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IntsRefBuilder;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.FSTCompiler;
import org.apache.lucene.util.fst.PositiveIntOutputs;
import org.apache.lucene.util.fst.Util;

/**
 * Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is
 * stored as an FST: token-to-ordinal plus a dense binary file holding the vectors.
 *
 * <p>The binary file consists of the vectors written back to back, each occupying {@code dimension
 * * Float.BYTES} bytes, followed by a single trailing {@code int} holding the vector dimension.
 */
public class KnnVectorDict implements Closeable {

  private final FST<Long> fst;
  private final IndexInput vectors;
  private final int dimension;

  /**
   * Sole constructor
   *
   * @param directory Lucene directory from which knn directory should be read.
   * @param dictName the base name of the directory files that store the knn vector dictionary. A
   *     file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the
   *     '.bin' file.
   * @throws IOException if there is a problem reading either dictionary file
   * @throws IllegalStateException if the '.bin' file declares a non-positive dimension, or if its
   *     size is not a whole number of vectors of that dimension
   */
  public KnnVectorDict(Directory directory, String dictName) throws IOException {
    try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) {
      fst = new FST<>(readMetadata(fstIn, PositiveIntOutputs.getSingleton()), fstIn);
    }

    // The vector input has to outlive the constructor, so try-with-resources cannot be used.
    // Validate through a local and close it ourselves if anything goes wrong, otherwise the
    // underlying file handle would leak.
    IndexInput in = directory.openInput(dictName + ".bin", IOContext.READ);
    int dim;
    try {
      long size = in.length();
      // The dimension is stored as the last int of the file.
      in.seek(size - Integer.BYTES);
      dim = in.readInt();
      if (dim <= 0) {
        // A zero dimension would make the modulo below divide by zero; a negative one can never
        // be matched by any output array passed to get().
        throw new IllegalStateException(
            "vector file declares an invalid vector dimension " + dim);
      }
      if ((size - Integer.BYTES) % (dim * (long) Float.BYTES) != 0) {
        throw new IllegalStateException(
            "vector file size " + size + " is not consonant with the vector dimension " + dim);
      }
    } catch (Throwable t) {
      try {
        in.close();
      } catch (Throwable closeFailure) {
        t.addSuppressed(closeFailure);
      }
      throw t;
    }
    vectors = in;
    dimension = dim;
  }

  /**
   * Get the vector corresponding to the given token, writing its raw bytes into the caller-supplied
   * array. Nothing is returned and no internal buffer is exposed; the caller owns {@code output}
   * and may reuse it across calls.
   *
   * <p>NOTE(review): each call seeks and reads the single shared vector input held by this
   * instance, so concurrent calls on one instance are presumably unsafe - confirm before sharing
   * across threads.
   *
   * @param token the token to look up
   * @param output the array in which to write the corresponding vector. Its length must be {@link
   *     #getDimension()} * {@link Float#BYTES}. It will be filled with zeros if the token is not
   *     present in the dictionary.
   * @throws IllegalArgumentException if the output array is incorrectly sized
   * @throws IOException if there is a problem reading the dictionary
   */
  public void get(BytesRef token, byte[] output) throws IOException {
    if (output.length != dimension * Float.BYTES) {
      throw new IllegalArgumentException(
          "the output array must be of length "
              + (dimension * Float.BYTES)
              + ", got "
              + output.length);
    }
    Long ord = Util.get(fst, token);
    if (ord == null) {
      Arrays.fill(output, (byte) 0);
    } else {
      // ord is a Long, so the offset is computed in 64-bit arithmetic.
      vectors.seek(ord * dimension * Float.BYTES);
      vectors.readBytes(output, 0, output.length);
    }
  }

  /**
   * Get the dimension of the vectors returned by this.
   *
   * @return the vector dimension
   */
  public int getDimension() {
    return dimension;
  }

  /** Closes the underlying vector input. */
  @Override
  public void close() throws IOException {
    vectors.close();
  }

  /**
   * Convert from a GloVe-formatted dictionary file to a KnnVectorDict file pair.
   *
   * @param gloveInput the path to the input dictionary. The dictionary is delimited by newlines,
   *     and each line is space-delimited. The first column has the token, and the remaining columns
   *     are the vector components, as text. The dictionary must be sorted by its leading tokens
   *     (considered as bytes).
   * @param directory a Lucene directory to write the dictionary to.
   * @param dictName Base name for the knn dictionary files.
   * @throws IOException if there is a problem reading the input or writing the dictionary files
   * @throws IllegalStateException if a line has a different number of fields than the first line
   */
  public static void build(Path gloveInput, Directory directory, String dictName)
      throws IOException {
    new Builder().build(gloveInput, directory, dictName);
  }

  /** Single-use helper that streams a GloVe text file into the '.bin' and '.fst' files. */
  private static class Builder {
    private static final Pattern SPACE_RE = Pattern.compile(" ");

    private final IntsRefBuilder intsRefBuilder = new IntsRefBuilder();
    private final FSTCompiler<Long> fstCompiler;
    // Reused per line: parsed components, and their little-endian byte image.
    private float[] scratch;
    private ByteBuffer byteBuffer;
    // Ordinal assigned to the next token added to the FST; the first input line takes slot 0 of
    // the vector file, so FST entries start at 1.
    private long ordinal = 1;
    // Number of space-delimited fields per line (token + vector components), from the first line.
    private int numFields;

    Builder() throws IOException {
      fstCompiler =
          new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, PositiveIntOutputs.getSingleton())
              .build();
    }

    /**
     * Reads every line of {@code gloveInput}, writing vectors to {@code dictName + ".bin"} and the
     * token FST to {@code dictName + ".fst"}, then appends the vector dimension to the '.bin' file.
     */
    void build(Path gloveInput, Directory directory, String dictName) throws IOException {
      try (BufferedReader in = Files.newBufferedReader(gloveInput);
          IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT);
          IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) {
        writeFirstLine(in, binOut);
        while (addOneLine(in, binOut)) {
          // all the work happens in the loop condition
        }
        fstCompiler.compile().save(fstOut, fstOut);
        // Trailer read back by the KnnVectorDict constructor: components per vector.
        binOut.writeInt(numFields - 1);
      }
    }

    /**
     * Reads the first line to learn the field count, allocates the scratch buffers accordingly and
     * writes that line's vector. Does nothing if the input is empty.
     *
     * <p>NOTE(review): the first line's token is not added to the FST here, so as written it does
     * not appear to be retrievable through {@link KnnVectorDict#get} - confirm whether this is
     * intentional.
     */
    private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException {
      String[] fields = readOneLine(in);
      if (fields == null) {
        return;
      }
      numFields = fields.length;
      byteBuffer =
          ByteBuffer.allocate((numFields - 1) * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
      scratch = new float[numFields - 1];
      writeVector(fields, out);
    }

    /** Returns the next line split on single spaces, or {@code null} at end of input. */
    private String[] readOneLine(BufferedReader in) throws IOException {
      String line = in.readLine();
      if (line == null) {
        return null;
      }
      return SPACE_RE.split(line, 0);
    }

    /**
     * Reads one line, maps its token to the next ordinal in the FST and writes its vector.
     *
     * @return {@code false} at end of input, {@code true} otherwise
     * @throws IllegalStateException if the line's field count differs from the first line's
     */
    private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException {
      String[] fields = readOneLine(in);
      if (fields == null) {
        return false;
      }
      if (fields.length != numFields) {
        throw new IllegalStateException(
            "different field count at line "
                + ordinal
                + " got "
                + fields.length
                + " when expecting "
                + numFields);
      }
      fstCompiler.add(Util.toIntsRef(new BytesRef(fields[0]), intsRefBuilder), ordinal++);
      writeVector(fields, out);
      return true;
    }

    /**
     * Parses {@code fields[1..]} as floats, L2-normalizes them and writes them to {@code out} as
     * little-endian bytes.
     */
    private void writeVector(String[] fields, IndexOutput out) throws IOException {
      byteBuffer.position(0);
      FloatBuffer floatBuffer = byteBuffer.asFloatBuffer();
      for (int i = 1; i < fields.length; i++) {
        scratch[i - 1] = Float.parseFloat(fields[i]);
      }
      VectorUtil.l2normalize(scratch);
      floatBuffer.put(scratch);
      byte[] bytes = byteBuffer.array();
      out.writeBytes(bytes, bytes.length);
    }
  }

  /** Return the size of the dictionary in bytes */
  public long ramBytesUsed() {
    return fst.ramBytesUsed() + vectors.length();
  }
}