Remove stop tokens and completion cleanup
This commit is contained in:
parent
b50555b87a
commit
cce9949d92
2 changed files with 14 additions and 76 deletions
|
@ -437,7 +437,7 @@ mod tests {
|
||||||
num_predict: Some(150),
|
num_predict: Some(150),
|
||||||
temperature: Some(0.1),
|
temperature: Some(0.1),
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
stop: Some(vec!["<|endoftext|>".to_string()]),
|
stop: None,
|
||||||
}),
|
}),
|
||||||
keep_alive: None,
|
keep_alive: None,
|
||||||
context: None,
|
context: None,
|
||||||
|
|
|
@ -37,39 +37,12 @@ impl OllamaCompletionProvider {
|
||||||
self.model = new_model;
|
self.model = new_model;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_stop_tokens(&self) -> Vec<String> {
|
/// Updates the file extension used by this provider
|
||||||
// Basic stop tokens for code completion
|
pub fn update_file_extension(&mut self, new_file_extension: String) {
|
||||||
// Ollama handles FIM tokens internally, so we only need general completion stops
|
self.file_extension = Some(new_file_extension);
|
||||||
vec![
|
|
||||||
"\n\n".to_string(), // Double newline often indicates end of completion
|
|
||||||
"```".to_string(), // Code block delimiter
|
|
||||||
"<|endoftext|>".to_string(), // Common model end token
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clean_completion(&self, completion: &str) -> String {
|
// Removed get_stop_tokens and clean_completion - Ollama handles everything natively with FIM
|
||||||
let mut cleaned = completion.to_string();
|
|
||||||
|
|
||||||
// Basic cleaning - Ollama should handle FIM tokens internally
|
|
||||||
// Remove leading/trailing whitespace
|
|
||||||
cleaned = cleaned.trim().to_string();
|
|
||||||
|
|
||||||
// Remove common unwanted prefixes that models sometimes generate
|
|
||||||
let unwanted_prefixes = [
|
|
||||||
"// COMPLETION HERE",
|
|
||||||
"// Complete the following code:",
|
|
||||||
"// completion:",
|
|
||||||
"// TODO:",
|
|
||||||
];
|
|
||||||
|
|
||||||
for prefix in &unwanted_prefixes {
|
|
||||||
if cleaned.starts_with(prefix) {
|
|
||||||
cleaned = cleaned[prefix.len()..].trim_start().to_string();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cleaned
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
|
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
|
||||||
let cursor_offset = cursor_position.to_offset(buffer);
|
let cursor_offset = cursor_position.to_offset(buffer);
|
||||||
|
@ -102,7 +75,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
|
fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
|
||||||
// TODO: Could ping Ollama API to check if it's running
|
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,8 +114,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
this.extract_context(buffer_snapshot, cursor_position)
|
this.extract_context(buffer_snapshot, cursor_position)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let (model, stop_tokens) =
|
let model = this.update(cx, |this, _| this.model.clone())?;
|
||||||
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
|
|
||||||
|
|
||||||
let request = GenerateRequest {
|
let request = GenerateRequest {
|
||||||
model,
|
model,
|
||||||
|
@ -154,7 +125,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
num_predict: Some(150), // Reasonable completion length
|
num_predict: Some(150), // Reasonable completion length
|
||||||
temperature: Some(0.1), // Low temperature for more deterministic results
|
temperature: Some(0.1), // Low temperature for more deterministic results
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
stop: Some(stop_tokens),
|
stop: None, // Let Ollama handle stop tokens natively
|
||||||
}),
|
}),
|
||||||
keep_alive: None,
|
keep_alive: None,
|
||||||
context: None,
|
context: None,
|
||||||
|
@ -166,9 +137,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.pending_refresh = None;
|
this.pending_refresh = None;
|
||||||
let cleaned_completion = this.clean_completion(&response.response);
|
if !response.response.trim().is_empty() {
|
||||||
if !cleaned_completion.is_empty() {
|
this.current_completion = Some(response.response);
|
||||||
this.current_completion = Some(cleaned_completion);
|
|
||||||
} else {
|
} else {
|
||||||
this.current_completion = None;
|
this.current_completion = None;
|
||||||
}
|
}
|
||||||
|
@ -236,37 +206,9 @@ mod tests {
|
||||||
use http_client::FakeHttpClient;
|
use http_client::FakeHttpClient;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[gpui::test]
|
// Removed test_get_stop_tokens - no longer using custom stop tokens
|
||||||
async fn test_get_stop_tokens(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"qwen2.5-coder:32b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let stop_tokens = provider.get_stop_tokens();
|
// Removed test_clean_completion_basic - no longer using custom completion cleaning
|
||||||
assert!(stop_tokens.contains(&"\n\n".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"```".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
|
||||||
assert_eq!(stop_tokens.len(), 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_clean_completion_basic(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"qwen2.5-coder:32b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let completion = " console.log('hello'); ";
|
|
||||||
let cleaned = provider.clean_completion(completion);
|
|
||||||
assert_eq!(cleaned, "console.log('hello');");
|
|
||||||
|
|
||||||
let completion_with_prefix = "// COMPLETION HERE\nconsole.log('hello');";
|
|
||||||
let cleaned = provider.clean_completion(completion_with_prefix);
|
|
||||||
assert_eq!(cleaned, "console.log('hello');");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_extract_context(cx: &mut TestAppContext) {
|
async fn test_extract_context(cx: &mut TestAppContext) {
|
||||||
|
@ -405,7 +347,7 @@ mod tests {
|
||||||
num_predict: Some(150),
|
num_predict: Some(150),
|
||||||
temperature: Some(0.1),
|
temperature: Some(0.1),
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
stop: Some(provider.get_stop_tokens()),
|
stop: None, // Ollama handles stop tokens natively
|
||||||
}),
|
}),
|
||||||
keep_alive: None,
|
keep_alive: None,
|
||||||
context: None,
|
context: None,
|
||||||
|
@ -417,11 +359,7 @@ mod tests {
|
||||||
assert_eq!(request.suffix, Some(" return result".to_string()));
|
assert_eq!(request.suffix, Some(" return result".to_string()));
|
||||||
assert!(!request.stream);
|
assert!(!request.stream);
|
||||||
|
|
||||||
// Verify stop tokens are simplified (no FIM-specific tokens)
|
// Verify stop tokens are handled natively by Ollama
|
||||||
let stop_tokens = request.options.as_ref().unwrap().stop.as_ref().unwrap();
|
assert!(request.options.as_ref().unwrap().stop.is_none());
|
||||||
assert!(stop_tokens.contains(&"\n\n".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"```".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
|
||||||
assert_eq!(stop_tokens.len(), 3);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue