diff --git a/src/main.cpp b/src/main.cpp index 32ddcf9..34ebbbd 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -32,7 +32,6 @@ int main(int argc, char **argv) .default_value(4) .scan<'i', int>(); - program.add_argument("-p", "--port") .help("The tcp port that turbopilot should listen on") .default_value(18080) @@ -43,6 +42,17 @@ int main(int argc, char **argv) .default_value(-1) .scan<'i', int>(); + program.add_argument("--temperature") + .help("Set the generation temperature") + .default_value(0.2) + .scan<'g', double>(); + + program.add_argument("--top-p") + .help("Set the generation top_p") + .default_value(0.1) + .scan<'g', double>(); + + program.add_argument("prompt").remaining(); @@ -70,6 +80,8 @@ int main(int argc, char **argv) std::mt19937 rng(program.get("--random-seed")); config.n_threads = program.get("--threads"); + config.temp = program.get("--temperature"); + config.top_p = program.get("--top-p"); if(model_type.compare("codegen") == 0) { spdlog::info("Initializing GPT-J type model for '{}' model", model_type);