001/* 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.lucene.demo.knn; 018 019import java.io.BufferedReader; 020import java.io.Closeable; 021import java.io.IOException; 022import java.nio.ByteBuffer; 023import java.nio.ByteOrder; 024import java.nio.FloatBuffer; 025import java.nio.file.Files; 026import java.nio.file.Path; 027import java.util.Arrays; 028import java.util.regex.Pattern; 029import org.apache.lucene.store.Directory; 030import org.apache.lucene.store.IOContext; 031import org.apache.lucene.store.IndexInput; 032import org.apache.lucene.store.IndexOutput; 033import org.apache.lucene.util.BytesRef; 034import org.apache.lucene.util.IntsRefBuilder; 035import org.apache.lucene.util.VectorUtil; 036import org.apache.lucene.util.fst.FST; 037import org.apache.lucene.util.fst.FSTCompiler; 038import org.apache.lucene.util.fst.PositiveIntOutputs; 039import org.apache.lucene.util.fst.Util; 040 041/** 042 * Manages a map from token to numeric vector for use with KnnVector indexing and search. The map is 043 * stored as an FST: token-to-ordinal plus a dense binary file holding the vectors. 044 */ 045public class KnnVectorDict implements Closeable { 046 047 private final FST<Long> fst; 048 private final IndexInput vectors; 049 private final int dimension; 050 051 /** 052 * Sole constructor 053 * 054 * @param directory Lucene directory from which knn directory should be read. 055 * @param dictName the base name of the directory files that store the knn vector dictionary. A 056 * file with extension '.bin' holds the vectors and the '.fst' maps tokens to offsets in the 057 * '.bin' file. 058 */ 059 public KnnVectorDict(Directory directory, String dictName) throws IOException { 060 try (IndexInput fstIn = directory.openInput(dictName + ".fst", IOContext.READ)) { 061 fst = new FST<>(fstIn, fstIn, PositiveIntOutputs.getSingleton()); 062 } 063 064 vectors = directory.openInput(dictName + ".bin", IOContext.READ); 065 long size = vectors.length(); 066 vectors.seek(size - Integer.BYTES); 067 dimension = vectors.readInt(); 068 if ((size - Integer.BYTES) % (dimension * Float.BYTES) != 0) { 069 throw new IllegalStateException( 070 "vector file size " + size + " is not consonant with the vector dimension " + dimension); 071 } 072 } 073 074 /** 075 * Get the vector corresponding to the given token. NOTE: the returned array is shared and its 076 * contents will be overwritten by subsequent calls. The caller is responsible to copy the data as 077 * needed. 078 * 079 * @param token the token to look up 080 * @param output the array in which to write the corresponding vector. Its length must be {@link 081 * #getDimension()} * {@link Float#BYTES}. It will be filled with zeros if the token is not 082 * present in the dictionary. 083 * @throws IllegalArgumentException if the output array is incorrectly sized 084 * @throws IOException if there is a problem reading the dictionary 085 */ 086 public void get(BytesRef token, byte[] output) throws IOException { 087 if (output.length != dimension * Float.BYTES) { 088 throw new IllegalArgumentException( 089 "the output array must be of length " 090 + (dimension * Float.BYTES) 091 + ", got " 092 + output.length); 093 } 094 Long ord = Util.get(fst, token); 095 if (ord == null) { 096 Arrays.fill(output, (byte) 0); 097 } else { 098 vectors.seek(ord * dimension * Float.BYTES); 099 vectors.readBytes(output, 0, output.length); 100 } 101 } 102 103 /** 104 * Get the dimension of the vectors returned by this. 105 * 106 * @return the vector dimension 107 */ 108 public int getDimension() { 109 return dimension; 110 } 111 112 @Override 113 public void close() throws IOException { 114 vectors.close(); 115 } 116 117 /** 118 * Convert from a GloVe-formatted dictionary file to a KnnVectorDict file pair. 119 * 120 * @param gloveInput the path to the input dictionary. The dictionary is delimited by newlines, 121 * and each line is space-delimited. The first column has the token, and the remaining columns 122 * are the vector components, as text. The dictionary must be sorted by its leading tokens 123 * (considered as bytes). 124 * @param directory a Lucene directory to write the dictionary to. 125 * @param dictName Base name for the knn dictionary files. 126 */ 127 public static void build(Path gloveInput, Directory directory, String dictName) 128 throws IOException { 129 new Builder().build(gloveInput, directory, dictName); 130 } 131 132 private static class Builder { 133 private static final Pattern SPACE_RE = Pattern.compile(" "); 134 135 private final IntsRefBuilder intsRefBuilder = new IntsRefBuilder(); 136 private final FSTCompiler<Long> fstCompiler = 137 new FSTCompiler<>(FST.INPUT_TYPE.BYTE1, PositiveIntOutputs.getSingleton()); 138 private float[] scratch; 139 private ByteBuffer byteBuffer; 140 private long ordinal = 1; 141 private int numFields; 142 143 void build(Path gloveInput, Directory directory, String dictName) throws IOException { 144 try (BufferedReader in = Files.newBufferedReader(gloveInput); 145 IndexOutput binOut = directory.createOutput(dictName + ".bin", IOContext.DEFAULT); 146 IndexOutput fstOut = directory.createOutput(dictName + ".fst", IOContext.DEFAULT)) { 147 writeFirstLine(in, binOut); 148 while (addOneLine(in, binOut)) { 149 // continue; 150 } 151 fstCompiler.compile().save(fstOut, fstOut); 152 binOut.writeInt(numFields - 1); 153 } 154 } 155 156 private void writeFirstLine(BufferedReader in, IndexOutput out) throws IOException { 157 String[] fields = readOneLine(in); 158 if (fields == null) { 159 return; 160 } 161 numFields = fields.length; 162 byteBuffer = 163 ByteBuffer.allocate((numFields - 1) * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); 164 scratch = new float[numFields - 1]; 165 writeVector(fields, out); 166 } 167 168 private String[] readOneLine(BufferedReader in) throws IOException { 169 String line = in.readLine(); 170 if (line == null) { 171 return null; 172 } 173 return SPACE_RE.split(line, 0); 174 } 175 176 private boolean addOneLine(BufferedReader in, IndexOutput out) throws IOException { 177 String[] fields = readOneLine(in); 178 if (fields == null) { 179 return false; 180 } 181 if (fields.length != numFields) { 182 throw new IllegalStateException( 183 "different field count at line " 184 + ordinal 185 + " got " 186 + fields.length 187 + " when expecting " 188 + numFields); 189 } 190 fstCompiler.add(Util.toIntsRef(new BytesRef(fields[0]), intsRefBuilder), ordinal++); 191 writeVector(fields, out); 192 return true; 193 } 194 195 private void writeVector(String[] fields, IndexOutput out) throws IOException { 196 byteBuffer.position(0); 197 FloatBuffer floatBuffer = byteBuffer.asFloatBuffer(); 198 for (int i = 1; i < fields.length; i++) { 199 scratch[i - 1] = Float.parseFloat(fields[i]); 200 } 201 VectorUtil.l2normalize(scratch); 202 floatBuffer.put(scratch); 203 byte[] bytes = byteBuffer.array(); 204 out.writeBytes(bytes, bytes.length); 205 } 206 } 207 208 /** Return the size of the dictionary in bytes */ 209 public long ramBytesUsed() { 210 return fst.ramBytesUsed() + vectors.length(); 211 } 212}