Make chunk length/count customizable

This commit is contained in:
oobabooga 2023-05-07 05:02:04 -03:00
parent 8c06eeaf84
commit 04eca9b65b

View File

@ -65,13 +65,15 @@ class SentenceTransformerEmbedder(Embedder):
embedder = SentenceTransformerEmbedder() embedder = SentenceTransformerEmbedder()
collector = ChromaCollector(embedder) collector = ChromaCollector(embedder)
chunk_count = 5
def feed_data_into_collector(corpus): def feed_data_into_collector(corpus, chunk_len, _chunk_count):
global collector global collector, chunk_count
chunk_count = int(_chunk_count)
chunk_len = int(chunk_len)
cumulative = '' cumulative = ''
chunk_len = 700
cumulative += "Breaking the input dataset...\n\n" cumulative += "Breaking the input dataset...\n\n"
yield cumulative yield cumulative
data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)] data_chunks = [corpus[i:i + chunk_len] for i in range(0, len(corpus), chunk_len)]
@ -83,14 +85,14 @@ def feed_data_into_collector(corpus):
yield cumulative yield cumulative
def feed_file_into_collector(file): def feed_file_into_collector(file, chunk_len, chunk_count):
yield 'Reading the input dataset...\n\n' yield 'Reading the input dataset...\n\n'
text = file.decode('utf-8') text = file.decode('utf-8')
for i in feed_data_into_collector(text): for i in feed_data_into_collector(text, chunk_len, chunk_count):
yield i yield i
def feed_url_into_collector(url): def feed_url_into_collector(url, chunk_len, chunk_count):
yield 'Loading the URL...' yield 'Loading the URL...'
html = urlopen(url).read() html = urlopen(url).read()
soup = BeautifulSoup(html, features="html.parser") soup = BeautifulSoup(html, features="html.parser")
@ -101,7 +103,7 @@ def feed_url_into_collector(url):
lines = (line.strip() for line in text.splitlines()) lines = (line.strip() for line in text.splitlines())
chunks = (phrase.strip() for line in lines for phrase in line.split(" ")) chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
text = '\n\n'.join(chunk for chunk in chunks if chunk) text = '\n\n'.join(chunk for chunk in chunks if chunk)
for i in feed_data_into_collector(text): for i in feed_data_into_collector(text, chunk_len, chunk_count):
yield i yield i
@ -115,8 +117,8 @@ def input_modifier(string):
else: else:
user_input = '' user_input = ''
# Get the 5 most similar chunks # Get the most similar chunks
results = collector.get(user_input, n_results=5) results = collector.get(user_input, n_results=chunk_count)
# Make the replacements # Make the replacements
string = string.replace('<|begin-user-input|>', '') string = string.replace('<|begin-user-input|>', '')
@ -178,9 +180,13 @@ def ui():
file_input = gr.File(label='Input file', type='binary') file_input = gr.File(label='Input file', type='binary')
update_file = gr.Button('Apply') update_file = gr.Button('Apply')
with gr.Row():
chunk_len = gr.Number(value=700, label='Chunk length', info='In characters, not tokens')
chunk_count = gr.Number(value=5, label='Chunk count', info='The number of closest-matching chunks to include in the prompt')
with gr.Column(): with gr.Column():
last_updated = gr.Markdown() last_updated = gr.Markdown()
update_data.click(feed_data_into_collector, data_input, last_updated, show_progress=False) update_data.click(feed_data_into_collector, [data_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_url.click(feed_url_into_collector, url_input, last_updated, show_progress=False) update_url.click(feed_url_into_collector, [url_input, chunk_len, chunk_count], last_updated, show_progress=False)
update_file.click(feed_file_into_collector, file_input, last_updated, show_progress=False) update_file.click(feed_file_into_collector, [file_input, chunk_len, chunk_count], last_updated, show_progress=False)