Add /v1/internal/lora endpoints (#4652)

This commit is contained in:
oobabooga 2023-11-19 00:35:22 -03:00 committed by GitHub
parent ef6feedeb2
commit 771e62e476
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 72 additions and 19 deletions

View File

@ -1,8 +1,9 @@
from modules import shared
from modules.logging_colors import logger
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.utils import get_available_models
from modules.utils import get_available_loras, get_available_models
def get_current_model_info():
@ -13,12 +14,17 @@ def get_current_model_info():
def list_models():
return {'model_names': get_available_models()[1:]}
def list_dummy_models():
result = {
"object": "list",
"data": []
}
for model in get_dummy_models() + get_available_models()[1:]:
# these are expected by so much, so include some here as a dummy
for model in ['gpt-3.5-turbo', 'text-embedding-ada-002']:
result["data"].append(model_info_dict(model))
return result
@ -33,13 +39,6 @@ def model_info_dict(model_name: str) -> dict:
}
def get_dummy_models() -> list:
return [ # these are expected by so much, so include some here as a dummy
'gpt-3.5-turbo',
'text-embedding-ada-002',
]
def _load_model(data):
model_name = data["model_name"]
args = data["args"]
@ -67,3 +66,15 @@ def _load_model(data):
logger.info(f"TRUNCATION LENGTH (UPDATED): {shared.settings['truncation_length']}")
elif k == 'instruction_template':
logger.info(f"INSTRUCTION TEMPLATE (UPDATED): {shared.settings['instruction_template']}")
def list_loras():
return {'lora_names': get_available_loras()[1:]}
def load_loras(lora_names):
add_lora_to_model(lora_names)
def unload_all_loras():
add_lora_to_model([])

View File

@ -38,10 +38,13 @@ from .typing import (
EmbeddingsResponse,
EncodeRequest,
EncodeResponse,
LoadLorasRequest,
LoadModelRequest,
LogitsRequest,
LogitsResponse,
LoraListResponse,
ModelInfoResponse,
ModelListResponse,
TokenCountResponse,
to_dict
)
@ -141,7 +144,7 @@ async def handle_models(request: Request):
is_list = request.url.path.split('?')[0].split('#')[0] == '/v1/models'
if is_list:
response = OAImodels.list_models()
response = OAImodels.list_dummy_models()
else:
model_name = path[len('/v1/models/'):]
response = OAImodels.model_info_dict(model_name)
@ -267,6 +270,12 @@ async def handle_model_info():
return JSONResponse(content=payload)
@app.get("/v1/internal/model/list", response_model=ModelListResponse, dependencies=check_admin_key)
async def handle_list_models():
payload = OAImodels.list_models()
return JSONResponse(content=payload)
@app.post("/v1/internal/model/load", dependencies=check_admin_key)
async def handle_load_model(request_data: LoadModelRequest):
'''
@ -307,6 +316,27 @@ async def handle_load_model(request_data: LoadModelRequest):
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
unload_model()
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
async def handle_list_loras():
response = OAImodels.list_loras()
return JSONResponse(content=response)
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
async def handle_load_loras(request_data: LoadLorasRequest):
try:
OAImodels.load_loras(request_data.lora_names)
return JSONResponse(content="OK")
except:
traceback.print_exc()
return HTTPException(status_code=400, detail="Failed to apply the LoRA(s).")
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
async def handle_unload_loras():
OAImodels.unload_all_loras()
return JSONResponse(content="OK")

View File

@ -122,6 +122,19 @@ class ChatCompletionResponse(BaseModel):
usage: dict
class EmbeddingsRequest(BaseModel):
input: str | List[str]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
encoding_format: str = Field(default="float", description="Can be float or base64.")
user: str | None = Field(default=None, description="Unused parameter.")
class EmbeddingsResponse(BaseModel):
index: int
embedding: List[float]
object: str = "embedding"
class EncodeRequest(BaseModel):
text: str
@ -166,23 +179,22 @@ class ModelInfoResponse(BaseModel):
lora_names: List[str]
class ModelListResponse(BaseModel):
model_names: List[str]
class LoadModelRequest(BaseModel):
model_name: str
args: dict | None = None
settings: dict | None = None
class EmbeddingsRequest(BaseModel):
input: str | List[str]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
encoding_format: str = Field(default="float", description="Can be float or base64.")
user: str | None = Field(default=None, description="Unused parameter.")
class LoraListResponse(BaseModel):
lora_names: List[str]
class EmbeddingsResponse(BaseModel):
index: int
embedding: List[float]
object: str = "embedding"
class LoadLorasRequest(BaseModel):
lora_names: List[str]
def to_json(obj):