Style improvements (#1957)

This commit is contained in:
oobabooga 2023-05-09 22:49:39 -03:00 committed by GitHub
parent 334486f527
commit 3913155c1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 50 deletions

View File

@ -5,7 +5,7 @@ import sys
try: try:
import websockets import websockets
except ImportError: except ImportError:
print("Websockets package not found. Make sure it's installed.") print("Websockets package not found. Make sure it's installed.")
# For local streaming, the websockets are hosted without ssl - ws:// # For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005' HOST = 'localhost:5005'
@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss:// # For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
async def run(context): async def run(context):
# Note: the selected defaults change from time to time. # Note: the selected defaults change from time to time.
request = { request = {
@ -42,7 +43,7 @@ async def run(context):
async with websockets.connect(URI, ping_interval=None) as websocket: async with websockets.connect(URI, ping_interval=None) as websocket:
await websocket.send(json.dumps(request)) await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply yield context # Remove this if you just want to see the reply
while True: while True:
incoming_data = await websocket.recv() incoming_data = await websocket.recv()
@ -58,7 +59,7 @@ async def run(context):
async def print_response_stream(prompt): async def print_response_stream(prompt):
async for response in run(prompt): async for response in run(prompt):
print(response, end='') print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
# For reverse-proxied streaming, the remote will likely host with ssl - https:// # For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' # URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
def run(prompt): def run(prompt):
request = { request = {
'prompt': prompt, 'prompt': prompt,
@ -37,6 +38,7 @@ def run(prompt):
result = response.json()['results'][0]['text'] result = response.json()['results'][0]['text']
print(prompt + result) print(prompt + result)
if __name__ == '__main__': if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)" prompt = "In order to make homemade bread, follow these steps:\n1)"
run(prompt) run(prompt)

View File

@ -2,11 +2,10 @@ import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared from modules import shared
from modules.text_generation import encode, generate_reply from modules.text_generation import encode, generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):

View File

@ -5,6 +5,7 @@ from modules import shared
BLOCKING_PORT = 5000 BLOCKING_PORT = 5000
STREAMING_PORT = 5005 STREAMING_PORT = 5005
def setup(): def setup():
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api) blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api) streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)

View File

@ -1,12 +1,12 @@
import json
import asyncio import asyncio
from websockets.server import serve import json
from threading import Thread from threading import Thread
from modules import shared from websockets.server import serve
from modules.text_generation import generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import generate_reply
PATH = '/api/v1/stream' PATH = '/api/v1/stream'

View File

@ -1,6 +1,7 @@
import gradio as gr
import os import os
import gradio as gr
# get the current directory of the script # get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))

View File

@ -1,6 +1,8 @@
import gradio as gr
import logging import logging
import gradio as gr
def ui(): def ui():
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead") gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead") logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View File

@ -6,10 +6,11 @@ from io import BytesIO
from typing import Any, List, Optional from typing import Any, List, Optional
import torch import torch
from PIL import Image
from extensions.multimodal.pipeline_loader import load_pipeline from extensions.multimodal.pipeline_loader import load_pipeline
from modules import shared from modules import shared
from modules.text_generation import encode, get_max_prompt_length from modules.text_generation import encode, get_max_prompt_length
from PIL import Image
@dataclass @dataclass

View File

@ -7,6 +7,7 @@ from io import BytesIO
import gradio as gr import gradio as gr
import torch import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared from modules import shared

View File

@ -1,11 +1,12 @@
import base64 import base64
import json import json
import numpy as np
import os import os
import time import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
import numpy as np
from modules import shared from modules import shared
from modules.text_generation import encode, generate_reply from modules.text_generation import encode, generate_reply
@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
ascii_string = encoded_bytes.decode('ascii') ascii_string = encoded_bytes.decode('ascii')
return ascii_string return ascii_string
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path.startswith('/v1/models'): if self.path.startswith('/v1/models'):
@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler):
"created": created_time, "created": created_time,
"model": model, # TODO: add Lora info? "model": model, # TODO: add Lora info?
resp_list: [{ resp_list: [{
"index": 0, "index": 0,
"finish_reason": "stop", "finish_reason": "stop",
}], }],
"usage": { "usage": {
"prompt_tokens": token_count, "prompt_tokens": token_count,

View File

@ -6,12 +6,13 @@ from datetime import date
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.shared as shared
import requests import requests
import torch import torch
from modules.models import reload_model, unload_model
from PIL import Image from PIL import Image
import modules.shared as shared
from modules.models import reload_model, unload_model
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui # parameters which can be customized in settings.json of webui
@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
@ -122,7 +124,6 @@ def input_modifier(string):
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params global params
if params['manage_VRAM']: if params['manage_VRAM']:
@ -259,6 +260,7 @@ def SD_api_address_update(address):
return gr.Textbox.update(label=msg) return gr.Textbox.update(label=msg)
def ui(): def ui():
# Gradio elements # Gradio elements
@ -290,12 +292,11 @@ def ui():
cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box") cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
with gr.Column() as hr_options: with gr.Column() as hr_options:
restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces') restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix') enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options: with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by') hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength') denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler') hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
address.change(lambda x: params.update({"address": filter_address(x)}), address, None) address.change(lambda x: params.update({"address": filter_address(x)}), address, None)

