This is a trie that uses a sentinel node to denote the end of a word. This is more space efficient than having to flag each node as to whether it denotes an end of a word. To quickly find the number of prefix matches, it stores the prefix count in the node.
/**
 * A trie node whose children are kept in a map keyed by character.
 *
 * <p>The end of a word is marked by a child stored under the sentinel key
 * {@code (char) 0} rather than by a per-node boolean flag. Each node also keeps
 * a counter that is bumped every time {@link #add(char)} is called on it, which
 * makes it the number of inserted words that have the path to this node as a
 * prefix. Inserting the same word twice counts it twice.
 */
class Trie {
    /** Map key under which the end-of-word marker node is stored. */
    private static final char END_OF_WORD = (char) 0;

    /** The character this node was created for. */
    char ch;

    /** Number of times a child was added (or re-visited) through this node. */
    int count = 0;

    /** Child nodes keyed by their character; also holds the sentinel child. */
    Map<Character, Trie> list = new HashMap<Character, Trie>();

    /**
     * Creates a node for the given character.
     *
     * @param ch the character this node stands for
     */
    public Trie(char ch) {
        this.ch = ch;
    }

    /**
     * Returns the child for {@code ch}, creating it when absent, and increments
     * this node's prefix count.
     *
     * @param ch the character of the child to descend into
     * @return the existing or newly created child node
     */
    public Trie add(char ch) {
        Trie node = this.list.get(ch);
        if (node == null) {
            node = new Trie(ch);
            this.list.put(ch, node);
        }
        // Counting on the current node rather than on the matched child means
        // the sentinel node is never counted and the count is updated in
        // exactly one place.
        this.count++;
        return node;
    }

    /**
     * @return this node's prefix count
     */
    public int size() {
        return this.count;
    }

    /** Looks up the child for {@code ch}; returns {@code null} when there is none. */
    private Trie findChar(char ch) {
        return this.list.get(ch);
    }

    /**
     * Tells whether {@code word} was inserted as a complete word.
     *
     * @param word the word to look up
     * @return {@code true} only if every character is found and the final node
     *         has the end-of-word sentinel child
     */
    public boolean findWord(String word) {
        Trie node = this;
        for (char c : word.toCharArray()) {
            node = node.findChar(c);
            if (node == null) {
                return false;
            }
        }
        // Reaching this point only proves that a prefix exists; it is a word
        // only when the sentinel child is present.
        return node.list.get(END_OF_WORD) != null;
    }

    /**
     * Returns the prefix count stored on the node reached by {@code prefix}.
     *
     * @param prefix the prefix to follow from this node
     * @return the reached node's count, or 0 when the prefix is not in the trie
     */
    public int findPartial(String prefix) {
        Trie node = this;
        for (char c : prefix.toCharArray()) {
            node = node.list.get(c);
            if (node == null) {
                return 0;
            }
        }
        return node.size();
    }

    /**
     * Inserts a word, character by character, followed by the sentinel.
     *
     * @param s the word to insert
     */
    public void add(String s) {
        Trie node = this;
        for (char c : s.toCharArray()) {
            node = node.add(c);
        }
        // The sentinel marks the end of the word.
        node.add(END_OF_WORD);
    }
}
Now it is possible to reduce the space taken by the trie further by using an array instead of the map. Knowing that we need to store only lower case letters, we can use the character immediately before 'a' (the backtick, '`') as the sentinel, so that the array length is set to 27.
Another space optimization comes from using a single word (32 bits) to store both the character and the prefix count. Java uses two bytes for the char type, whereas one byte is enough for our alphabet. A one-byte character stored next to a separate 32-bit count would still take 5 bytes per Trie node; however, we don't need the 2 billion range possible with 32 bits to represent the count of all prefixes for any English substring.
The prefix count is highest on the root node, as every word has the root node as its prefix. So the highest prefix count is the number of words in the dictionary. This is generally never more than 250,000. We can safely use 24 bits, which can represent about 16 million values (about 8 million if treated as a signed integer).
So we can combine the character and the prefix count to a single word.
Is there anything else we could do? Yes - we could read all the words into our Trie and then trim the list on each Trie node. This follows from the observation that we rarely use all the slots in our list - especially as the trie fans out, there are fewer and fewer new words. Thus we could find the last used index in the list, and create a new, shorter list.
Doing all of these drops the size of the trie from ~ 228M to ~ 68M.
Here is an implementation.
I store a random word list in pastebin for testing - there is code here that uses this, as well as code that pulls in a dictionary of lower case words. If you use this, you will need to make sure the dictionary you substitute has only lower case words, so some pre-processing might be necessary - in particular, you are likely to find the hyphen (-) in some words, which you will need to remove.
Last but not least, the memory stats don't give a reliable idea of the space saving, because the garbage collector is not deterministic. I use sizeInBytes() to recursively calculate the memory footprint of the Trie.
/**
 * A space-optimized trie for lower case words ({@code 'a'..'z'}).
 *
 * <p>Each node packs its character (top 8 bits) and its prefix count (low 24
 * bits) into one {@code int}. Children are stored by a {@link GetAndPut}
 * strategy; the default is an array indexed by {@code ch - '`'}, where the
 * backtick (the character immediately before {@code 'a'}) is the end-of-word
 * sentinel. After loading, {@link #trim()} shortens every child array to its
 * last used slot.
 *
 * <p>The prefix count of a node is incremented each time {@link #add(char)} is
 * called on it, so inserting the same word twice counts it twice.
 */
