diff --git a/modules/RWKV.py b/modules/RWKV.py index bb6bab50..1b0078ad 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -87,7 +87,9 @@ class RWKVModel: while len(tokens) > 0: out, state = self.model.forward(tokens[:args.chunk_len], state) tokens = tokens[args.chunk_len:] - + if i == 0: + begin_token= len(all_tokens) + last_token_posi=begin_token # cache the model state after scanning the context # we don't cache the state after processing our own generated tokens because # the output string might be post-processed arbitrarily. Therefore, what's fed into the model @@ -116,13 +118,13 @@ class RWKVModel: occurrence[token] += 1 # output - tmp = self.pipeline.decode([token]) + tmp = self.pipeline.decode(all_tokens[last_token_posi:]) if '\ufffd' not in tmp: # is valid utf-8 string? if callback: callback(tmp) - + out_str += tmp - + last_token_posi = begin_token + i + 1 return out_str