mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Make chunk length/count customizable
This commit is contained in:
parent
8c06eeaf84
commit
04eca9b65b
@ -65,13 +65,15 @@ class SentenceTransformerEmbedder(Embedder):
|
|||||||
|
|
||||||
embedder = SentenceTransformerEmbedder()
|
embedder = SentenceTransformerEmbedder()
|
||||||
collector = ChromaCollector(embedder)
|
collector = ChromaCollector(embedder)
|
||||||
|
chunk_count = 5
|
||||||
|
|
||||||
|
|
||||||
def feed_data_into_collector(corpus):
|
def feed_data_into_collector(corpus, chunk_len, _chunk_count):
|
||||||
global collector
|
global collector, chunk_count
|
||||||
|
chunk_count = int(_chunk_count)
|
||||||
|
chunk_len = int(chunk_len)
|
||||||
|
|
||||||
cumulative = ''
|
cumulative = ''
|
||||||
chunk_len = 700
|
|
||||||
cumulative += "Breaking the input dataset...\n\n"
|
cumulative += "Breaking the input dataset...\n\n"
|
||||||
yield cumulative
|
yield cumulative
|
||||||
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
|
||||||
@ -83,14 +85,14 @@ def feed_data_into_collector(corpus):
|
|||||||
yield cumulative
|
yield cumulative
|
||||||
|
|
||||||
|
|
||||||
def feed_file_into_collector(file):
|
def feed_file_into_collector(file, chunk_len, chunk_count):
|
||||||
yield 'Reading the input dataset...\n\n'
|
yield 'Reading the input dataset...\n\n'
|
||||||
text = file.decode('utf-8')
|
text = file.decode('utf-8')
|
||||||
for i in feed_data_into_collector(text):
|
for i in feed_data_into_collector(text, chunk_len, chunk_count):
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
|
|
||||||
def feed_url_into_collector(url):
|
def feed_url_into_collector(url, chunk_len, chunk_count):
|
||||||
yield 'Loading the URL...'
|
yield 'Loading the URL...'
|
||||||
html = urlopen(url).read()
|
html = urlopen(url).read()
|
||||||
soup = BeautifulSoup(html, features="html.parser")
|
soup = BeautifulSoup(html, features="html.parser")
|
||||||
@ -101,7 +103,7 @@ def feed_url_into_collector(url):
|
|||||||
lines = (line.strip() for line in text.splitlines())
|
lines = (line.strip() for line in text.splitlines())
|
||||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||||
text = '\n\n'.join(chunk for chunk in chunks if chunk)
|
text = '\n\n'.join(chunk for chunk in chunks if chunk)
|
||||||
for i in feed_data_into_collector(text):
|
for i in feed_data_into_collector(text, chunk_len, chunk_count):
|
||||||
yield i
|
yield i
|
||||||
|
|
||||||
|
|
||||||
@ -115,8 +117,8 @@ def input_modifier(string):
|
|||||||
else:
|
else:
|
||||||
user_input = ''
|
user_input = ''
|
||||||
|
|
||||||
# Get the 5 most similar chunks
|
# Get the most similar chunks
|
||||||
results = collector.get(user_input, n_results=5)
|
results = collector.get(user_input, n_results=chunk_count)
|
||||||
|
|
||||||
# Make the replacements
|
# Make the replacements
|
||||||
string = string.replace('<|begin-user-input|>', '')
|
string = string.replace('<|begin-user-input|>', '')
|
||||||
@ -178,9 +180,13 @@ def ui():
|
|||||||
file_input = gr.File(label='Input file', type='binary')
|
file_input = gr.File(label='Input file', type='binary')
|
||||||
update_file = gr.Button('Apply')
|
update_file = gr.Button('Apply')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens')
|
||||||
|
chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt')
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
last_updated = gr.Markdown()
|
last_updated = gr.Markdown()
|
||||||
|
|
||||||
update_data.click(feed_data_into_collector, data_input, last_updated, show_progress=False)
|
update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_count], last_updated, show_progress=False)
|
||||||
update_url.click(feed_url_into_collector, url_input, last_updated, show_progress=False)
|
update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_count], last_updated, show_progress=False)
|
||||||
update_file.click(feed_file_into_collector, file_input, last_updated, show_progress=False)
|
update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_count], last_updated, show_progress=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user