Mirror of https://github.com/oobabooga/text-generation-webui.git

Add --admin-key flag for API (#4649)

commit 8f4f4daf8b (parent af76fbedb8)
README.md
@@ -413,6 +413,7 @@ Optionally, you can use the following command-line flags:
 | `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
 | `--api-port API_PORT` | The listening port for the API. |
 | `--api-key API_KEY` | API authentication key. |
+| `--admin-key ADMIN_KEY` | API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key. |
 
 #### Multimodal
 
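A quick usage sketch (not part of the diff): start the server with, e.g., `python server.py --api --api-key sk-user --admin-key sk-admin` (both key values are placeholders), then authenticate with a `Bearer` header, matching the check in `verify_api_key`/`verify_admin_key` below:

```python
import requests  # third-party HTTP client, assumed installed

API = "http://127.0.0.1:5000"  # default --api-port is 5000

# Regular endpoints accept the ordinary API key.
r = requests.get(f"{API}/v1/models",
                 headers={"Authorization": "Bearer sk-user"})
print(r.status_code)  # 200 with the right key, 401 otherwise
```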
extensions/openai/script.py
@@ -60,7 +60,15 @@ def verify_api_key(authorization: str = Header(None)) -> None:
         raise HTTPException(status_code=401, detail="Unauthorized")
 
 
-app = FastAPI(dependencies=[Depends(verify_api_key)])
+def verify_admin_key(authorization: str = Header(None)) -> None:
+    expected_api_key = shared.args.admin_key
+    if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
+        raise HTTPException(status_code=401, detail="Unauthorized")
+
+
+app = FastAPI()
+check_key = [Depends(verify_api_key)]
+check_admin_key = [Depends(verify_admin_key)]
 
 # Configure CORS settings to allow all origins, methods, and headers
 app.add_middleware(
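This change replaces the app-wide `dependencies=[Depends(verify_api_key)]` with per-route dependency lists, so ordinary routes and admin routes can be guarded by different checks. A minimal, self-contained sketch of the same FastAPI pattern (all names and keys here are illustrative, not from the codebase):

```python
from fastapi import Depends, FastAPI, Header, HTTPException

def require_key(expected: str):
    # Build a dependency that rejects requests lacking the right Bearer token.
    def check(authorization: str = Header(None)) -> None:
        if expected and authorization != f"Bearer {expected}":
            raise HTTPException(status_code=401, detail="Unauthorized")
    return check

app = FastAPI()
user_guard = [Depends(require_key("sk-user"))]
admin_guard = [Depends(require_key("sk-admin"))]

@app.get("/info", dependencies=user_guard)  # guarded by the user key
def info():
    return {"ok": True}

@app.post("/admin/task", dependencies=admin_guard)  # guarded by the admin key
def admin_task():
    return {"done": True}
```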
@@ -72,12 +80,12 @@ app.add_middleware(
 )
 
 
-@app.options("/")
+@app.options("/", dependencies=check_key)
 async def options_route():
     return JSONResponse(content="OK")
 
 
-@app.post('/v1/completions', response_model=CompletionResponse)
+@app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
 async def openai_completions(request: Request, request_data: CompletionRequest):
     path = request.url.path
     is_legacy = "/generate" in path
@@ -100,7 +108,7 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
     return JSONResponse(response)
 
 
-@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
+@app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
 async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
     path = request.url.path
     is_legacy = "/generate" in path
@@ -123,8 +131,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
     return JSONResponse(response)
 
 
-@app.get("/v1/models")
-@app.get("/v1/models/{model}")
+@app.get("/v1/models", dependencies=check_key)
+@app.get("/v1/models/{model}", dependencies=check_key)
 async def handle_models(request: Request):
     path = request.url.path
     is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
@@ -138,7 +146,7 @@ async def handle_models(request: Request):
     return JSONResponse(response)
 
 
-@app.get('/v1/billing/usage')
+@app.get('/v1/billing/usage', dependencies=check_key)
 def handle_billing_usage():
     '''
     Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
@@ -146,7 +154,7 @@ def handle_billing_usage():
     return JSONResponse(content={"total_usage": 0})
 
 
-@app.post('/v1/audio/transcriptions')
+@app.post('/v1/audio/transcriptions', dependencies=check_key)
 async def handle_audio_transcription(request: Request):
     r = sr.Recognizer()
 
@@ -176,7 +184,7 @@ async def handle_audio_transcription(request: Request):
     return JSONResponse(content=transcription)
 
 
-@app.post('/v1/images/generations')
+@app.post('/v1/images/generations', dependencies=check_key)
 async def handle_image_generation(request: Request):
 
     if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
@@ -192,7 +200,7 @@ async def handle_image_generation(request: Request):
     return JSONResponse(response)
 
 
-@app.post("/v1/embeddings", response_model=EmbeddingsResponse)
+@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
 async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
     input = request_data.input
     if not input:
@@ -205,7 +213,7 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
     return JSONResponse(response)
 
 
-@app.post("/v1/moderations")
+@app.post("/v1/moderations", dependencies=check_key)
 async def handle_moderations(request: Request):
     body = await request.json()
     input = body["input"]
@@ -216,37 +224,37 @@ async def handle_moderations(request: Request):
     return JSONResponse(response)
 
 
-@app.post("/v1/internal/encode", response_model=EncodeResponse)
+@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
 async def handle_token_encode(request_data: EncodeRequest):
     response = token_encode(request_data.text)
     return JSONResponse(response)
 
 
-@app.post("/v1/internal/decode", response_model=DecodeResponse)
+@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
 async def handle_token_decode(request_data: DecodeRequest):
     response = token_decode(request_data.tokens)
     return JSONResponse(response)
 
 
-@app.post("/v1/internal/token-count", response_model=TokenCountResponse)
+@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
 async def handle_token_count(request_data: EncodeRequest):
     response = token_count(request_data.text)
     return JSONResponse(response)
 
 
-@app.post("/v1/internal/stop-generation")
+@app.post("/v1/internal/stop-generation", dependencies=check_key)
 async def handle_stop_generation(request: Request):
     stop_everything_event()
     return JSONResponse(content="OK")
 
 
-@app.get("/v1/internal/model/info", response_model=ModelInfoResponse)
+@app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
 async def handle_model_info():
     payload = OAImodels.get_current_model_info()
     return JSONResponse(content=payload)
 
 
-@app.post("/v1/internal/model/load")
+@app.post("/v1/internal/model/load", dependencies=check_admin_key)
 async def handle_load_model(request_data: LoadModelRequest):
     '''
     This endpoint is experimental and may change in the future.
@@ -283,7 +291,7 @@ async def handle_load_model(request_data: LoadModelRequest):
         return HTTPException(status_code=400, detail="Failed to load the model.")
 
 
-@app.post("/v1/internal/model/unload")
+@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
 async def handle_unload_model():
     unload_model()
     return JSONResponse(content="OK")
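Client-side, the two admin-gated endpoints can be exercised as below; a hedged sketch where the admin key and the request-body field are placeholders (the exact body shape is defined by `LoadModelRequest`):

```python
import requests  # assumed installed

API = "http://127.0.0.1:5000"
ADMIN = {"Authorization": "Bearer sk-admin"}  # placeholder admin key

# Load a model; the field name below is hypothetical.
r = requests.post(f"{API}/v1/internal/model/load",
                  headers=ADMIN, json={"model_name": "some-model"})
print(r.status_code)

# Unload the current model; requires the admin key, returns "OK" on success.
r = requests.post(f"{API}/v1/internal/model/unload", headers=ADMIN)
print(r.status_code)
```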
@@ -308,8 +316,14 @@ def run_server():
     logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n')
 
     if shared.args.api_key:
+        if not shared.args.admin_key:
+            shared.args.admin_key = shared.args.api_key
+
         logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
 
+    if shared.args.admin_key:
+        logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
+
     uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)
 
 
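The fallback above means that passing only `--api-key` also protects the admin routes with that same key, while passing only `--admin-key` leaves the regular routes open but locks the admin ones. A minimal restatement of that logic, with assertions as a sanity check:

```python
def effective_keys(api_key: str, admin_key: str):
    # Mirrors run_server(): the admin key falls back to the API key
    # when only --api-key is provided.
    if api_key and not admin_key:
        admin_key = api_key
    return api_key, admin_key

assert effective_keys("A", "") == ("A", "A")   # --api-key only: one key for everything
assert effective_keys("A", "B") == ("A", "B")  # both set: separate keys
assert effective_keys("", "B") == ("", "B")    # --admin-key only: user routes stay open
```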
modules/shared.py
@@ -170,6 +170,7 @@ parser.add_argument('--public-api', action='store_true', help='Create a public U
 parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
 parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
 parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
+parser.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.')
 
 # Multimodal
 parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')