Style improvements (#1957)

This commit is contained in:
oobabooga 2023-05-09 22:49:39 -03:00 committed by GitHub
parent 334486f527
commit 3913155c1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 50 deletions

View File

@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss:// # For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
async def run(context): async def run(context):
# Note: the selected defaults change from time to time. # Note: the selected defaults change from time to time.
request = { request = {

View File

@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
# For reverse-proxied streaming, the remote will likely host with ssl - https:// # For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' # URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
def run(prompt): def run(prompt):
request = { request = {
'prompt': prompt, 'prompt': prompt,
@ -37,6 +38,7 @@ def run(prompt):
result = response.json()['results'][0]['text'] result = response.json()['results'][0]['text']
print(prompt + result) print(prompt + result)
if __name__ == '__main__': if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)" prompt = "In order to make homemade bread, follow these steps:\n1)"
run(prompt) run(prompt)

View File

@ -2,11 +2,10 @@ import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared from modules import shared
from modules.text_generation import encode, generate_reply from modules.text_generation import encode, generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):

View File

@ -5,6 +5,7 @@ from modules import shared
BLOCKING_PORT = 5000 BLOCKING_PORT = 5000
STREAMING_PORT = 5005 STREAMING_PORT = 5005
def setup(): def setup():
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api) blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api) streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)

View File

@ -1,12 +1,12 @@
import json
import asyncio import asyncio
from websockets.server import serve import json
from threading import Thread from threading import Thread
from modules import shared from websockets.server import serve
from modules.text_generation import generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import generate_reply
PATH = '/api/v1/stream' PATH = '/api/v1/stream'

View File

@ -1,6 +1,7 @@
import gradio as gr
import os import os
import gradio as gr
# get the current directory of the script # get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))

View File

@ -1,6 +1,8 @@
import gradio as gr
import logging import logging
import gradio as gr
def ui(): def ui():
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead") gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead") logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View File

@ -6,10 +6,11 @@ from io import BytesIO
from typing import Any, List, Optional from typing import Any, List, Optional
import torch import torch
from PIL import Image
from extensions.multimodal.pipeline_loader import load_pipeline from extensions.multimodal.pipeline_loader import load_pipeline
from modules import shared from modules import shared
from modules.text_generation import encode, get_max_prompt_length from modules.text_generation import encode, get_max_prompt_length
from PIL import Image
@dataclass @dataclass

View File

@ -7,6 +7,7 @@ from io import BytesIO
import gradio as gr import gradio as gr
import torch import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared from modules import shared

View File

@ -1,11 +1,12 @@
import base64 import base64
import json import json
import numpy as np
import os import os
import time import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread from threading import Thread
import numpy as np
from modules import shared from modules import shared
from modules.text_generation import encode, generate_reply from modules.text_generation import encode, generate_reply
@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
ascii_string = encoded_bytes.decode('ascii') ascii_string = encoded_bytes.decode('ascii')
return ascii_string return ascii_string
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path.startswith('/v1/models'): if self.path.startswith('/v1/models'):

View File

@ -6,12 +6,13 @@ from datetime import date
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.shared as shared
import requests import requests
import torch import torch
from modules.models import reload_model, unload_model
from PIL import Image from PIL import Image
import modules.shared as shared
from modules.models import reload_model, unload_model
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui # parameters which can be customized in settings.json of webui
@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
@ -122,7 +124,6 @@ def input_modifier(string):
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params global params
if params['manage_VRAM']: if params['manage_VRAM']:
@ -259,6 +260,7 @@ def SD_api_address_update(address):
return gr.Textbox.update(label=msg) return gr.Textbox.update(label=msg)
def ui(): def ui():
# Gradio elements # Gradio elements
@ -296,7 +298,6 @@ def ui():
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength') denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler') hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend
address.change(lambda x: params.update({"address": filter_address(x)}), address, None) address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
mode.select(lambda x: params.update({"mode": x}), mode, None) mode.select(lambda x: params.update({"mode": x}), mode, None)

View File

@ -4,6 +4,7 @@ from pathlib import Path
import gradio as gr import gradio as gr
import torch import torch
from extensions.silero_tts import tts_preprocessor from extensions.silero_tts import tts_preprocessor
from modules import chat, shared from modules import chat, shared
from modules.html_generator import chat_html_wrapper from modules.html_generator import chat_html_wrapper

View File

@ -2,7 +2,6 @@ import time
from pathlib import Path from pathlib import Path
import torch import torch
import tts_preprocessor import tts_preprocessor
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)

View File

