fixed code

This commit is contained in:
Saifeddine ALOUI 2024-09-05 01:54:52 +02:00
parent f7e0742a50
commit 7a0290d6b6

View File

@ -186,6 +186,8 @@ class LollmsClient {
switch (this.default_generation_mode) {
case ELF_GENERATION_FORMAT.LOLLMS:
return this.lollms_generate_with_images(prompt, images, this.host_address, this.model_name, -1, n_predict, stream, temperature, top_k, top_p, repeat_penalty, repeat_last_n, seed, n_threads, service_key, streamingCallback);
case ELF_GENERATION_FORMAT.OPENAI:
return this.openai_generate_with_images(prompt, images, this.host_address, this.model_name, -1, n_predict, stream, temperature, top_k, top_p, repeat_penalty, repeat_last_n, seed, n_threads, ELF_COMPLETION_FORMAT.INSTRUCT, service_key, streamingCallback);
default:
throw new Error('Invalid generation mode');
}
@ -385,6 +387,330 @@ async openai_generate(prompt, host_address = this.host_address, model_name = thi
}
}
async openai_generate_with_images(prompt, images, options = {}) {
const {
host_address = this.host_address,
model_name = this.model_name,
personality = this.personality,
n_predict = this.n_predict,
stream = false,
temperature = this.temperature,
top_k = this.top_k,
top_p = this.top_p,
repeat_penalty = this.repeat_penalty,
repeat_last_n = this.repeat_last_n,
seed = this.seed,
n_threads = this.n_threads,
max_image_width = -1,
service_key = this.service_key,
streamingCallback = null,
} = options;
const headers = {
'Content-Type': 'application/json',
...(service_key ? { 'Authorization': `Bearer ${service_key}` } : {})
};
const data = {
model: model_name,
messages: [
{
role: "user",
content: [
{
type: "text",
text: prompt
},
...images.map(image_path => ({
type: "image_url",
image_url: {
url: `data:image/jpeg;base64,${this.encode_image(image_path, max_image_width)}`
}
}))
]
}
],
stream: true,
temperature: parseFloat(temperature),
max_tokens: n_predict
};
const completion_format_path = "/v1/chat/completions";
const url = `${host_address.endsWith("/") ? host_address.slice(0, -1) : host_address}${completion_format_path}`;
try {
const response = await fetch(url, {
method: 'POST',
headers: headers,
body: JSON.stringify(data)
});
if (!response.ok) {
const content = await response.json();
if (response.status === 400) {
this.error(content.error?.message || content.message);
} else if (response.status === 404) {
console.error(await response.text());
}
return;
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
let text = "";
while (true) {
const { done, value } = await reader.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n');
for (const line of lines) {
if (line.startsWith("data: ")) {
try {
const jsonData = JSON.parse(line.slice(5).trim());
const content = jsonData.choices[0]?.delta?.content || "";
text += content;
if (streamingCallback) {
if (!streamingCallback(content, "MSG_TYPE_CHUNK")) {
return text;
}
}
} catch (error) {
// Handle JSON parsing error
}
} else if (line.startsWith("{")) {
try {
const jsonData = JSON.parse(line);
if (jsonData.object === "error") {
this.error(jsonData.message);
return text;
}
} catch (error) {
this.error("Couldn't generate text, verify your key or model name");
}
} else {
text += line;
if (streamingCallback) {
if (!streamingCallback(line, "MSG_TYPE_CHUNK")) {
return text;
}
}
}
}
}
return text;
} catch (error) {
console.error("Error in openai_generate_with_images:", error);
throw error;
}
}
async encode_image(image_path, max_image_width = -1) {
// In a browser environment, we'll use the File API and canvas
// For Node.js, you'd need to use libraries like 'sharp' or 'jimp'
return new Promise((resolve, reject) => {
const img = new Image();
img.onload = () => {
let width = img.width;
let height = img.height;
// Resize if necessary
if (max_image_width !== -1 && width > max_image_width) {
const ratio = max_image_width / width;
width = max_image_width;
height = Math.round(height * ratio);
}
// Create a canvas to draw the image
const canvas = document.createElement('canvas');
canvas.width = width;
canvas.height = height;
const ctx = canvas.getContext('2d');
// Draw the image on the canvas
ctx.drawImage(img, 0, 0, width, height);
// Convert to base64
const base64Image = canvas.toDataURL('image/jpeg').split(',')[1];
resolve(base64Image);
};
img.onerror = (error) => reject(error);
// Load the image from the provided path
img.src = image_path;
});
}
async generateCode(prompt, images = [], {
n_predict = null,
stream = false,
temperature = 0.1,
top_k = 50,
top_p = 0.95,
repeat_penalty = 0.8,
repeat_last_n = 40,
seed = null,
n_threads = 8,
service_key = "",
streamingCallback = null
} = {}){
let response;
const systemHeader = this.custom_message("Generation infos");
const codeInstructions = "Generated code must be put inside the adequate markdown code tag. Use this template:\n```language name\nCode\n```\nMake sure only a single code tag is generated at each dialogue turn.";
const fullPrompt = systemHeader + codeInstructions + this.separatorTemplate + prompt;
if (images.length > 0) {
response = await this.generate_with_images(fullPrompt, images, {
n_predict: n_predict,
temperature: temperature,
top_k: top_k,
top_p: top_p,
repeat_penalty: repeat_penalty,
repeat_last_n: repeat_last_n,
callback: streamingCallback
});
} else {
response = await this.generate(fullPrompt, {
n_predict: n_predict,
temperature: temperature,
top_k: top_k,
top_p: top_p,
repeat_penalty: repeat_penalty,
repeat_last_n: repeat_last_n,
callback: streamingCallback
});
}
const codes = this.extractCodeBlocks(response);
if (codes.length > 0) {
let code = '';
if (!codes[0].is_complete) {
code = codes[0].content.split('\n').slice(0, -1).join('\n');
while (!codes[0].is_complete) {
console.warn("The AI did not finish the code, let's ask it to continue")
const continuePrompt = prompt + code + this.userFullHeader + "continue the code. Rewrite last line and continue the code." + this.separatorTemplate + this.aiFullHeader;
response = await this.generate(fullPrompt, {
n_predict: n_predict,
temperature: temperature,
top_k: top_k,
top_p: top_p,
repeat_penalty: repeat_penalty,
repeat_last_n: repeat_last_n,
callback: streamingCallback
});
const newCodes = this.extractCodeBlocks(response);
if (newCodes.length === 0) break;
if (!newCodes[0].is_complete) {
code += '\n' + newCodes[0].content.split('\n').slice(0, -1).join('\n');
} else {
code += '\n' + newCodes[0].content;
}
}
} else {
code = codes[0].content;
}
return code;
} else {
return null;
}
}
extractCodeBlocks(text) {
const codeBlocks = [];
let remaining = text;
let blocIndex = 0;
let firstIndex = 0;
const indices = [];
// Find all code block delimiters
while (remaining.length > 0) {
const index = remaining.indexOf("```");
if (index === -1) {
if (blocIndex % 2 === 1) {
indices.push(remaining.length + firstIndex);
}
break;
}
indices.push(index + firstIndex);
remaining = remaining.slice(index + 3);
firstIndex += index + 3;
blocIndex++;
}
let isStart = true;
for (let i = 0; i < indices.length; i++) {
if (isStart) {
const blockInfo = {
index: i,
file_name: "",
section: "",
content: "",
type: "",
is_complete: false
};
// Check for file name in preceding line
const precedingText = text.slice(0, indices[i]).trim().split('\n');
if (precedingText.length > 0) {
const lastLine = precedingText[precedingText.length - 1].trim();
if (lastLine.startsWith("<file_name>") && lastLine.endsWith("</file_name>")) {
blockInfo.file_name = lastLine.slice("<file_name>".length, -"</file_name>".length).trim();
} else if (lastLine.startsWith("## filename:")) {
blockInfo.file_name = lastLine.slice("## filename:".length).trim();
}
if (lastLine.startsWith("<section>") && lastLine.endsWith("</section>")) {
blockInfo.section = lastLine.slice("<section>".length, -"</section>".length).trim();
}
}
const subText = text.slice(indices[i] + 3);
if (subText.length > 0) {
const findSpace = subText.indexOf(" ");
const findReturn = subText.indexOf("\n");
let nextIndex = Math.min(findSpace === -1 ? Infinity : findSpace, findReturn === -1 ? Infinity : findReturn);
if (subText.slice(0, nextIndex).includes('{')) {
nextIndex = 0;
}
const startPos = nextIndex;
if (text[indices[i] + 3] === "\n" || text[indices[i] + 3] === " " || text[indices[i] + 3] === "\t") {
blockInfo.type = 'language-specific';
} else {
blockInfo.type = subText.slice(0, nextIndex);
}
if (i + 1 < indices.length) {
const nextPos = indices[i + 1] - indices[i];
if (nextPos - 3 < subText.length && subText[nextPos - 3] === "`") {
blockInfo.content = subText.slice(startPos, nextPos - 3).trim();
blockInfo.is_complete = true;
} else {
blockInfo.content = subText.slice(startPos, nextPos).trim();
blockInfo.is_complete = false;
}
} else {
blockInfo.content = subText.slice(startPos).trim();
blockInfo.is_complete = false;
}
codeBlocks.push(blockInfo);
}
isStart = false;
} else {
isStart = true;
}
}
return codeBlocks;
}
async listMountedPersonalities(host_address = this.host_address) {
const url = `${host_address}/list_mounted_personalities`;
@ -785,34 +1111,95 @@ buildPrompt(promptParts, sacrificeId = -1, contextSize = null, minimumSpareConte
}
extractCodeBlocks(text) {
const codeBlockRegex = /```([\s\S]*?)```/g;
const codeBlocks = [];
let match;
let index = 0;
let remaining = text;
let blocIndex = 0;
let firstIndex = 0;
const indices = [];
while ((match = codeBlockRegex.exec(text)) !== null) {
const [fullMatch, content] = match;
const blockLines = content.trim().split('\n');
let type = 'language-specific';
let blockContent = content.trim();
// Check if the first line is a language specifier
if (blockLines.length > 1 && blockLines[0].trim().length > 0 && !blockLines[0].includes(' ')) {
type = blockLines[0].trim().toLowerCase();
blockContent = blockLines.slice(1).join('\n').trim();
// Find all code block delimiters
while (remaining.length > 0) {
const index = remaining.indexOf("```");
if (index === -1) {
if (blocIndex % 2 === 1) {
indices.push(remaining.length + firstIndex);
}
break;
}
indices.push(index + firstIndex);
remaining = remaining.slice(index + 3);
firstIndex += index + 3;
blocIndex++;
}
codeBlocks.push({
index: index++,
file_name: '',
content: blockContent,
type: type
});
let isStart = true;
for (let i = 0; i < indices.length; i++) {
if (isStart) {
const blockInfo = {
index: i,
file_name: "",
section: "",
content: "",
type: "",
is_complete: false
};
// Check for file name in preceding line
const precedingText = text.slice(0, indices[i]).trim().split('\n');
if (precedingText.length > 0) {
const lastLine = precedingText[precedingText.length - 1].trim();
if (lastLine.startsWith("<file_name>") && lastLine.endsWith("</file_name>")) {
blockInfo.file_name = lastLine.slice("<file_name>".length, -"</file_name>".length).trim();
} else if (lastLine.startsWith("## filename:")) {
blockInfo.file_name = lastLine.slice("## filename:".length).trim();
}
if (lastLine.startsWith("<section>") && lastLine.endsWith("</section>")) {
blockInfo.section = lastLine.slice("<section>".length, -"</section>".length).trim();
}
}
const subText = text.slice(indices[i] + 3);
if (subText.length > 0) {
const findSpace = subText.indexOf(" ");
const findReturn = subText.indexOf("\n");
let nextIndex = Math.min(findSpace === -1 ? Infinity : findSpace, findReturn === -1 ? Infinity : findReturn);
if (subText.slice(0, nextIndex).includes('{')) {
nextIndex = 0;
}
const startPos = nextIndex;
if (text[indices[i] + 3] === "\n" || text[indices[i] + 3] === " " || text[indices[i] + 3] === "\t") {
blockInfo.type = 'language-specific';
} else {
blockInfo.type = subText.slice(0, nextIndex);
}
if (i + 1 < indices.length) {
const nextPos = indices[i + 1] - indices[i];
if (nextPos - 3 < subText.length && subText[nextPos - 3] === "`") {
blockInfo.content = subText.slice(startPos, nextPos - 3).trim();
blockInfo.is_complete = true;
} else {
blockInfo.content = subText.slice(startPos, nextPos).trim();
blockInfo.is_complete = false;
}
} else {
blockInfo.content = subText.slice(startPos).trim();
blockInfo.is_complete = false;
}
codeBlocks.push(blockInfo);
}
isStart = false;
} else {
isStart = true;
}
}
return codeBlocks;
}
/**
* Updates the given code based on the provided query string.
* The query string can contain two types of modifications: