From 93fa9bbe010ad6f0df657894fe1ea30ac3b328cb Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 19 Jan 2023 10:43:05 -0300 Subject: [PATCH] Clean up the streaming implementation --- README.md | 1 + server.py | 67 +++++++++++++++++++++++++++---------------------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 0075bde0..636db8a4 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,7 @@ Optionally, you can use the following command-line flags: | `--load-in-8bit` | Load the model with 8-bit precision.| | `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. | | `--no-listen` | Make the web UI unreachable from your local network.| +| `--no-stream` | Don't stream the text output in real time. This slightly improves the text generation performance.| | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.| ## Presets diff --git a/server.py b/server.py index 3adf3478..97acb1dc 100644 --- a/server.py +++ b/server.py @@ -25,7 +25,7 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.') parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.') parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.') -parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.') +parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This slightly improves the text generation performance.') parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.') args = parser.parse_args() @@ -125,6 +125,21 @@ def encode(prompt, tokens): input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens) return input_ids +def decode(output_ids): + reply = tokenizer.decode(output_ids, skip_special_tokens=True) + reply = reply.replace(r'<|endoftext|>', '') + return reply + +def formatted_outputs(reply, model_name): + if model_name.lower().startswith('galactica'): + reply = fix_galactica(reply) + return reply, reply, generate_basic_html(reply) + elif model_name.lower().startswith('gpt4chan'): + reply = fix_gpt4chan(reply) + return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) + else: + return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) + def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None): global model, tokenizer, model_name, loaded_preset, preset @@ -141,43 +156,27 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok loaded_preset = inference_settings cuda = "" if args.cpu else ".cuda()" - if not args.no_stream: + n = None if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1] + + # Generate the entire reply at once + if args.no_stream: + input_ids = encode(question, tokens) + output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}") + reply = decode(output[0]) + yield formatted_outputs(reply, model_name) + + # Generate the reply 1 token at a time + else: input_ids = encode(question, 1) preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') for i in range(tokens): output = eval(f"model.generate(input_ids, {preset}){cuda}") - reply = tokenizer.decode(output[0], skip_special_tokens=True) - reply = reply.replace(r'<|endoftext|>', '') + reply = decode(output[0]) if eos_token is not None and reply[-1] == eos_token: break - if model_name.lower().startswith('galactica'): - reply = fix_galactica(reply) - yield reply, reply, generate_basic_html(reply) - elif model_name.lower().startswith('gpt4chan'): - reply = fix_gpt4chan(reply) - yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) - else: - yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) + yield formatted_outputs(reply, model_name) input_ids = output - else: - input_ids = encode(question, tokens) - if eos_token is None: - output = eval(f"model.generate(input_ids, {preset}){cuda}") - else: - n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1] - output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}") - reply = tokenizer.decode(output[0], skip_special_tokens=True) - reply = reply.replace(r'<|endoftext|>', '') - if model_name.lower().startswith('galactica'): - reply = fix_galactica(reply) - yield reply, reply, generate_basic_html(reply) - elif model_name.lower().startswith('gpt4chan'): - reply = fix_gpt4chan(reply) - yield reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply) - else: - yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply) - # Choosing the default model if args.model is not None: @@ -206,7 +205,6 @@ else: description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n" css = ".my-4 {margin-top: 0} .py-6 {padding-top: 2.5rem}" - if args.chat or args.cai_chat: history = [] @@ -257,20 +255,21 @@ if args.chat or args.cai_chat: reply = clean_chat_message(reply) history[-1] = [text, reply] + if next_character_found: + break # Prevent the chat log from flashing if something like "\nYo" is generated just # before "\nYou:" is completed tmp = f"\n{name1}:" next_character_substring_found = False - for j in range(1, len(tmp)+1): + for j in range(1, len(tmp)): if reply[-j:] == tmp[:j]: next_character_substring_found = True if not next_character_substring_found: yield history - if next_character_found: - break + yield history def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check): for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):