mirror of
https://github.com/tatsu-lab/stanford_alpaca.git
synced 2024-10-01 05:35:37 -04:00
174 lines
6.3 KiB
Python
174 lines
6.3 KiB
Python
|
import dataclasses
|
||
|
import logging
|
||
|
import math
|
||
|
import os
|
||
|
import io
|
||
|
import sys
|
||
|
import time
|
||
|
import json
|
||
|
from typing import Optional, Sequence, Union
|
||
|
|
||
|
import openai
|
||
|
import tqdm
|
||
|
from openai import openai_object
|
||
|
import copy
|
||
|
|
||
|
# A single completion as returned to callers: plain text or the raw API object.
StrOrOpenAIObject = Union[str, openai_object.OpenAIObject]

# Import-time side effect: if OPENAI_ORG is set in the environment, route all
# subsequent API calls through that organization and log the switch.
openai_org = os.environ.get("OPENAI_ORG")
if openai_org is not None:
    openai.organization = openai_org
    logging.warning(f"Switching to organization: {openai_org} for OAI API key.")
|
||
|
|
||
|
|
||
|
@dataclasses.dataclass
class OpenAIDecodingArguments:
    """Decoding parameters for an OpenAI completion request.

    ``openai_completion`` expands an instance's ``__dict__`` into keyword
    arguments for ``openai.Completion.create``. Every field name is therefore
    sent to the API as-is, and the field order and defaults below form the
    public (positional) constructor interface.
    """

    # Generation budget; openai_completion shrinks a per-batch copy of this
    # when the API answers "Please reduce your prompt".
    max_tokens: int = 1800
    temperature: float = 0.2
    top_p: float = 1.0
    # Completions requested per prompt; openai_completion groups results by this.
    n: int = 1
    stream: bool = False
    stop: Optional[Sequence[str]] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    suffix: Optional[str] = None
    logprobs: Optional[int] = None
    echo: bool = False
