Send codellama:7b-code stop token in request

So Ollama filters it out
This commit is contained in:
Oliver Azevedo Barnes 2025-07-17 17:55:21 +01:00
parent a50dc886da
commit 909b2eca03
No known key found for this signature in database
2 changed files with 72 additions and 1 deletions

View file

@ -831,4 +831,33 @@ mod tests {
// Note: The API key parameter is passed to the generate function itself,
// not included in the GenerateRequest struct that gets serialized to JSON
}
#[test]
fn test_generate_request_with_stop_tokens() {
let request = GenerateRequest {
model: "codellama:7b-code".to_string(),
prompt: "def fibonacci(n):".to_string(),
suffix: Some(" return result".to_string()),
stream: false,
options: Some(GenerateOptions {
num_predict: Some(150),
temperature: Some(0.1),
top_p: Some(0.95),
stop: Some(vec!["<EOT>".to_string()]),
}),
keep_alive: None,
context: None,
};
let json = serde_json::to_string(&request).unwrap();
let parsed: GenerateRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.model, "codellama:7b-code");
assert_eq!(parsed.prompt, "def fibonacci(n):");
assert_eq!(parsed.suffix, Some(" return result".to_string()));
assert!(!parsed.stream);
assert!(parsed.options.is_some());
let options = parsed.options.unwrap();
assert_eq!(options.stop, Some(vec!["<EOT>".to_string()]));
}
}

View file

@ -66,6 +66,19 @@ impl OllamaCompletionProvider {
(prefix, suffix)
}
/// Get stop tokens for the current model
/// For now we only handle the case for codellama:7b-code model
/// that we found was including the stop token in the completion suggestion.
/// We wanted to avoid going down this route and let Ollama abstract all template tokens away.
/// But apparently, and surprisingly for a llama model, Ollama misses this case.
fn get_stop_tokens(&self) -> Option<Vec<String>> {
if self.model.contains("codellama") && self.model.contains("code") {
Some(vec!["<EOT>".to_string()])
} else {
None
}
}
}
impl EditPredictionProvider for OllamaCompletionProvider {
@ -124,6 +137,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
let (model, api_key) =
this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?;
let request = GenerateRequest {
model,
prompt: prefix,
@ -133,7 +148,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
num_predict: Some(150), // Reasonable completion length
temperature: Some(0.1), // Low temperature for more deterministic results
top_p: Some(0.95),
stop: None, // Let Ollama handle stop tokens natively
stop: stop_tokens,
}),
keep_alive: None,
context: None,
@ -254,6 +269,33 @@ mod tests {
}
/// Test the complete Ollama completion flow from refresh to suggestion
#[test]
fn test_get_stop_tokens() {
let http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Test CodeLlama code model gets stop tokens
let codellama_provider = OllamaCompletionProvider::new(
http_client.clone(),
"http://localhost:11434".to_string(),
"codellama:7b-code".to_string(),
None,
);
assert_eq!(
codellama_provider.get_stop_tokens(),
Some(vec!["<EOT>".to_string()])
);
// Test non-CodeLlama model doesn't get stop tokens
let qwen_provider = OllamaCompletionProvider::new(
http_client.clone(),
"http://localhost:11434".to_string(),
"qwen2.5-coder:3b".to_string(),
None,
);
assert_eq!(qwen_provider.get_stop_tokens(), None);
}
#[gpui::test]
async fn test_full_completion_flow(cx: &mut TestAppContext) {
init_test(cx);