@@ -31,7 +31,8 @@ Llama::Llama() :
3131 _min_p(0 ),
3232 _top_k(0 ),
3333 _max_tokens(0 ),
34- _log_level(GGML_LOG_LEVEL_CONT) {
34+ _log_level(GGML_LOG_LEVEL_CONT),
35+ _seed(LLAMA_DEFAULT_SEED) {
3536 llama_log_set ([](enum ggml_log_level level, const char * text, void *user_data) {
3637 Llama *llama = (Llama *)user_data;
3738 if (level > llama->_log_level ) {
@@ -63,6 +64,9 @@ void Llama::reset() {
6364 _top_p = 1 .0f ;
6465 _min_p = 0 .0f ;
6566 _max_tokens = 150 ;
67+ _grammar_src.clear ();
68+ _grammar_root.clear ();
69+ _seed = LLAMA_DEFAULT_SEED;
6670 if (_ctx) {
6771 llama_memory_clear (llama_get_memory (_ctx), true );
6872 }
@@ -93,36 +97,53 @@ bool Llama::construct(string model_path, int n_ctx, int n_batch, int n_gpu_layer
9397 _last_error = " Failed to create context" ;
9498 } else {
9599 _vocab = llama_model_get_vocab (_model);
96-
97- auto sparams = llama_sampler_chain_default_params ();
98- sparams.no_perf = false ;
99- _sampler = llama_sampler_chain_init (sparams);
100100 }
101101 }
102102 return _last_error.empty ();
103103}
104104
105- void Llama::configure_sampler () {
106- llama_sampler_reset (_sampler);
105+ void Llama::set_grammar (const string &src, const string &root) {
106+ _grammar_src = src;
107+ _grammar_root = root;
108+ }
109+
110+ bool Llama::configure_sampler () {
111+ auto sparams = llama_sampler_chain_default_params ();
112+ sparams.no_perf = false ;
113+ llama_sampler *chain = llama_sampler_chain_init (sparams);
114+
115+ if (!_grammar_src.empty ()) {
116+ llama_sampler *grammar = llama_sampler_init_grammar (_vocab, _grammar_src.c_str (), _grammar_root.c_str ());
117+ if (!grammar) {
118+ _last_error = " failed to initialize grammar sampler" ;
119+ return false ;
120+ }
121+ llama_sampler_chain_add (chain, grammar);
122+ }
107123 if (_penalty_last_n != 0 && _penalty_repeat != 1 .0f ) {
108124 auto penalties = llama_sampler_init_penalties (_penalty_last_n, _penalty_repeat, 0 .0f , 0 .0f );
109- llama_sampler_chain_add (_sampler , penalties);
125+ llama_sampler_chain_add (chain , penalties);
110126 }
111127 if (_temperature <= 0 .0f ) {
112- llama_sampler_chain_add (_sampler , llama_sampler_init_greedy ());
128+ llama_sampler_chain_add (chain , llama_sampler_init_greedy ());
113129 } else {
114- llama_sampler_chain_add (_sampler, llama_sampler_init_temp (_temperature));
115130 if (_top_k > 0 ) {
116- llama_sampler_chain_add (_sampler , llama_sampler_init_top_k (_top_k));
131+ llama_sampler_chain_add (chain , llama_sampler_init_top_k (_top_k));
117132 }
118- if (_top_p < 1 .0f ) {
119- llama_sampler_chain_add (_sampler , llama_sampler_init_top_p (_top_p, 1 ));
133+ if (_top_p < 1 .0f || _min_p > 0 . 0f ) {
134+ llama_sampler_chain_add (chain , llama_sampler_init_top_p (_top_p, 1 ));
120135 }
121136 if (_min_p > 0 .0f ) {
122- llama_sampler_chain_add (_sampler , llama_sampler_init_min_p (_min_p, 1 ));
137+ llama_sampler_chain_add (chain , llama_sampler_init_min_p (_min_p, 1 ));
123138 }
124- llama_sampler_chain_add (_sampler, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));
139+ llama_sampler_chain_add (chain, llama_sampler_init_temp (_temperature));
140+ llama_sampler_chain_add (chain, llama_sampler_init_dist (_seed));
125141 }
142+ if (_sampler) {
143+ llama_sampler_free (_sampler);
144+ }
145+ _sampler = chain;
146+ return true ;
126147}
127148
128149vector<llama_token> Llama::tokenize (const string &prompt) {
@@ -201,7 +222,9 @@ bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
201222}
202223
203224bool Llama::generate (LlamaIter &iter, const string &prompt) {
204- configure_sampler ();
225+ if (!configure_sampler ()) {
226+ return false ;
227+ }
205228
206229 vector<llama_token> prompt_tokens = tokenize (prompt);
207230 if (prompt_tokens.size () == 0 ) {
0 commit comments