View File

@ -4,6 +4,7 @@ from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
from extensions.silero_tts import tts_preprocessor from extensions.silero_tts import tts_preprocessor
from modules import chat, shared from modules import chat, shared
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper
@ -216,4 +217,4 @@ def ui():
# Play preview # Play preview
preview_text.submit(voice_preview, preview_text, preview_audio) preview_text.submit(voice_preview, preview_text, preview_audio)
preview_play.click(voice_preview, preview_text, preview_audio) preview_play.click(voice_preview, preview_text, preview_audio)

View File

@ -2,7 +2,6 @@ import time
from pathlib import Path from pathlib import Path
import torch import torch
import tts_preprocessor import tts_preprocessor
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)

View File

@ -69,7 +69,7 @@ def remove_surrounded_chars(string):
# first this expression will check if there is a string nested exclusively between a alt= # first this expression will check if there is a string nested exclusively between a alt=
# and a style= string. This would correspond to only a the alt text of an embedded image # and a style= string. This would correspond to only a the alt text of an embedded image
# If it matches it will only keep that part as the string, and rend it for further processing # If it matches it will only keep that part as the string, and rend it for further processing
# Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any # Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
# asterisks' OR' as few symbols as possible (0 upwards) between an asterisk and the end of the string' # asterisks' OR' as few symbols as possible (0 upwards) between an asterisk and the end of the string'
if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL): if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL):
m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL) m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL)

View File

@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
n_results = min(len(self.ids), n_results) n_results = min(len(self.ids), n_results)
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
return list(map(lambda x : int(x[2:]), result)) return list(map(lambda x: int(x[2:]), result))
def clear(self): def clear(self):
self.collection.delete(ids=self.ids) self.collection.delete(ids=self.ids)
@ -162,13 +162,13 @@ def input_modifier(string):
def custom_generate_chat_prompt(user_input, state, **kwargs): def custom_generate_chat_prompt(user_input, state, **kwargs):
if len(shared.history['internal']) > 2 and user_input != '': if len(shared.history['internal']) > 2 and user_input != '':
chunks = [] chunks = []
for i in range(len(shared.history['internal'])-1): for i in range(len(shared.history['internal']) - 1):
chunks.append('\n'.join(shared.history['internal'][i])) chunks.append('\n'.join(shared.history['internal'][i]))
add_chunks_to_collector(chunks) add_chunks_to_collector(chunks)
query = '\n'.join(shared.history['internal'][-1] + [user_input]) query = '\n'.join(shared.history['internal'][-1] + [user_input])
try: try:
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1) best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
# Sort the history by relevance instead of by chronological order, # Sort the history by relevance instead of by chronological order,
# except for the latest message # except for the latest message
@ -226,7 +226,7 @@ def ui():
## Chat mode ## Chat mode
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair. In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
That is, the prompt will include (starting from the end): That is, the prompt will include (starting from the end):

View File