@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
n_results = min(len(self.ids), n_results) n_results = min(len(self.ids), n_results)
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
return list(map(lambda x : int(x[2:]), result)) return list(map(lambda x: int(x[2:]), result))
def clear(self): def clear(self):
self.collection.delete(ids=self.ids) self.collection.delete(ids=self.ids)
@ -162,13 +162,13 @@ def input_modifier(string):
def custom_generate_chat_prompt(user_input, state, **kwargs): def custom_generate_chat_prompt(user_input, state, **kwargs):
if len(shared.history['internal']) > 2 and user_input != '': if len(shared.history['internal']) > 2 and user_input != '':
chunks = [] chunks = []
for i in range(len(shared.history['internal'])-1): for i in range(len(shared.history['internal']) - 1):
chunks.append('\n'.join(shared.history['internal'][i])) chunks.append('\n'.join(shared.history['internal'][i]))
add_chunks_to_collector(chunks) add_chunks_to_collector(chunks)
query = '\n'.join(shared.history['internal'][-1] + [user_input]) query = '\n'.join(shared.history['internal'][-1] + [user_input])
try: try:
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1) best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
# Sort the history by relevance instead of by chronological order, # Sort the history by relevance instead of by chronological order,
# except for the latest message # except for the latest message

View File

@ -1,5 +1,6 @@
import gradio as gr import gradio as gr
import speech_recognition as sr import speech_recognition as sr
from modules import shared from modules import shared
input_hijack = { input_hijack = {

View File

@ -24,13 +24,12 @@ class RWKVModel:
@classmethod @classmethod
def from_pretrained(self, path, dtype="fp16", device="cuda"): def from_pretrained(self, path, dtype="fp16", device="cuda"):
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
if shared.args.rwkv_strategy is None: if shared.args.rwkv_strategy is None:
model = RWKV(model=str(path), strategy=f'{device} {dtype}') model = RWKV(model=str(path), strategy=f'{device} {dtype}')
else: else:
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy) model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, str(tokenizer_path))
pipeline = PIPELINE(model, str(tokenizer_path))
result = self() result = self()
result.pipeline = pipeline result.pipeline = pipeline
result.model = model result.model = model
@ -83,7 +82,6 @@ class RWKVModel:
out = self.cached_output_logits out = self.cached_output_logits
for i in range(token_count): for i in range(token_count):
# forward # forward
tokens = self.pipeline.encode(ctx) if i == 0 else [token] tokens = self.pipeline.encode(ctx) if i == 0 else [token]
while len(tokens) > 0: while len(tokens) > 0:
@ -102,6 +100,7 @@ class RWKVModel:
# adjust probabilities # adjust probabilities
for n in args.token_ban: for n in args.token_ban:
out[n] = -float('inf') out[n] = -float('inf')
for n in occurrence: for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
@ -109,6 +108,7 @@ class RWKVModel:
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k) token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
if token in args.token_stop: if token in args.token_stop:
break break
all_tokens += [token] all_tokens += [token]
if token not in occurrence: if token not in occurrence:
occurrence[token] = 1 occurrence[token] = 1
@ -120,6 +120,7 @@ class RWKVModel:
if '\ufffd' not in tmp: # is valid utf-8 string? if '\ufffd' not in tmp: # is valid utf-8 string?
if callback: if callback:
callback(tmp) callback(tmp)
out_str += tmp out_str += tmp
return out_str return out_str
@ -133,7 +134,6 @@ class RWKVTokenizer:
def from_pretrained(self, path): def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json" tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(str(tokenizer_path)) tokenizer = Tokenizer.from_file(str(tokenizer_path))
result = self() result = self()
result.tokenizer = tokenizer result.tokenizer = tokenizer
return result return result

View File

@ -1,5 +1,4 @@
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir): def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
''' '''
DeepSpeed configration DeepSpeed configration
https://huggingface.co/docs/transformers/main_classes/deepspeed https://huggingface.co/docs/transformers/main_classes/deepspeed

View File

@ -20,6 +20,8 @@ def load_past_evaluations():
return df return df
else: else:
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment']) return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
past_evaluations = load_past_evaluations() past_evaluations = load_past_evaluations()

View File

@ -7,7 +7,6 @@ import gradio as gr
import extensions import extensions
import modules.shared as shared import modules.shared as shared
state = {} state = {}
available_extensions = [] available_extensions = []
setup_called = set() setup_called = set()

View File

@ -1,6 +1,8 @@
# Copied from https://stackoverflow.com/a/1336640 # Copied from https://stackoverflow.com/a/1336640
import logging import logging
import platform
def add_coloring_to_emit_windows(fn): def add_coloring_to_emit_windows(fn):
# add methods we need to the class # add methods we need to the class
@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
def _set_color(self, code): def _set_color(self, code):
import ctypes import ctypes
# Constants from the Windows API # Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11 self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
return new return new
import platform
if platform.system() == 'Windows': if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color # Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit) logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)

View File

@ -377,7 +377,7 @@ def create_model_menus():
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load) shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):