Send codellama:7b-code stop token in request
So that Ollama filters it out of completion suggestions
commit 909b2eca03
parent a50dc886da
2 changed files with 72 additions and 1 deletion
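
With this change, the stop sequence travels in the body of the POST to Ollama's /api/generate endpoint, so the server truncates generation at <EOT> and excludes it from the returned completion. A rough sketch of the serialized payload, using the values visible in the diff below (the exact serialization depends on this crate's request types):

    {
      "model": "codellama:7b-code",
      "prompt": "<prefix>",
      "options": {
        "num_predict": 150,
        "temperature": 0.1,
        "top_p": 0.95,
        "stop": ["<EOT>"]
      },
      "keep_alive": null,
      "context": null
    }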
@@ -66,6 +66,19 @@ impl OllamaCompletionProvider {
         (prefix, suffix)
     }
 
+    /// Get stop tokens for the current model.
+    /// For now we only handle the codellama:7b-code model, which we found
+    /// was including the stop token in the completion suggestion.
+    /// We wanted to avoid going down this route and let Ollama abstract all
+    /// template tokens away, but, surprisingly for a llama model, Ollama
+    /// misses this case.
+    fn get_stop_tokens(&self) -> Option<Vec<String>> {
+        if self.model.contains("codellama") && self.model.contains("code") {
+            Some(vec!["<EOT>".to_string()])
+        } else {
+            None
+        }
+    }
 }
 
 impl EditPredictionProvider for OllamaCompletionProvider {
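
For context, codellama:7b-code is CodeLlama's infill variant: its fill-in-the-middle template ends the generated span with the <EOT> token, so when the server fails to strip it, the raw suggestion arrives with a trailing "<EOT>". A minimal sketch of the symptom and a client-side trim, assuming the raw completion text is available as a &str (the strip_stop_token helper is hypothetical, not part of this commit):

    /// Hypothetical fallback: trim a trailing stop token from a raw
    /// completion when the server does not filter it out.
    fn strip_stop_token(completion: &str, stop: &str) -> String {
        completion
            .strip_suffix(stop)
            .unwrap_or(completion)
            .trim_end()
            .to_string()
    }

    fn main() {
        // Raw suggestion shape as observed from codellama:7b-code before this fix.
        let raw = "return x + y\n<EOT>";
        assert_eq!(strip_stop_token(raw, "<EOT>"), "return x + y");
    }

Sending the token via the request's stop option avoids the need for such trimming: Ollama then stops generation before emitting it.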
@@ -124,6 +137,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                 let (model, api_key) =
                     this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
 
+                let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?;
+
                 let request = GenerateRequest {
                     model,
                     prompt: prefix,
@@ -133,7 +148,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                         num_predict: Some(150), // Reasonable completion length
                         temperature: Some(0.1), // Low temperature for more deterministic results
                         top_p: Some(0.95),
-                        stop: None, // Let Ollama handle stop tokens natively
+                        stop: stop_tokens,
                     }),
                     keep_alive: None,
                     context: None,
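
For reference, a sketch of the request types these fields imply; the names mirror Ollama's documented /api/generate parameters, but the real definitions live elsewhere in this crate, and the keep_alive and context types are assumptions here:

    use serde::Serialize;

    #[derive(Serialize)]
    struct GenerateRequest {
        model: String,
        prompt: String,
        options: Option<GenerateOptions>,
        keep_alive: Option<String>, // assumed; Ollama accepts durations like "5m"
        context: Option<Vec<i64>>,  // assumed; opaque context tokens from a prior response
    }

    #[derive(Serialize)]
    struct GenerateOptions {
        num_predict: Option<i32>,
        temperature: Option<f32>,
        top_p: Option<f32>,
        // Sequences at which generation halts; Ollama excludes them from the output.
        stop: Option<Vec<String>>,
    }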
@@ -254,6 +269,33 @@ mod tests {
     }
 
+    /// Test that stop tokens are returned only for CodeLlama code models
+    #[test]
+    fn test_get_stop_tokens() {
+        let http_client = Arc::new(crate::fake::FakeHttpClient::new());
+
+        // Test CodeLlama code model gets stop tokens
+        let codellama_provider = OllamaCompletionProvider::new(
+            http_client.clone(),
+            "http://localhost:11434".to_string(),
+            "codellama:7b-code".to_string(),
+            None,
+        );
+
+        assert_eq!(
+            codellama_provider.get_stop_tokens(),
+            Some(vec!["<EOT>".to_string()])
+        );
+
+        // Test non-CodeLlama model doesn't get stop tokens
+        let qwen_provider = OllamaCompletionProvider::new(
+            http_client.clone(),
+            "http://localhost:11434".to_string(),
+            "qwen2.5-coder:3b".to_string(),
+            None,
+        );
+        assert_eq!(qwen_provider.get_stop_tokens(), None);
+    }
+
     #[gpui::test]
     async fn test_full_completion_flow(cx: &mut TestAppContext) {
         init_test(cx);
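
The new unit test needs no running Ollama server, since get_stop_tokens only inspects the model name string; it runs with the standard test runner:

    cargo test test_get_stop_tokens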