Remove stop tokens and completion cleanup

Oliver Azevedo Barnes 2025-07-07 20:49:08 -03:00
parent b50555b87a
commit cce9949d92
2 changed files with 14 additions and 76 deletions
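
The gist of the change: the provider previously supplied its own stop-token list and post-processed the model's output; after this commit it sends `stop: None` and uses the response verbatim, relying on Ollama's `/api/generate` endpoint to apply the model's own fill-in-the-middle template and end tokens when given `prompt` and `suffix`. A minimal sketch of the kind of request body the provider now sends — not the actual code from this diff; the model name and code fragment are illustrative, and the field names follow Ollama's generate API:

use serde_json::json;

fn main() {
    // Body for POST /api/generate. With `suffix` present, Ollama applies the
    // model's own FIM template and stop tokens server-side, so the client
    // sends no `stop` list and does no output cleanup.
    let body = json!({
        "model": "qwen2.5-coder:32b",                      // illustrative model
        "prompt": "fn add(a: i32, b: i32) -> i32 {\n    ", // text before the cursor
        "suffix": "\n}",                                   // text after the cursor
        "stream": false,
        "options": {
            "num_predict": 150,  // cap completion length
            "temperature": 0.1,  // near-deterministic sampling
            "top_p": 0.95
            // no "stop": the model's template supplies its own end tokens
        }
    });
    println!("{}", serde_json::to_string_pretty(&body).unwrap());
}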

File 1 of 2

@@ -437,7 +437,7 @@ mod tests {
                 num_predict: Some(150),
                 temperature: Some(0.1),
                 top_p: Some(0.95),
-                stop: Some(vec!["<|endoftext|>".to_string()]),
+                stop: None,
             }),
             keep_alive: None,
             context: None,

File 2 of 2

@@ -37,39 +37,12 @@ impl OllamaCompletionProvider {
         self.model = new_model;
     }
 
-    fn get_stop_tokens(&self) -> Vec<String> {
-        // Basic stop tokens for code completion
-        // Ollama handles FIM tokens internally, so we only need general completion stops
-        vec![
-            "\n\n".to_string(),          // Double newline often indicates end of completion
-            "```".to_string(),           // Code block delimiter
-            "<|endoftext|>".to_string(), // Common model end token
-        ]
-    }
-
-    fn clean_completion(&self, completion: &str) -> String {
-        let mut cleaned = completion.to_string();
-
-        // Basic cleaning - Ollama should handle FIM tokens internally
-        // Remove leading/trailing whitespace
-        cleaned = cleaned.trim().to_string();
-
-        // Remove common unwanted prefixes that models sometimes generate
-        let unwanted_prefixes = [
-            "// COMPLETION HERE",
-            "// Complete the following code:",
-            "// completion:",
-            "// TODO:",
-        ];
-
-        for prefix in &unwanted_prefixes {
-            if cleaned.starts_with(prefix) {
-                cleaned = cleaned[prefix.len()..].trim_start().to_string();
-            }
-        }
-
-        cleaned
-    }
+    /// Updates the file extension used by this provider
+    pub fn update_file_extension(&mut self, new_file_extension: String) {
+        self.file_extension = Some(new_file_extension);
+    }
+
+    // Removed get_stop_tokens and clean_completion - Ollama handles everything natively with FIM
 
     fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
         let cursor_offset = cursor_position.to_offset(buffer);
@@ -102,7 +75,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
     }
 
     fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
-        // TODO: Could ping Ollama API to check if it's running
         true
     }
@@ -142,8 +114,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                     this.extract_context(buffer_snapshot, cursor_position)
                 })?;
 
-                let (model, stop_tokens) =
-                    this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
+                let model = this.update(cx, |this, _| this.model.clone())?;
 
                 let request = GenerateRequest {
                     model,
@@ -154,7 +125,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                         num_predict: Some(150), // Reasonable completion length
                        temperature: Some(0.1), // Low temperature for more deterministic results
                         top_p: Some(0.95),
-                        stop: Some(stop_tokens),
+                        stop: None, // Let Ollama handle stop tokens natively
                     }),
                     keep_alive: None,
                     context: None,
@@ -166,9 +137,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                 this.update(cx, |this, cx| {
                     this.pending_refresh = None;
 
-                    let cleaned_completion = this.clean_completion(&response.response);
-                    if !cleaned_completion.is_empty() {
-                        this.current_completion = Some(cleaned_completion);
+                    if !response.response.trim().is_empty() {
+                        this.current_completion = Some(response.response);
                     } else {
                         this.current_completion = None;
                     }
@@ -236,37 +206,9 @@ mod tests {
     use http_client::FakeHttpClient;
     use std::sync::Arc;
 
-    #[gpui::test]
-    async fn test_get_stop_tokens(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "qwen2.5-coder:32b".to_string(),
-        );
-
-        let stop_tokens = provider.get_stop_tokens();
-        assert!(stop_tokens.contains(&"\n\n".to_string()));
-        assert!(stop_tokens.contains(&"```".to_string()));
-        assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
-        assert_eq!(stop_tokens.len(), 3);
-    }
-
-    #[gpui::test]
-    async fn test_clean_completion_basic(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "qwen2.5-coder:32b".to_string(),
-        );
-
-        let completion = " console.log('hello'); ";
-        let cleaned = provider.clean_completion(completion);
-        assert_eq!(cleaned, "console.log('hello');");
-
-        let completion_with_prefix = "// COMPLETION HERE\nconsole.log('hello');";
-        let cleaned = provider.clean_completion(completion_with_prefix);
-        assert_eq!(cleaned, "console.log('hello');");
-    }
+    // Removed test_get_stop_tokens - no longer using custom stop tokens
+
+    // Removed test_clean_completion_basic - no longer using custom completion cleaning
 
     #[gpui::test]
     async fn test_extract_context(cx: &mut TestAppContext) {
@@ -405,7 +347,7 @@ mod tests {
                 num_predict: Some(150),
                 temperature: Some(0.1),
                 top_p: Some(0.95),
-                stop: Some(provider.get_stop_tokens()),
+                stop: None, // Ollama handles stop tokens natively
             }),
             keep_alive: None,
             context: None,
@@ -417,11 +359,7 @@ mod tests {
         assert_eq!(request.suffix, Some(" return result".to_string()));
         assert!(!request.stream);
 
-        // Verify stop tokens are simplified (no FIM-specific tokens)
-        let stop_tokens = request.options.as_ref().unwrap().stop.as_ref().unwrap();
-        assert!(stop_tokens.contains(&"\n\n".to_string()));
-        assert!(stop_tokens.contains(&"```".to_string()));
-        assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
-        assert_eq!(stop_tokens.len(), 3);
+        // Verify stop tokens are handled natively by Ollama
+        assert!(request.options.as_ref().unwrap().stop.is_none());
     }
 }
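
With clean_completion gone, the only filter left on the response is a trimmed-emptiness check, and the stored completion keeps its original whitespace. A standalone sketch of that acceptance logic — the accept_completion helper is hypothetical; in the diff the same check runs inline inside the entity update:

/// Accept a completion only if it contains non-whitespace content.
/// Mirrors the post-commit behavior: no prefix stripping, no token cleanup,
/// and the accepted text is stored untrimmed.
fn accept_completion(response: String) -> Option<String> {
    if response.trim().is_empty() {
        None
    } else {
        Some(response)
    }
}

fn main() {
    assert_eq!(accept_completion("   \n".to_string()), None);
    assert_eq!(
        accept_completion("console.log('hello');".to_string()),
        Some("console.log('hello');".to_string())
    );
}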