Mirror of https://github.com/oobabooga/text-generation-webui.git

Commit 3913155c1f: Style improvements (#1957)
Parent commit: 334486f527
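
The hunks that follow touch the API example scripts, a number of extensions, and several core modules, and consist almost entirely of import reordering plus blank-line and whitespace fixes. The resulting grouping matches what isort produces (standard library, then third-party, then project-local imports), though the commit itself does not say which tool, if any, was used. A minimal sketch of that convention, assuming isort is installed:

import isort

# Unsorted imports, in the style seen on the removed side of these hunks.
messy = "import gradio as gr\nimport os\n\nfrom modules import shared\n"

# isort.code() returns the reordered source: stdlib ("os") first, then the
# rest. Splitting "modules"/"extensions" into their own first-party group, as
# this commit does, would additionally require known_first_party to be
# configured; that configuration is an assumption, not shown in the diff.
print(isort.code(messy))
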
@@ -5,7 +5,7 @@ import sys
 try:
     import websockets
 except ImportError:
     print("Websockets package not found. Make sure it's installed.")
 
 # For local streaming, the websockets are hosted without ssl - ws://
 HOST = 'localhost:5005'
@@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
 # For reverse-proxied streaming, the remote will likely host with ssl - wss://
 # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
 
+
 async def run(context):
     # Note: the selected defaults change from time to time.
     request = {
@@ -42,7 +43,7 @@ async def run(context):
     async with websockets.connect(URI, ping_interval=None) as websocket:
         await websocket.send(json.dumps(request))
 
         yield context  # Remove this if you just want to see the reply
 
         while True:
             incoming_data = await websocket.recv()
@@ -58,7 +59,7 @@ async def run(context):
 async def print_response_stream(prompt):
     async for response in run(prompt):
         print(response, end='')
         sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.
 
 
 if __name__ == '__main__':
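
The four hunks above come from the streaming example client: an async generator run() that sends a request over the websocket and yields progressively longer text, and print_response_stream(), which prints each chunk as it arrives. A minimal self-contained sketch of that pattern follows; the 'event'/'text' fields of the streamed JSON messages and any request parameters beyond 'prompt' are assumptions, since the diff only shows the connection and flushing logic.

import asyncio
import json
import sys

import websockets

HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/stream'


async def run(context):
    # Real requests carry more generation parameters; only 'prompt' appears in the diff.
    request = {'prompt': context}
    async with websockets.connect(URI, ping_interval=None) as websocket:
        await websocket.send(json.dumps(request))
        while True:
            incoming_data = json.loads(await websocket.recv())
            if incoming_data.get('event') == 'text_stream':  # assumed field names
                yield incoming_data['text']
            elif incoming_data.get('event') == 'stream_end':
                return


async def print_response_stream(prompt):
    async for response in run(prompt):
        print(response, end='')
        sys.stdout.flush()  # flush so tokens show up in real time


if __name__ == '__main__':
    asyncio.run(print_response_stream("In order to make homemade bread, follow these steps:\n1)"))
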
@@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
 # For reverse-proxied streaming, the remote will likely host with ssl - https://
 # URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
 
+
 def run(prompt):
     request = {
         'prompt': prompt,
@@ -37,6 +38,7 @@ def run(prompt):
         result = response.json()['results'][0]['text']
         print(prompt + result)
 
+
 if __name__ == '__main__':
     prompt = "In order to make homemade bread, follow these steps:\n1)"
     run(prompt)
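
These two hunks are from the blocking (non-streaming) example client, which POSTs a prompt to /api/v1/generate and prints the first result. A minimal sketch, assuming the API extension shown further below is running on port 5000; parameters other than 'prompt' (such as 'max_new_tokens') are assumed, as the diff only shows the prompt field and the response access path.

import requests

HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/generate'


def run(prompt):
    request = {
        'prompt': prompt,
        'max_new_tokens': 200,  # assumed parameter name, not shown in the diff
    }
    response = requests.post(URI, json=request)
    if response.status_code == 200:
        result = response.json()['results'][0]['text']
        print(prompt + result)


if __name__ == '__main__':
    run("In order to make homemade bread, follow these steps:\n1)")
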
@@ -2,11 +2,10 @@ import json
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
+from extensions.api.util import build_parameters, try_start_cloudflared
 from modules import shared
 from modules.text_generation import encode, generate_reply
-
-from extensions.api.util import build_parameters, try_start_cloudflared
 
 
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
@@ -5,6 +5,7 @@ from modules import shared
 BLOCKING_PORT = 5000
 STREAMING_PORT = 5005
 
+
 def setup():
     blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
     streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
@@ -1,12 +1,12 @@
-import json
 import asyncio
-from websockets.server import serve
+import json
 from threading import Thread
 
-from modules import shared
-from modules.text_generation import generate_reply
+from websockets.server import serve
 
 from extensions.api.util import build_parameters, try_start_cloudflared
+from modules import shared
+from modules.text_generation import generate_reply
 
 PATH = '/api/v1/stream'
 
@@ -1,6 +1,7 @@
-import gradio as gr
 import os
 
+import gradio as gr
+
 # get the current directory of the script
 current_dir = os.path.dirname(os.path.abspath(__file__))
 
@@ -1,6 +1,8 @@
-import gradio as gr
 import logging
 
+import gradio as gr
+
+
 def ui():
     gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
     logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
@@ -6,10 +6,11 @@ from io import BytesIO
 from typing import Any, List, Optional
 
 import torch
+from PIL import Image
 
 from extensions.multimodal.pipeline_loader import load_pipeline
 from modules import shared
 from modules.text_generation import encode, get_max_prompt_length
-from PIL import Image
 
+
 @dataclass
@@ -7,6 +7,7 @@ from io import BytesIO
 
 import gradio as gr
 import torch
 
 from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
 from modules import shared
+
@@ -1,11 +1,12 @@
 import base64
 import json
-import numpy as np
 import os
 import time
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
+import numpy as np
+
 from modules import shared
 from modules.text_generation import encode, generate_reply
 
@@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
     ascii_string = encoded_bytes.decode('ascii')
     return ascii_string
 
+
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
         if self.path.startswith('/v1/models'):
@@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler):
             "created": created_time,
             "model": model,  # TODO: add Lora info?
             resp_list: [{
                 "index": 0,
                 "finish_reason": "stop",
             }],
             "usage": {
                 "prompt_tokens": token_count,
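
The middle hunk above ends float_list_to_base64() in what appears to be the OpenAI-compatible API extension: an embedding vector is packed into bytes, base64-encoded and returned as ASCII. A standalone sketch of that round trip; the float32 dtype is an assumption, since only the final decode('ascii') step is visible in the hunk.

import base64

import numpy as np


def float_list_to_base64(float_list):
    float_array = np.array(float_list, dtype=np.float32)  # dtype assumed
    encoded_bytes = base64.b64encode(float_array.tobytes())
    ascii_string = encoded_bytes.decode('ascii')
    return ascii_string


print(float_list_to_base64([0.1, 0.2, 0.3]))
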
@@ -6,12 +6,13 @@ from datetime import date
 from pathlib import Path
 
 import gradio as gr
-import modules.shared as shared
 import requests
 import torch
-from modules.models import reload_model, unload_model
 from PIL import Image
 
+import modules.shared as shared
+from modules.models import reload_model, unload_model
+
 torch._C._jit_set_profiling_mode(False)
 
 # parameters which can be customized in settings.json of webui
|
|||||||
|
|
||||||
picture_response = False # specifies if the next model response should appear as a picture
|
picture_response = False # specifies if the next model response should appear as a picture
|
||||||
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
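
remove_surrounded_chars(), whose body begins above, strips text "surrounded" by asterisks: anything between a pair of asterisks, or from a lone asterisk to the end of the string, according to its comments. The exact pattern is not part of this hunk, so the following is an illustrative regex with that behaviour rather than the repository's own code:

import re


def remove_surrounded_chars(string):
    # drop '*...*' pairs and a trailing unmatched '*...' tail
    return re.sub(r'\*[^*]*?(\*|$)', '', string)


print(remove_surrounded_chars('Hello *waves enthusiastically* there *trails off'))
# -> 'Hello  there '
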
@@ -122,7 +124,6 @@ def input_modifier(string):
 
 # Get and save the Stable Diffusion-generated picture
 def get_SD_pictures(description):
-
     global params
 
     if params['manage_VRAM']:
@@ -259,6 +260,7 @@ def SD_api_address_update(address):
 
     return gr.Textbox.update(label=msg)
 
+
 def ui():
 
     # Gradio elements
@@ -290,12 +292,11 @@ def ui():
             cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
             with gr.Column() as hr_options:
                 restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
                 enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
             with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
                 hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
                 denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
                 hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
 
-
     # Event functions to update the parameters in the backend
     address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
@@ -4,6 +4,7 @@ from pathlib import Path
 
 import gradio as gr
 import torch
+
 from extensions.silero_tts import tts_preprocessor
 from modules import chat, shared
 from modules.html_generator import chat_html_wrapper
@@ -216,4 +217,4 @@ def ui():
 
     # Play preview
     preview_text.submit(voice_preview, preview_text, preview_audio)
     preview_play.click(voice_preview, preview_text, preview_audio)
@@ -2,7 +2,6 @@ import time
 from pathlib import Path
 
 import torch
-
 import tts_preprocessor
 
 torch._C._jit_set_profiling_mode(False)
@@ -69,7 +69,7 @@ def remove_surrounded_chars(string):
     # first this expression will check if there is a string nested exclusively between a alt=
     # and a style= string. This would correspond to only a the alt text of an embedded image
     # If it matches it will only keep that part as the string, and rend it for further processing
     # Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
     # asterisks' OR' as few symbols as possible (0 upwards) between an asterisk and the end of the string'
     if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL):
         m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL)
@@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
     def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
         n_results = min(len(self.ids), n_results)
         result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
-        return list(map(lambda x : int(x[2:]), result))
+        return list(map(lambda x: int(x[2:]), result))
 
     def clear(self):
         self.collection.delete(ids=self.ids)
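
get_ids() above queries the chroma collection and strips a two-character prefix from each returned id with int(x[2:]). A tiny sketch of that round trip; the 'idN' naming convention is inferred from the slicing and is not shown directly in this hunk.

ids = ['id0', 'id1', 'id2']        # ids as they would be stored in the collection
result = ['id2', 'id0']            # e.g. a query result, most relevant first
indices = list(map(lambda x: int(x[2:]), result))
print(indices)                     # [2, 0]
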
@@ -162,13 +162,13 @@ def input_modifier(string):
 def custom_generate_chat_prompt(user_input, state, **kwargs):
     if len(shared.history['internal']) > 2 and user_input != '':
         chunks = []
-        for i in range(len(shared.history['internal'])-1):
+        for i in range(len(shared.history['internal']) - 1):
             chunks.append('\n'.join(shared.history['internal'][i]))
 
         add_chunks_to_collector(chunks)
         query = '\n'.join(shared.history['internal'][-1] + [user_input])
         try:
-            best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
+            best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
 
             # Sort the history by relevance instead of by chronological order,
             # except for the latest message
@@ -226,7 +226,7 @@ def ui():
 
     ## Chat mode
 
     In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
 
     That is, the prompt will include (starting from the end):
 
@@ -1,5 +1,6 @@
 import gradio as gr
 import speech_recognition as sr
+
 from modules import shared
 
 input_hijack = {
@@ -24,13 +24,12 @@ class RWKVModel:
     @classmethod
     def from_pretrained(self, path, dtype="fp16", device="cuda"):
         tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
 
         if shared.args.rwkv_strategy is None:
             model = RWKV(model=str(path), strategy=f'{device} {dtype}')
         else:
             model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
-        pipeline = PIPELINE(model, str(tokenizer_path))
 
+        pipeline = PIPELINE(model, str(tokenizer_path))
         result = self()
         result.pipeline = pipeline
         result.model = model
@@ -83,7 +82,6 @@ class RWKVModel:
             out = self.cached_output_logits
 
         for i in range(token_count):
-
             # forward
             tokens = self.pipeline.encode(ctx) if i == 0 else [token]
             while len(tokens) > 0:
@@ -91,35 +89,38 @@ class RWKVModel:
                 tokens = tokens[args.chunk_len:]
 
             # cache the model state after scanning the context
             # we don't cache the state after processing our own generated tokens because
             # the output string might be post-processed arbitrarily. Therefore, what's fed into the model
             # on the next round of chat might be slightly different what what it output on the previous round
             if i == 0:
                 self.cached_context += ctx
                 self.cached_model_state = copy.deepcopy(state)
                 self.cached_output_logits = copy.deepcopy(out)
 
             # adjust probabilities
             for n in args.token_ban:
                 out[n] = -float('inf')
 
             for n in occurrence:
                 out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
 
             # sampler
             token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
             if token in args.token_stop:
                 break
 
             all_tokens += [token]
             if token not in occurrence:
                 occurrence[token] = 1
             else:
                 occurrence[token] += 1
 
             # output
             tmp = self.pipeline.decode([token])
             if '\ufffd' not in tmp:  # is valid utf-8 string?
                 if callback:
                     callback(tmp)
 
                 out_str += tmp
 
         return out_str
|
|||||||
def from_pretrained(self, path):
|
def from_pretrained(self, path):
|
||||||
tokenizer_path = path / "20B_tokenizer.json"
|
tokenizer_path = path / "20B_tokenizer.json"
|
||||||
tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
||||||
|
|
||||||
result = self()
|
result = self()
|
||||||
result.tokenizer = tokenizer
|
result.tokenizer = tokenizer
|
||||||
return result
|
return result
|
||||||
|
@@ -1,5 +1,4 @@
 def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
-
     '''
     DeepSpeed configration
     https://huggingface.co/docs/transformers/main_classes/deepspeed
@@ -20,6 +20,8 @@ def load_past_evaluations():
         return df
     else:
         return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
 
+
 past_evaluations = load_past_evaluations()
 
+
@@ -7,7 +7,6 @@ import gradio as gr
 import extensions
 import modules.shared as shared
 
-
 state = {}
 available_extensions = []
 setup_called = set()
@@ -91,7 +90,7 @@ def _apply_state_modifier_extensions(state):
             state = getattr(extension, "state_modifier")(state)
 
     return state
 
 
 # Extension functions that override the default tokenizer output - currently only the first one will work
 def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
@@ -108,7 +107,7 @@ def _apply_custom_tokenized_length(prompt):
     for extension, _ in iterator():
         if hasattr(extension, 'custom_tokenized_length'):
             return getattr(extension, 'custom_tokenized_length')(prompt)
 
     return None
 
 
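
_apply_custom_tokenized_length() above illustrates the dispatch pattern this codebase uses for extension hooks: iterate over the loaded extensions and call the hook on the first one that defines it, otherwise return None. A self-contained sketch with a stand-in extension object and iterator():

def iterator():
    # stand-in for the real iterator over loaded extension modules
    class FakeExtension:
        def custom_tokenized_length(self, prompt):
            return len(prompt.split())

    yield FakeExtension(), 'fake_extension'


def _apply_custom_tokenized_length(prompt):
    for extension, _ in iterator():
        if hasattr(extension, 'custom_tokenized_length'):
            return getattr(extension, 'custom_tokenized_length')(prompt)

    return None


print(_apply_custom_tokenized_length('four words long prompt'))  # prints 4
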
@@ -1,6 +1,8 @@
 # Copied from https://stackoverflow.com/a/1336640
 
 import logging
+import platform
+
 
 def add_coloring_to_emit_windows(fn):
     # add methods we need to the class
@@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
 
     def _set_color(self, code):
         import ctypes
+
         # Constants from the Windows API
         self.STD_OUTPUT_HANDLE = -11
         hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
@@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
     return new
 
 
-import platform
 if platform.system() == 'Windows':
     # Windows does not support ANSI escapes and we are using API calls to set the console color
     logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
@@ -161,10 +161,10 @@ def load_model(model_name):
     # Custom
     else:
         params = {
             "low_cpu_mem_usage": True,
             "trust_remote_code": trust_remote_code
         }
 
         if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
             logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
             shared.args.cpu = True
@@ -288,7 +288,7 @@ def load_soft_prompt(name):
                 logging.info(f"{field}: {', '.join(j[field])}")
             else:
                 logging.info(f"{field}: {j[field]}")
 
         logging.info()
         tensor = np.load('tensor.npy')
         Path('tensor.npy').unlink()
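
The first hunk above builds the keyword arguments for loading a custom model and falls back to CPU when no GPU is detected. A minimal sketch of how such a params dict is typically passed on to transformers; the AutoModelForCausalLM class, the 'gpt2' checkpoint and the accelerate requirement are assumptions, as none of them appears in this hunk.

import torch
from transformers import AutoModelForCausalLM

params = {
    "low_cpu_mem_usage": True,   # assumed to require the accelerate package
    "trust_remote_code": False,
}

if not torch.cuda.is_available():
    # mirrors the warning-and-fallback above: no GPU detected, stay on CPU
    print("torch.cuda.is_available() returned False. Falling back to CPU mode.")

model = AutoModelForCausalLM.from_pretrained("gpt2", **params)
print(model.config.model_type)
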
@@ -377,7 +377,7 @@ def create_model_menus():
 
     shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
     shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
-    shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load)
+    shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
 
 
 def create_settings_menus(default_preset):