@ -1,5 +1,6 @@
import gradio as gr import gradio as gr
import speech_recognition as sr import speech_recognition as sr
from modules import shared from modules import shared
input_hijack = { input_hijack = {

View File

@ -24,13 +24,12 @@ class RWKVModel:
@classmethod @classmethod
def from_pretrained(self, path, dtype="fp16", device="cuda"): def from_pretrained(self, path, dtype="fp16", device="cuda"):
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
if shared.args.rwkv_strategy is None: if shared.args.rwkv_strategy is None:
model = RWKV(model=str(path), strategy=f'{device} {dtype}') model = RWKV(model=str(path), strategy=f'{device} {dtype}')
else: else:
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy) model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, str(tokenizer_path))
pipeline = PIPELINE(model, str(tokenizer_path))
result = self() result = self()
result.pipeline = pipeline result.pipeline = pipeline
result.model = model result.model = model
@ -83,7 +82,6 @@ class RWKVModel:
out = self.cached_output_logits out = self.cached_output_logits
for i in range(token_count): for i in range(token_count):
# forward # forward
tokens = self.pipeline.encode(ctx) if i == 0 else [token] tokens = self.pipeline.encode(ctx) if i == 0 else [token]
while len(tokens) > 0: while len(tokens) > 0:
@ -91,35 +89,38 @@ class RWKVModel:
tokens = tokens[args.chunk_len:] tokens = tokens[args.chunk_len:]
# cache the model state after scanning the context # cache the model state after scanning the context
# we don't cache the state after processing our own generated tokens because # we don't cache the state after processing our own generated tokens because
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model # the output string might be post-processed arbitrarily. Therefore, what's fed into the model
# on the next round of chat might be slightly different what what it output on the previous round # on the next round of chat might be slightly different what what it output on the previous round
if i == 0: if i == 0:
self.cached_context += ctx self.cached_context += ctx
self.cached_model_state = copy.deepcopy(state) self.cached_model_state = copy.deepcopy(state)
self.cached_output_logits = copy.deepcopy(out) self.cached_output_logits = copy.deepcopy(out)
# adjust probabilities # adjust probabilities
for n in args.token_ban: for n in args.token_ban:
out[n] = -float('inf') out[n] = -float('inf')
for n in occurrence: for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
# sampler # sampler
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k) token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
if token in args.token_stop: if token in args.token_stop:
break break
all_tokens += [token] all_tokens += [token]
if token not in occurrence: if token not in occurrence:
occurrence[token] = 1 occurrence[token] = 1
else: else:
occurrence[token] += 1 occurrence[token] += 1
# output # output
tmp = self.pipeline.decode([token]) tmp = self.pipeline.decode([token])
if '\ufffd' not in tmp: # is valid utf-8 string? if '\ufffd' not in tmp: # is valid utf-8 string?
if callback: if callback:
callback(tmp) callback(tmp)
out_str += tmp out_str += tmp
return out_str return out_str
@ -133,7 +134,6 @@ class RWKVTokenizer:
def from_pretrained(self, path): def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json" tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(str(tokenizer_path)) tokenizer = Tokenizer.from_file(str(tokenizer_path))
result = self() result = self()
result.tokenizer = tokenizer result.tokenizer = tokenizer
return result return result

View File

@ -1,5 +1,4 @@
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir): def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
''' '''
DeepSpeed configration DeepSpeed configration
https://huggingface.co/docs/transformers/main_classes/deepspeed https://huggingface.co/docs/transformers/main_classes/deepspeed

View File

@ -20,6 +20,8 @@ def load_past_evaluations():
return df return df
else: else:
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment']) return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
past_evaluations = load_past_evaluations() past_evaluations = load_past_evaluations()

View File

@ -7,7 +7,6 @@ import gradio as gr
import extensions import extensions
import modules.shared as shared import modules.shared as shared
state = {} state = {}
available_extensions = [] available_extensions = []
setup_called = set() setup_called = set()
@ -91,7 +90,7 @@ def _apply_state_modifier_extensions(state):
state = getattr(extension, "state_modifier")(state) state = getattr(extension, "state_modifier")(state)
return state return state
# Extension functions that override the default tokenizer output - currently only the first one will work # Extension functions that override the default tokenizer output - currently only the first one will work
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
@ -108,7 +107,7 @@ def _apply_custom_tokenized_length(prompt):
for extension, _ in iterator(): for extension, _ in iterator():
if hasattr(extension, 'custom_tokenized_length'): if hasattr(extension, 'custom_tokenized_length'):
return getattr(extension, 'custom_tokenized_length')(prompt) return getattr(extension, 'custom_tokenized_length')(prompt)
return None return None

View File

@ -1,6 +1,8 @@
# Copied from https://stackoverflow.com/a/1336640 # Copied from https://stackoverflow.com/a/1336640
import logging import logging
import platform
def add_coloring_to_emit_windows(fn): def add_coloring_to_emit_windows(fn):
# add methods we need to the class # add methods we need to the class
@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
def _set_color(self, code): def _set_color(self, code):
import ctypes import ctypes
# Constants from the Windows API # Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11 self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
return new return new
import platform
if platform.system() == 'Windows': if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color # Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit) logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)

View File

@ -161,10 +161,10 @@ def load_model(model_name):
# Custom # Custom
else: else:
params = { params = {
"low_cpu_mem_usage": True, "low_cpu_mem_usage": True,
"trust_remote_code": trust_remote_code "trust_remote_code": trust_remote_code
} }
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)): if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.") logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
shared.args.cpu = True shared.args.cpu = True
@ -288,7 +288,7 @@ def load_soft_prompt(name):
logging.info(f"{field}: {', '.join(j[field])}") logging.info(f"{field}: {', '.join(j[field])}")
else: else:
logging.info(f"{field}: {j[field]}") logging.info(f"{field}: {j[field]}")
logging.info() logging.info()
tensor = np.load('tensor.npy') tensor = np.load('tensor.npy')
Path('tensor.npy').unlink() Path('tensor.npy').unlink()

View File

@ -377,7 +377,7 @@ def create_model_menus():
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):