Skip to content

Commit 7a26890

Browse files
author
Chris Warren-Smith
committed
LLAMA: add support for grammar
1 parent 65b1e5f commit 7a26890

File tree

3 files changed

+85
-18
lines changed

3 files changed

+85
-18
lines changed

llama/llama-sb.cpp

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ Llama::Llama() :
3131
_min_p(0),
3232
_top_k(0),
3333
_max_tokens(0),
34-
_log_level(GGML_LOG_LEVEL_CONT) {
34+
_log_level(GGML_LOG_LEVEL_CONT),
35+
_seed(LLAMA_DEFAULT_SEED) {
3536
llama_log_set([](enum ggml_log_level level, const char * text, void *user_data) {
3637
Llama *llama = (Llama *)user_data;
3738
if (level > llama->_log_level) {
@@ -63,6 +64,9 @@ void Llama::reset() {
6364
_top_p = 1.0f;
6465
_min_p = 0.0f;
6566
_max_tokens = 150;
67+
_grammar_src.clear();
68+
_grammar_root.clear();
69+
_seed = LLAMA_DEFAULT_SEED;
6670
if (_ctx) {
6771
llama_memory_clear(llama_get_memory(_ctx), true);
6872
}
@@ -93,36 +97,53 @@ bool Llama::construct(string model_path, int n_ctx, int n_batch, int n_gpu_layer
9397
_last_error = "Failed to create context";
9498
} else {
9599
_vocab = llama_model_get_vocab(_model);
96-
97-
auto sparams = llama_sampler_chain_default_params();
98-
sparams.no_perf = false;
99-
_sampler = llama_sampler_chain_init(sparams);
100100
}
101101
}
102102
return _last_error.empty();
103103
}
104104

105-
void Llama::configure_sampler() {
106-
llama_sampler_reset(_sampler);
105+
void Llama::set_grammar(const string &src, const string &root) {
106+
_grammar_src = src;
107+
_grammar_root = root;
108+
}
109+
110+
bool Llama::configure_sampler() {
111+
auto sparams = llama_sampler_chain_default_params();
112+
sparams.no_perf = false;
113+
llama_sampler *chain = llama_sampler_chain_init(sparams);
114+
115+
if (!_grammar_src.empty()) {
116+
llama_sampler *grammar = llama_sampler_init_grammar(_vocab, _grammar_src.c_str(), _grammar_root.c_str());
117+
if (!grammar) {
118+
_last_error = "failed to initialize grammar sampler";
119+
return false;
120+
}
121+
llama_sampler_chain_add(chain, grammar);
122+
}
107123
if (_penalty_last_n != 0 && _penalty_repeat != 1.0f) {
108124
auto penalties = llama_sampler_init_penalties(_penalty_last_n, _penalty_repeat, 0.0f, 0.0f);
109-
llama_sampler_chain_add(_sampler, penalties);
125+
llama_sampler_chain_add(chain, penalties);
110126
}
111127
if (_temperature <= 0.0f) {
112-
llama_sampler_chain_add(_sampler, llama_sampler_init_greedy());
128+
llama_sampler_chain_add(chain, llama_sampler_init_greedy());
113129
} else {
114-
llama_sampler_chain_add(_sampler, llama_sampler_init_temp(_temperature));
115130
if (_top_k > 0) {
116-
llama_sampler_chain_add(_sampler, llama_sampler_init_top_k(_top_k));
131+
llama_sampler_chain_add(chain, llama_sampler_init_top_k(_top_k));
117132
}
118-
if (_top_p < 1.0f) {
119-
llama_sampler_chain_add(_sampler, llama_sampler_init_top_p(_top_p, 1));
133+
if (_top_p < 1.0f || _min_p > 0.0f) {
134+
llama_sampler_chain_add(chain, llama_sampler_init_top_p(_top_p, 1));
120135
}
121136
if (_min_p > 0.0f) {
122-
llama_sampler_chain_add(_sampler, llama_sampler_init_min_p(_min_p, 1));
137+
llama_sampler_chain_add(chain, llama_sampler_init_min_p(_min_p, 1));
123138
}
124-
llama_sampler_chain_add(_sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
139+
llama_sampler_chain_add(chain, llama_sampler_init_temp(_temperature));
140+
llama_sampler_chain_add(chain, llama_sampler_init_dist(_seed));
125141
}
142+
if (_sampler) {
143+
llama_sampler_free(_sampler);
144+
}
145+
_sampler = chain;
146+
return true;
126147
}
127148

128149
vector<llama_token> Llama::tokenize(const string &prompt) {
@@ -201,7 +222,9 @@ bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
201222
}
202223

203224
bool Llama::generate(LlamaIter &iter, const string &prompt) {
204-
configure_sampler();
225+
if (!configure_sampler()) {
226+
return false;
227+
}
205228

206229
vector<llama_token> prompt_tokens = tokenize(prompt);
207230
if (prompt_tokens.size() == 0) {

llama/llama-sb.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ struct Llama {
5050
void set_temperature(float temperature) { _temperature = temperature; }
5151
void set_top_k(int top_k) { _top_k = top_k; }
5252
void set_top_p(float top_p) { _top_p = top_p; }
53+
void set_grammar(const string &src, const string &root);
54+
void set_seed(unsigned int seed) { _seed = seed; }
5355

5456
// error handling
5557
const char *last_error() { return _last_error.c_str(); }
@@ -58,7 +60,7 @@ struct Llama {
5860

5961
private:
6062
bool ends_with_sentence_boundary(const string &out);
61-
void configure_sampler();
63+
bool configure_sampler();
6264
bool make_space_for_tokens(int n_tokens, int keep_min);
6365
vector<llama_token> tokenize(const string &prompt);
6466
string token_to_string(LlamaIter &iter, llama_token tok);
@@ -68,6 +70,8 @@ struct Llama {
6870
llama_sampler *_sampler;
6971
const llama_vocab *_vocab;
7072
vector<string> _stop_sequences;
73+
string _grammar_src;
74+
string _grammar_root;
7175
string _last_error;
7276
int32_t _penalty_last_n;
7377
float _penalty_repeat;
@@ -77,4 +81,5 @@ struct Llama {
7781
int _top_k;
7882
int _max_tokens;
7983
int _log_level;
84+
unsigned int _seed;
8085
};

llama/main.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,42 @@ static int cmd_llama_set_top_p(var_s *self, int argc, slib_par_t *arg, var_s *re
211211
return result;
212212
}
213213

214+
//
215+
// llama.set_grammar("text")
216+
//
217+
static int cmd_llama_set_grammar(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
218+
int result = 0;
219+
if (argc != 1) {
220+
error(retval, "llama.set_grammar", 1, 1);
221+
} else {
222+
int id = get_llama_class_id(self, retval);
223+
if (id != -1) {
224+
Llama &llama = g_llama.at(id);
225+
llama.set_grammar(get_param_str(argc, arg, 0, 0), "root");
226+
result = 1;
227+
}
228+
}
229+
return result;
230+
}
231+
232+
//
233+
// llama.set_seed(123)
234+
//
235+
static int cmd_llama_set_seed(var_s *self, int argc, slib_par_t *arg, var_s *retval) {
236+
int result = 0;
237+
if (argc != 1) {
238+
error(retval, "llama.set_seed", 1, 1);
239+
} else {
240+
int id = get_llama_class_id(self, retval);
241+
if (id != -1) {
242+
Llama &llama = g_llama.at(id);
243+
llama.set_seed(get_param_num(argc, arg, 0, 0));
244+
result = 1;
245+
}
246+
}
247+
return result;
248+
}
249+
214250
//
215251
// llama.reset() - make the model forget everything
216252
//
@@ -355,6 +391,8 @@ static int cmd_create_llama(int argc, slib_par_t *params, var_t *retval) {
355391
v_create_callback(retval, "set_temperature", cmd_llama_set_temperature);
356392
v_create_callback(retval, "set_top_k", cmd_llama_set_top_k);
357393
v_create_callback(retval, "set_top_p", cmd_llama_set_top_p);
394+
v_create_callback(retval, "set_grammar", cmd_llama_set_grammar);
395+
v_create_callback(retval, "set_seed", cmd_llama_set_seed);
358396
result = 1;
359397
} else {
360398
error(retval, llama.last_error());
@@ -388,7 +426,7 @@ int sblib_init(const char *sourceFile) {
388426
//
389427
// Release variables falling out of scope
390428
//
391-
SBLIB_API void sblib_free(int cls_id, int id) {
429+
SBLIB_API int sblib_free(int cls_id, int id) {
392430
if (id != -1) {
393431
switch (cls_id) {
394432
case CLASS_ID_LLAMA:
@@ -403,6 +441,7 @@ SBLIB_API void sblib_free(int cls_id, int id) {
403441
break;
404442
}
405443
}
444+
return 0;
406445
}
407446

408447
//

0 commit comments

Comments
 (0)