From 120fb86c6ac8a30aed96cdefce7248735361bb72 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 20 Aug 2023 20:49:21 -0300 Subject: [PATCH] Add a simple logit viewer (#3636) --- css/main.css | 15 ++++++++++++++- modules/logits.py | 19 +++++++++++++++++++ modules/ui_default.py | 7 ++++++- modules/ui_notebook.py | 7 ++++++- 4 files changed, 45 insertions(+), 3 deletions(-) create mode 100644 modules/logits.py diff --git a/css/main.css b/css/main.css index 7bfd0146..2562cf7d 100644 --- a/css/main.css +++ b/css/main.css @@ -116,7 +116,20 @@ div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { height: calc(100dvh - 241px); } -.textbox_default textarea, .textbox_default_output textarea, .textbox textarea { +.textbox_logits textarea { + height: calc(100dvh - 241px); +} + +.textbox_logits_notebook textarea { + height: calc(100dvh - 292px); +} + +.textbox_default textarea, +.textbox_default_output textarea, +.textbox_logits textarea, +.textbox_logits_notebook textarea, +.textbox textarea +{ font-size: 16px !important; color: #46464A !important; } diff --git a/modules/logits.py b/modules/logits.py new file mode 100644 index 00000000..99cb336f --- /dev/null +++ b/modules/logits.py @@ -0,0 +1,19 @@ +import torch + +from modules import shared + + +def get_next_logits(prompt): + tokens = shared.tokenizer.encode(prompt, return_tensors='pt').cuda() + output = shared.model(input_ids=tokens) + + scores = output['logits'][-1][-1] + probs = torch.softmax(scores, dim=-1, dtype=torch.float) + + topk_values, topk_indices = torch.topk(probs, k=20, largest=True, sorted=True) + topk_values = [f"{float(i):.5f}" % i for i in topk_values] + output = '' + for row in list(zip(topk_values, shared.tokenizer.convert_ids_to_tokens(topk_indices))): + output += f"{row[0]} {row[1]}\n" + + return output diff --git a/modules/ui_default.py b/modules/ui_default.py index 99657227..a5fbc3f5 100644 --- a/modules/ui_default.py +++ b/modules/ui_default.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import shared, ui, utils +from modules import logits, shared, ui, utils from modules.prompts import count_tokens, load_prompt from modules.text_generation import ( generate_reply_wrapper, @@ -43,6 +43,10 @@ def create_ui(): with gr.Tab('HTML'): shared.gradio['html-default'] = gr.HTML() + with gr.Tab('Logits'): + shared.gradio['get_logits-default'] = gr.Button('Get next token probabilities') + shared.gradio['logits-default'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits', 'add_scrollbar']) + def create_event_handlers(): shared.gradio['Generate-default'].click( @@ -80,3 +84,4 @@ def create_event_handlers(): lambda: gr.update(visible=True), None, gradio('file_deleter')) shared.gradio['count_tokens-default'].click(count_tokens, gradio('textbox-default'), gradio('status-default'), show_progress=False) + shared.gradio['get_logits-default'].click(logits.get_next_logits, gradio('textbox-default'), gradio('logits-default')) diff --git a/modules/ui_notebook.py b/modules/ui_notebook.py index 6949ed78..289cf62c 100644 --- a/modules/ui_notebook.py +++ b/modules/ui_notebook.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import shared, ui, utils +from modules import logits, shared, ui, utils from modules.prompts import count_tokens, load_prompt from modules.text_generation import ( generate_reply_wrapper, @@ -27,6 +27,10 @@ def create_ui(): with gr.Tab('HTML'): shared.gradio['html-notebook'] = gr.HTML() + with gr.Tab('Logits'): + shared.gradio['get_logits-notebook'] = gr.Button('Get next token probabilities') + shared.gradio['logits-notebook'] = gr.Textbox(lines=23, label='Output', elem_classes=['textbox_logits_notebook', 'add_scrollbar']) + with gr.Row(): shared.gradio['Generate-notebook'] = gr.Button('Generate', variant='primary', elem_classes='small-button') shared.gradio['Stop-notebook'] = gr.Button('Stop', elem_classes='small-button', elem_id='stop') @@ -83,3 +87,4 @@ def create_event_handlers(): lambda: gr.update(visible=True), None, gradio('file_deleter')) shared.gradio['count_tokens-notebook'].click(count_tokens, gradio('textbox-notebook'), gradio('status-notebook'), show_progress=False) + shared.gradio['get_logits-notebook'].click(logits.get_next_logits, gradio('textbox-notebook'), gradio('logits-notebook'))