text-generation-webui/Colab-TextGen-GPU.ipynb

133 lines
4.8 KiB
Plaintext
Raw Normal View History

2023-10-21 23:27:52 -04:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# oobabooga/text-generation-webui\n",
"\n",
"After running both cells, a public gradio URL will appear at the bottom in a few minutes. You can optionally generate API links.\n",
"\n",
"* Project page: https://github.com/oobabooga/text-generation-webui\n",
"* Gradio server status: https://status.gradio.app/"
],
"metadata": {
"id": "MFQl6-FjSYtY"
}
},
{
"cell_type": "code",
"source": [
"#@title 1. Keep this tab alive to prevent Colab from disconnecting you { display-mode: \"form\" }\n",
"\n",
"#@markdown Press play on the music player that will appear below:\n",
"%%html\n",
"<audio src=\"https://oobabooga.github.io/silence.m4a\" controls>"
],
"metadata": {
"id": "f7TVVj_z4flw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title 2. Launch the web UI\n",
"\n",
"#@markdown If unsure about the branch, write \"main\" or leave it blank.\n",
"\n",
"import torch\n",
"from pathlib import Path\n",
"\n",
"if Path.cwd().name != 'text-generation-webui':\n",
" print(\"Installing the webui...\")\n",
"\n",
" !git clone https://github.com/oobabooga/text-generation-webui\n",
" %cd text-generation-webui\n",
"\n",
" torver = torch.__version__\n",
" print(f\"TORCH: {torver}\")\n",
" is_cuda118 = '+cu118' in torver # 2.1.0+cu118\n",
" is_cuda117 = '+cu117' in torver # 2.0.1+cu117\n",
"\n",
" textgen_requirements = open('requirements.txt').read().splitlines()\n",
" if is_cuda117:\n",
" textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]\n",
2023-10-21 23:27:52 -04:00
" elif is_cuda118:\n",
" textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]\n",
2023-10-21 23:27:52 -04:00
" with open('temp_requirements.txt', 'w') as file:\n",
" file.write('\\n'.join(textgen_requirements))\n",
"\n",
" !pip install -r extensions/api/requirements.txt --upgrade\n",
" !pip install -r temp_requirements.txt --upgrade\n",
"\n",
2023-10-22 11:57:16 -04:00
" print(\"\\033[1;32;1m\\n --> If you see a warning about \\\"previously imported packages\\\", just ignore it.\\033[0;37;0m\")\n",
2023-10-21 23:27:52 -04:00
" print(\"\\033[1;32;1m\\n --> There is no need to restart the runtime.\\n\\033[0;37;0m\")\n",
"\n",
" try:\n",
" import flash_attn\n",
" except:\n",
" !pip uninstall -y flash_attn\n",
"\n",
"# Parameters\n",
"model_url = \"https://huggingface.co/turboderp/Mistral-7B-instruct-exl2\" #@param {type:\"string\"}\n",
"branch = \"4.0bpw\" #@param {type:\"string\"}\n",
"command_line_flags = \"\" #@param {type:\"string\"}\n",
"api = False #@param {type:\"boolean\"}\n",
"\n",
"if api:\n",
" for param in ['--api', '--public-api']:\n",
" if param not in command_line_flags:\n",
" command_line_flags += f\" {param}\"\n",
"\n",
"model_url = model_url.strip()\n",
2023-10-22 11:57:16 -04:00
"if model_url != \"\":\n",
" if not model_url.startswith('http'):\n",
" model_url = 'https://huggingface.co/' + model_url\n",
"\n",
" # Download the model\n",
" url_parts = model_url.strip('/').strip().split('/')\n",
" output_folder = f\"{url_parts[-2]}_{url_parts[-1]}\"\n",
" branch = branch.strip('\"\\' ')\n",
" if branch.strip() != '':\n",
" output_folder += f\"_{branch}\"\n",
" !python download-model.py {model_url} --branch {branch}\n",
" else:\n",
" !python download-model.py {model_url}\n",
2023-10-21 23:27:52 -04:00
"else:\n",
2023-10-22 11:57:16 -04:00
" output_folder = \"\"\n",
2023-10-21 23:27:52 -04:00
"\n",
"# Start the web UI\n",
"cmd = f\"python server.py --share --n-gpu-layers 128\"\n",
2023-10-22 11:57:16 -04:00
"if output_folder != \"\":\n",
" cmd += f\" --model {output_folder}\"\n",
"cmd += f\" {command_line_flags}\"\n",
2023-10-21 23:27:52 -04:00
"print(cmd)\n",
"!$cmd"
],
"metadata": {
"id": "LGQ8BiMuXMDG",
"cellView": "form"
},
"execution_count": null,
"outputs": []
}
]
}