Fix: Mirostat fails on models split across multiple GPUs

This commit is contained in:
Forkoz 2023-08-05 16:45:47 +00:00 committed by GitHub
parent 23055b21ee
commit 9dcb37e8d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -104,7 +104,7 @@ class MirostatLogitsWarper(LogitsWarper):
break
# Normalize the probabilities of the remaining words
-prob_topk = torch.softmax(sorted_logits, dim=0)
+prob_topk = torch.softmax(sorted_logits, dim=0).to('cuda')
prev_i = torch.multinomial(prob_topk, num_samples=1, replacement=True).to('cuda')