Mirror of https://github.com/tloen/alpaca-lora.git (synced 2024-10-01 01:05:56 -04:00)
Support streaming output on generate (#263)
This commit is contained in:
parent 8e51ebf3f4
commit e2ed209d3b

48 generate.py
@@ -8,6 +8,7 @@ import transformers
 from peft import PeftModel
 from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
 
+from utils.callbacks import Iteratorize, Stream
 from utils.prompter import Prompter
 
 if torch.cuda.is_available():
@@ -91,6 +92,7 @@ def main(
         top_k=40,
         num_beams=4,
         max_new_tokens=128,
+        stream_output=False,
         **kwargs,
     ):
         prompt = prompter.generate_prompt(instruction, input)
@@ -103,6 +105,47 @@ def main(
             num_beams=num_beams,
             **kwargs,
         )
+
+        generate_params = {
+            "input_ids": input_ids,
+            "generation_config": generation_config,
+            "return_dict_in_generate": True,
+            "output_scores": True,
+            "max_new_tokens": max_new_tokens,
+        }
+
+        if stream_output:
+            # Stream the reply 1 token at a time.
+            # This is based on the trick of using 'stopping_criteria' to create an iterator,
+            # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
+
+            def generate_with_callback(callback=None, **kwargs):
+                kwargs.setdefault(
+                    "stopping_criteria", transformers.StoppingCriteriaList()
+                )
+                kwargs["stopping_criteria"].append(
+                    Stream(callback_func=callback)
+                )
+                with torch.no_grad():
+                    model.generate(**kwargs)
+
+            def generate_with_streaming(**kwargs):
+                return Iteratorize(
+                    generate_with_callback, kwargs, callback=None
+                )
+
+            with generate_with_streaming(**generate_params) as generator:
+                for output in generator:
+                    # new_tokens = len(output) - len(input_ids[0])
+                    decoded_output = tokenizer.decode(output)
+
+                    if output[-1] in [tokenizer.eos_token_id]:
+                        break
+
+                    yield prompter.get_response(decoded_output)
+            return # early return for stream_output
+
+        # Without streaming
         with torch.no_grad():
             generation_output = model.generate(
                 input_ids=input_ids,
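For context on the trick: `model.generate()` blocks until generation finishes, so the commit hooks a `Stream` stopping criterion into it to fire a callback on every decoding step, and `Iteratorize` (added below in `utils/callbacks.py`) runs the blocking call on a background thread and feeds those callbacks through a queue, yielding one token sequence at a time. Newer `transformers` releases ship `TextIteratorStreamer`, which packages the same thread-plus-queue idea behind a built-in API; a rough, hypothetical sketch of that alternative (not part of this commit, assuming `model`, `tokenizer`, and `input_ids` are already set up):

```python
# Not part of this commit: equivalent streaming via transformers' built-in
# TextIteratorStreamer (assumes `model`, `tokenizer`, `input_ids` exist).
from threading import Thread

from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=128),
)
thread.start()
for new_text in streamer:  # yields decoded text chunks as they are generated
    print(new_text, end="", flush=True)
thread.join()
```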
@@ -113,7 +156,7 @@ def main(
             )
         s = generation_output.sequences[0]
         output = tokenizer.decode(s)
-        return prompter.get_response(output)
+        yield prompter.get_response(output)
 
     gr.Interface(
         fn=evaluate,
@@ -139,6 +182,7 @@ def main(
             gr.components.Slider(
                 minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
             ),
+            gr.components.Checkbox(label="Stream output"),
         ],
         outputs=[
             gr.inputs.Textbox(
@@ -148,7 +192,7 @@ def main(
         ],
         title="🦙🌲 Alpaca-LoRA",
         description="Alpaca-LoRA is a 7B-parameter LLaMA model finetuned to follow instructions. It is trained on the [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset and makes use of the Huggingface LLaMA implementation. For more information, please visit [the project's website](https://github.com/tloen/alpaca-lora).", # noqa: E501
-    ).launch(server_name="0.0.0.0", share=share_gradio)
+    ).queue().launch(server_name="0.0.0.0", share=share_gradio)
     # Old testing code follows.
 
     """
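Why the interface changes go together: turning `evaluate` into a generator (`yield` instead of `return`) only streams to the browser if the Gradio app is backed by a queue, hence `.queue()` before `.launch()`. A minimal, self-contained sketch of that Gradio pattern, using an invented `count_up` function rather than anything from this repo:

```python
import time

import gradio as gr


def count_up(n):
    # A generator fn: Gradio re-renders the output component on every yield.
    text = ""
    for i in range(int(n)):
        text += f"{i} "
        time.sleep(0.2)
        yield text


gr.Interface(fn=count_up, inputs=gr.Number(value=5), outputs=gr.Textbox()).queue().launch()
```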
@@ -4,4 +4,10 @@
 
 Prompter class, a template manager.
 
 `from utils.prompter import Prompter`
+
+## callbacks.py
+
+Helpers to support streaming generate output.
+
+`from utils.callbacks import Iteratorize, Stream`
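To make the documented import concrete, here is a small hypothetical example of what `Iteratorize` does to a callback-style function (the `produce` function is invented for illustration, and importing `utils.callbacks` still pulls in `torch` and `transformers`):

```python
from utils.callbacks import Iteratorize


# Hypothetical callback-style producer, standing in for generate_with_callback.
def produce(callback=None, n=3):
    for i in range(n):
        callback(i)


# Iteratorize runs produce() on a background thread and turns its callbacks
# into a queue-backed iterator.
for item in Iteratorize(produce, {"n": 3}):
    print(item)  # 0, then 1, then 2
```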
75 utils/callbacks.py Normal file
@@ -0,0 +1,75 @@
"""
|
||||
Helpers to support streaming generate output.
|
||||
Borrowed from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/callbacks.py
|
||||
"""
|
||||
|
||||
import gc
|
||||
import traceback
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
|
||||
class Stream(transformers.StoppingCriteria):
|
||||
def __init__(self, callback_func=None):
|
||||
self.callback_func = callback_func
|
||||
|
||||
def __call__(self, input_ids, scores) -> bool:
|
||||
if self.callback_func is not None:
|
||||
self.callback_func(input_ids[0])
|
||||
return False
|
||||
|
||||
|
||||
class Iteratorize:
|
||||
|
||||
"""
|
||||
Transforms a function that takes a callback
|
||||
into a lazy iterator (generator).
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs={}, callback=None):
|
||||
self.mfunc = func
|
||||
self.c_callback = callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.kwargs = kwargs
|
||||
self.stop_now = False
|
||||
|
||||
def _callback(val):
|
||||
if self.stop_now:
|
||||
raise ValueError
|
||||
self.q.put(val)
|
||||
|
||||
def gentask():
|
||||
try:
|
||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||
except ValueError:
|
||||
pass
|
||||
except:
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
self.q.put(self.sentinel)
|
||||
if self.c_callback:
|
||||
self.c_callback(ret)
|
||||
|
||||
self.thread = Thread(target=gentask)
|
||||
self.thread.start()
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
obj = self.q.get(True, None)
|
||||
if obj is self.sentinel:
|
||||
raise StopIteration
|
||||
else:
|
||||
return obj
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.stop_now = True
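A few notes on the design: `Iteratorize` runs the wrapped function on a plain `Thread`, pushes every callback value onto a `Queue`, and uses a private sentinel object to signal `StopIteration` once the function returns. The context-manager protocol is what lets a consumer stop reading early: `__exit__` sets `stop_now`, and the next callback raises `ValueError` inside the worker thread, which `gentask` swallows. A hypothetical sketch of that early-termination path (the `slow_produce` function is invented for illustration):

```python
import time

from utils.callbacks import Iteratorize


# Hypothetical producer that keeps calling back until told to stop.
def slow_produce(callback=None, n=100):
    for i in range(n):
        time.sleep(0.1)
        callback(i)


with Iteratorize(slow_produce, {"n": 100}) as it:
    for i in it:
        if i >= 2:
            # Leaving the with-block sets stop_now; the next callback raises
            # ValueError in the worker thread and the producer stops cleanly.
            break
```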