From f7f1991e2c839e481a6bc5388d56b00afb38ca5a Mon Sep 17 00:00:00 2001
From: James Ravenscroft
Date: Thu, 10 Aug 2023 09:26:54 +0100
Subject: [PATCH] add huggingface request handler and refactor old req handler

---
 include/turbopilot/server.hpp |  3 +-
 src/main.cpp                  | 12 +++++--
 src/server.cpp                | 68 +++++++++++++++++++++--------------
 3 files changed, 52 insertions(+), 31 deletions(-)

diff --git a/include/turbopilot/server.hpp b/include/turbopilot/server.hpp
index 71735e8..7000544 100644
--- a/include/turbopilot/server.hpp
+++ b/include/turbopilot/server.hpp
@@ -6,8 +6,9 @@
 
 #include "crow_all.h"
 
-crow::response serve_response(TurbopilotModel *model, const crow::request& req);
+crow::response handle_openai_request(TurbopilotModel *model, const crow::request& req);
+crow::response handle_hf_request(TurbopilotModel *model, const crow::request& req);
 
 
 
 #endif // __TURBOPILOT_SERVER_H
diff --git a/src/main.cpp b/src/main.cpp
index d2ce02e..32ddcf9 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -117,22 +117,28 @@ int main(int argc, char **argv)
         return res;
     });
 
+    //huggingface code compatible endpoint
+    CROW_ROUTE(app, "/api/generate").methods(crow::HTTPMethod::Post)
+    ([&model](const crow::request& req) {
+        return handle_hf_request(model, req);
+    });
 
     CROW_ROUTE(app, "/v1/completions").methods(crow::HTTPMethod::Post)
     ([&model](const crow::request& req) {
-        return serve_response(model, req);
+        return handle_openai_request(model, req);
     });
 
     CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::Post)
     ([&model](const crow::request& req) {
-        return serve_response(model, req);
+        return handle_openai_request(model, req);
     });
 
     CROW_ROUTE(app, "/v1/engines/copilot-codex/completions").methods(crow::HTTPMethod::Post)
     ([&model](const crow::request& req) {
-        return serve_response(model, req);
+        return handle_openai_request(model, req);
     });
 
+
     app.port(program.get("--port")).multithreaded().run();
 
 
diff --git a/src/server.cpp b/src/server.cpp
index f4cdbaa..b373137 100644
--- a/src/server.cpp
+++ b/src/server.cpp
@@ -7,11 +7,51 @@
 #include // generators
 #include // streaming operators etc.
 
+
+
 /**
  * This function serves requests for autocompletion from crow
  * 
 */
-crow::response serve_response(TurbopilotModel *model, const crow::request& req){
+crow::response handle_hf_request(TurbopilotModel *model, const crow::request& req){
+
+    crow::json::rvalue data = crow::json::load(req.body);
+
+    if(!data.has("inputs")){
+        crow::response res;
+        res.code = 400;
+        res.set_header("Content-Type", "application/json");
+        res.body = "{\"message\":\"you must specify inputs field or\"}";
+        return res;
+    }
+
+    // std::string suffix = data["suffix"].s();
+    int maxTokens = 200;
+    if(data.has("max_tokens")){
+        maxTokens = data["max_tokens"].i();
+    }
+
+
+    auto result = model->predict(data["inputs"].s(), maxTokens, false);
+
+    crow::json::wvalue response = {
+        {"generated_text", result.str()},
+    };
+
+
+
+    crow::response res;
+    res.code = 200;
+    res.set_header("Content-Type", "application/json");
+    res.body = response.dump(); //ss.str();
+    return res;
+}
+
+/**
+ * This function serves requests for autocompletion from crow
+ * 
+*/
+crow::response handle_openai_request(TurbopilotModel *model, const crow::request& req){
 
     crow::json::rvalue data = crow::json::load(req.body);
 
@@ -23,24 +63,6 @@ crow::response serve_response(TurbopilotModel *model, const crow::request& req){
         return res;
     }
 
-
-
-
-    // tokenize the prompt
-    // std::vector embd_inp;
-
-    // if (data.has("prompt")) {
-    //     std::string prompt = data["prompt"].s();
-    //     embd_inp = ::gpt_tokenize(vocab, prompt);
-    // }
-    // else {
-    //     crow::json::rvalue input_ids = data["input_ids"];
-    //     for (auto id : input_ids.lo()) {
-    //         embd_inp.push_back(id.i());
-    //     }
-    // }
-
-
     // std::string suffix = data["suffix"].s();
     int maxTokens = 200;
     if(data.has("max_tokens")){
@@ -64,14 +86,6 @@ crow::response serve_response(TurbopilotModel *model, const crow::request& req){
     crow::json::wvalue::list choices = 
     {choice};
 
-    // crow::json::wvalue usage = {
-    //     {"completion_tokens", n_past},
-    //     // {"prompt_tokens", static_cast(embd_inp.size())},
-    //     {"prompt_tokens", 0},
-    //     {"total_tokens", static_cast(n_past - embd_inp.size())}
-    // };
-
-
     crow::json::wvalue usage = {
         {"completion_tokens", 0},
         // {"prompt_tokens", static_cast(embd_inp.size())},