mirror of
https://github.com/ravenscroftj/turbopilot.git
synced 2024-07-01 00:31:27 +00:00
add server component
This commit is contained in:
parent
887d348188
commit
fd3a127aaa
3
.gitmodules
vendored
3
.gitmodules
vendored
|
@ -7,3 +7,6 @@
|
|||
[submodule "extern/spdlog"]
|
||||
path = extern/spdlog
|
||||
url = https://github.com/gabime/spdlog.git
|
||||
[submodule "extern/crow"]
|
||||
path = extern/crow
|
||||
url = https://github.com/CrowCpp/Crow.git
|
||||
|
|
|
@ -8,6 +8,7 @@ set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
|
|||
add_subdirectory(extern/ggml)
|
||||
add_subdirectory(extern/argparse)
|
||||
add_subdirectory(extern/spdlog)
|
||||
add_subdirectory(extern/crow)
|
||||
add_subdirectory(src)
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
1
extern/crow
vendored
Submodule
1
extern/crow
vendored
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 4f3f5deaaa01825c63c83431bfa96ccec195f741
|
|
@ -71,7 +71,7 @@ public:
|
|||
}
|
||||
virtual ~GPTJModel();
|
||||
bool load_model(std::string path);
|
||||
virtual std::stringstream predict(std::string prompt, int max_length);
|
||||
virtual std::stringstream predict(std::string prompt, int max_length, bool include_prompt);
|
||||
|
||||
private:
|
||||
gptj_model *model = NULL;
|
||||
|
|
|
@ -51,7 +51,7 @@ public:
|
|||
rng(rng)
|
||||
{}
|
||||
virtual bool load_model(std::string model_path) = 0;
|
||||
virtual std::stringstream predict(std::string prompt, int max_length) = 0;
|
||||
virtual std::stringstream predict(std::string prompt, int max_length, bool include_prompt) = 0;
|
||||
|
||||
protected:
|
||||
ModelConfig config;
|
||||
|
|
19
include/turbopilot/server.hpp
Normal file
19
include/turbopilot/server.hpp
Normal file
|
@ -0,0 +1,19 @@
|
|||
#ifndef __TURBOPILOT_SERVER_H
#define __TURBOPILOT_SERVER_H

#include "turbopilot/model.hpp"

#include <crow.h>

/**
 * Serve an autocompletion request.
 *
 * Parses the JSON body of `req`, runs inference via `model`, and returns an
 * OpenAI-style "text_completion" JSON response. Defined in src/server.cpp.
 *
 * NOTE(review): the previous revision also *defined* a non-inline
 * `extern "C"` wrapper (`server_response`) here that referenced an
 * undeclared global `model`. That could not compile, would have violated
 * the ODR if this header were included from more than one translation
 * unit, and `extern "C"` is not valid for functions using C++ class types.
 * Callers (see main.cpp routes) invoke serve_response directly, so the
 * wrapper has been removed.
 */
crow::response serve_response(TurbopilotModel *model, const crow::request& req);

