Make top_p and top_k optional in CompletionRequest (Optional[float] / Optional[int])

This commit is contained in:
Andriy Mulyar 2023-08-15 12:06:49 -04:00
parent 2c0ee50dce
commit a9668eb2e4

View File

@@ -2,7 +2,7 @@ import json
from typing import List, Dict, Iterable, AsyncIterable from typing import List, Dict, Iterable, AsyncIterable
import logging import logging
import time import time
from typing import Dict, List, Union from typing import Dict, List, Union, Optional
from uuid import uuid4 from uuid import uuid4
import aiohttp import aiohttp
import asyncio import asyncio
@@ -24,8 +24,8 @@ class CompletionRequest(BaseModel):
prompt: Union[List[str], str] = Field(..., description='The prompt to begin completing from.') prompt: Union[List[str], str] = Field(..., description='The prompt to begin completing from.')
max_tokens: int = Field(None, description='Max tokens to generate') max_tokens: int = Field(None, description='Max tokens to generate')
temperature: float = Field(settings.temp, description='Model temperature') temperature: float = Field(settings.temp, description='Model temperature')
top_p: float = Field(settings.top_p, description='top_p') top_p: Optional[float] = Field(settings.top_p, description='top_p')
top_k: int = Field(settings.top_k, description='top_k') top_k: Optional[int] = Field(settings.top_k, description='top_k')
n: int = Field(1, description='How many completions to generate for each prompt') n: int = Field(1, description='How many completions to generate for each prompt')
stream: bool = Field(False, description='Stream responses') stream: bool = Field(False, description='Stream responses')
repeat_penalty: float = Field(settings.repeat_penalty, description='Repeat penalty') repeat_penalty: float = Field(settings.repeat_penalty, description='Repeat penalty')