Mirror of https://github.com/oobabooga/text-generation-webui.git
Synced 2024-10-01 01:26:03 -04:00

Style improvements (#1957)

This commit is contained in:
parent 334486f527
commit 3913155c1f
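The hunks below share three recurring cleanups: imports are sorted alphabetically and grouped (standard library first, then third-party packages, then project modules, separated by blank lines); top-level definitions get two blank lines above them; and stray whitespace is normalized (e.g. 'lambda x :' becomes 'lambda x:'). As a sketch of the import convention (assuming isort-style grouping; this example mirrors the streaming_api hunk rather than quoting the commit verbatim):

# Before: stdlib, third-party and project imports interleaved
# import json
# from websockets.server import serve
# import asyncio
# from modules import shared

# After: grouped and alphabetized
import asyncio
import json

from websockets.server import serve

from modules import shared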
@@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
 # For reverse-proxied streaming, the remote will likely host with ssl - wss://
 # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
 
+
 async def run(context):
     # Note: the selected defaults change from time to time.
     request = {
@@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
 # For reverse-proxied streaming, the remote will likely host with ssl - https://
 # URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
 
+
 def run(prompt):
     request = {
         'prompt': prompt,
@@ -37,6 +38,7 @@ def run(prompt):
     result = response.json()['results'][0]['text']
     print(prompt + result)
 
+
 if __name__ == '__main__':
     prompt = "In order to make homemade bread, follow these steps:\n1)"
     run(prompt)
@@ -2,11 +2,10 @@ import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
+from extensions.api.util import build_parameters, try_start_cloudflared
 from modules import shared
 from modules.text_generation import encode, generate_reply
-
-from extensions.api.util import build_parameters, try_start_cloudflared
 
 
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
@@ -5,6 +5,7 @@ from modules import shared
 BLOCKING_PORT = 5000
 STREAMING_PORT = 5005
 
+
 def setup():
     blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
     streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
@@ -1,12 +1,12 @@
-import json
 import asyncio
-from websockets.server import serve
+import json
 from threading import Thread
 
-from modules import shared
-from modules.text_generation import generate_reply
+from websockets.server import serve
 
 from extensions.api.util import build_parameters, try_start_cloudflared
+from modules import shared
+from modules.text_generation import generate_reply
 
 PATH = '/api/v1/stream'
 
@@ -1,6 +1,7 @@
-import gradio as gr
 import os
 
+import gradio as gr
+
 # get the current directory of the script
 current_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -1,6 +1,8 @@
-import gradio as gr
 import logging
 
+import gradio as gr
+
+
 def ui():
     gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
     logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
@@ -6,10 +6,11 @@ from io import BytesIO
 from typing import Any, List, Optional
 
 import torch
+from PIL import Image
 
 from extensions.multimodal.pipeline_loader import load_pipeline
 from modules import shared
 from modules.text_generation import encode, get_max_prompt_length
-from PIL import Image
+
 
 @dataclass
@@ -7,6 +7,7 @@ from io import BytesIO
 
 import gradio as gr
 import torch
 
 from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
 from modules import shared
+
@@ -1,11 +1,12 @@
 import base64
 import json
-import numpy as np
 import os
 import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
+import numpy as np
+
 from modules import shared
 from modules.text_generation import encode, generate_reply
 
@@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
     ascii_string = encoded_bytes.decode('ascii')
     return ascii_string
 
+
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
         if self.path.startswith('/v1/models'):
@@ -6,12 +6,13 @@ from datetime import date
 from pathlib import Path
 
 import gradio as gr
-import modules.shared as shared
 import requests
 import torch
-from modules.models import reload_model, unload_model
 from PIL import Image
 
+import modules.shared as shared
+from modules.models import reload_model, unload_model
+
 torch._C._jit_set_profiling_mode(False)
 
 # parameters which can be customized in settings.json of webui
@@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
 
 picture_response = False # specifies if the next model response should appear as a picture
 
+
 def remove_surrounded_chars(string):
     # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
@@ -122,7 +124,6 @@ def input_modifier(string):
 
 # Get and save the Stable Diffusion-generated picture
 def get_SD_pictures(description):
-
     global params
 
     if params['manage_VRAM']:
@@ -259,6 +260,7 @@ def SD_api_address_update(address):
 
     return gr.Textbox.update(label=msg)
 
+
 def ui():
 
     # Gradio elements
@@ -296,7 +298,6 @@ def ui():
             denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
             hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
 
-
     # Event functions to update the parameters in the backend
     address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
     mode.select(lambda x: params.update({"mode": x}), mode, None)
@@ -4,6 +4,7 @@ from pathlib import Path
 
 import gradio as gr
 import torch
 
+from extensions.silero_tts import tts_preprocessor
 from modules import chat, shared
 from modules.html_generator import chat_html_wrapper
@@ -2,7 +2,6 @@ import time
 from pathlib import Path
 
 import torch
-
 import tts_preprocessor
 
 torch._C._jit_set_profiling_mode(False)
@@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
     def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
         n_results = min(len(self.ids), n_results)
         result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
-        return list(map(lambda x : int(x[2:]), result))
+        return list(map(lambda x: int(x[2:]), result))
 
     def clear(self):
         self.collection.delete(ids=self.ids)
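The change above only normalizes whitespace (no space before the lambda's colon); behavior is identical. For reference, a minimal sketch of what the rewritten line computes, assuming the collector stores ids of the form 'id<N>' (an assumption inferred from the int(x[2:]) slice, not stated in this diff):

# Hypothetical ids, as the slice implies: 'id' followed by an integer
ids = ['id3', 'id0', 'id7']
print(list(map(lambda x: int(x[2:]), ids)))  # -> [3, 0, 7], the numeric part of each id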
@@ -162,13 +162,13 @@ def input_modifier(string):
 def custom_generate_chat_prompt(user_input, state, **kwargs):
     if len(shared.history['internal']) > 2 and user_input != '':
         chunks = []
-        for i in range(len(shared.history['internal'])-1):
+        for i in range(len(shared.history['internal']) - 1):
             chunks.append('\n'.join(shared.history['internal'][i]))
 
         add_chunks_to_collector(chunks)
         query = '\n'.join(shared.history['internal'][-1] + [user_input])
         try:
-            best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
+            best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
 
             # Sort the history by relevance instead of by chronological order,
             # except for the latest message
@@ -1,5 +1,6 @@
 import gradio as gr
 import speech_recognition as sr
+
 from modules import shared
 
 input_hijack = {
@@ -24,13 +24,12 @@ class RWKVModel:
     @classmethod
     def from_pretrained(self, path, dtype="fp16", device="cuda"):
         tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
-
         if shared.args.rwkv_strategy is None:
             model = RWKV(model=str(path), strategy=f'{device} {dtype}')
         else:
             model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
-        pipeline = PIPELINE(model, str(tokenizer_path))
 
+        pipeline = PIPELINE(model, str(tokenizer_path))
         result = self()
         result.pipeline = pipeline
         result.model = model
@@ -83,7 +82,6 @@ class RWKVModel:
             out = self.cached_output_logits
 
         for i in range(token_count):
-
             # forward
             tokens = self.pipeline.encode(ctx) if i == 0 else [token]
             while len(tokens) > 0:
@@ -102,6 +100,7 @@ class RWKVModel:
             # adjust probabilities
             for n in args.token_ban:
                 out[n] = -float('inf')
+
             for n in occurrence:
                 out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
 
@@ -109,6 +108,7 @@ class RWKVModel:
             token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
             if token in args.token_stop:
                 break
+
             all_tokens += [token]
             if token not in occurrence:
                 occurrence[token] = 1
@@ -120,6 +120,7 @@ class RWKVModel:
             if '\ufffd' not in tmp: # is valid utf-8 string?
                 if callback:
                     callback(tmp)
+
                 out_str += tmp
 
         return out_str
@@ -133,7 +134,6 @@ class RWKVTokenizer:
     def from_pretrained(self, path):
         tokenizer_path = path / "20B_tokenizer.json"
         tokenizer = Tokenizer.from_file(str(tokenizer_path))
-
         result = self()
         result.tokenizer = tokenizer
         return result
@@ -1,5 +1,4 @@
 def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
-
     '''
     DeepSpeed configration
     https://huggingface.co/docs/transformers/main_classes/deepspeed
@@ -20,6 +20,8 @@ def load_past_evaluations():
         return df
     else:
         return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
 
+
 past_evaluations = load_past_evaluations()
 
+
@@ -7,7 +7,6 @@ import gradio as gr
 import extensions
 import modules.shared as shared
 
-
 state = {}
 available_extensions = []
 setup_called = set()
@@ -1,6 +1,8 @@
 # Copied from https://stackoverflow.com/a/1336640
 
 import logging
+import platform
+
 
 def add_coloring_to_emit_windows(fn):
     # add methods we need to the class
@@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
 
     def _set_color(self, code):
         import ctypes
+
         # Constants from the Windows API
         self.STD_OUTPUT_HANDLE = -11
         hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
@@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
     return new
 
 
-import platform
 if platform.system() == 'Windows':
     # Windows does not support ANSI escapes and we are using API calls to set the console color
     logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
@@ -377,7 +377,7 @@ def create_model_menus():
 
     shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
     shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
-    shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load)
+    shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
 
 
 def create_settings_menus(default_preset):
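Most of the blank-line churn in the hunks above enforces PEP 8 spacing: two blank lines around top-level functions and classes, one blank line between methods. A standalone sketch of the target layout (illustrative only, not taken from the commit):

import os


def cwd():
    # two blank lines separate top-level definitions (PEP 8 E302)
    return os.getcwd()


class Lister:
    def __init__(self):
        self.root = cwd()

    def entries(self):
        # a single blank line separates methods (PEP 8 E301)
        return os.listdir(self.root)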