#endif // __TURBOPILOT_SERVER_H
|
||||
|
|
@ -1,12 +1,14 @@
|
|||
set(TURBOPILOT_TARGET turbopilot)
|
||||
|
||||
|
||||
find_package(Boost REQUIRED)
|
||||
include_directories(${Boost_INCLUDE_DIRS})
|
||||
|
||||
|
||||
add_executable(${TURBOPILOT_TARGET}
|
||||
main.cpp
|
||||
gptj.cpp
|
||||
common.cpp
|
||||
server.cpp
|
||||
../include/turbopilot/model.hpp
|
||||
../include/turbopilot/gptj.hpp
|
||||
)
|
||||
|
@ -15,7 +17,10 @@ add_executable(${TURBOPILOT_TARGET}
|
|||
target_include_directories(${TURBOPILOT_TARGET} PRIVATE
|
||||
../include
|
||||
../extern/spdlog/include
|
||||
../extern/crow/include
|
||||
)
|
||||
|
||||
|
||||
target_link_libraries(${TURBOPILOT_TARGET} PRIVATE ggml argparse)
|
||||
|
||||
target_link_libraries(${TURBOPILOT_TARGET} PUBLIC Crow::Crow)
|
19
src/gptj.cpp
19
src/gptj.cpp
|
@ -555,7 +555,7 @@ bool GPTJModel::load_model(std::string fname) {
|
|||
return true;
|
||||
}
|
||||
|
||||
std::stringstream GPTJModel::predict(std::string prompt, int max_length) {
|
||||
std::stringstream GPTJModel::predict(std::string prompt, int max_length, bool include_prompt) {
|
||||
|
||||
std::stringstream result;
|
||||
// tokenize the prompt
|
||||
|
@ -614,10 +614,20 @@ std::stringstream GPTJModel::predict(std::string prompt, int max_length) {
|
|||
|
||||
// add it to the context
|
||||
embd.push_back(id);
|
||||
|
||||
if(id != 50256){
|
||||
result << vocab->id_to_token[id].c_str();
|
||||
}
|
||||
|
||||
} else {
|
||||
// if here, it means we are still processing the input prompt
|
||||
for (int k = i; k < embd_inp.size(); k++) {
|
||||
embd.push_back(embd_inp[k]);
|
||||
|
||||
if(include_prompt){
|
||||
result << vocab->id_to_token[embd_inp[k]].c_str();
|
||||
}
|
||||
|
||||
if (embd.size() > config.n_batch) {
|
||||
break;
|
||||
}
|
||||
|
@ -625,13 +635,6 @@ std::stringstream GPTJModel::predict(std::string prompt, int max_length) {
|
|||
i += embd.size() - 1;
|
||||
}
|
||||
|
||||
// display text
|
||||
for (auto id : embd) {
|
||||
result << vocab->id_to_token[id].c_str();
|
||||
//printf("%s", vocab->id_to_token[id].c_str());
|
||||
}
|
||||
fflush(stdout);
|
||||
|
||||
// end of text token
|
||||
if (embd.back() == 50256) {
|
||||
break;
|
||||
|
|
50
src/main.cpp
50
src/main.cpp
|
@ -4,10 +4,13 @@
|
|||
|
||||
#include <spdlog/spdlog.h>
|
||||
|
||||
#include <argparse/argparse.hpp>
|
||||
#include "turbopilot/model.hpp"
|
||||
#include <crow.h>
|
||||
|
||||
#include <argparse/argparse.hpp>
|
||||
|
||||
#include "turbopilot/model.hpp"
|
||||
#include "turbopilot/gptj.hpp"
|
||||
#include "turbopilot/server.hpp"
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
|
@ -24,13 +27,16 @@ int main(int argc, char **argv)
|
|||
|
||||
program.add_argument("-p", "--port")
|
||||
.help("The tcp port that turbopilot should listen on")
|
||||
.default_value("18080");
|
||||
.default_value(18080)
|
||||
.scan<'i', int>();
|
||||
|
||||
program.add_argument("-r", "--random-seed")
|
||||
.help("Set the random seed for RNG functions")
|
||||
.default_value(-1)
|
||||
.scan<'i', int>();
|
||||
|
||||
program.add_argument("prompt").remaining();
|
||||
|
||||
|
||||
try
|
||||
{
|
||||
|
@ -76,9 +82,43 @@ int main(int argc, char **argv)
|
|||
|
||||
spdlog::info("Loaded model in {:0.2f}ms", t_load_us/1000.0f);
|
||||
|
||||
auto result = model->predict("test", 100);
|
||||
|
||||
spdlog::info("output: {}", result.str());
|
||||
crow::SimpleApp app;
|
||||
|
||||
CROW_ROUTE(app, "/")([](){
|
||||
return "Hello world";
|
||||
});
|
||||
|
||||
CROW_ROUTE(app, "/copilot_internal/v2/token")([](){
|
||||
//return "Hello world";
|
||||
|
||||
crow::json::wvalue response = {{"token","1"}, {"expires_at", static_cast<std::uint64_t>(2600000000)}, {"refresh_in",900}};
|
||||
|
||||
crow::response res;
|
||||
res.code = 200;
|
||||
res.set_header("Content-Type", "application/json");
|
||||
res.body = response.dump();
|
||||
return res;
|
||||
});
|
||||
|
||||
|
||||
CROW_ROUTE(app, "/v1/completions").methods(crow::HTTPMethod::Post)
|
||||
([&model](const crow::request& req) {
|
||||
return serve_response(model, req);
|
||||
});
|
||||
|
||||
CROW_ROUTE(app, "/v1/engines/codegen/completions").methods(crow::HTTPMethod::Post)
|
||||
([&model](const crow::request& req) {
|
||||
return serve_response(model, req);
|
||||
});
|
||||
|
||||
|
||||
CROW_ROUTE(app, "/v1/engines/copilot-codex/completions").methods(crow::HTTPMethod::Post)
|
||||
([&model](const crow::request& req) {
|
||||
return serve_response(model, req);
|
||||
});
|
||||
|
||||
app.port(program.get<int>("--port")).multithreaded().run();
|
||||
|
||||
free(model);
|
||||
}
|
97
src/server.cpp
Normal file
97
src/server.cpp
Normal file
|
@ -0,0 +1,97 @@
|
|||
|
||||
#include "turbopilot/server.hpp"
|
||||
#include "turbopilot/model.hpp"
|
||||
|
||||
#include <boost/lexical_cast.hpp>
|
||||
#include <boost/uuid/uuid.hpp> // uuid class
|
||||
#include <boost/uuid/uuid_generators.hpp> // generators
|
||||
#include <boost/uuid/uuid_io.hpp> // streaming operators etc.
|
||||
|
||||
/**
|
||||
* This function serves requests for autocompletion from crow
|
||||
*
|
||||
*/
|
||||
/**
 * This function serves requests for autocompletion from crow.
 *
 * Expects a JSON body containing a "prompt" string and, optionally, a
 * "max_tokens" integer (default 200). Returns an OpenAI-style
 * "text_completion" JSON response with a single choice, or HTTP 400 when
 * the request body lacks a usable prompt.
 *
 * @param model the loaded model used to generate the completion (borrowed,
 *              not owned; must be non-null)
 * @param req   the incoming crow HTTP request
 */
crow::response serve_response(TurbopilotModel *model, const crow::request& req){

    crow::json::rvalue data = crow::json::load(req.body);

    // BUGFIX: the previous check (`!has("prompt") && !has("input_ids")`)
    // let a request carrying only "input_ids" through validation, but the
    // code below unconditionally reads data["prompt"] — crashing on such a
    // request. Tokenized "input_ids" input is not implemented yet (see the
    // commented-out block below), so require "prompt" until it is.
    if(!data.has("prompt")){
        crow::response res;
        res.code = 400;
        res.set_header("Content-Type", "application/json");
        res.body = "{\"message\":\"you must specify a prompt\"}";
        return res;
    }

    // Planned support for pre-tokenized input — kept for reference:
    // std::vector<gpt_vocab::id> embd_inp;
    // if (data.has("prompt")) {
    //     std::string prompt = data["prompt"].s();
    //     embd_inp = ::gpt_tokenize(vocab, prompt);
    // }
    // else {
    //     crow::json::rvalue input_ids = data["input_ids"];
    //     for (auto id : input_ids.lo()) {
    //         embd_inp.push_back(id.i());
    //     }
    // }

    // std::string suffix = data["suffix"].s();

    // Cap on generated tokens; client may override via "max_tokens".
    int maxTokens = 200;
    if(data.has("max_tokens")){
        maxTokens = data["max_tokens"].i();
    }

    // include_prompt=false: echo only the completion, not the prompt.
    auto result = model->predict(data["prompt"].s(), maxTokens, false);

    // Random id for the completion object, mirroring the OpenAI API shape.
    boost::uuids::uuid uuid = boost::uuids::random_generator()();

    // Single choice; "length" finish_reason is hard-coded since the model
    // API does not report why generation stopped.
    crow::json::wvalue choice = {
        {"text", result.str()},
        {"index", 0},
        {"finish_reason", "length"},
        {"logprobs", nullptr}
    };
    crow::json::wvalue::list choices = {choice};

    // Token accounting is not wired up yet — placeholder zeros.
    // TODO(review): populate from the tokenizer once input_ids support
    // lands; see the commented-out variant below.
    // crow::json::wvalue usage = {
    //     {"completion_tokens", n_past},
    //     {"prompt_tokens", static_cast<std::uint64_t>(embd_inp.size())},
    //     {"total_tokens", static_cast<std::uint64_t>(n_past - embd_inp.size())}
    // };
    crow::json::wvalue usage = {
        {"completion_tokens", 0},
        {"prompt_tokens", 0},
        {"total_tokens", 0}
    };

    crow::json::wvalue response = {
        {"id", boost::lexical_cast<std::string>(uuid)},
        {"model", "codegen"},
        {"object", "text_completion"},
        {"created", static_cast<std::int64_t>(std::time(nullptr))},
        {"choices", choices},
        {"usage", usage}
    };

    crow::response res;
    res.code = 200;
    res.set_header("Content-Type", "application/json");
    res.body = response.dump();
    return res;
}
|
Loading…
Reference in New Issue
Block a user