assistant: space fix, system: prompt fix (#2219)

This commit is contained in:
matatonic 2023-05-20 22:32:34 -04:00 committed by GitHub
parent 05593a7834
commit 78b2478d9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -243,9 +243,9 @@ class Handler(BaseHTTPRequestHandler):
messages = body['messages']
- system_msg = '' # You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
+ system_msgs = []
if 'prompt' in body: # Maybe they sent both? This is not documented in the API, but some clients seem to do this.
- system_msg = body['prompt']
+ system_msgs = [ body['prompt'] ]
chat_msgs = []
@ -254,10 +254,15 @@ class Handler(BaseHTTPRequestHandler):
content = m['content']
# name = m.get('name', 'user')
if role == 'system':
- system_msg += content
+ system_msgs.extend([content.strip()])
else:
chat_msgs.extend([f"\n{role}: {content.strip()}"]) # Strip content? linefeed?
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
system_msg = 'You are assistant, a large language model. Answer as concisely as possible.'
if system_msgs:
system_msg = '\n'.join(system_msgs)
system_token_count = len(encode(system_msg)[0])
remaining_tokens = req_params['truncation_length'] - req_params['max_new_tokens'] - system_token_count
chat_msg = ''
@ -277,9 +282,9 @@ class Handler(BaseHTTPRequestHandler):
print(f"truncating chat messages, dropping {len(chat_msgs)} messages.")
if system_msg:
- prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant: '
+ prompt = 'system: ' + system_msg + '\n' + chat_msg + '\nassistant:'
else:
- prompt = chat_msg + '\nassistant: '
+ prompt = chat_msg + '\nassistant:'
token_count = len(encode(prompt)[0])
@ -396,6 +401,11 @@ class Handler(BaseHTTPRequestHandler):
"finish_reason": None,
}],
}
# strip extra leading space off new generated content
if len_seen == 0 and new_content[0] == ' ':
new_content = new_content[1:]
if stream_object_type == 'text_completion.chunk':
chunk[resp_list][0]['text'] = new_content
else:
@ -432,9 +442,15 @@ class Handler(BaseHTTPRequestHandler):
self.wfile.write(response.encode('utf-8'))
# Finished if streaming.
if debug:
if answer and answer[0] == ' ':
answer = answer[1:]
print({'response': answer})
return
# strip extra leading space off new generated content
if answer and answer[0] == ' ':
answer = answer[1:]
if debug:
print({'response': answer})