mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-10-01 01:26:03 -04:00
Give API extension access to all generate_reply parameters (#744)
* Make every parameter of the generate_reply function parameterizable * Add stopping strings as parameterizable
This commit is contained in:
parent
9318e16ed5
commit
7aab88bcc6
@ -44,20 +44,21 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
generator = generate_reply(
|
generator = generate_reply(
|
||||||
question = prompt,
|
question = prompt,
|
||||||
max_new_tokens = int(body.get('max_length', 200)),
|
max_new_tokens = int(body.get('max_length', 200)),
|
||||||
do_sample=True,
|
do_sample=bool(body.get('do_sample', True)),
|
||||||
temperature=float(body.get('temperature', 0.5)),
|
temperature=float(body.get('temperature', 0.5)),
|
||||||
top_p=float(body.get('top_p', 1)),
|
top_p=float(body.get('top_p', 1)),
|
||||||
typical_p=float(body.get('typical', 1)),
|
typical_p=float(body.get('typical', 1)),
|
||||||
repetition_penalty=float(body.get('rep_pen', 1.1)),
|
repetition_penalty=float(body.get('rep_pen', 1.1)),
|
||||||
encoder_repetition_penalty=1,
|
encoder_repetition_penalty=1,
|
||||||
top_k=int(body.get('top_k', 0)),
|
top_k=int(body.get('top_k', 0)),
|
||||||
min_length=0,
|
min_length=int(body.get('min_length', 0)),
|
||||||
no_repeat_ngram_size=0,
|
no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
|
||||||
num_beams=1,
|
num_beams=int(body.get('num_beams',1)),
|
||||||
penalty_alpha=0,
|
penalty_alpha=float(body.get('penalty_alpha', 0)),
|
||||||
length_penalty=1,
|
length_penalty=float(body.get('length_penalty', 1)),
|
||||||
early_stopping=False,
|
early_stopping=bool(body.get('early_stopping', False)),
|
||||||
seed=-1,
|
seed=int(body.get('seed', -1)),
|
||||||
|
stopping_strings=body.get('stopping_strings', []),
|
||||||
)
|
)
|
||||||
|
|
||||||
answer = ''
|
answer = ''
|
||||||
|
Loading…
Reference in New Issue
Block a user