Improved FIM token handling per model

This commit is contained in:
Oliver Azevedo Barnes 2025-07-03 11:53:06 -03:00
parent 902a07606b
commit af66570bfe
No known key found for this signature in database

View file

@ -39,21 +39,182 @@ impl OllamaCompletionProvider {
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String { fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
// Use model-specific FIM patterns // Use model-specific FIM patterns
match self.model.as_str() { let model_lower = self.model.to_lowercase();
m if m.contains("codellama") => {
if model_lower.contains("qwen") && model_lower.contains("coder") {
// QwenCoder models use pipes
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
} else if model_lower.contains("codellama") {
format!("<PRE> {prefix} <SUF>{suffix} <MID>") format!("<PRE> {prefix} <SUF>{suffix} <MID>")
} } else if model_lower.contains("deepseek") {
m if m.contains("deepseek") => {
format!("<fim▁begin>{prefix}<fim▁hole>{suffix}<fim▁end>") format!("<fim▁begin>{prefix}<fim▁hole>{suffix}<fim▁end>")
} } else if model_lower.contains("codestral") {
m if m.contains("starcoder") => { // Codestral uses suffix-first order
format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
} else if model_lower.contains("codegemma") {
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
} else if model_lower.contains("wizardcoder") {
// WizardCoder models inherit patterns from their base model
if model_lower.contains("deepseek") {
format!("<fim▁begin>{prefix}<fim▁hole>{suffix}<fim▁end>")
} else {
// Most WizardCoder models use stable code pattern
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>") format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
} }
_ => { } else if model_lower.contains("starcoder")
// Generic FIM pattern - fallback for models without specific support || model_lower.contains("santacoder")
format!("// Complete the following code:\n{prefix}\n// COMPLETION HERE\n{suffix}") || model_lower.contains("stable")
|| model_lower.contains("qwen")
|| model_lower.contains("replit")
{
// Stable code pattern (no pipes) - used by StarCoder, SantaCoder, StableCode,
// non-coder Qwen models, and Replit models
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
} else {
// Default to stable code pattern for unknown models
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
} }
} }
/// Builds the stop-sequence list sent with a completion request,
/// matched to the FIM marker vocabulary of `self.model`.
///
/// Every model gets the generic `"\n\n"` and "```" stops first; the
/// family-specific FIM markers follow so generation halts before the
/// model echoes template tokens into the completion.
fn get_stop_tokens(&self) -> Vec<String> {
    let model = self.model.to_lowercase();

    // Stops shared by every model family.
    let common: &[&str] = &["\n\n", "```"];

    // Family-specific markers, checked in the same precedence order as
    // `build_fim_prompt`: Qwen-Coder before the generic Qwen match,
    // dedicated families next, the shared stable-code group, then the
    // catch-all for unknown models.
    let family: &[&str] = if model.contains("qwen") && model.contains("coder") {
        &[
            "<|endoftext|>",
            "<|fim_prefix|>",
            "<|fim_middle|>",
            "<|fim_suffix|>",
            "<|fim_pad|>",
            "<|repo_name|>",
            "<|file_sep|>",
            "<|im_start|>",
            "<|im_end|>",
        ]
    } else if model.contains("codellama") {
        &["<PRE>", "<SUF>", "<MID>", "</PRE>"]
    } else if model.contains("deepseek") {
        // NOTE(review): the "//" stop also halts on ordinary line
        // comments and URLs ("https://") inside an otherwise valid
        // completion — confirm this truncation is intended.
        &[
            "<fim▁begin>",
            "<fim▁hole>",
            "<fim▁end>",
            "//",
            "<end▁of▁sentence>",
        ]
    } else if model.contains("codestral") {
        &["[PREFIX]", "[SUFFIX]"]
    } else if model.contains("codegemma") {
        &[
            "<|fim_prefix|>",
            "<|fim_suffix|>",
            "<|fim_middle|>",
            "<|file_separator|>",
            "<|endoftext|>",
        ]
    } else if model.contains("wizardcoder") {
        // WizardCoder inherits the marker vocabulary of its base model.
        if model.contains("deepseek") {
            &["<fim▁begin>", "<fim▁hole>", "<fim▁end>"]
        } else {
            &["<fim_prefix>", "<fim_suffix>", "<fim_middle>", "<|endoftext|>"]
        }
    } else if model.contains("starcoder")
        || model.contains("santacoder")
        || model.contains("stable")
        || model.contains("qwen")
        || model.contains("replit")
    {
        // Stable-code pattern (no pipes).
        &["<fim_prefix>", "<fim_suffix>", "<fim_middle>", "<|endoftext|>"]
    } else {
        // Unknown model: cover both the piped and un-piped marker styles.
        &[
            "<|fim_prefix|>",
            "<|fim_suffix|>",
            "<|fim_middle|>",
            "<fim_prefix>",
            "<fim_suffix>",
            "<fim_middle>",
            "<|endoftext|>",
        ]
    };

    common
        .iter()
        .chain(family.iter())
        .map(|token| token.to_string())
        .collect()
}
/// Post-processes a raw model response into a usable inline completion:
/// strips any FIM template markers the model echoed back, trims
/// surrounding whitespace, and drops boilerplate comment prefixes.
fn clean_completion(&self, completion: &str) -> String {
let mut cleaned = completion.to_string();
// Remove common FIM tokens that might appear in responses.
// This is the union of every supported family's marker vocabulary
// (piped Qwen/CodeGemma style, un-piped stable-code style, CodeLlama,
// DeepSeek, and Codestral), so it is safe regardless of which model
// produced the response.
let fim_tokens = [
"<|fim_prefix|>",
"<|fim_suffix|>",
"<|fim_middle|>",
"<|fim_pad|>",
"<|repo_name|>",
"<|file_sep|>",
"<|im_start|>",
"<|im_end|>",
"<fim_prefix>",
"<fim_suffix>",
"<fim_middle>",
"<PRE>",
"<SUF>",
"<MID>",
"</PRE>",
"<fim▁begin>",
"<fim▁hole>",
"<fim▁end>",
"<end▁of▁sentence>",
"[PREFIX]",
"[SUFFIX]",
"<|file_separator|>",
"<|endoftext|>",
];
// Global replace removes every occurrence of each marker.
// NOTE(review): this also rewrites markers that happen to appear inside
// string literals of the completed code itself — presumed acceptable.
for token in &fim_tokens {
cleaned = cleaned.replace(token, "");
}
// Remove leading/trailing whitespace and common prefixes
cleaned = cleaned.trim().to_string();
// Remove common unwanted prefixes that models sometimes generate
let unwanted_prefixes = [
"// COMPLETION HERE",
"// Complete the following code:",
"// completion:",
"// TODO:",
];
// Each prefix is stripped at most once, in declaration order; slicing by
// `prefix.len()` is byte-safe because all prefixes are ASCII.
for prefix in &unwanted_prefixes {
if cleaned.starts_with(prefix) {
cleaned = cleaned[prefix.len()..].trim_start().to_string();
}
}
cleaned
} }
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) { fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
@ -129,7 +290,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?; let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
let model = this.update(cx, |this, _| this.model.clone())?; let (model, stop_tokens) =
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
let request = GenerateRequest { let request = GenerateRequest {
model, model,
@ -139,12 +301,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
num_predict: Some(150), // Reasonable completion length num_predict: Some(150), // Reasonable completion length
temperature: Some(0.1), // Low temperature for more deterministic results temperature: Some(0.1), // Low temperature for more deterministic results
top_p: Some(0.95), top_p: Some(0.95),
stop: Some(vec![ stop: Some(stop_tokens),
"\n\n".to_string(),
"```".to_string(),
"</PRE>".to_string(),
"<SUF>".to_string(),
]),
}), }),
keep_alive: None, keep_alive: None,
context: None, context: None,
@ -156,8 +313,9 @@ impl EditPredictionProvider for OllamaCompletionProvider {
this.update(cx, |this, cx| { this.update(cx, |this, cx| {
this.pending_refresh = None; this.pending_refresh = None;
if !response.response.trim().is_empty() { let cleaned_completion = this.clean_completion(&response.response);
this.current_completion = Some(response.response); if !cleaned_completion.is_empty() {
this.current_completion = Some(cleaned_completion);
} else { } else {
this.current_completion = None; this.current_completion = None;
} }
@ -210,12 +368,9 @@ impl EditPredictionProvider for OllamaCompletionProvider {
let buffer_snapshot = buffer.read(cx); let buffer_snapshot = buffer.read(cx);
let position = cursor_position.bias_right(buffer_snapshot); let position = cursor_position.bias_right(buffer_snapshot);
// Clean up the completion text
let completion_text = completion_text.trim_start().trim_end();
Some(InlineCompletion { Some(InlineCompletion {
id: None, id: None,
edits: vec![(position..position, completion_text.to_string())], edits: vec![(position..position, completion_text)],
edit_preview: None, edit_preview: None,
}) })
} }
@ -229,7 +384,46 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
#[gpui::test] #[gpui::test]
async fn test_fim_prompt_patterns(_cx: &mut TestAppContext) { async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"qwen2.5-coder:32b".to_string(),
);
let prefix = "def hello():";
let suffix = " pass";
let prompt = provider.build_fim_prompt(prefix, suffix);
assert!(prompt.contains("<|fim_prefix|>"));
assert!(prompt.contains("<|fim_suffix|>"));
assert!(prompt.contains("<|fim_middle|>"));
assert!(prompt.contains(prefix));
assert!(prompt.contains(suffix));
}
#[gpui::test]
async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) {
    // A plain Qwen model (no "coder" in the name) must fall back to the
    // pipe-less stable-code template rather than the Qwen-Coder one.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "qwen2.5:32b".to_string(),
    );

    let before = "def hello():";
    let after = " pass";
    let prompt = completion_provider.build_fim_prompt(before, after);

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>", before, after] {
        assert!(prompt.contains(expected));
    }
    // Piped markers belong to Qwen-Coder only.
    assert!(!prompt.contains("<|fim_prefix|>"));
}
#[gpui::test]
async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new( let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()), Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(), "http://localhost:11434".to_string(),
@ -281,6 +475,168 @@ mod tests {
assert!(prompt.contains("<fim_middle>")); assert!(prompt.contains("<fim_middle>"));
} }
#[gpui::test]
async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) {
    // Replit models share the stable-code template (no pipes).
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "replit-code:3b".to_string(),
    );

    let prompt = completion_provider.build_fim_prompt("def hello():", " pass");

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(prompt.contains(expected));
    }
    assert!(!prompt.contains("<|fim_prefix|>")); // no piped variant
}
#[gpui::test]
async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) {
    // Codestral's FIM template places the suffix before the prefix.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "codestral:22b".to_string(),
    );

    let before = "def hello():";
    let after = " pass";
    let prompt = completion_provider.build_fim_prompt(before, after);

    assert!(prompt.starts_with("[SUFFIX]"));
    for piece in ["[SUFFIX]", "[PREFIX]", before, after] {
        assert!(prompt.contains(piece));
    }
}
#[gpui::test]
async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) {
    // CodeGemma uses the piped FIM markers.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "codegemma:7b".to_string(),
    );

    let prompt = completion_provider.build_fim_prompt("def hello():", " pass");

    for expected in ["<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>"] {
        assert!(prompt.contains(expected));
    }
}
#[gpui::test]
async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) {
    // Plain WizardCoder (not deepseek-based) gets the pipe-less
    // stable-code template.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "wizardcoder:13b".to_string(),
    );

    let prompt = completion_provider.build_fim_prompt("def hello():", " pass");

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(prompt.contains(expected));
    }
    assert!(!prompt.contains("<|fim_prefix|>")); // no piped variant
}
#[gpui::test]
async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) {
    // SantaCoder uses the un-piped stable-code markers.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "santacoder:1b".to_string(),
    );

    let prompt = completion_provider.build_fim_prompt("def hello():", " pass");

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(prompt.contains(expected));
    }
}
#[gpui::test]
async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) {
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "qwen2.5-coder:32b".to_string(),
    );

    // Each case pairs a raw model response with the expected cleaned text,
    // covering piped, stable-code, and Codestral marker styles.
    let cases = [
        ("console.log('hello');<|fim_middle|>", "console.log('hello');"),
        ("<|fim_prefix|>console.log('hello');<|fim_suffix|>", "console.log('hello');"),
        ("console.log('hello');<fim_middle>", "console.log('hello');"),
        ("console.log('hello');[SUFFIX]", "console.log('hello');"),
    ];
    for (raw, expected) in cases {
        assert_eq!(completion_provider.clean_completion(raw), expected);
    }
}
#[gpui::test]
async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) {
    // Qwen-Coder stop list must include its full piped marker vocabulary.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "qwen2.5-coder:32b".to_string(),
    );

    let stops = completion_provider.get_stop_tokens();
    for expected in [
        "<|fim_prefix|>",
        "<|fim_suffix|>",
        "<|fim_middle|>",
        "<|endoftext|>",
        "<|fim_pad|>",
        "<|repo_name|>",
    ] {
        assert!(stops.contains(&expected.to_string()));
    }
}
#[gpui::test]
async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) {
    // CodeLlama stop list must cover its <PRE>/<SUF>/<MID> markers.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "codellama:7b".to_string(),
    );

    let stops = completion_provider.get_stop_tokens();
    for expected in ["<PRE>", "<SUF>", "<MID>", "</PRE>"] {
        assert!(stops.contains(&expected.to_string()));
    }
}
#[gpui::test]
async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) {
    // Codestral stop list must cover its bracketed markers.
    let completion_provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "codestral:7b".to_string(),
    );

    let stops = completion_provider.get_stop_tokens();
    for expected in ["[PREFIX]", "[SUFFIX]"] {
        assert!(stops.contains(&expected.to_string()));
    }
}
#[gpui::test] #[gpui::test]
async fn test_extract_context(cx: &mut TestAppContext) { async fn test_extract_context(cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new( let provider = OllamaCompletionProvider::new(
@ -378,19 +734,31 @@ mod tests {
// Verify initial model // Verify initial model
assert_eq!(provider.model, "codellama:7b"); assert_eq!(provider.model, "codellama:7b");
// Test updating model // Test updating model to Qwen Coder
provider.update_model("deepseek-coder:6.7b".to_string()); provider.update_model("qwen2.5-coder:32b".to_string());
assert_eq!(provider.model, "deepseek-coder:6.7b"); assert_eq!(provider.model, "qwen2.5-coder:32b");
// Test FIM prompt changes with different model // Test FIM prompt changes with different model
let prefix = "def hello():"; let prefix = "def hello():";
let suffix = " pass"; let suffix = " pass";
let prompt = provider.build_fim_prompt(prefix, suffix); let prompt = provider.build_fim_prompt(prefix, suffix);
// Should now use deepseek pattern // Should now use qwen coder pattern (with pipes)
assert!(prompt.contains("<fim▁begin>")); assert!(prompt.contains("<|fim_prefix|>"));
assert!(prompt.contains("<fim▁hole>")); assert!(prompt.contains("<|fim_suffix|>"));
assert!(prompt.contains("<fim▁end>")); assert!(prompt.contains("<|fim_middle|>"));
// Update to regular Qwen model (non-coder)
provider.update_model("qwen2.5:32b".to_string());
assert_eq!(provider.model, "qwen2.5:32b");
let prompt = provider.build_fim_prompt(prefix, suffix);
// Should now use stable code pattern (no pipes)
assert!(prompt.contains("<fim_prefix>"));
assert!(prompt.contains("<fim_suffix>"));
assert!(prompt.contains("<fim_middle>"));
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
// Update to starcoder model // Update to starcoder model
provider.update_model("starcoder:7b".to_string()); provider.update_model("starcoder:7b".to_string());
@ -398,9 +766,21 @@ mod tests {
let prompt = provider.build_fim_prompt(prefix, suffix); let prompt = provider.build_fim_prompt(prefix, suffix);
// Should now use starcoder pattern // Should also use stable code pattern (no pipes)
assert!(prompt.contains("<fim_prefix>")); assert!(prompt.contains("<fim_prefix>"));
assert!(prompt.contains("<fim_suffix>")); assert!(prompt.contains("<fim_suffix>"));
assert!(prompt.contains("<fim_middle>")); assert!(prompt.contains("<fim_middle>"));
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
// Update to codestral model
provider.update_model("codestral:22b".to_string());
assert_eq!(provider.model, "codestral:22b");
let prompt = provider.build_fim_prompt(prefix, suffix);
// Should use codestral pattern (suffix-first)
assert!(prompt.contains("[SUFFIX]"));
assert!(prompt.contains("[PREFIX]"));
assert!(prompt.starts_with("[SUFFIX]"));
} }
} }