Send codellama:7b-code stop token in request
So Ollama filters it out
parent a50dc886da
commit 909b2eca03

2 changed files with 72 additions and 1 deletion
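
For context, here is a minimal sketch of the request body this change produces, using only serde_json. The field names mirror the GenerateRequest/GenerateOptions structs in the diff below and the values come from the new test; the standalone main() wrapper and the pretty-printing are illustrative only, not part of the commit.

// Sketch: roughly the JSON payload sent to Ollama's generate endpoint once
// the stop token is included. Requires only the serde_json crate.
use serde_json::json;

fn main() {
    let body = json!({
        "model": "codellama:7b-code",
        "prompt": "def fibonacci(n):",
        "suffix": " return result",
        "stream": false,
        "options": {
            "num_predict": 150,
            "temperature": 0.1,
            "top_p": 0.95,
            // New in this commit: ask Ollama to stop at codellama's <EOT>
            // token and filter it out of the returned completion.
            "stop": ["<EOT>"]
        }
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}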
@@ -831,4 +831,33 @@ mod tests {
         // Note: The API key parameter is passed to the generate function itself,
         // not included in the GenerateRequest struct that gets serialized to JSON
     }
+
+    #[test]
+    fn test_generate_request_with_stop_tokens() {
+        let request = GenerateRequest {
+            model: "codellama:7b-code".to_string(),
+            prompt: "def fibonacci(n):".to_string(),
+            suffix: Some(" return result".to_string()),
+            stream: false,
+            options: Some(GenerateOptions {
+                num_predict: Some(150),
+                temperature: Some(0.1),
+                top_p: Some(0.95),
+                stop: Some(vec!["<EOT>".to_string()]),
+            }),
+            keep_alive: None,
+            context: None,
+        };
+
+        let json = serde_json::to_string(&request).unwrap();
+        let parsed: GenerateRequest = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(parsed.model, "codellama:7b-code");
+        assert_eq!(parsed.prompt, "def fibonacci(n):");
+        assert_eq!(parsed.suffix, Some(" return result".to_string()));
+        assert!(!parsed.stream);
+        assert!(parsed.options.is_some());
+        let options = parsed.options.unwrap();
+        assert_eq!(options.stop, Some(vec!["<EOT>".to_string()]));
+    }
 }

@@ -66,6 +66,19 @@ impl OllamaCompletionProvider {
         (prefix, suffix)
     }
 
+    /// Get stop tokens for the current model.
+    /// For now we only handle the codellama:7b-code case, which we found was
+    /// including the stop token in the completion suggestion. We wanted to avoid
+    /// going down this route and let Ollama abstract all template tokens away,
+    /// but apparently, and surprisingly for a llama model, Ollama misses this case.
+    fn get_stop_tokens(&self) -> Option<Vec<String>> {
+        if self.model.contains("codellama") && self.model.contains("code") {
+            Some(vec!["<EOT>".to_string()])
+        } else {
+            None
+        }
+    }
 }
 
 impl EditPredictionProvider for OllamaCompletionProvider {

@@ -124,6 +137,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
             let (model, api_key) =
                 this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
 
+            let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?;
+
             let request = GenerateRequest {
                 model,
                 prompt: prefix,

@@ -133,7 +148,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                     num_predict: Some(150), // Reasonable completion length
                     temperature: Some(0.1), // Low temperature for more deterministic results
                     top_p: Some(0.95),
-                    stop: None, // Let Ollama handle stop tokens natively
+                    stop: stop_tokens,
                 }),
                 keep_alive: None,
                 context: None,
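
The new test below exercises this rule through the provider; for quick reference, here is the same matching rule distilled into a standalone function. stop_tokens_for is a hypothetical helper written for this note, not part of the commit.

// A minimal, self-contained restatement of get_stop_tokens(): only CodeLlama
// code/infill model names get a stop list, everything else is left to Ollama.
fn stop_tokens_for(model: &str) -> Option<Vec<String>> {
    if model.contains("codellama") && model.contains("code") {
        Some(vec!["<EOT>".to_string()])
    } else {
        None
    }
}

fn main() {
    assert_eq!(
        stop_tokens_for("codellama:7b-code"),
        Some(vec!["<EOT>".to_string()])
    );
    assert_eq!(stop_tokens_for("qwen2.5-coder:3b"), None);
}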

@@ -254,6 +269,33 @@ mod tests {
     }
 
     /// Test the complete Ollama completion flow from refresh to suggestion
+    #[test]
+    fn test_get_stop_tokens() {
+        let http_client = Arc::new(crate::fake::FakeHttpClient::new());
+
+        // Test that the CodeLlama code model gets stop tokens
+        let codellama_provider = OllamaCompletionProvider::new(
+            http_client.clone(),
+            "http://localhost:11434".to_string(),
+            "codellama:7b-code".to_string(),
+            None,
+        );
+
+        assert_eq!(
+            codellama_provider.get_stop_tokens(),
+            Some(vec!["<EOT>".to_string()])
+        );
+
+        // Test that a non-CodeLlama model doesn't get stop tokens
+        let qwen_provider = OllamaCompletionProvider::new(
+            http_client.clone(),
+            "http://localhost:11434".to_string(),
+            "qwen2.5-coder:3b".to_string(),
+            None,
+        );
+        assert_eq!(qwen_provider.get_stop_tokens(), None);
+    }
+
     #[gpui::test]
     async fn test_full_completion_flow(cx: &mut TestAppContext) {
         init_test(cx);