diff --git a/css/main.css b/css/main.css index 83b8850b..b87053be 100644 --- a/css/main.css +++ b/css/main.css @@ -59,6 +59,10 @@ ol li p, ul li p { margin-bottom: 35px; } +.extension-tab { + border: 0 !important; +} + span.math.inline { font-size: 27px; vertical-align: baseline !important; diff --git a/docs/Extensions.md b/docs/Extensions.md index 0e52c8d1..287ee32a 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -49,18 +49,35 @@ Additionally, the script may define two special global variables: #### `params` dictionary +`script.py` may contain a special dictionary called `params`: + ```python params = { - "language string": "ja", + "display_name": "Google Translate", + "is_tab": True, } ``` -This dicionary can be used to make the extension parameters customizable by adding entries to a `settings.json` file like this: +In this dictionary, `display_name` is used to define the displayed name of the extension inside the UI, and `is_tab` is used to define whether the extension's `ui()` function should be called in a new `gr.Tab()` that will appear in the header bar. By default, the extension appears at the bottom of the "Text generation" tab. + +Additionally, `params` may contain variables that you want to be customizable through a `settings.json` file. For instance, assuming the extension is in `extensions/google_translate`, the variable `language string` below + +```python +params = { + "display_name": "Google Translate", + "is_tab": True, + "language string": "jp" +} +``` + +can be customized by adding a key called `google_translate-language string` to `settings.json`: ```python "google_translate-language string": "fr", ``` +That is, the syntax is `extension_name-variable_name`. + #### `input_hijack` dictionary ```python diff --git a/modules/extensions.py b/modules/extensions.py index fe8cb7be..2f68caf5 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -137,6 +137,30 @@ def _apply_custom_js(): return all_js +def create_extensions_block(): + to_display = [] + for extension, name in iterator(): + if hasattr(extension, "ui") and not (hasattr(extension, 'params') and extension.params.get('is_tab', False)): + to_display.append((extension, name)) + + # Creating the extension ui elements + if len(to_display) > 0: + with gr.Column(elem_id="extensions"): + for row in to_display: + extension, name = row + display_name = getattr(extension, 'params', {}).get('display_name', name) + gr.Markdown(f"\n### {display_name}") + extension.ui() + + +def create_extensions_tabs(): + for extension, name in iterator(): + if hasattr(extension, "ui") and (hasattr(extension, 'params') and extension.params.get('is_tab', False)): + display_name = getattr(extension, 'params', {}).get('display_name', name) + with gr.Tab(display_name, elem_classes="extension-tab"): + extension.ui() + + EXTENSION_MAP = { "input": partial(_apply_string_extensions, "input_modifier"), "output": partial(_apply_string_extensions, "output_modifier"), @@ -157,21 +181,3 @@ def apply_extensions(typ, *args, **kwargs): raise ValueError(f"Invalid extension type {typ}") return EXTENSION_MAP[typ](*args, **kwargs) - - -def create_extensions_block(): - global setup_called - - should_display_ui = False - for extension, name in iterator(): - if hasattr(extension, "ui"): - should_display_ui = True - break - - # Creating the extension ui elements - if should_display_ui: - with gr.Column(elem_id="extensions"): - for extension, name in iterator(): - if hasattr(extension, "ui"): - gr.Markdown(f"\n### {name}") - extension.ui() diff --git a/server.py b/server.py index 334841e6..0115a97e 100644 --- a/server.py +++ b/server.py @@ -877,9 +877,11 @@ def create_interface(): shared.gradio['interface'].load(None, None, None, _js=f"() => {{{js}}}") shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) + # Extensions tabs + extensions_module.create_extensions_tabs() + # Extensions block - if shared.args.extensions is not None: - extensions_module.create_extensions_block() + extensions_module.create_extensions_block() # Launch the interface shared.gradio['interface'].queue()