Compare commits

..

2 Commits

Author SHA1 Message Date
James Ravenscroft
6b0a25cb71
Merge pull request #59 from ravenscroftj/feature/batch-flag
expose batch size flag to cli
2023-08-24 14:13:40 +01:00
James Ravenscroft
2d617b458e expose batch size flag to cli 2023-08-24 11:40:19 +01:00

View File

@ -60,6 +60,11 @@ int main(int argc, char **argv)
.default_value(0.1f) .default_value(0.1f)
.scan<'g', float>(); .scan<'g', float>();
program.add_argument("-b", "--batch-size")
.help("set batch size for model completion")
.default_value(512)
.scan<'i',int>();
program.add_argument("prompt").remaining(); program.add_argument("prompt").remaining();
@ -96,6 +101,7 @@ int main(int argc, char **argv)
config.n_threads = program.get<int>("--threads"); config.n_threads = program.get<int>("--threads");
config.temp = program.get<float>("--temperature"); config.temp = program.get<float>("--temperature");
config.top_p = program.get<float>("--top-p"); config.top_p = program.get<float>("--top-p");
config.n_batch = program.get<int>("--batch-size");
if(model_type.compare("codegen") == 0) { if(model_type.compare("codegen") == 0) {
spdlog::info("Initializing GPT-J type model for '{}' model", model_type); spdlog::info("Initializing GPT-J type model for '{}' model", model_type);
@ -131,6 +137,7 @@ int main(int argc, char **argv)
return "Hello world"; return "Hello world";
}); });
CROW_ROUTE(app, "/copilot_internal/v2/token")([](){ CROW_ROUTE(app, "/copilot_internal/v2/token")([](){
//return "Hello world"; //return "Hello world";