text-generation-webui/Colab-TextGen-GPU.ipynb

130 lines
4.6 KiB
Plaintext
Raw Normal View History

2023-10-21 23:27:52 -04:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# oobabooga/text-generation-webui\n",
"\n",
2023-11-10 12:18:25 -05:00
"After running both cells, a public Gradio URL will appear at the bottom in a few minutes. You can optionally generate an API link by ticking the `api` checkbox in the second cell.\n",
2023-10-21 23:27:52 -04:00
"\n",
"* Project page: https://github.com/oobabooga/text-generation-webui\n",
"* Gradio server status: https://status.gradio.app/"
],
"metadata": {
"id": "MFQl6-FjSYtY"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Keep this tab alive to prevent Colab from disconnecting you { display-mode: \"form\" }\n",
"\n",
"#@markdown Press play on the music player that will appear below:\n",
"%%html\n",
"<!-- Silent track; per the cell title, playing it is meant to keep this tab alive. -->\n",
"<audio src=\"https://oobabooga.github.io/silence.m4a\" controls></audio>"
],
"metadata": {
"id": "f7TVVj_z4flw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 2. Launch the web UI\n",
"\n",
"#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
"\n",
"import torch\n",
"from pathlib import Path\n",
"\n",
"# One-time setup. After the %cd below the working directory is the cloned repo,\n",
"# so re-running this cell skips the clone and the dependency installation.\n",
"if Path.cwd().name != 'text-generation-webui':\n",
"    print(\"Installing the webui...\")\n",
"\n",
"    !git clone https://github.com/oobabooga/text-generation-webui\n",
"    %cd text-generation-webui\n",
"\n",
"    torver = torch.__version__\n",
"    print(f\"TORCH: {torver}\")\n",
"    is_cuda118 = '+cu118' in torver  # 2.1.0+cu118\n",
"\n",
"    # Read the repo's requirements; the context manager closes the file promptly.\n",
"    with open('requirements.txt') as file:\n",
"        textgen_requirements = file.read().splitlines()\n",
"\n",
"    # On a CUDA 11.8 torch build, retarget the +cu121/+cu122 tags to +cu118.\n",
"    if is_cuda118:\n",
"        textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]\n",
"\n",
"    # pip installs from this (possibly rewritten) copy, not from requirements.txt.\n",
"    with open('temp_requirements.txt', 'w') as file:\n",
"        file.write('\\n'.join(textgen_requirements))\n",
"\n",
"    !pip install -r temp_requirements.txt --upgrade\n",
"\n",
"    print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",
"    print(\"\\033[1;32;1m\\n --> There is no need to restart the runtime.\\n\\033[0;37;0m\")\n",
"\n",
"    # Remove flash_attn when the installed build cannot be imported.\n",
"    # `except Exception` keeps this best-effort for any import-time failure\n",
"    # without also swallowing KeyboardInterrupt / SystemExit.\n",
"    try:\n",
"        import flash_attn\n",
"    except Exception:\n",
"        !pip uninstall -y flash_attn\n",
"\n",
"# Parameters\n",
"# The values below are Colab form fields; edit them in the form on the right.\n",
"model_url = \"https://huggingface.co/TheBloke/MythoMax-L2-13B-GPTQ\" #@param {type:\"string\"}\n",
"branch = \"gptq-4bit-32g-actorder_True\" #@param {type:\"string\"}\n",
"command_line_flags = \"--n-gpu-layers 128 --load-in-4bit --use_double_quant\" #@param {type:\"string\"}\n",
"api = False #@param {type:\"boolean\"}\n",
"\n",
"# When the API checkbox is ticked, make sure both API flags are present.\n",
"# Compare against whole whitespace-separated tokens: a plain substring test\n",
"# would treat e.g. \"--api-port 5000\" as if \"--api\" had already been given.\n",
"if api:\n",
"    for param in ['--api', '--public-api']:\n",
"        if param not in command_line_flags.split():\n",
"            command_line_flags += f\" {param}\"\n",
"\n",
"import shlex\n",
"\n",
"# Resolve the model to download and the folder name later passed to --model.\n",
"# An empty model_url skips the download and leaves output_folder empty.\n",
"model_url = model_url.strip()\n",
"if model_url != \"\":\n",
"    # Accept a bare \"user/model\" identifier as well as a full URL.\n",
"    if not model_url.startswith('http'):\n",
"        model_url = 'https://huggingface.co/' + model_url\n",
"\n",
"    # Download the model\n",
"    # The folder name is \"<second-to-last>_<last>\" URL segment, plus \"_<branch>\"\n",
"    # when a branch other than main is requested.\n",
"    url_parts = model_url.strip('/').strip().split('/')\n",
"    output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
"\n",
"    # Normalise the branch once (surrounding quotes and whitespace) so the same\n",
"    # value is used for the comparison, the folder name and the command line.\n",
"    branch = branch.strip('\"\\' \\t\\r\\n')\n",
"\n",
"    # The form values are interpolated into a shell command, so quote them.\n",
"    quoted_url = shlex.quote(model_url)\n",
"    if branch not in ['', 'main']:\n",
"        output_folder += f\"_{branch}\"\n",
"        quoted_branch = shlex.quote(branch)\n",
"        !python download-model.py {quoted_url} --branch {quoted_branch}\n",
"    else:\n",
"        !python download-model.py {quoted_url}\n",
"else:\n",
"    output_folder = \"\"\n",
2023-10-21 23:27:52 -04:00
"\n",
"# Start the web UI\n",
"# Assemble the launch command: --share is always passed, --model only when a\n",
"# model folder was resolved above, then the user-supplied flags verbatim.\n",
"model_arg = f\" --model {output_folder}\" if output_folder != \"\" else \"\"\n",
"cmd = f\"python server.py --share{model_arg} {command_line_flags}\"\n",
"\n",
"# Echo the exact command before handing it to the shell.\n",
"print(cmd)\n",
"!$cmd"
],
"metadata": {
"id": "LGQ8BiMuXMDG",
"cellView": "form"
},
"execution_count": null,
"outputs": []
}
]
}