Merge pull request #44 from ravenscroftj/feature/hf-code

Feature/hf code
This commit is contained in:
James Ravenscroft 2023-08-10 10:08:20 +01:00 committed by GitHub
commit c0cde10046
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 31 deletions

View File

@ -6,8 +6,9 @@
#include "crow_all.h"
// Legacy request handler (presumably the pre-rename name of the OpenAI
// handler below — TODO confirm whether it still exists after this change).
crow::response serve_response(TurbopilotModel *model, const crow::request& req);
// Handles OpenAI-compatible completion requests (the /v1/... endpoints).
crow::response handle_openai_request(TurbopilotModel *model, const crow::request& req);
// Handles HuggingFace-compatible generation requests (/api/generate).
crow::response handle_hf_request(TurbopilotModel *model, const crow::request& req);
#endif // __TURBOPILOT_SERVER_H

View File

@ -117,22 +117,28 @@ int main(int argc, char **argv)
return res;
});
//huggingface code compatible endpoint
CROW_ROUTE(app, "/api/generate").methods(crow::HTTPMethod::Post)
([&model](const crow::request& req) {
return handle_hf_request(model, req);
});
CROW_ROUTE(app, "/v1/completions").methods(crow::HTTPMethod::Post)
([&model](const crow::request& req) {
return serve_response(model, req);
return handle_openai_request(model, req);
});
CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::Post)
([&model](const crow::request& req) {
return serve_response(model, req);
return handle_openai_request(model, req);
});
CROW_ROUTE(app, "/v1/engines/copilot-codex/completions").methods(crow::HTTPMethod::Post)
([&model](const crow::request& req) {
return serve_response(model, req);
return handle_openai_request(model, req);
});
app.port(program.get<int>("--port")).multithreaded().run();

View File

@ -7,11 +7,51 @@
#include <boost/uuid/uuid_generators.hpp> // generators
#include <boost/uuid/uuid_io.hpp> // streaming operators etc.
/**
* This function serves requests for autocompletion from crow
*
*/
crow::response serve_response(TurbopilotModel *model, const crow::request& req){
/**
 * Handle a HuggingFace-inference-compatible generation request
 * (POST /api/generate).
 *
 * Expects a JSON body containing a required "inputs" string and an optional
 * "max_tokens" integer (defaults to 200 when absent).
 *
 * @param model the loaded model used to generate a completion
 * @param req   the incoming HTTP request whose body is the JSON payload
 * @return 200 with {"generated_text": ...} on success, or 400 with a JSON
 *         error message when "inputs" is missing.
 */
crow::response handle_hf_request(TurbopilotModel *model, const crow::request& req){

    crow::json::rvalue data = crow::json::load(req.body);

    // "inputs" is mandatory in the HF generate payload; reject early if absent.
    if(!data.has("inputs")){
        crow::response res;
        res.code = 400;
        res.set_header("Content-Type", "application/json");
        // Fix: original message was truncated ("...inputs field or").
        res.body = "{\"message\":\"you must specify inputs field\"}";
        return res;
    }

    // Cap on generated tokens; caller may override via "max_tokens".
    int maxTokens = 200;
    if(data.has("max_tokens")){
        maxTokens = data["max_tokens"].i();
    }

    auto result = model->predict(data["inputs"].s(), maxTokens, false);

    crow::json::wvalue response = {
        {"generated_text", result.str()},
    };

    crow::response res;
    res.code = 200;
    res.set_header("Content-Type", "application/json");
    res.body = response.dump();
    return res;
}
/**
* This function serves requests for autocompletion from crow
*
*/
crow::response handle_openai_request(TurbopilotModel *model, const crow::request& req){
crow::json::rvalue data = crow::json::load(req.body);
@ -23,24 +63,6 @@ crow::response serve_response(TurbopilotModel *model, const crow::request& req){
return res;
}
// tokenize the prompt
// std::vector<gpt_vocab::id> embd_inp;
// if (data.has("prompt")) {
// std::string prompt = data["prompt"].s();
// embd_inp = ::gpt_tokenize(vocab, prompt);
// }
// else {
// crow::json::rvalue input_ids = data["input_ids"];
// for (auto id : input_ids.lo()) {
// embd_inp.push_back(id.i());
// }
// }
// std::string suffix = data["suffix"].s();
int maxTokens = 200;
if(data.has("max_tokens")){
@ -64,14 +86,6 @@ crow::response serve_response(TurbopilotModel *model, const crow::request& req){
crow::json::wvalue::list choices = {choice};
// crow::json::wvalue usage = {
// {"completion_tokens", n_past},
// // {"prompt_tokens", static_cast<std::uint64_t>(embd_inp.size())},
// {"prompt_tokens", 0},
// {"total_tokens", static_cast<std::uint64_t>(n_past - embd_inp.size())}
// };
crow::json::wvalue usage = {
{"completion_tokens", 0},
// {"prompt_tokens", static_cast<std::uint64_t>(embd_inp.size())},