mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2024-10-01 01:06:10 -04:00
Improves Java API signatures while maintaining backward compatibility
This commit is contained in:
parent
f39df0906e
commit
3c45a555e9
@@ -8,9 +8,8 @@ import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;

public class LLModel implements AutoCloseable {

@@ -306,6 +305,197 @@ public class LLModel implements AutoCloseable {
        };
    }

    /**
     * The list of messages for the conversation.
     */
    public static class Messages {

        private final List<PromptMessage> messages = new ArrayList<>();

        public Messages(PromptMessage... messages) {
            this.messages.addAll(Arrays.asList(messages));
        }

        public Messages(List<PromptMessage> messages) {
            this.messages.addAll(messages);
        }

        public Messages addPromptMessage(PromptMessage promptMessage) {
            this.messages.add(promptMessage);
            return this;
        }

        List<PromptMessage> toList() {
            return Collections.unmodifiableList(this.messages);
        }

        List<Map<String, String>> toListMap() {
            return messages.stream()
                    .map(PromptMessage::toMap)
                    .collect(Collectors.toList());
        }

    }
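A minimal usage sketch of the Messages builder added above (the PromptMessage and Role types are defined further down in this diff; the values are borrowed from the tests at the end):

    // Build a conversation from varargs, then append another message fluently.
    LLModel.Messages messages = new LLModel.Messages(
            new LLModel.PromptMessage(LLModel.Role.SYSTEM, "You are a helpful assistant"))
            .addPromptMessage(new LLModel.PromptMessage(LLModel.Role.USER, "Add 2+2"));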

    /**
     * A message in the conversation, identical to OpenAI's chat message.
     */
    public static class PromptMessage {

        private static final String ROLE = "role";
        private static final String CONTENT = "content";

        private final Map<String, String> message = new HashMap<>();

        public PromptMessage() {
        }

        public PromptMessage(Role role, String content) {
            addRole(role);
            addContent(content);
        }

        public PromptMessage addRole(Role role) {
            return this.addParameter(ROLE, role.type());
        }

        public PromptMessage addContent(String content) {
            return this.addParameter(CONTENT, content);
        }

        public PromptMessage addParameter(String key, String value) {
            this.message.put(key, value);
            return this;
        }

        public String content() {
            return this.parameter(CONTENT);
        }

        public Role role() {
            String role = this.parameter(ROLE);
            return Role.from(role);
        }

        public String parameter(String key) {
            return this.message.get(key);
        }

        Map<String, String> toMap() {
            return Collections.unmodifiableMap(this.message);
        }

    }
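A quick sketch of the fluent PromptMessage API above; the two-argument constructor and the add* chain produce the same "role"/"content" map:

    LLModel.PromptMessage msg = new LLModel.PromptMessage()
            .addRole(LLModel.Role.USER)
            .addContent("Add 2+2");
    // role() and content() read back the stored entries;
    // toMap() exposes them as an unmodifiable map view.
    assert msg.role() == LLModel.Role.USER;
    assert msg.content().equals("Add 2+2");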

    public enum Role {

        SYSTEM("system"), ASSISTANT("assistant"), USER("user");

        private final String type;

        String type() {
            return this.type;
        }

        static Role from(String type) {

            if (type == null) {
                return null;
            }

            switch (type) {
                case "system": return SYSTEM;
                case "assistant": return ASSISTANT;
                case "user": return USER;
                default: throw new IllegalArgumentException(
                        String.format("Role '%s' is not supported; supported roles are %s",
                                type, Arrays.toString(Role.values())
                        )
                );
            }
        }

        Role(String type) {
            this.type = type;
        }

        @Override
        public String toString() {
            return type();
        }
    }
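How Role round-trips through its wire string (a sketch; note that from() and type() are package-private, so this only compiles from within the same package):

    assert LLModel.Role.from("assistant") == LLModel.Role.ASSISTANT;
    assert LLModel.Role.ASSISTANT.toString().equals("assistant");
    // from(null) returns null; any other unrecognized string
    // throws IllegalArgumentException.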

    /**
     * The result of the completion, similar to OpenAI's format.
     */
    public static class CompletionReturn {

        private String model;
        private Usage usage;
        private Choices choices;

        public CompletionReturn(String model, Usage usage, Choices choices) {
            this.model = model;
            this.usage = usage;
            this.choices = choices;
        }

        public Choices choices() {
            return choices;
        }

        public String model() {
            return model;
        }

        public Usage usage() {
            return usage;
        }
    }

    /**
     * The generated completions.
     */
    public static class Choices {

        private final List<CompletionChoice> choices = new ArrayList<>();

        public Choices(List<CompletionChoice> choices) {
            this.choices.addAll(choices);
        }

        public Choices(CompletionChoice... completionChoices) {
            this.choices.addAll(Arrays.asList(completionChoices));
        }

        public Choices addCompletionChoice(CompletionChoice completionChoice) {
            this.choices.add(completionChoice);
            return this;
        }

        public CompletionChoice first() {
            return this.choices.get(0);
        }

        public int totalChoices() {
            return this.choices.size();
        }

        public CompletionChoice get(int index) {
            return this.choices.get(index);
        }

        public List<CompletionChoice> choices() {
            return Collections.unmodifiableList(choices);
        }
    }
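A small sketch of reading a result back through CompletionReturn and Choices, constructing one by hand purely to show the accessors (this assumes Usage has a reachable no-arg constructor, and uses the CompletionChoice type defined just below):

    LLModel.CompletionReturn response = new LLModel.CompletionReturn(
            "example-model",   // illustrative model name
            new LLModel.Usage(),
            new LLModel.Choices(new LLModel.CompletionChoice(LLModel.Role.ASSISTANT, "4")));

    System.out.println(response.model());                      // example-model
    System.out.println(response.choices().first().content());  // 4
    System.out.println(response.choices().totalChoices());     // 1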

    /**
     * A completion choice, similar to OpenAI's format.
     */
    public static class CompletionChoice extends PromptMessage {
        public CompletionChoice(Role role, String content) {
            super(role, content);
        }
    }

    public static class ChatCompletionResponse {
        public String model;
@@ -323,6 +513,41 @@ public class LLModel implements AutoCloseable {
        // Getters and setters
    }

    public CompletionReturn chatCompletionResponse(Messages messages,
                                                   GenerationConfig generationConfig) {
        return chatCompletion(messages, generationConfig, false, false);
    }

    /**
     * chatCompletion formats the existing chat conversation into a template that is
     * easier for chat UIs to process. It is not strictly necessary, as the generate
     * method may be used directly to produce generations with GPT models.
     *
     * @param messages the Messages of the conversation to send to the GPT model
     * @param generationConfig how to decode/process the generation
     * @param streamToStdOut whether to stream tokens to standard output as they are generated
     * @param outputFullPromptToStdOut whether the full prompt built from the messages should be printed to standard output
     * @return a CompletionReturn containing usage stats and the generated text
     */
    public CompletionReturn chatCompletion(Messages messages,
                                           GenerationConfig generationConfig, boolean streamToStdOut,
                                           boolean outputFullPromptToStdOut) {

        String fullPrompt = buildPrompt(messages.toListMap());

        if (outputFullPromptToStdOut)
            System.out.print(fullPrompt);

        String generatedText = generate(fullPrompt, generationConfig, streamToStdOut);

        final CompletionChoice promptMessage = new CompletionChoice(Role.ASSISTANT, generatedText);
        final Choices choices = new Choices(promptMessage);

        final Usage usage = getUsage(fullPrompt, generatedText);
        return new CompletionReturn(this.modelName, usage, choices);
    }
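Putting it together, a minimal end-to-end sketch of the new object-based chatCompletion (model loading is elided because it depends on the bindings' loader API; the config value is illustrative):

    // model is an already-loaded LLModel instance.
    LLModel.GenerationConfig config = LLModel.config()
            .withNPredict(64)   // illustrative decoding setting
            .build();

    LLModel.Messages messages = new LLModel.Messages(
            new LLModel.PromptMessage(LLModel.Role.SYSTEM, "You are a helpful assistant"),
            new LLModel.PromptMessage(LLModel.Role.USER, "Add 2+2"));

    LLModel.CompletionReturn result = model.chatCompletion(messages, config, false, false);
    System.out.println(result.choices().first().content());
    System.out.println("characters in + out: " + result.usage().totalTokens);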

    public ChatCompletionResponse chatCompletion(List<Map<String, String>> messages,
                                                 GenerationConfig generationConfig) {
        return chatCompletion(messages, generationConfig, false, false);
@@ -352,19 +577,23 @@ public class LLModel implements AutoCloseable {
        ChatCompletionResponse response = new ChatCompletionResponse();
        response.model = this.modelName;

-        Usage usage = new Usage();
-        usage.promptTokens = fullPrompt.length();
-        usage.completionTokens = generatedText.length();
-        usage.totalTokens = fullPrompt.length() + generatedText.length();
-        response.usage = usage;
+        response.usage = getUsage(fullPrompt, generatedText);

        Map<String, String> message = new HashMap<>();
        message.put("role", "assistant");
        message.put("content", generatedText);

        response.choices = List.of(message);

        return response;
    }

    private Usage getUsage(String fullPrompt, String generatedText) {
        Usage usage = new Usage();
        usage.promptTokens = fullPrompt.length();
        usage.completionTokens = generatedText.length();
        usage.totalTokens = fullPrompt.length() + generatedText.length();
        return usage;
    }
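Note that getUsage approximates token counts by raw character length rather than by true model tokens. That is what makes the exact assertion in the test below possible:

    // The mocked generate(...) in the test returns "4" (1 character), and the
    // assertion of totalTokens == 224 implies the template-built prompt for the
    // two test messages is 223 characters: totalTokens = 223 + 1 = 224.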

    protected static String buildPrompt(List<Map<String, String>> messages) {
@@ -28,6 +28,33 @@ import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
public class BasicTests {

    @Test
    public void simplePromptWithObject() {

        LLModel model = Mockito.spy(new LLModel());

        LLModel.GenerationConfig config =
                LLModel.config()
                        .withNPredict(20)
                        .build();

        // The mocked generate method will return "4".
        doReturn("4").when(model).generate(anyString(), eq(config), eq(true));

        LLModel.PromptMessage promptMessage1 = new LLModel.PromptMessage(LLModel.Role.SYSTEM, "You are a helpful assistant");
        LLModel.PromptMessage promptMessage2 = new LLModel.PromptMessage(LLModel.Role.USER, "Add 2+2");

        LLModel.Messages messages = new LLModel.Messages(promptMessage1, promptMessage2);

        LLModel.CompletionReturn response = model.chatCompletion(
                messages, config, true, true);

        assertTrue(response.choices().first().content().contains("4"));

        // Verifies that the prompt and response are of the expected length.
        assertEquals(224, response.usage().totalTokens);
    }

    @Test
    public void simplePrompt() {