mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Revert "New tokenizer implementation for MPT and GPT-J"
This reverts commit bbcee1ced5
.
This commit is contained in:
parent
cdc7d6ccc4
commit
7f9f91ad94
@ -1,4 +1,4 @@
|
|||||||
[codespell]
|
[codespell]
|
||||||
skip = .git,*.pdf,*.svg,*_tokenizer_config.h
|
skip = .git,*.pdf,*.svg
|
||||||
#
|
#
|
||||||
# ignore-words-list =
|
# ignore-words-list =
|
||||||
|
@ -23,7 +23,6 @@ set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL
|
|||||||
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
|
project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C)
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
|
||||||
|
|
||||||
set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
|
set(LLAMA_BUILD_EXAMPLES ON CACHE BOOL "llama: build examples" FORCE)
|
||||||
set(BUILD_SHARED_LIBS ON FORCE)
|
set(BUILD_SHARED_LIBS ON FORCE)
|
||||||
@ -35,7 +34,6 @@ if (GPT4ALL_AVX_ONLY)
|
|||||||
set(LLAMA_FMA OFF CACHE BOOL "llama: enable FMA" FORCE)
|
set(LLAMA_FMA OFF CACHE BOOL "llama: enable FMA" FORCE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
find_package(ICU REQUIRED COMPONENTS uc i18n)
|
|
||||||
add_subdirectory(llama.cpp)
|
add_subdirectory(llama.cpp)
|
||||||
|
|
||||||
add_library(llmodel
|
add_library(llmodel
|
||||||
@ -43,14 +41,12 @@ add_library(llmodel
|
|||||||
llamamodel.h llamamodel.cpp
|
llamamodel.h llamamodel.cpp
|
||||||
llama.cpp/examples/common.cpp
|
llama.cpp/examples/common.cpp
|
||||||
llmodel.h llmodel_c.h llmodel_c.cpp
|
llmodel.h llmodel_c.h llmodel_c.cpp
|
||||||
mpt.h mpt.cpp tokenizer/bpe.cpp tokenizer/bpe.h
|
mpt.h mpt.cpp
|
||||||
tokenizer/mpt_tokenizer_config.h tokenizer/gptj_tokenizer_config.h
|
|
||||||
utils.h utils.cpp
|
utils.h utils.cpp
|
||||||
)
|
)
|
||||||
|
|
||||||
target_link_libraries(llmodel
|
target_link_libraries(llmodel
|
||||||
PRIVATE llama
|
PRIVATE llama)
|
||||||
PUBLIC ICU::uc ICU::i18n)
|
|
||||||
|
|
||||||
set_target_properties(llmodel PROPERTIES
|
set_target_properties(llmodel PROPERTIES
|
||||||
VERSION ${PROJECT_VERSION}
|
VERSION ${PROJECT_VERSION}
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <filesystem>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <string>
|
#include <string>
|
||||||
@ -861,8 +860,6 @@ bool GPTJ::loadModel(const std::string &modelPath) {
|
|||||||
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
|
||||||
d_ptr->modelLoaded = true;
|
d_ptr->modelLoaded = true;
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
get_bpecpp_tokenizer(TokenizerType::GPTJ, m_bpe, m_tokav);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -918,7 +915,7 @@ void GPTJ::prompt(const std::string &prompt,
|
|||||||
int64_t t_prompt_us = 0;
|
int64_t t_prompt_us = 0;
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<uint32_t> embd_inp = m_tokav->encode(prompt, *m_bpe);
|
std::vector<gpt_vocab::id> embd_inp = ::gpt_tokenize(d_ptr->vocab, prompt);
|
||||||
|
|
||||||
// save the context size
|
// save the context size
|
||||||
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
|
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
|
||||||
@ -1035,7 +1032,7 @@ void GPTJ::prompt(const std::string &prompt,
|
|||||||
if (id == 50256 /*end of text*/)
|
if (id == 50256 /*end of text*/)
|
||||||
goto stop_generating;
|
goto stop_generating;
|
||||||
|
|
||||||
const std::string str = m_tokav->decode({(uint32_t) id}, *m_bpe, true, false);
|
const std::string str = d_ptr->vocab.id_to_token[id];
|
||||||
|
|
||||||
// Check if the provided str is part of our reverse prompts
|
// Check if the provided str is part of our reverse prompts
|
||||||
bool foundPartialReversePrompt = false;
|
bool foundPartialReversePrompt = false;
|
||||||
@ -1065,8 +1062,7 @@ void GPTJ::prompt(const std::string &prompt,
|
|||||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||||
promptCtx.tokens.push_back(t);
|
promptCtx.tokens.push_back(t);
|
||||||
const std::string decoded = m_tokav->decode({(uint32_t) t}, *m_bpe, true, false);
|
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
|
||||||
if (!responseCallback(t, decoded))
|
|
||||||
goto stop_generating;
|
goto stop_generating;
|
||||||
}
|
}
|
||||||
cachedTokens.clear();
|
cachedTokens.clear();
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "llmodel.h"
|
#include "llmodel.h"
|
||||||
#include "tokenizer/bpe.h"
|
|
||||||
|
|
||||||
class GPTJPrivate;
|
class GPTJPrivate;
|
||||||
class GPTJ : public LLModel {
|
class GPTJ : public LLModel {
|
||||||
@ -32,8 +31,6 @@ protected:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
GPTJPrivate *d_ptr;
|
GPTJPrivate *d_ptr;
|
||||||
std::unique_ptr<bpecpp::AdditionalVocabAdapter> m_tokav;
|
|
||||||
std::unique_ptr<bpecpp::BPE> m_bpe;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // GPTJ_H
|
#endif // GPTJ_H
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <filesystem>
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <random>
|
#include <random>
|
||||||
@ -786,12 +785,6 @@ bool MPT::loadModel(const std::string &modelPath) {
|
|||||||
d_ptr->modelLoaded = true;
|
d_ptr->modelLoaded = true;
|
||||||
d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end();
|
d_ptr->has_im_end = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end();
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
if (modelPath.find("-chat") != std::string::npos) {
|
|
||||||
get_bpecpp_tokenizer(TokenizerType::MPT_CHAT, m_bpe, m_tokav);
|
|
||||||
} else {
|
|
||||||
get_bpecpp_tokenizer(TokenizerType::MPT, m_bpe, m_tokav);
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -847,7 +840,7 @@ void MPT::prompt(const std::string &prompt,
|
|||||||
int64_t t_prompt_us = 0;
|
int64_t t_prompt_us = 0;
|
||||||
|
|
||||||
// tokenize the prompt
|
// tokenize the prompt
|
||||||
std::vector<uint32_t> embd_inp = m_tokav->encode(prompt, *m_bpe);
|
std::vector<int> embd_inp = gpt_tokenize(d_ptr->vocab, prompt);
|
||||||
|
|
||||||
// save the context size
|
// save the context size
|
||||||
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
|
promptCtx.n_ctx = d_ptr->model->hparams.n_ctx;
|
||||||
@ -913,7 +906,6 @@ void MPT::prompt(const std::string &prompt,
|
|||||||
int r_instructFound = 0;
|
int r_instructFound = 0;
|
||||||
|
|
||||||
std::string cachedResponse;
|
std::string cachedResponse;
|
||||||
std::string decodeBuffer;
|
|
||||||
std::vector<int> cachedTokens;
|
std::vector<int> cachedTokens;
|
||||||
std::unordered_set<std::string> reversePrompts
|
std::unordered_set<std::string> reversePrompts
|
||||||
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
|
= { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" };
|
||||||
@ -969,7 +961,7 @@ void MPT::prompt(const std::string &prompt,
|
|||||||
if (id == 0 /*end of text*/)
|
if (id == 0 /*end of text*/)
|
||||||
goto stop_generating;
|
goto stop_generating;
|
||||||
|
|
||||||
const std::string str = m_tokav->decode({(uint32_t) id}, *m_bpe, true, false);
|
const std::string str = d_ptr->vocab.id_to_token[id];
|
||||||
|
|
||||||
// Check if the provided str is part of our reverse prompts
|
// Check if the provided str is part of our reverse prompts
|
||||||
bool foundPartialReversePrompt = false;
|
bool foundPartialReversePrompt = false;
|
||||||
@ -999,8 +991,7 @@ void MPT::prompt(const std::string &prompt,
|
|||||||
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
if (promptCtx.tokens.size() == promptCtx.n_ctx)
|
||||||
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
promptCtx.tokens.erase(promptCtx.tokens.begin());
|
||||||
promptCtx.tokens.push_back(t);
|
promptCtx.tokens.push_back(t);
|
||||||
const std::string decoded = m_tokav->decode({(uint32_t) t}, *m_bpe, true, false);
|
if (!responseCallback(t, d_ptr->vocab.id_to_token[t]))
|
||||||
if (!responseCallback(t, decoded))
|
|
||||||
goto stop_generating;
|
goto stop_generating;
|
||||||
}
|
}
|
||||||
cachedTokens.clear();
|
cachedTokens.clear();
|
||||||
|
@ -5,7 +5,6 @@
|
|||||||
#include <functional>
|
#include <functional>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include "llmodel.h"
|
#include "llmodel.h"
|
||||||
#include "tokenizer/bpe.h"
|
|
||||||
|
|
||||||
class MPTPrivate;
|
class MPTPrivate;
|
||||||
class MPT : public LLModel {
|
class MPT : public LLModel {
|
||||||
@ -32,8 +31,6 @@ protected:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
MPTPrivate *d_ptr;
|
MPTPrivate *d_ptr;
|
||||||
std::unique_ptr<bpecpp::AdditionalVocabAdapter> m_tokav;
|
|
||||||
std::unique_ptr<bpecpp::BPE> m_bpe;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MPT_H
|
#endif // MPT_H
|
||||||
|
@ -1,136 +0,0 @@
|
|||||||
import sys
|
|
||||||
import json
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
def iter_with_last(lst):
|
|
||||||
llen = len(lst)
|
|
||||||
for i, entry in enumerate(lst):
|
|
||||||
last = i == (llen - 1)
|
|
||||||
yield last, entry
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BufSlice:
|
|
||||||
offset: int
|
|
||||||
length: int
|
|
||||||
def __repr__(self):
|
|
||||||
return '{'f'0x{self.offset:x},{self.length}''}'
|
|
||||||
|
|
||||||
def c_str_dump(bs):
|
|
||||||
s = bytearray()
|
|
||||||
s += b'"'
|
|
||||||
llen = 0
|
|
||||||
lasthex = False
|
|
||||||
for byte in bs:
|
|
||||||
if byte in (b' 01234567890abcdefghijklmnopqrstuvwxyz_-=/;:<>'
|
|
||||||
b'ABCDEFGHIJKLMNOPQRSTUVWXYZ!@#$%^&*(),.[]{}`~|'):
|
|
||||||
# need to avoid hex characters not part of a hex escape
|
|
||||||
# appearing directly after a hex scape
|
|
||||||
if lasthex and byte in b'0123456789abcdefABCDEF':
|
|
||||||
s += b'""'
|
|
||||||
llen += 2
|
|
||||||
s += bytes([byte])
|
|
||||||
llen += 1
|
|
||||||
lasthex = False
|
|
||||||
else:
|
|
||||||
s += f'\\x{byte:02x}'.encode('utf8')
|
|
||||||
llen += 4
|
|
||||||
lasthex = True
|
|
||||||
if llen >= 80:
|
|
||||||
llen = 0
|
|
||||||
s += b"\"\n\""
|
|
||||||
s += b'"'
|
|
||||||
return s.decode('utf8')
|
|
||||||
|
|
||||||
class Buf:
|
|
||||||
def __init__(self):
|
|
||||||
self.buf = b''
|
|
||||||
self.cache = {}
|
|
||||||
|
|
||||||
def get(self, s):
|
|
||||||
if s in self.cache:
|
|
||||||
return self.cache[s]
|
|
||||||
offset = len(self.buf)
|
|
||||||
bs = s.encode('utf8')
|
|
||||||
exoffs = self.buf.find(bs)
|
|
||||||
if exoffs != -1:
|
|
||||||
slc = BufSlice(offset=exoffs, length=len(bs))
|
|
||||||
self.cache[s] = slc
|
|
||||||
return slc
|
|
||||||
return None
|
|
||||||
|
|
||||||
def insert(self, s):
|
|
||||||
slc = self.get(s)
|
|
||||||
if slc is None:
|
|
||||||
bs = s.encode('utf8')
|
|
||||||
offset = len(self.buf)
|
|
||||||
self.buf += bs
|
|
||||||
slc = BufSlice(offset=offset, length=len(bs))
|
|
||||||
return slc
|
|
||||||
|
|
||||||
class BreakEvery:
|
|
||||||
def __init__(self, n):
|
|
||||||
self.counter = 0
|
|
||||||
self.n = n
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
self.counter += 1
|
|
||||||
self.counter %= self.n
|
|
||||||
if self.counter == 0:
|
|
||||||
return '\n'
|
|
||||||
return ''
|
|
||||||
|
|
||||||
def do_convert(tkfilename, prefix):
|
|
||||||
with open(tkfilename, 'rb') as tkf:
|
|
||||||
tokconfig = json.load(tkf)
|
|
||||||
|
|
||||||
# every string in the vocab also appears in the merges list so we can store
|
|
||||||
# much less data in the binary by deduplicating these references, sorting by
|
|
||||||
# length descending makes it more likely prefixes of longer strings get
|
|
||||||
# deduped, and secondarily sorting lexicographically them makes the buffer
|
|
||||||
# data more compressible (they are not compressed in the binary itself, but
|
|
||||||
# the binary will be more compressible)
|
|
||||||
split_merges = [s.split(' ') for s in tokconfig['model']['merges']]
|
|
||||||
len_then = lambda m: (len(m),m)
|
|
||||||
avwords = sorted((av['content'] for av in tokconfig['added_tokens']), key=len_then, reverse=True)
|
|
||||||
all_strs = avwords + sorted(list(tokconfig['model']['vocab'].keys()), key=len_then, reverse=True)
|
|
||||||
buf = Buf()
|
|
||||||
for s in all_strs:
|
|
||||||
buf.insert(s)
|
|
||||||
|
|
||||||
print('// @generated GENERATED BY scripts/gen_tokenizer_include.py DO NOT MODIFY')
|
|
||||||
print(f'#ifndef {prefix.upper()}_TOKENIZER_CONFIG_H_')
|
|
||||||
print(f'#define {prefix.upper()}_TOKENIZER_CONFIG_H_')
|
|
||||||
print('#include "bpe.h"')
|
|
||||||
print(f"// buflen {len(buf.buf)}")
|
|
||||||
print(f"constexpr const char {prefix}_buffer[] =\n{c_str_dump(buf.buf)};")
|
|
||||||
avilen = len(tokconfig['added_tokens'])
|
|
||||||
print(f'constexpr std::array<bpecpp::additional_vocab_item_embedded, {avilen}> {prefix}_additional_vocab = ''{{')
|
|
||||||
for last, avi in iter_with_last(tokconfig['added_tokens']):
|
|
||||||
comma = ',' if not last else ''
|
|
||||||
print(' {'f'.id = {avi["id"]}, .content={buf.get(avi["content"])}, .special={json.dumps(avi["special"])}''}' + comma)
|
|
||||||
print('}};')
|
|
||||||
print()
|
|
||||||
mergeslen = len(tokconfig['model']['merges'])
|
|
||||||
print(f'constexpr std::array<std::pair<bpecpp::buf_ref, bpecpp::buf_ref>, {mergeslen}> {prefix}_merges = ''{{')
|
|
||||||
breaker = BreakEvery(4)
|
|
||||||
for last, (ma, mb) in iter_with_last(split_merges):
|
|
||||||
comma = ',' if not last else ''
|
|
||||||
print(' {'f'{buf.get(ma)},{buf.get(mb)}''}' + comma + repr(breaker), end='')
|
|
||||||
print('\n}};')
|
|
||||||
vocablen = len(tokconfig['model']['vocab'])
|
|
||||||
print(f'constexpr std::array<bpecpp::buf_ref, {vocablen}> {prefix}_vocab = ''{{')
|
|
||||||
breaker = BreakEvery(8)
|
|
||||||
for last, vi in iter_with_last(tokconfig['model']['vocab']):
|
|
||||||
comma = ',' if not last else ''
|
|
||||||
print(f' {buf.get(vi)}' + comma + repr(breaker), end='')
|
|
||||||
print('\n}};')
|
|
||||||
print(f'#endif // {prefix.upper()}_TOKENIZER_CONFIG_H_')
|
|
||||||
|
|
||||||
def main():
|
|
||||||
if len(sys.argv) < 3:
|
|
||||||
print(f'Usage: {sys.argv[0]} <hf tokenizer json> <symbol prefix>')
|
|
||||||
sys.exit(1)
|
|
||||||
do_convert(sys.argv[1], sys.argv[2])
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
@ -1,257 +0,0 @@
|
|||||||
#include "bpe.h"
|
|
||||||
#include <unicode/normalizer2.h>
|
|
||||||
#include <unicode/regex.h>
|
|
||||||
#include <unicode/schriter.h>
|
|
||||||
#include <unicode/unistr.h>
|
|
||||||
|
|
||||||
#include <regex>
|
|
||||||
#include <stdexcept>
|
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
namespace bpecpp {
|
|
||||||
const std::string_view BPE_PRETOK_REGEX =
|
|
||||||
R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
|
||||||
|
|
||||||
static void get_bigrams(const std::vector<icu::UnicodeString>& input,
|
|
||||||
std::unordered_set<UnicodeBigram, bigram_hash>& pairs) {
|
|
||||||
pairs.clear();
|
|
||||||
auto i = input.begin();
|
|
||||||
auto prev = *i++;
|
|
||||||
for (; i != input.end(); ++i) {
|
|
||||||
pairs.insert({prev, *i});
|
|
||||||
prev = *i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BPE::BPE(const std::unordered_map<std::string_view, uint32_t>& vocab,
|
|
||||||
const std::vector<std::pair<std::string_view, std::string_view>>& merges) {
|
|
||||||
for (auto pair : vocab) {
|
|
||||||
icu::UnicodeString encd = icu::UnicodeString::fromUTF8(pair.first);
|
|
||||||
m_vocab[encd] = pair.second;
|
|
||||||
m_reverse_vocab[pair.second] = encd;
|
|
||||||
}
|
|
||||||
size_t n = 0;
|
|
||||||
for (auto merge : merges) {
|
|
||||||
auto left = icu::UnicodeString::fromUTF8(merge.first);
|
|
||||||
auto right = icu::UnicodeString::fromUTF8(merge.second);
|
|
||||||
m_merges[{left, right}] = n++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint32_t> BPE::encode(const std::string& input) {
|
|
||||||
auto normalized = normalize_nfc(input);
|
|
||||||
auto pretokenized = pretokenize(normalized);
|
|
||||||
std::vector<icu::UnicodeString> tokens_merged;
|
|
||||||
for (auto &ptok : pretokenized) {
|
|
||||||
bpe(ptok, tokens_merged);
|
|
||||||
}
|
|
||||||
std::vector<uint32_t> final_tokens;
|
|
||||||
for (auto &mtok : tokens_merged) {
|
|
||||||
final_tokens.push_back(m_vocab[mtok]);
|
|
||||||
}
|
|
||||||
return final_tokens;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string BPE::decode(const std::vector<uint32_t>& tokens, bool valid_utf8) {
|
|
||||||
std::string out;
|
|
||||||
for (uint32_t t : tokens) {
|
|
||||||
icu::UnicodeString benc = m_reverse_vocab[t];
|
|
||||||
icu::StringCharacterIterator schriter(benc);
|
|
||||||
for (UChar32 c = schriter.first32(); schriter.hasNext();
|
|
||||||
c = schriter.next32()) {
|
|
||||||
out.push_back(m_bs_table.codepoint_to_byte((uint32_t)c));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// roundtrip through ICU to replace invalid utf8 with U+FFFD
|
|
||||||
if (valid_utf8) {
|
|
||||||
auto tmp = icu::UnicodeString::fromUTF8(out);
|
|
||||||
out.clear();
|
|
||||||
tmp.toUTF8String(out);
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://github.com/karpathy/minGPT/blob/37baab71b9abea1b76ab957409a1cc2fbfba8a26/mingpt/bpe.py#L95
|
|
||||||
void BPE::bpe(icu::UnicodeString token_pretoked,
|
|
||||||
std::vector<icu::UnicodeString>& output) {
|
|
||||||
if (token_pretoked.length() < 2) {
|
|
||||||
output.push_back(token_pretoked);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
std::vector<icu::UnicodeString> words;
|
|
||||||
std::vector<icu::UnicodeString> words_update;
|
|
||||||
icu::StringCharacterIterator schriter(token_pretoked);
|
|
||||||
UChar32 c;
|
|
||||||
for (schriter.setToStart(); schriter.hasNext();) {
|
|
||||||
c = schriter.next32PostInc();
|
|
||||||
icu::UnicodeString w;
|
|
||||||
w.append(c);
|
|
||||||
words.push_back(w);
|
|
||||||
}
|
|
||||||
std::unordered_set<UnicodeBigram, bigram_hash> pairs;
|
|
||||||
get_bigrams(words, pairs);
|
|
||||||
while (true) {
|
|
||||||
size_t min_rank = SIZE_MAX;
|
|
||||||
UnicodeBigram to_merge;
|
|
||||||
for (auto &bigram : pairs) {
|
|
||||||
auto loc = m_merges.find(bigram);
|
|
||||||
if (loc != m_merges.end() && loc->second < min_rank) {
|
|
||||||
min_rank = loc->second;
|
|
||||||
to_merge = loc->first;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (min_rank == SIZE_MAX) {
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
auto i = words.begin();
|
|
||||||
while (i < words.end()) {
|
|
||||||
if (*i == to_merge.first) {
|
|
||||||
auto inext = i;
|
|
||||||
inext++;
|
|
||||||
if (inext != words.end() && *inext == to_merge.second) {
|
|
||||||
words_update.push_back(*i + *inext);
|
|
||||||
i = inext;
|
|
||||||
} else {
|
|
||||||
words_update.push_back(*i);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
words_update.push_back(*i);
|
|
||||||
}
|
|
||||||
++i;
|
|
||||||
}
|
|
||||||
words.swap(words_update);
|
|
||||||
words_update.clear();
|
|
||||||
get_bigrams(words, pairs);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
output.insert(output.end(), words.begin(), words.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string BPE::normalize_nfc(const std::string& input) {
|
|
||||||
UErrorCode uerror = U_ZERO_ERROR;
|
|
||||||
auto nfcnorm = icu::Normalizer2::getNFCInstance(uerror);
|
|
||||||
if (!U_SUCCESS(uerror))
|
|
||||||
throw std::runtime_error("could not get ICU NFC normalizer");
|
|
||||||
auto icu_ti = icu::UnicodeString::fromUTF8(input);
|
|
||||||
std::string out;
|
|
||||||
nfcnorm->normalize(icu_ti, uerror).toUTF8String(out);
|
|
||||||
if (!U_SUCCESS(uerror))
|
|
||||||
throw std::runtime_error("ICU string normalization failed");
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<icu::UnicodeString> BPE::pretokenize(const std::string& input) {
|
|
||||||
UParseError pe;
|
|
||||||
UErrorCode uerror = U_ZERO_ERROR;
|
|
||||||
auto bpe_re_icustr = icu::UnicodeString::fromUTF8(BPE_PRETOK_REGEX);
|
|
||||||
if (m_pretok_re == nullptr) {
|
|
||||||
m_pretok_re = std::unique_ptr<icu::RegexPattern>(
|
|
||||||
icu::RegexPattern::compile(bpe_re_icustr, pe, uerror));
|
|
||||||
if (!U_SUCCESS(uerror))
|
|
||||||
throw std::runtime_error("Compiling BPE pretokenizer regex failed");
|
|
||||||
}
|
|
||||||
auto uinput = icu::UnicodeString::fromUTF8(input);
|
|
||||||
std::unique_ptr<icu::RegexMatcher> pretok_matcher(
|
|
||||||
m_pretok_re->matcher(uinput, uerror));
|
|
||||||
std::vector<icu::UnicodeString> pretoks;
|
|
||||||
if (!U_SUCCESS(uerror))
|
|
||||||
throw std::runtime_error("Creating BPE pretokenizer matcher failed");
|
|
||||||
while (pretok_matcher->find()) {
|
|
||||||
auto match = pretok_matcher->group(uerror);
|
|
||||||
if (!U_SUCCESS(uerror))
|
|
||||||
throw std::runtime_error(
|
|
||||||
"Getting BPE pretokenizer regex match failed");
|
|
||||||
std::string s;
|
|
||||||
icu::UnicodeString out;
|
|
||||||
match.toUTF8String(s);
|
|
||||||
for (char c : s) {
|
|
||||||
uint32_t codepoint = m_bs_table.byte_to_codepoint((uint8_t)c);
|
|
||||||
out.append((UChar32)codepoint);
|
|
||||||
}
|
|
||||||
pretoks.push_back(out);
|
|
||||||
}
|
|
||||||
return pretoks;
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::string regex_escape(const std::string_view inp) {
|
|
||||||
std::string s(inp);
|
|
||||||
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
|
|
||||||
return std::regex_replace(s, metacharacters, "\\$&");
|
|
||||||
}
|
|
||||||
|
|
||||||
AdditionalVocabAdapter::AdditionalVocabAdapter(
|
|
||||||
const std::vector<additional_vocab_item>& vocab) {
|
|
||||||
std::string addedtoken_regex;
|
|
||||||
for (const additional_vocab_item& item : vocab) {
|
|
||||||
if (!addedtoken_regex.empty()) {
|
|
||||||
addedtoken_regex += "|";
|
|
||||||
}
|
|
||||||
addedtoken_regex += regex_escape(item.content);
|
|
||||||
m_token_to_id[item.content] = item.id;
|
|
||||||
m_id_to_token[item.id] = item.content;
|
|
||||||
if (item.special) {
|
|
||||||
m_special_ids.insert(item.id);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m_addedtoken_re = std::regex(addedtoken_regex);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint32_t> AdditionalVocabAdapter::encode(
|
|
||||||
const std::string& input,
|
|
||||||
BPE& bpemodel,
|
|
||||||
bool encode_special_tokens) {
|
|
||||||
if (m_token_to_id.empty()) {
|
|
||||||
return bpemodel.encode(input);
|
|
||||||
}
|
|
||||||
std::vector<uint32_t> out;
|
|
||||||
std::string work = input;
|
|
||||||
std::smatch m;
|
|
||||||
while (std::regex_search(work, m, m_addedtoken_re)) {
|
|
||||||
auto tokloc = m_token_to_id.find(m.str());
|
|
||||||
if (tokloc != m_token_to_id.end()) {
|
|
||||||
auto tokid = tokloc->second;
|
|
||||||
auto prefix_decoded = bpemodel.encode(m.prefix());
|
|
||||||
out.insert(out.end(), prefix_decoded.begin(), prefix_decoded.end());
|
|
||||||
bool special = m_special_ids.find(tokid) != m_special_ids.end();
|
|
||||||
if (!special || encode_special_tokens) {
|
|
||||||
out.push_back(tokid);
|
|
||||||
}
|
|
||||||
work = m.suffix();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!work.empty()) {
|
|
||||||
auto rest_decoded = bpemodel.encode(work);
|
|
||||||
out.insert(out.end(), rest_decoded.begin(), rest_decoded.end());
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string AdditionalVocabAdapter::decode(const std::vector<uint32_t>& tokens,
|
|
||||||
BPE& bpemodel,
|
|
||||||
bool decode_special_tokens,
|
|
||||||
bool valid_utf8) {
|
|
||||||
std::string out;
|
|
||||||
std::vector<uint32_t> to_decode;
|
|
||||||
for (auto tokid : tokens) {
|
|
||||||
auto tokloc = m_id_to_token.find(tokid);
|
|
||||||
if (tokloc != m_id_to_token.end()) { // is an added token
|
|
||||||
if (!to_decode.empty()) {
|
|
||||||
out += bpemodel.decode(to_decode, valid_utf8);
|
|
||||||
to_decode.clear();
|
|
||||||
}
|
|
||||||
bool special = m_special_ids.find(tokid) != m_special_ids.end();
|
|
||||||
// only include non-special tokens unless decode_special_tokens
|
|
||||||
if (!special || decode_special_tokens) {
|
|
||||||
out += tokloc->second;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// non-added, regular token.
|
|
||||||
to_decode.push_back(tokid);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!to_decode.empty()) {
|
|
||||||
out += bpemodel.decode(to_decode, valid_utf8);
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
} // namespace bpecpp
|
|
@ -1,123 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
#include <unicode/regex.h>
|
|
||||||
#include <unicode/unistr.h>
|
|
||||||
|
|
||||||
#include <cstdint>
|
|
||||||
#include <regex>
|
|
||||||
#include <unordered_map>
|
|
||||||
#include <unordered_set>
|
|
||||||
#include <vector>
|
|
||||||
#include <string_view>
|
|
||||||
|
|
||||||
namespace bpecpp {
|
|
||||||
typedef std::pair<icu::UnicodeString, icu::UnicodeString> UnicodeBigram;
|
|
||||||
|
|
||||||
class bpe_char_byte_table {
|
|
||||||
public:
|
|
||||||
bpe_char_byte_table() {
|
|
||||||
int n = 0;
|
|
||||||
for (uint8_t byte = 0; m_codepoint_to_byte.size() < 256; byte++) {
|
|
||||||
bool keep = (byte >= '!' && byte <= '~') ||
|
|
||||||
(byte >= 0xa1 && byte <= 0xac) ||
|
|
||||||
(byte >= 0xae && byte <= 0xff);
|
|
||||||
uint32_t codepoint = byte;
|
|
||||||
if (!keep) {
|
|
||||||
codepoint = 256 + n;
|
|
||||||
n++;
|
|
||||||
}
|
|
||||||
m_byte_to_codepoint[byte] = codepoint;
|
|
||||||
m_codepoint_to_byte[codepoint] = byte;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
uint32_t byte_to_codepoint(uint8_t byte) {
|
|
||||||
return m_byte_to_codepoint[byte];
|
|
||||||
}
|
|
||||||
|
|
||||||
uint8_t codepoint_to_byte(uint32_t codepoint) {
|
|
||||||
return m_codepoint_to_byte.at(codepoint);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::array<uint32_t, 256> m_byte_to_codepoint;
|
|
||||||
std::unordered_map<uint32_t, uint8_t> m_codepoint_to_byte;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct bigram_hash {
|
|
||||||
std::size_t operator()(const UnicodeBigram& pair) const {
|
|
||||||
return pair.first.hashCode() + pair.second.hashCode();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
struct icu_hash {
|
|
||||||
std::size_t operator()(const icu::UnicodeString& us) const {
|
|
||||||
return us.hashCode();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class BPE {
|
|
||||||
public:
|
|
||||||
BPE(const std::unordered_map<std::string_view, uint32_t> &vocab,
|
|
||||||
const std::vector<std::pair<std::string_view, std::string_view>> &merges);
|
|
||||||
|
|
||||||
std::vector<uint32_t> encode(const std::string& input);
|
|
||||||
|
|
||||||
std::string decode(const std::vector<uint32_t>& tokens,
|
|
||||||
bool valid_utf8 = true);
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unordered_map<icu::UnicodeString, uint32_t, icu_hash> m_vocab;
|
|
||||||
std::unordered_map<uint32_t, icu::UnicodeString> m_reverse_vocab;
|
|
||||||
std::unordered_map<UnicodeBigram, size_t, bigram_hash> m_merges;
|
|
||||||
bpe_char_byte_table m_bs_table;
|
|
||||||
|
|
||||||
void bpe(icu::UnicodeString token_pretoked,
|
|
||||||
std::vector<icu::UnicodeString>& output);
|
|
||||||
std::unique_ptr<icu::RegexPattern> m_pretok_re;
|
|
||||||
std::string normalize_nfc(const std::string& input);
|
|
||||||
std::vector<icu::UnicodeString> pretokenize(const std::string& input);
|
|
||||||
};
|
|
||||||
|
|
||||||
// for embedding tokenizer configs in the library - had initially constructed
|
|
||||||
// `string_view`s in the generated headers, *but* generating thousands actual
|
|
||||||
// references into the buffer generates thousands of *relocations* and makes
|
|
||||||
// compilation rather slow, delaying resolving the real address into a
|
|
||||||
// string_view until runtime fixes that
|
|
||||||
struct buf_ref {
|
|
||||||
// packing these into a single u32 reduces the size of the embedded
|
|
||||||
// configs significantly (5.0MB->1.6MB)
|
|
||||||
uint32_t offset : 20;
|
|
||||||
uint32_t length : 12;
|
|
||||||
|
|
||||||
std::string_view into(const char* buf) {
|
|
||||||
return std::string_view(&buf[offset], length);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
struct additional_vocab_item_embedded {
|
|
||||||
uint32_t id;
|
|
||||||
buf_ref content;
|
|
||||||
bool special;
|
|
||||||
};
|
|
||||||
struct additional_vocab_item {
|
|
||||||
uint32_t id;
|
|
||||||
std::string_view content;
|
|
||||||
bool special = false;
|
|
||||||
};
|
|
||||||
class AdditionalVocabAdapter {
|
|
||||||
public:
|
|
||||||
AdditionalVocabAdapter(const std::vector<additional_vocab_item> &vocab);
|
|
||||||
std::vector<uint32_t> encode(const std::string& input,
|
|
||||||
BPE& bpemodel,
|
|
||||||
bool encode_special_tokens = true);
|
|
||||||
std::string decode(const std::vector<uint32_t>& tokens,
|
|
||||||
BPE& bpemodel,
|
|
||||||
bool decode_special_tokens = true,
|
|
||||||
bool valid_utf8 = true);
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unordered_map<std::string_view, uint32_t> m_token_to_id;
|
|
||||||
std::unordered_map<uint32_t, std::string_view> m_id_to_token;
|
|
||||||
std::unordered_set<uint32_t> m_special_ids;
|
|
||||||
std::regex m_addedtoken_re;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace bpecpp
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,49 +1,220 @@
|
|||||||
#include "utils.h"
|
#include "utils.h"
|
||||||
#include "tokenizer/bpe.h"
|
|
||||||
#include "tokenizer/mpt_tokenizer_config.h"
|
|
||||||
#include "tokenizer/gptj_tokenizer_config.h"
|
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <regex>
|
#include <regex>
|
||||||
#include <stdexcept>
|
|
||||||
|
|
||||||
void get_bpecpp_tokenizer(const TokenizerType ttype, std::unique_ptr<bpecpp::BPE>& bpe, std::unique_ptr<bpecpp::AdditionalVocabAdapter>& av) {
|
void replace(std::string & str, const std::string & needle, const std::string & replacement) {
|
||||||
std::vector<bpecpp::additional_vocab_item> avis;
|
size_t pos = 0;
|
||||||
std::unordered_map<std::string_view, uint32_t> vocab;
|
while ((pos = str.find(needle, pos)) != std::string::npos) {
|
||||||
std::vector<std::pair<std::string_view, std::string_view>> merges;
|
str.replace(pos, needle.length(), replacement);
|
||||||
|
pos += replacement.length();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t tok_id = 0;
|
std::map<std::string, int32_t> json_parse(const std::string & fname) {
|
||||||
switch (ttype) {
|
std::map<std::string, int32_t> result;
|
||||||
case TokenizerType::MPT_CHAT:
|
|
||||||
avis.push_back({ .id = 50277, .content = std::string_view("<|im_start|>"), .special = true });
|
// read file into string
|
||||||
avis.push_back({ .id = 50278, .content = std::string_view("<|im_end|>"), .special = true });
|
std::string json;
|
||||||
case TokenizerType::MPT:
|
{
|
||||||
for (auto avi_e: mpt_additional_vocab) {
|
std::ifstream ifs(fname);
|
||||||
avis.push_back({avi_e.id, avi_e.content.into(mpt_buffer), avi_e.special});
|
if (!ifs) {
|
||||||
|
fprintf(stderr, "Failed to open %s\n", fname.c_str());
|
||||||
|
exit(1);
|
||||||
}
|
}
|
||||||
for (auto merge: mpt_merges) {
|
|
||||||
merges.push_back({merge.first.into(mpt_buffer), merge.second.into(mpt_buffer)});
|
json = std::string((std::istreambuf_iterator<char>(ifs)),
|
||||||
|
(std::istreambuf_iterator<char>()));
|
||||||
}
|
}
|
||||||
for (auto bufref: mpt_vocab) {
|
|
||||||
vocab.insert({bufref.into(mpt_buffer), tok_id++});
|
if (json[0] != '{') {
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parse json
|
||||||
|
{
|
||||||
|
bool has_key = false;
|
||||||
|
bool in_token = false;
|
||||||
|
|
||||||
|
std::string str_key = "";
|
||||||
|
std::string str_val = "";
|
||||||
|
|
||||||
|
int n = json.size();
|
||||||
|
for (int i = 1; i < n; ++i) {
|
||||||
|
if (!in_token) {
|
||||||
|
if (json[i] == ' ') continue;
|
||||||
|
if (json[i] == '"') {
|
||||||
|
in_token = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (json[i] == '\\' && i+1 < n) {
|
||||||
|
if (has_key == false) {
|
||||||
|
str_key += json[i];
|
||||||
|
} else {
|
||||||
|
str_val += json[i];
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
} else if (json[i] == '"') {
|
||||||
|
if (has_key == false) {
|
||||||
|
has_key = true;
|
||||||
|
++i;
|
||||||
|
while (json[i] == ' ') ++i;
|
||||||
|
++i; // :
|
||||||
|
while (json[i] == ' ') ++i;
|
||||||
|
if (json[i] != '\"') {
|
||||||
|
while (json[i] != ',' && json[i] != '}') {
|
||||||
|
str_val += json[i++];
|
||||||
|
}
|
||||||
|
has_key = false;
|
||||||
|
} else {
|
||||||
|
in_token = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
has_key = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
::replace(str_key, "\\u0120", " " ); // \u0120 -> space
|
||||||
|
::replace(str_key, "\\u010a", "\n"); // \u010a -> new line
|
||||||
|
::replace(str_key, "\\\"", "\""); // \\\" -> "
|
||||||
|
|
||||||
|
try {
|
||||||
|
result[str_key] = std::stoi(str_val);
|
||||||
|
} catch (...) {
|
||||||
|
//fprintf(stderr, "%s: ignoring key '%s' with value '%s'\n", fname.c_str(), str_key.c_str(), str_val.c_str());
|
||||||
|
|
||||||
|
}
|
||||||
|
str_key = "";
|
||||||
|
str_val = "";
|
||||||
|
in_token = false;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (has_key == false) {
|
||||||
|
str_key += json[i];
|
||||||
|
} else {
|
||||||
|
str_val += json[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) {
|
||||||
|
std::vector<std::string> words;
|
||||||
|
|
||||||
|
// first split the text into words
|
||||||
|
{
|
||||||
|
std::string str = text;
|
||||||
|
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
|
||||||
|
|
||||||
|
std::regex re(pat);
|
||||||
|
std::smatch m;
|
||||||
|
|
||||||
|
while (std::regex_search(str, m, re)) {
|
||||||
|
for (auto x : m) {
|
||||||
|
words.push_back(x);
|
||||||
|
}
|
||||||
|
str = m.suffix();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// find the longest tokens that form the words:
|
||||||
|
std::vector<gpt_vocab::id> tokens;
|
||||||
|
for (const auto & word : words) {
|
||||||
|
if (word.size() == 0) continue;
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
int n = word.size();
|
||||||
|
while (i < n) {
|
||||||
|
int j = n;
|
||||||
|
while (j > i) {
|
||||||
|
auto it = vocab.token_to_id.find(word.substr(i, j-i));
|
||||||
|
if (it != vocab.token_to_id.end()) {
|
||||||
|
tokens.push_back(it->second);
|
||||||
|
i = j;
|
||||||
break;
|
break;
|
||||||
case TokenizerType::GPTJ:
|
|
||||||
for (auto avi_e: gptj_additional_vocab) {
|
|
||||||
avis.push_back({avi_e.id, avi_e.content.into(gptj_buffer), avi_e.special});
|
|
||||||
}
|
}
|
||||||
for (auto merge: gptj_merges) {
|
--j;
|
||||||
merges.push_back({merge.first.into(gptj_buffer), merge.second.into(gptj_buffer)});
|
|
||||||
}
|
|
||||||
for (auto bufref: gptj_vocab) {
|
|
||||||
vocab.insert({bufref.into(gptj_buffer), tok_id++});
|
|
||||||
}
|
}
|
||||||
|
if (i == n) {
|
||||||
break;
|
break;
|
||||||
default:
|
|
||||||
throw std::invalid_argument("invalid tokenizer type");
|
|
||||||
}
|
}
|
||||||
av = std::make_unique<bpecpp::AdditionalVocabAdapter>(avis);
|
if (j == i) {
|
||||||
bpe = std::make_unique<bpecpp::BPE>(vocab, merges);
|
auto sub = word.substr(i, 1);
|
||||||
|
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
|
||||||
|
tokens.push_back(vocab.token_to_id.at(sub));
|
||||||
|
} else {
|
||||||
|
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
|
||||||
|
}
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string regex_escape(const std::string &s) {
|
||||||
|
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
|
||||||
|
return std::regex_replace(s, metacharacters, "\\$&");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
|
||||||
|
// Generate the subpattern from the special_tokens vector if it's not empty
|
||||||
|
if (!vocab.special_tokens.empty()) {
|
||||||
|
std::vector<gpt_vocab::id> out;
|
||||||
|
std::vector<std::string> chunks;
|
||||||
|
std::string str = text;
|
||||||
|
std::string special_tokens_subpattern;
|
||||||
|
for (const auto &token : vocab.special_tokens) {
|
||||||
|
if (!special_tokens_subpattern.empty()) {
|
||||||
|
special_tokens_subpattern += "|";
|
||||||
|
}
|
||||||
|
special_tokens_subpattern += regex_escape(token);
|
||||||
|
}
|
||||||
|
std::regex re(special_tokens_subpattern);
|
||||||
|
std::smatch m;
|
||||||
|
while (std::regex_search(str, m, re)) {
|
||||||
|
auto tok = vocab.token_to_id.find(m.str());
|
||||||
|
if (tok != vocab.token_to_id.end()) {
|
||||||
|
auto tokid = tok->second;
|
||||||
|
auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
|
||||||
|
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
|
||||||
|
out.push_back(tokid);
|
||||||
|
str = m.suffix();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!str.empty()) {
|
||||||
|
auto tokrest = gpt_tokenize_inner(vocab, str);
|
||||||
|
out.insert(out.end(), tokrest.begin(), tokrest.end());
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
} else {
|
||||||
|
return gpt_tokenize_inner(vocab, text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
|
||||||
|
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
|
||||||
|
|
||||||
|
vocab.token_to_id = ::json_parse(fname);
|
||||||
|
|
||||||
|
for (const auto & kv : vocab.token_to_id) {
|
||||||
|
vocab.id_to_token[kv.second] = kv.first;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("%s: vocab size = %d\n", __func__, (int) vocab.token_to_id.size());
|
||||||
|
|
||||||
|
// print the vocabulary
|
||||||
|
//for (auto kv : vocab.token_to_id) {
|
||||||
|
// printf("'%s' -> %d\n", kv.first.data(), kv.second);
|
||||||
|
//}
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
gpt_vocab::id gpt_sample_top_k_top_p(
|
gpt_vocab::id gpt_sample_top_k_top_p(
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include "tokenizer/bpe.h"
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// CLI argument parsing
|
// CLI argument parsing
|
||||||
@ -52,6 +51,26 @@ struct gpt_vocab {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void replace(std::string & str, const std::string & needle, const std::string & replacement);
|
||||||
|
|
||||||
|
// poor-man's JSON parsing
|
||||||
|
std::map<std::string, int32_t> json_parse(const std::string & fname);
|
||||||
|
|
||||||
|
// split text into tokens
|
||||||
|
//
|
||||||
|
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
|
||||||
|
//
|
||||||
|
// Regex (Python):
|
||||||
|
// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||||
|
//
|
||||||
|
// Regex (C++):
|
||||||
|
// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"
|
||||||
|
//
|
||||||
|
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text);
|
||||||
|
|
||||||
|
// load the tokens from encoder.json
|
||||||
|
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
|
||||||
|
|
||||||
// sample next token given probabilities for each embedding
|
// sample next token given probabilities for each embedding
|
||||||
//
|
//
|
||||||
// - consider only the top K tokens
|
// - consider only the top K tokens
|
||||||
@ -70,9 +89,3 @@ gpt_vocab::id gpt_sample_top_k_top_p(
|
|||||||
double temp,
|
double temp,
|
||||||
float repeat_penalty,
|
float repeat_penalty,
|
||||||
std::mt19937 & rng);
|
std::mt19937 & rng);
|
||||||
|
|
||||||
enum TokenizerType {
|
|
||||||
MPT, MPT_CHAT, GPTJ
|
|
||||||
};
|
|
||||||
|
|
||||||
void get_bpecpp_tokenizer(const TokenizerType ttype, std::unique_ptr<bpecpp::BPE>& bpe, std::unique_ptr<bpecpp::AdditionalVocabAdapter>& av);
|
|
||||||
|
Loading…
Reference in New Issue
Block a user