Add --admin-key flag for API (#4649)

This commit is contained in:
oobabooga 2023-11-18 22:33:27 -03:00 committed by GitHub
parent af76fbedb8
commit 8f4f4daf8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 34 additions and 18 deletions

View File

@ -413,6 +413,7 @@ Optionally, you can use the following command-line flags:
| `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. | | `--public-api-id PUBLIC_API_ID` | Tunnel ID for named Cloudflare Tunnel. Use together with public-api option. |
| `--api-port API_PORT` | The listening port for the API. | | `--api-port API_PORT` | The listening port for the API. |
| `--api-key API_KEY` | API authentication key. | | `--api-key API_KEY` | API authentication key. |
| `--admin-key ADMIN_KEY` | API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key. |
#### Multimodal #### Multimodal

View File

@ -60,7 +60,15 @@ def verify_api_key(authorization: str = Header(None)) -> None:
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
app = FastAPI(dependencies=[Depends(verify_api_key)]) def verify_admin_key(authorization: str = Header(None)) -> None:
expected_api_key = shared.args.admin_key
if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
raise HTTPException(status_code=401, detail="Unauthorized")
app = FastAPI()
check_key = [Depends(verify_api_key)]
check_admin_key = [Depends(verify_admin_key)]
# Configure CORS settings to allow all origins, methods, and headers # Configure CORS settings to allow all origins, methods, and headers
app.add_middleware( app.add_middleware(
@ -72,12 +80,12 @@ app.add_middleware(
) )
@app.options("/") @app.options("/", dependencies=check_key)
async def options_route(): async def options_route():
return JSONResponse(content="OK") return JSONResponse(content="OK")
@app.post('/v1/completions', response_model=CompletionResponse) @app.post('/v1/completions', response_model=CompletionResponse, dependencies=check_key)
async def openai_completions(request: Request, request_data: CompletionRequest): async def openai_completions(request: Request, request_data: CompletionRequest):
path = request.url.path path = request.url.path
is_legacy = "/generate" in path is_legacy = "/generate" in path
@ -100,7 +108,7 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
return JSONResponse(response) return JSONResponse(response)
@app.post('/v1/chat/completions', response_model=ChatCompletionResponse) @app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key)
async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest): async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
path = request.url.path path = request.url.path
is_legacy = "/generate" in path is_legacy = "/generate" in path
@ -123,8 +131,8 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
return JSONResponse(response) return JSONResponse(response)
@app.get("/v1/models") @app.get("/v1/models", dependencies=check_key)
@app.get("/v1/models/{model}") @app.get("/v1/models/{model}", dependencies=check_key)
async def handle_models(request: Request): async def handle_models(request: Request):
path = request.url.path path = request.url.path
is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models' is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
@ -138,7 +146,7 @@ async def handle_models(request: Request):
return JSONResponse(response) return JSONResponse(response)
@app.get('/v1/billing/usage') @app.get('/v1/billing/usage', dependencies=check_key)
def handle_billing_usage(): def handle_billing_usage():
''' '''
Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31 Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
@ -146,7 +154,7 @@ def handle_billing_usage():
return JSONResponse(content={"total_usage": 0}) return JSONResponse(content={"total_usage": 0})
@app.post('/v1/audio/transcriptions') @app.post('/v1/audio/transcriptions', dependencies=check_key)
async def handle_audio_transcription(request: Request): async def handle_audio_transcription(request: Request):
r = sr.Recognizer() r = sr.Recognizer()
@ -176,7 +184,7 @@ async def handle_audio_transcription(request: Request):
return JSONResponse(content=transcription) return JSONResponse(content=transcription)
@app.post('/v1/images/generations') @app.post('/v1/images/generations', dependencies=check_key)
async def handle_image_generation(request: Request): async def handle_image_generation(request: Request):
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')): if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
@ -192,7 +200,7 @@ async def handle_image_generation(request: Request):
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/embeddings", response_model=EmbeddingsResponse) @app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest): async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
input = request_data.input input = request_data.input
if not input: if not input:
@ -205,7 +213,7 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/moderations") @app.post("/v1/moderations", dependencies=check_key)
async def handle_moderations(request: Request): async def handle_moderations(request: Request):
body = await request.json() body = await request.json()
input = body["input"] input = body["input"]
@ -216,37 +224,37 @@ async def handle_moderations(request: Request):
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/internal/encode", response_model=EncodeResponse) @app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
async def handle_token_encode(request_data: EncodeRequest): async def handle_token_encode(request_data: EncodeRequest):
response = token_encode(request_data.text) response = token_encode(request_data.text)
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/internal/decode", response_model=DecodeResponse) @app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
async def handle_token_decode(request_data: DecodeRequest): async def handle_token_decode(request_data: DecodeRequest):
response = token_decode(request_data.tokens) response = token_decode(request_data.tokens)
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/internal/token-count", response_model=TokenCountResponse) @app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
async def handle_token_count(request_data: EncodeRequest): async def handle_token_count(request_data: EncodeRequest):
response = token_count(request_data.text) response = token_count(request_data.text)
return JSONResponse(response) return JSONResponse(response)
@app.post("/v1/internal/stop-generation") @app.post("/v1/internal/stop-generation", dependencies=check_key)
async def handle_stop_generation(request: Request): async def handle_stop_generation(request: Request):
stop_everything_event() stop_everything_event()
return JSONResponse(content="OK") return JSONResponse(content="OK")
@app.get("/v1/internal/model/info", response_model=ModelInfoResponse) @app.get("/v1/internal/model/info", response_model=ModelInfoResponse, dependencies=check_key)
async def handle_model_info(): async def handle_model_info():
payload = OAImodels.get_current_model_info() payload = OAImodels.get_current_model_info()
return JSONResponse(content=payload) return JSONResponse(content=payload)
@app.post("/v1/internal/model/load") @app.post("/v1/internal/model/load", dependencies=check_admin_key)
async def handle_load_model(request_data: LoadModelRequest): async def handle_load_model(request_data: LoadModelRequest):
''' '''
This endpoint is experimental and may change in the future. This endpoint is experimental and may change in the future.
@ -283,7 +291,7 @@ async def handle_load_model(request_data: LoadModelRequest):
return HTTPException(status_code=400, detail="Failed to load the model.") return HTTPException(status_code=400, detail="Failed to load the model.")
@app.post("/v1/internal/model/unload") @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model(): async def handle_unload_model():
unload_model() unload_model()
return JSONResponse(content="OK") return JSONResponse(content="OK")
@ -308,8 +316,14 @@ def run_server():
logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n') logger.info(f'OpenAI-compatible API URL:\n\nhttp://{server_addr}:{port}\n')
if shared.args.api_key: if shared.args.api_key:
if not shared.args.admin_key:
shared.args.admin_key = shared.args.api_key
logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n') logger.info(f'OpenAI API key:\n\n{shared.args.api_key}\n')
if shared.args.admin_key:
logger.info(f'OpenAI API admin key (for loading/unloading models):\n\n{shared.args.admin_key}\n')
uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile) uvicorn.run(app, host=server_addr, port=port, ssl_certfile=ssl_certfile, ssl_keyfile=ssl_keyfile)

View File

@ -170,6 +170,7 @@ parser.add_argument('--public-api', action='store_true', help='Create a public U
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None) parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.') parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
parser.add_argument('--api-key', type=str, default='', help='API authentication key.') parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
parser.add_argument('--admin-key', type=str, default='', help='API authentication key for admin tasks like loading and unloading models. If not set, will be the same as --api-key.')
# Multimodal # Multimodal
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.') parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')