public class Trie {

    /** End-of-word marker; occupies slot 0 of the child array. */
    private static final char SENTINEL = '`';

    /** Number of child slots needed for the sentinel plus 'a'..'z' (27). */
    private static final int ALPHABET_SIZE = 'z' - SENTINEL + 1;

    /** Mask selecting the 24-bit prefix count inside the packed word. */
    private static final int COUNT_MASK = 0x00FFFFFF;

    /** Number of bits the character is shifted left inside the packed word. */
    private static final int CHAR_SHIFT = 24;

    /** Accepts lines made only of ASCII letters; compiled once and reused. */
    private static final Pattern ALPHA = Pattern.compile("^[A-Za-z]+$");

    /** Storage strategy for a node's children. */
    interface GetAndPut {
        /** Stores {@code trie} as the child for {@code ch}. */
        public void put(Character ch, Trie trie);

        /** Returns the child for {@code ch}, or {@code null} when there is none. */
        public Trie get(Character ch);

        /** Returns the highest occupied slot index, or -1 when there are no children. */
        public int lastUsedIndex();

        /** Returns every child node, including the end-of-word sentinel if present. */
        public Trie[] children();

        /** Releases unused trailing capacity, if the strategy has any. */
        public void trim();
    }

    /** Map-backed child storage; accepts any character. */
    static class SuffixCharsWithMap implements GetAndPut {
        Map<Character, Trie> list = new HashMap<Character, Trie>();

        @Override
        public void put(Character ch, Trie node) {
            list.put(ch, node);
        }

        @Override
        public Trie get(Character ch) {
            return list.get(ch);
        }

        @Override
        public int lastUsedIndex() {
            return list.size() - 1;
        }

        @Override
        public Trie[] children() {
            return list.values().toArray(new Trie[list.size()]);
        }

        @Override
        public void trim() {
            // A map has no unused trailing slots to release.
        }
    }

    /** Array-backed child storage for the sentinel plus 'a'..'z'. */
    static class SuffixCharsWithArray implements GetAndPut {
        public Trie[] list;

        public SuffixCharsWithArray() {
            list = new Trie[ALPHABET_SIZE];
        }

        /**
         * Stores a child, re-growing the array if it was shortened by
         * {@link #trim()}.
         *
         * @throws IllegalArgumentException if {@code ch} is outside '`'..'z'
         */
        @Override
        public void put(Character ch, Trie node) {
            int idx = ch - SENTINEL;
            if (idx < 0 || idx >= ALPHABET_SIZE) {
                throw new IllegalArgumentException(
                        "character '" + ch + "' is outside the supported range '`'..'z'");
            }
            if (idx >= list.length) {
                // The slot was trimmed away; grow just enough to hold it.
                Trie[] grown = new Trie[idx + 1];
                System.arraycopy(list, 0, grown, 0, list.length);
                list = grown;
            }
            list[idx] = node;
        }

        @Override
        public Trie get(Character ch) {
            int idx = ch - SENTINEL;
            // Out-of-range covers both unsupported characters and trimmed slots.
            return (idx >= 0 && idx < list.length) ? list[idx] : null;
        }

        @Override
        public int lastUsedIndex() {
            for (int i = list.length - 1; i >= 0; i--) {
                if (list[i] != null) {
                    return i;
                }
            }
            return -1;
        }

        @Override
        public void trim() {
            int used = lastUsedIndex() + 1;
            if (used < list.length) {
                Trie[] shorter = new Trie[used];
                System.arraycopy(list, 0, shorter, 0, used);
                list = shorter;
            }
        }

        @Override
        public Trie[] children() {
            List<Trie> present = new ArrayList<Trie>();
            for (Trie t : list) {
                if (t != null) {
                    present.add(t);
                }
            }
            return present.toArray(new Trie[present.size()]);
        }
    }

    // Packed word: the top byte is the character, the low 3 bytes are the
    // prefix count. 24 bits hold ~16 million, far more than the number of
    // English words (fewer than 250,000), and no prefix count can exceed that.
    private int count = 0;

    /** @return the character stored in the top byte of the packed word */
    public char getChar() {
        // Shift first, then mask: '>>' binds tighter than '&', so the shift
        // must be applied to count itself, not to the mask.
        return (char) ((count >>> CHAR_SHIFT) & 0xFF);
    }

    /**
     * Stores {@code ch} in the top byte, leaving the prefix count untouched.
     * Only the low 8 bits of {@code ch} survive the shift.
     */
    public void setChar(char ch) {
        count = (((int) ch) << CHAR_SHIFT) | (count & COUNT_MASK);
    }

    /** @return the 24-bit prefix count */
    public int getCount() {
        return count & COUNT_MASK;
    }

    /**
     * Increments the prefix count.
     *
     * @throws IllegalStateException if the 24-bit count is already at its
     *         maximum, since one more increment would corrupt the character byte
     */
    public void incCount() {
        if ((count & COUNT_MASK) == COUNT_MASK) {
            throw new IllegalStateException("prefix count overflow");
        }
        count++;
    }

    GetAndPut suffixChars = new SuffixCharsWithArray();
    //GetAndPut suffixChars = new SuffixCharsWithMap();

    /**
     * Creates a node for the given character with a prefix count of zero.
     *
     * @param ch the character this node stands for
     */
    public Trie(char ch) {
        this.setChar(ch);
    }

    /**
     * Returns the child for {@code ch}, creating it when absent, and increments
     * this node's prefix count.
     *
     * @param ch the character of the child to descend into
     * @return the existing or newly created child node
     * @throws IllegalArgumentException if the array storage is in use and
     *         {@code ch} is outside '`'..'z'
     */
    public Trie add(char ch) {
        Trie node = this.suffixChars.get(ch);
        if (node == null) {
            node = new Trie(ch);
            this.suffixChars.put(ch, node);
        }
        // Counting on the current node rather than on the matched child means
        // the sentinel node is never counted and the count is updated in
        // exactly one place.
        this.incCount();
        return node;
    }

    /** @return this node's prefix count (without the packed character byte) */
    public int size() {
        return getCount();
    }

    /** Looks up the child for {@code ch}; returns {@code null} when there is none. */
    private Trie findChar(char ch) {
        return this.suffixChars.get(ch);
    }

    /**
     * Tells whether {@code word} was inserted as a complete word.
     *
     * @param word the word to look up
     * @return {@code true} only if every character is found and the final node
     *         has the end-of-word sentinel child
     */
    public boolean findWord(String word) {
        Trie node = this;
        for (char c : word.toCharArray()) {
            node = node.findChar(c);
            if (node == null) {
                return false;
            }
        }
        // Reaching this point only proves that a prefix exists; it is a word
        // only when the sentinel child is present.
        return node.suffixChars.get(SENTINEL) != null;
    }

    /**
     * Returns the prefix count stored on the node reached by {@code prefix}.
     *
     * @param prefix the prefix to follow from this node
     * @return the reached node's count, or 0 when the prefix is not in the trie
     */
    public int findPartial(String prefix) {
        Trie node = this;
        for (char c : prefix.toCharArray()) {
            node = node.suffixChars.get(c);
            if (node == null) {
                return 0;
            }
        }
        return node.size();
    }

    /**
     * Inserts a word, character by character, followed by the sentinel.
     *
     * @param s the word to insert
     * @throws IllegalArgumentException if the array storage is in use and the
     *         word contains a character outside '`'..'z'; characters before the
     *         offending one have already been inserted at that point
     */
    public void add(String s) {
        Trie node = this;
        for (char c : s.toCharArray()) {
            node = node.add(c);
        }
        // The sentinel marks the end of the word.
        node.add(SENTINEL);
    }

    /** Tallies, for this subtree, how many nodes have each last-used index. */
    private void walk(int[] indices) {
        int last = this.suffixChars.lastUsedIndex();
        // Leaf nodes (such as sentinels) report -1 and have nothing to tally.
        if (last >= 0 && last < indices.length) {
            indices[last]++;
        }
        for (Trie child : suffixChars.children()) {
            child.walk(indices);
        }
    }

    /**
     * Builds a histogram of last-used child indices over the whole subtree.
     *
     * @return an array of length 27 where element {@code i} is the number of
     *         nodes whose highest occupied child slot is {@code i}
     */
    public int[] lastUsedIndices() {
        int[] indices = new int[ALPHABET_SIZE];
        walk(indices);
        return indices;
    }

    /** Trims this node's child storage, then recurses into every child. */
    private void walk2() {
        suffixChars.trim();
        for (Trie child : suffixChars.children()) {
            child.walk2();
        }
    }

