Send codellama:7b-code stop token in request
So Ollama filters it out
This commit is contained in:
parent
a50dc886da
commit
909b2eca03
2 changed files with 72 additions and 1 deletions
|
@ -831,4 +831,33 @@ mod tests {
|
|||
// Note: The API key parameter is passed to the generate function itself,
|
||||
// not included in the GenerateRequest struct that gets serialized to JSON
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_request_with_stop_tokens() {
|
||||
let request = GenerateRequest {
|
||||
model: "codellama:7b-code".to_string(),
|
||||
prompt: "def fibonacci(n):".to_string(),
|
||||
suffix: Some(" return result".to_string()),
|
||||
stream: false,
|
||||
options: Some(GenerateOptions {
|
||||
num_predict: Some(150),
|
||||
temperature: Some(0.1),
|
||||
top_p: Some(0.95),
|
||||
stop: Some(vec!["<EOT>".to_string()]),
|
||||
}),
|
||||
keep_alive: None,
|
||||
context: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
let parsed: GenerateRequest = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.model, "codellama:7b-code");
|
||||
assert_eq!(parsed.prompt, "def fibonacci(n):");
|
||||
assert_eq!(parsed.suffix, Some(" return result".to_string()));
|
||||
assert!(!parsed.stream);
|
||||
assert!(parsed.options.is_some());
|
||||
let options = parsed.options.unwrap();
|
||||
assert_eq!(options.stop, Some(vec!["<EOT>".to_string()]));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -66,6 +66,19 @@ impl OllamaCompletionProvider {
|
|||
|
||||
(prefix, suffix)
|
||||
}
|
||||
|
||||
/// Get stop tokens for the current model
|
||||
/// For now we only handle the case for codellama:7b-code model
|
||||
/// that we found was including the stop token in the completion suggestion.
|
||||
/// We wanted to avoid going down this route and let Ollama abstract all template tokens away.
|
||||
/// But apparently, and surprisingly for a llama model, Ollama misses this case.
|
||||
fn get_stop_tokens(&self) -> Option<Vec<String>> {
|
||||
if self.model.contains("codellama") && self.model.contains("code") {
|
||||
Some(vec!["<EOT>".to_string()])
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EditPredictionProvider for OllamaCompletionProvider {
|
||||
|
@ -124,6 +137,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
|||
let (model, api_key) =
|
||||
this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
|
||||
|
||||
let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?;
|
||||
|
||||
let request = GenerateRequest {
|
||||
model,
|
||||
prompt: prefix,
|
||||
|
@ -133,7 +148,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
|||
num_predict: Some(150), // Reasonable completion length
|
||||
temperature: Some(0.1), // Low temperature for more deterministic results
|
||||
top_p: Some(0.95),
|
||||
stop: None, // Let Ollama handle stop tokens natively
|
||||
stop: stop_tokens,
|
||||
}),
|
||||
keep_alive: None,
|
||||
context: None,
|
||||
|
@ -254,6 +269,33 @@ mod tests {
|
|||
}
|
||||
|
||||
/// Test the complete Ollama completion flow from refresh to suggestion
|
||||
#[test]
|
||||
fn test_get_stop_tokens() {
|
||||
let http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
|
||||
// Test CodeLlama code model gets stop tokens
|
||||
let codellama_provider = OllamaCompletionProvider::new(
|
||||
http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codellama:7b-code".to_string(),
|
||||
None,
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
codellama_provider.get_stop_tokens(),
|
||||
Some(vec!["<EOT>".to_string()])
|
||||
);
|
||||
|
||||
// Test non-CodeLlama model doesn't get stop tokens
|
||||
let qwen_provider = OllamaCompletionProvider::new(
|
||||
http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwen2.5-coder:3b".to_string(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(qwen_provider.get_stop_tokens(), None);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_full_completion_flow(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue