non-llama: explicitly greedy sampling for temp<=0 (#901)

Copied directly from llama.cpp. Without this, temp=0.0 just scales all the logits to infinity and gives bad output.
Aaron Miller 2023-06-08 11:08:30 -07:00 committed by GitHub
parent b14953e136
commit 47fbc0e309

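For context, the failure mode is easy to reproduce: the usual sampling path scales each logit by 1/temp, so temp == 0 turns the scale factor into +inf and pushes every nonzero logit to +/-inf. A minimal standalone sketch (illustrative only, not code from this repository):

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> logits = {1.2f, -0.4f, 3.1f};
    float temp = 0.0f;
    float scale = 1.0f / temp;        // IEEE-754: 1/0 == +inf
    for (float & l : logits) {
        l *= scale;                   // nonzero logits saturate to +/-inf
    }
    for (float l : logits) {
        std::printf("%f\n", l);       // prints inf, -inf, inf
    }
    return 0;
}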

@@ -232,6 +232,19 @@ gpt_vocab::id gpt_sample_top_k_top_p(
    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
    const auto * plogits = logits.data();

    if (temp <= 0) {
        // select the token with the highest logit directly
        float max_logit = plogits[0];
        gpt_vocab::id max_id = 0;

        for (int i = 1; i < n_logits; ++i) {
            if (plogits[i] > max_logit) {
                max_logit = plogits[i];
                max_id = i;
            }
        }
        return max_id;
    }

    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
    logits_id.reserve(n_logits);
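The greedy branch above is the hand-rolled argmax loop copied from llama.cpp. An equivalent standalone sketch (with gpt_vocab::id replaced by plain int so it compiles on its own) can use std::max_element:

#include <algorithm>
#include <cstdio>
#include <vector>

// Same result as the loop in the diff: the id of the highest logit.
int sample_greedy(const std::vector<float> & logits) {
    return (int)(std::max_element(logits.begin(), logits.end()) - logits.begin());
}

int main() {
    std::vector<float> logits = {1.2f, -0.4f, 3.1f};
    std::printf("greedy token id: %d\n", sample_greedy(logits)); // prints 2
    return 0;
}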