    /** Sums the estimated footprint of {@code t} and all of its descendants. */
    static private int walk3(Trie t) {
        if (t == null) {
            return 0;
        }
        int slots;
        if (t.suffixChars instanceof SuffixCharsWithArray) {
            slots = ((SuffixCharsWithArray) t.suffixChars).list.length;
        } else {
            // Other storage strategies: count one reference per stored child.
            slots = t.suffixChars.lastUsedIndex() + 1;
        }
        // 4 = size of `count`
        // 8 = size of each reference to a Trie
        int acc = 4 + 8 * slots;
        for (Trie node : t.suffixChars.children()) {
            acc += walk3(node);
        }
        return acc;
    }

    /** Shortens the child storage of every node in this subtree. */
    public void trim() {
        walk2();
    }

    /**
     * Estimates the memory used by this subtree as 4 bytes per node plus
     * 8 bytes per child slot.
     *
     * @return the estimated size in bytes
     */
    public int sizeInBytes() {
        return walk3(this);
    }

    /**
     * Loads the word list from the default location, one word per line.
     *
     * @throws FileNotFoundException if the file cannot be opened
     */
    public void read() throws FileNotFoundException {
        read("/Users/thushara/lcwords.txt");
    }

    /**
     * Loads a word list, one word per line, adding every line to this trie.
     * A read failure after the file has been opened is reported on standard
     * error and loading stops; words read so far stay in the trie.
     *
     * @param wordFilePath path of the word file
     * @throws FileNotFoundException if the file cannot be opened
     */
    public void read(String wordFilePath) throws FileNotFoundException {
        // Opened outside the try so that a missing file still propagates as
        // FileNotFoundException instead of being reported as a disk error.
        BufferedReader br = new BufferedReader(new FileReader(wordFilePath));
        try (BufferedReader reader = br) {
            String word;
            while ((word = reader.readLine()) != null) {
                add(word);
            }
        } catch (IOException e) {
            System.err.format("disk error! %s", e.getMessage());
        }
    }

    /**
     * Downloads the test word list and returns its purely alphabetic lines,
     * lower-cased and separated by single spaces.
     *
     * @return the space-separated word list; empty when no line qualifies
     * @throws MalformedURLException if the URL cannot be parsed
     * @throws IOException if the download fails
     */
    static public String getRandomWordList() throws MalformedURLException, IOException {
        String url = "https://pastebin.com/raw/NXH7UAr1";
        URL obj = new URL(url);
        HttpURLConnection con = (HttpURLConnection) obj.openConnection();
        con.setRequestMethod("GET");

        StringBuilder response = new StringBuilder();
        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(con.getInputStream()))) {
            String inputLine;
            while ((inputLine = in.readLine()) != null) {
                if (ALPHA.matcher(inputLine).matches()) {
                    // Separate the words so that callers can split the result.
                    if (response.length() > 0) {
                        response.append(' ');
                    }
                    // Locale.ROOT keeps the result inside 'a'..'z' regardless
                    // of the default locale.
                    response.append(inputLine.toLowerCase(java.util.Locale.ROOT));
                }
            }
        } finally {
            con.disconnect();
        }
        return response.toString();
    }

    /**
     * Loads the dictionary, prints the estimated size before and after
     * trimming, answers interactive prefix queries, and finally checks the
     * downloaded word list against the trie.
     */
    static public void main(String[] args) throws FileNotFoundException, IOException {
        long mem1 = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
        System.out.format("memory usage at start %d\n", mem1);

        Trie trie = new Trie('$');
        trie.read();
        System.out.format("size of trie in bytes: %d\n", trie.sizeInBytes());
        trie.trim();
        System.out.format("size of trimmed trie in bytes: %d\n", trie.sizeInBytes());

        Scanner in = new Scanner(System.in);
        System.out.println("type a word in lower case (upper case char to exit)> ");
        // hasNext() ends the loop cleanly when standard input is exhausted.
        while (in.hasNext()) {
            String s = in.next();
            if (Character.isUpperCase(s.charAt(0))) {
                break;
            }
            boolean found = trie.findPartial(s) > 0;
            System.out.println(found ? "yes" : "no");
        }

        long st = System.currentTimeMillis();
        String words = getRandomWordList();
        String[] arr = words.split(" ");
        for (String s : arr) {
            if (!s.isEmpty() && !trie.findWord(s)) {
                System.out.println("couldn't find " + s);
            }
        }
        long elapsed = System.currentTimeMillis() - st;

        long mem2 = Runtime.getRuntime().totalMemory() - Runtime.getRuntime().freeMemory();
        System.out.format("memory usage at end %d\n", mem2);
        System.out.format("took %d ms for %d words using %d MB\n",
                elapsed, arr.length, (mem2 - mem1) / 1024 / 1024);
    }
}
No comments:
Post a Comment