diff --git a/server.py b/server.py
index 6c3c3039..8f658b43 100644
--- a/server.py
+++ b/server.py
@@ -69,7 +69,7 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
-def generate_reply(question, temperature, max_length, inference_settings, selected_model):
+def generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
     if selected_model != model_name:
@@ -86,7 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
         torch.cuda.empty_cache()
 
     input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
-    output = eval(f"model.generate(input_ids, {preset}).cuda()")
+    if eos_token is None:
+        output = eval(f"model.generate(input_ids, {preset}).cuda()")
+    else:
+        n = tokenizer.encode(eos_token, return_tensors='pt')[0][1]
+        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
 
     if model_name.lower().startswith('galactica'):
@@ -159,7 +163,7 @@ elif args.chat:
         question += f"{name1}: {text.strip()}\n"
         question += f"{name2}:"
 
-        reply = generate_reply(question, temperature, max_length, inference_settings, selected_model)[0]
+        reply = generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token='\n')[0]
         reply = reply[len(question):].split('\n')[0].strip()
         history.append((text, reply))
         return history
@@ -175,7 +179,7 @@ elif args.chat:
             with gr.Column():
                 with gr.Row(equal_height=True):
                     with gr.Column():
-                        length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=100)
+                        length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                         preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                     with gr.Column():
                         temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
@@ -203,7 +207,7 @@ else:
            with gr.Column():
                textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
-               length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=100)
+               length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                btn = gr.Button("Generate")