Move the example dialogue to the chat history, and keep it hidden.

This greatly improves the performance of text generation, as
histories can be quite long. It also makes more sense to implement
it this way.
This commit is contained in:
oobabooga 2023-01-21 02:48:06 -03:00
parent 3f2c1e7170
commit 990ee54ddd
3 changed files with 39 additions and 8 deletions

View File

@ -139,9 +139,9 @@ Optionally, you can use the following command-line flags:
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
| `--disk-cache-dir DISK_CACHE_DIR` | Directory which you want the disk cache to load to. |
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--gpu-memory GPU_MEMORY` | Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99 GiB.|
| `--no-stream` | Don't stream the text output in real time. This slightly improves the text generation performance.|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
| `--listen` | Make the web UI reachable from your local network.|

View File

@ -6,6 +6,7 @@ This is a library for formatting GPT-4chan and chat outputs as nice HTML.
import re
from pathlib import Path
import copy
def generate_basic_html(s):
s = '\n'.join([f'<p style="margin-bottom: 20px">{line}</p>' for line in s.split('\n')])
@ -160,7 +161,7 @@ def generate_4chan_html(f):
return output
def generate_chat_html(history, name1, name2, character):
def generate_chat_html(_history, name1, name2, character):
css = """
.chat {
margin-left: auto;
@ -233,6 +234,13 @@ def generate_chat_html(history, name1, name2, character):
img = f'<img src="file/{i}">'
break
history = copy.deepcopy(_history)
for i in range(len(history)):
if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
history[i][0] = history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
history = history[i:]
break
for i,_row in enumerate(history[::-1]):
row = _row.copy()
row[0] = re.sub(r"[\\]*\*", r"*", row[0])

View File

@ -26,9 +26,9 @@ parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
parser.add_argument('--disk-cache-dir', type=str, help='Directory which you want the disk cache to load to.')
parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number.')
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99 GiB.')
parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time. This slightly improves the text generation performance.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
@ -262,6 +262,7 @@ if args.chat or args.cai_chat:
rows.pop(1)
question = ''.join(rows)
question = question.replace('<|BEGIN-VISIBLE-CHAT|>', '')
return question
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
@ -336,6 +337,26 @@ if args.chat or args.cai_chat:
global history
history = json.loads(file.decode('utf-8'))['data']
def tokenize_example_dialogue(dialogue, name1, name2):
dialogue = re.sub('<START>', '', dialogue)
dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
idx = [m.start() for m in re.finditer(f"(^|\n)({name1}|{name2}):", dialogue)]
messages = []
for i in range(len(idx)-1):
messages.append(dialogue[idx[i]:idx[i+1]].strip())
history = []
entry = ['', '']
for i in messages:
if i.startswith(f'{name1}:'):
entry[0] = i[len(f'{name1}:'):].strip()
elif i.startswith(f'{name2}:'):
entry[1] = i[len(f'{name2}:'):].strip()
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
history.append(entry)
entry = ['', '']
return history
def load_character(_character, name1, name2):
global history, character
context = ""
@ -351,9 +372,11 @@ if args.chat or args.cai_chat:
context += f"Scenario: {data['world_scenario']}\n"
context = f"{context.strip()}\n<START>\n"
if 'example_dialogue' in data and data['example_dialogue'] != '':
context += f"{data['example_dialogue'].strip()}\n"
if 'char_greeting' in data:
history = [['', data['char_greeting']]]
history = tokenize_example_dialogue(data['example_dialogue'], name1, name2)
if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
history += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
else:
history += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
else:
character = None
context = settings['context_pygmalion']