text-generation-webui/Colab-TextGen-GPU.ipynb
2023-10-22 08:44:35 -07:00

127 lines
4.6 KiB
Plaintext

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# oobabooga/text-generation-webui\n",
"\n",
"After running both cells, a public gradio URL will appear at the bottom in a few minutes. You can optionally generate API links.\n",
"\n",
"* Project page: https://github.com/oobabooga/text-generation-webui\n",
"* Gradio server status: https://status.gradio.app/"
],
"metadata": {
"id": "MFQl6-FjSYtY"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Keep this tab alive to prevent Colab from disconnecting you { display-mode: \"form\" }\n",
"\n",
"#@markdown Press play on the music player that will appear below:\n",
"%%html\n",
"<audio src=\"https://oobabooga.github.io/silence.m4a\" controls>"
],
"metadata": {
"id": "f7TVVj_z4flw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 2. Launch the web UI\n",
"\n",
"#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
"\n",
"import torch\n",
"from pathlib import Path\n",
"\n",
"if Path.cwd().name != 'text-generation-webui':\n",
" print(\"Installing the webui...\")\n",
"\n",
" !git clone https://github.com/oobabooga/text-generation-webui\n",
" %cd text-generation-webui\n",
"\n",
" torver = torch.__version__\n",
" print(f\"TORCH: {torver}\")\n",
" is_cuda118 = '+cu118' in torver # 2.1.0+cu118\n",
" is_cuda117 = '+cu117' in torver # 2.0.1+cu117\n",
"\n",
" textgen_requirements = open('requirements.txt').read().splitlines()\n",
" if is_cuda117:\n",
" textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n",
" elif is_cuda118:\n",
" textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]\n",
" with open('temp_requirements.txt', 'w') as file:\n",
" file.write('\\n'.join(textgen_requirements))\n",
"\n",
" !pip install -r extensions/api/requirements.txt --upgrade\n",
" !pip install -r temp_requirements.txt --upgrade\n",
"\n",
" print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\n\\033[0;37;0m\")",
" print(\"\\033[1;32;1m\\n --> There is no need to restart the runtime.\\n\\033[0;37;0m\")\n",
"\n",
" try:\n",
" import flash_attn\n",
" except:\n",
" !pip uninstall -y flash_attn\n",
"\n",
"# Parameters\n",
"model_url = \"https://huggingface.co/turboderp/Mistral-7B-instruct-exl2\" #@param {type:\"string\"}\n",
"branch = \"4.0bpw\" #@param {type:\"string\"}\n",
"command_line_flags = \"\" #@param {type:\"string\"}\n",
"api = False #@param {type:\"boolean\"}\n",
"\n",
"if api:\n",
" for param in ['--api', '--public-api']:\n",
" if param not in command_line_flags:\n",
" command_line_flags += f\" {param}\"\n",
"\n",
"model_url = model_url.strip()\n",
"if not model_url.startswith('http'):\n",
" model_url = 'https://huggingface.co/' + model_url\n",
"\n",
"# Download the model\n",
"url_parts = model_url.strip('/').strip().split('/')\n",
"output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
"branch = branch.strip('\"\\' ')\n",
"if branch.strip() != '':\n",
" output_folder += f\"_{branch}\"\n",
" !python download-model.py {model_url} --branch {branch}\n",
"else:\n",
" !python download-model.py {model_url}\n",
"\n",
"# Start the web UI\n",
"cmd = f\"python server.py --share --model {output_folder} {command_line_flags}\"\n",
"print(cmd)\n",
"!$cmd"
],
"metadata": {
"id": "LGQ8BiMuXMDG",
"cellView": "form"
},
"execution_count": null,
"outputs": []
}
]
}