|
||
|
|
||
|
|
||
|
def openai_completion(
    prompts: Union[str, Sequence[str], Sequence[dict[str, str]], dict[str, str]],
    decoding_args: OpenAIDecodingArguments,
    model_name="text-davinci-003",
    sleep_time=2,
    batch_size=1,
    max_instances=sys.maxsize,
    max_batches=sys.maxsize,
    return_text=False,
    **decoding_kwargs,
) -> Union[Union[StrOrOpenAIObject], Sequence[StrOrOpenAIObject], Sequence[Sequence[StrOrOpenAIObject]],]:
    """Decode with OpenAI API.

    Prompts are split into batches of ``batch_size`` and each batch is sent with
    ``openai.Completion.create``. A failed request is retried until it succeeds:
    - If the error message contains "Please reduce your prompt", that batch's
      ``max_tokens`` is cut to 80% and the request is retried.
    - Any other ``openai.error.OpenAIError`` is treated as a rate limit. The
      function sleeps ``sleep_time`` seconds and retries, with no retry cap.

    Args:
        prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted
            as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model
            it can also be a dictionary (or list thereof) as explained here:
            https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
            NOTE: this implementation always calls ``openai.Completion.create`` and passes the batch as ``prompt``.
        decoding_args: Decoding arguments; every field is forwarded to the API. Not mutated (a copy is made per batch).
        model_name: Model name. Can be either in the format of "org/model" or just "model".
        sleep_time: Time to sleep once the rate-limit is hit.
        batch_size: Number of prompts to send in a single request (must be >= 1). Only for non chat model.
        max_instances: Maximum number of prompts to decode.
        max_batches: Maximum number of batches to decode. This argument will be deprecated in the future.
        return_text: If True, return text instead of full completion object (which contains things like logprob).
        decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them.

    Returns:
        A completion or a list of completions. Depending on return_text and decoding_args.n, each entry is
            - a string (if return_text is True)
            - an openai_object.OpenAIObject choice (if return_text is False), carrying an extra
              "total_tokens" entry with the token usage of the whole request it came from
            - a list of decoding_args.n objects of the above types, one list per prompt (if decoding_args.n > 1)
        When ``prompts`` is a single str or dict, the outer list is unwrapped.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
        openai.error.OpenAIError: If the API keeps asking to reduce the prompt after ``max_tokens``
            can no longer be reduced (previously this retried forever).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}.")

    is_single_prompt = isinstance(prompts, (str, dict))
    if is_single_prompt:
        prompts = [prompts]

    if max_batches < sys.maxsize:
        logging.warning(
            "`max_batches` will be deprecated in the future, please use `max_instances` instead. "
            "Setting `max_instances` to `max_batches * batch_size` for now."
        )
        max_instances = max_batches * batch_size

    prompts = prompts[:max_instances]
    num_prompts = len(prompts)
    prompt_batches = [prompts[start : start + batch_size] for start in range(0, num_prompts, batch_size)]

    completions = []
    for prompt_batch in tqdm.tqdm(prompt_batches, desc="prompt_batches", total=len(prompt_batches)):
        # Per-batch copy so that shrinking max_tokens below affects only this batch.
        batch_decoding_args = copy.deepcopy(decoding_args)

        while True:
            shared_kwargs = dict(
                model=model_name,
                **batch_decoding_args.__dict__,
                **decoding_kwargs,
            )
            try:
                completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs)
            except openai.error.OpenAIError as e:
                logging.warning(f"OpenAIError: {e}.")
                if "Please reduce your prompt" in str(e):
                    reduced_max_tokens = int(batch_decoding_args.max_tokens * 0.8)
                    if reduced_max_tokens == batch_decoding_args.max_tokens:
                        # No further reduction is possible (e.g. already 0); retrying would loop forever.
                        raise
                    batch_decoding_args.max_tokens = reduced_max_tokens
                    logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...")
                else:
                    logging.warning("Hit request rate limit; retrying...")
                    time.sleep(sleep_time)  # Annoying rate limit on requests.
            else:
                choices = completion_batch.choices
                for choice in choices:
                    # Usage is reported per request, so every choice records the batch-wide total.
                    choice["total_tokens"] = completion_batch.usage.total_tokens
                completions.extend(choices)
                break

    if return_text:
        completions = [completion.text for completion in completions]
    if decoding_args.n > 1:
        # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries.
        completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)]
    if is_single_prompt:
        # Unwrap the single prompt's result (a list of n completions when n > 1, otherwise one completion).
        (completions,) = completions
    return completions
|
||
|
|
||
|
|
||
|
def _make_w_io_base(f, mode: str):
|
||
|
if not isinstance(f, io.IOBase):
|
||
|
f_dirname = os.path.dirname(f)
|
||
|
if f_dirname != "":
|
||
|
os.makedirs(f_dirname, exist_ok=True)
|
||
|
f = open(f, mode=mode)
|
||
|
return f
|
||
|
|
||
|
|
||
|
def _make_r_io_base(f, mode: str):
|
||
|
if not isinstance(f, io.IOBase):
|
||
|
f = open(f, mode=mode)
|
||
|
return f
|
||
|
|
||
|
|
||
|
def jdump(obj, f, mode="w", indent=4, default=str):
    """Dump a str, dict or list to a file; dicts and lists are written as json.

    Args:
        obj: An object to be written. dicts and lists go through ``json.dump``;
            a str is written verbatim.
        f: A string path to the location on disk (missing parent directories are
            created), or an already open ``io.IOBase`` file object. Either way the
            file object is closed when this returns.
        mode: Mode for opening the file.
        indent: Indent for storing json dictionaries.
        default: A function to handle non-serializable entries; defaults to `str`.

    Raises:
        ValueError: If ``obj`` is not a dict, list or str. This is checked before
            the file is opened, so no file is created or truncated in that case.
    """
    if not isinstance(obj, (dict, list, str)):
        raise ValueError(f"Unexpected type: {type(obj)}")
    # `with` guarantees the handle is closed even if serialization fails midway.
    with _make_w_io_base(f, mode) as fh:
        if isinstance(obj, str):
            fh.write(obj)
        else:
            json.dump(obj, fh, indent=indent, default=default)
|
||
|
|
||
|
|
||
|
def jload(f, mode="r"):
    """Load a .json file and return the decoded object (e.g. a dict or a list).

    Args:
        f: A path on disk or an already open ``io.IOBase`` file object; the file
            object is closed when this returns.
        mode: Mode for opening the file when ``f`` is a path.

    Returns:
        Whatever ``json.load`` decodes from the file.

    Raises:
        json.JSONDecodeError: If the file does not contain valid JSON (the file
            is still closed).
    """
    # `with` closes the handle even when decoding raises.
    with _make_r_io_base(f, mode) as fh:
        return json.load(fh)
|