Use Ollama's suffix field and remove FIM token handling

This commit is contained in:
Oliver Azevedo Barnes 2025-07-03 14:37:24 -03:00
parent 4b096b9a6b
commit b50555b87a
No known key found for this signature in database
2 changed files with 95 additions and 426 deletions

View file

@ -103,6 +103,7 @@ impl Model {
pub struct GenerateRequest {
pub model: String,
pub prompt: String,
pub suffix: Option<String>,
pub stream: bool,
pub options: Option<GenerateOptions>,
pub keep_alive: Option<KeepAlive>,
@ -425,6 +426,33 @@ pub async fn generate(
mod tests {
use super::*;
/// Serializes a `GenerateRequest` that carries a `suffix` and checks both the
/// JSON that goes on the wire and the value that comes back from parsing it.
#[test]
fn test_generate_request_with_suffix_serialization() {
    let request = GenerateRequest {
        model: "qwen2.5-coder:32b".to_string(),
        prompt: "def fibonacci(n):".to_string(),
        suffix: Some(" return result".to_string()),
        stream: false,
        options: Some(GenerateOptions {
            num_predict: Some(150),
            temperature: Some(0.1),
            top_p: Some(0.95),
            stop: Some(vec!["<|endoftext|>".to_string()]),
        }),
        keep_alive: None,
        context: None,
    };

    let json = serde_json::to_string(&request).unwrap();

    // A round trip alone would still pass if the field were written under a
    // different key, so inspect the serialized payload directly as well.
    let wire: serde_json::Value = serde_json::from_str(&json).unwrap();
    assert_eq!(wire["prompt"], "def fibonacci(n):");
    assert_eq!(wire["suffix"], " return result");

    let parsed: GenerateRequest = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed.model, "qwen2.5-coder:32b");
    assert_eq!(parsed.prompt, "def fibonacci(n):");
    assert_eq!(parsed.suffix, Some(" return result".to_string()));
    assert!(!parsed.stream);
    assert!(parsed.options.is_some());
}
#[test]
fn parse_completion() {
let response = serde_json::json!({

View file

@ -37,167 +37,21 @@ impl OllamaCompletionProvider {
self.model = new_model;
}
/// Builds the fill-in-the-middle prompt for the configured model.
///
/// `prefix` and `suffix` are wrapped in sentinel tokens selected by a
/// case-insensitive substring match on `self.model`. Families are tested in
/// the order written below and the first hit wins; a model that matches none
/// of them gets the pipe-less `<fim_prefix>` template.
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
    let model_lower = self.model.to_lowercase();
    let is_family = |name: &str| model_lower.contains(name);

    if is_family("qwen") && is_family("coder") {
        // QwenCoder models use pipes
        format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
    } else if is_family("codellama") {
        format!("<PRE> {prefix} <SUF>{suffix} <MID>")
    } else if is_family("deepseek") {
        // This arm also handles DeepSeek-based WizardCoder names, since it is
        // tested before any "wizardcoder" check could run.
        // NOTE(review): DeepSeek's published sentinels may be spelled with
        // fullwidth bars around the name — confirm against the tokenizer.
        format!("<fim▁begin>{prefix}<fim▁hole>{suffix}<fim▁end>")
    } else if is_family("codestral") {
        // Codestral uses suffix-first order
        format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
    } else if is_family("codegemma") {
        format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
    } else {
        // Stable code pattern (no pipes): WizardCoder, StarCoder, SantaCoder,
        // StableCode, non-coder Qwen, Replit, and any unknown model
        format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
    }
}
fn get_stop_tokens(&self) -> Vec<String> {
let model_lower = self.model.to_lowercase();
let mut stop_tokens = vec!["\n\n".to_string(), "```".to_string()];
if model_lower.contains("qwen") && model_lower.contains("coder") {
stop_tokens.extend(vec![
"<|endoftext|>".to_string(),
"<|fim_prefix|>".to_string(),
"<|fim_middle|>".to_string(),
"<|fim_suffix|>".to_string(),
"<|fim_pad|>".to_string(),
"<|repo_name|>".to_string(),
"<|file_sep|>".to_string(),
"<|im_start|>".to_string(),
"<|im_end|>".to_string(),
]);
} else if model_lower.contains("codellama") {
stop_tokens.extend(vec![
"<PRE>".to_string(),
"<SUF>".to_string(),
"<MID>".to_string(),
"</PRE>".to_string(),
]);
} else if model_lower.contains("deepseek") {
stop_tokens.extend(vec![
"<fim▁begin>".to_string(),
"<fim▁hole>".to_string(),
"<fim▁end>".to_string(),
"//".to_string(),
"<end▁of▁sentence>".to_string(),
]);
} else if model_lower.contains("codestral") {
stop_tokens.extend(vec!["[PREFIX]".to_string(), "[SUFFIX]".to_string()]);
} else if model_lower.contains("codegemma") {
stop_tokens.extend(vec![
"<|fim_prefix|>".to_string(),
"<|fim_suffix|>".to_string(),
"<|fim_middle|>".to_string(),
"<|file_separator|>".to_string(),
"<|endoftext|>".to_string(),
]);
} else if model_lower.contains("wizardcoder") {
// WizardCoder models inherit patterns from their base model
if model_lower.contains("deepseek") {
stop_tokens.extend(vec![
"<fim▁begin>".to_string(),
"<fim▁hole>".to_string(),
"<fim▁end>".to_string(),
]);
} else {
stop_tokens.extend(vec![
"<fim_prefix>".to_string(),
"<fim_suffix>".to_string(),
"<fim_middle>".to_string(),
"<|endoftext|>".to_string(),
]);
}
} else if model_lower.contains("starcoder")
|| model_lower.contains("santacoder")
|| model_lower.contains("stable")
|| model_lower.contains("qwen")
|| model_lower.contains("replit")
{
// Stable code pattern stop tokens
stop_tokens.extend(vec![
"<fim_prefix>".to_string(),
"<fim_suffix>".to_string(),
"<fim_middle>".to_string(),
"<|endoftext|>".to_string(),
]);
} else {
// Generic stop tokens for unknown models - cover both patterns
stop_tokens.extend(vec![
"<|fim_prefix|>".to_string(),
"<|fim_suffix|>".to_string(),
"<|fim_middle|>".to_string(),
"<fim_prefix>".to_string(),
"<fim_suffix>".to_string(),
"<fim_middle>".to_string(),
"<|endoftext|>".to_string(),
]);
}
stop_tokens
// Basic stop tokens for code completion
// Ollama handles FIM tokens internally, so we only need general completion stops
vec![
"\n\n".to_string(), // Double newline often indicates end of completion
"```".to_string(), // Code block delimiter
"<|endoftext|>".to_string(), // Common model end token
]
}
fn clean_completion(&self, completion: &str) -> String {
let mut cleaned = completion.to_string();
// Remove common FIM tokens that might appear in responses
let fim_tokens = [
"<|fim_prefix|>",
"<|fim_suffix|>",
"<|fim_middle|>",
"<|fim_pad|>",
"<|repo_name|>",
"<|file_sep|>",
"<|im_start|>",
"<|im_end|>",
"<fim_prefix>",
"<fim_suffix>",
"<fim_middle>",
"<PRE>",
"<SUF>",
"<MID>",
"</PRE>",
"<fim▁begin>",
"<fim▁hole>",
"<fim▁end>",
"<end▁of▁sentence>",
"[PREFIX]",
"[SUFFIX]",
"<|file_separator|>",
"<|endoftext|>",
];
for token in &fim_tokens {
cleaned = cleaned.replace(token, "");
}
// Remove leading/trailing whitespace and common prefixes
// Basic cleaning - Ollama should handle FIM tokens internally
// Remove leading/trailing whitespace
cleaned = cleaned.trim().to_string();
// Remove common unwanted prefixes that models sometimes generate
@ -288,14 +142,13 @@ impl EditPredictionProvider for OllamaCompletionProvider {
this.extract_context(buffer_snapshot, cursor_position)
})?;
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
let (model, stop_tokens) =
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
let request = GenerateRequest {
model,
prompt,
prompt: prefix,
suffix: Some(suffix),
stream: false,
options: Some(GenerateOptions {
num_predict: Some(150), // Reasonable completion length
@ -384,216 +237,7 @@ mod tests {
use std::sync::Arc;
/// A Qwen coder model gets the piped sentinels wrapped around both halves.
#[gpui::test]
async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("qwen2.5-coder:32b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>", before, after] {
        assert!(fim.contains(expected));
    }
}
/// A Qwen model without "coder" in its name gets the pipe-less sentinels.
#[gpui::test]
async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("qwen2.5:32b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(fim.contains(expected));
    }
    // Should NOT contain pipes
    assert!(!fim.contains("<|fim_prefix|>"));
    for half in [before, after] {
        assert!(fim.contains(half));
    }
}
/// A CodeLlama model gets the `<PRE>`/`<SUF>`/`<MID>` sentinels.
#[gpui::test]
async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("function hello() {", "}");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("codellama:7b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<PRE>", "<SUF>", "<MID>", before, after] {
        assert!(fim.contains(expected));
    }
}
/// A DeepSeek model gets the begin/hole/end sentinels.
#[gpui::test]
async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("deepseek-coder:6.7b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<fim▁begin>", "<fim▁hole>", "<fim▁end>"] {
        assert!(fim.contains(expected));
    }
}
/// A StarCoder model gets the pipe-less sentinels.
#[gpui::test]
async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("starcoder:7b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(fim.contains(expected));
    }
}
/// A Replit model gets the pipe-less sentinels and never the piped ones.
#[gpui::test]
async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("replit-code:3b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    // Replit should use stable code pattern (no pipes)
    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(fim.contains(expected));
    }
    assert!(!fim.contains("<|fim_prefix|>"));
}
/// A Codestral model gets a prompt that leads with `[SUFFIX]` and carries
/// both halves.
#[gpui::test]
async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("codestral:22b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    // Codestral uses suffix-first order
    for marker in ["[SUFFIX]", "[PREFIX]"] {
        assert!(fim.contains(marker));
    }
    assert!(fim.starts_with("[SUFFIX]"));
    for half in [before, after] {
        assert!(fim.contains(half));
    }
}
/// A CodeGemma model gets the piped sentinels.
#[gpui::test]
async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("codegemma:7b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<|fim_prefix|>", "<|fim_suffix|>", "<|fim_middle|>"] {
        assert!(fim.contains(expected));
    }
}
/// A plain WizardCoder model gets the pipe-less sentinels and never the
/// piped ones.
#[gpui::test]
async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("wizardcoder:13b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    // WizardCoder should use stable code pattern (no pipes) unless it's deepseek-based
    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(fim.contains(expected));
    }
    assert!(!fim.contains("<|fim_prefix|>"));
}
/// A SantaCoder model gets the pipe-less sentinels.
#[gpui::test]
async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) {
    let (before, after) = ("def hello():", " pass");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("santacoder:1b"),
    );

    let fim = ollama.build_fim_prompt(before, after);

    for expected in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(fim.contains(expected));
    }
}
/// `clean_completion` strips sentinel tokens from several model families,
/// leaving only the generated code.
#[gpui::test]
async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) {
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("qwen2.5-coder:32b"),
    );

    // Each raw completion must clean down to the same bare statement.
    let raw_completions = [
        "console.log('hello');<|fim_middle|>",
        "<|fim_prefix|>console.log('hello');<|fim_suffix|>",
        "console.log('hello');<fim_middle>",
        "console.log('hello');[SUFFIX]",
    ];
    for raw in raw_completions {
        assert_eq!(ollama.clean_completion(raw), "console.log('hello');");
    }
}
#[gpui::test]
async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) {
async fn test_get_stop_tokens(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
@ -601,40 +245,27 @@ mod tests {
);
let stop_tokens = provider.get_stop_tokens();
assert!(stop_tokens.contains(&"<|fim_prefix|>".to_string()));
assert!(stop_tokens.contains(&"<|fim_suffix|>".to_string()));
assert!(stop_tokens.contains(&"<|fim_middle|>".to_string()));
assert!(stop_tokens.contains(&"\n\n".to_string()));
assert!(stop_tokens.contains(&"```".to_string()));
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
assert!(stop_tokens.contains(&"<|fim_pad|>".to_string()));
assert!(stop_tokens.contains(&"<|repo_name|>".to_string()));
assert_eq!(stop_tokens.len(), 3);
}
#[gpui::test]
async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) {
async fn test_clean_completion_basic(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
"qwen2.5-coder:32b".to_string(),
);
let stop_tokens = provider.get_stop_tokens();
assert!(stop_tokens.contains(&"<PRE>".to_string()));
assert!(stop_tokens.contains(&"<SUF>".to_string()));
assert!(stop_tokens.contains(&"<MID>".to_string()));
assert!(stop_tokens.contains(&"</PRE>".to_string()));
}
let completion = " console.log('hello'); ";
let cleaned = provider.clean_completion(completion);
assert_eq!(cleaned, "console.log('hello');");
#[gpui::test]
async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codestral:7b".to_string(),
);
let stop_tokens = provider.get_stop_tokens();
assert!(stop_tokens.contains(&"[PREFIX]".to_string()));
assert!(stop_tokens.contains(&"[SUFFIX]".to_string()));
let completion_with_prefix = "// COMPLETION HERE\nconsole.log('hello');";
let cleaned = provider.clean_completion(completion_with_prefix);
assert_eq!(cleaned, "console.log('hello');");
}
#[gpui::test]
@ -738,49 +369,59 @@ mod tests {
provider.update_model("qwen2.5-coder:32b".to_string());
assert_eq!(provider.model, "qwen2.5-coder:32b");
// Test FIM prompt changes with different model
let prefix = "def hello():";
let suffix = " pass";
let prompt = provider.build_fim_prompt(prefix, suffix);
// Should now use qwen coder pattern (with pipes)
assert!(prompt.contains("<|fim_prefix|>"));
assert!(prompt.contains("<|fim_suffix|>"));
assert!(prompt.contains("<|fim_middle|>"));
// Update to regular Qwen model (non-coder)
// Test updating to different models
provider.update_model("qwen2.5:32b".to_string());
assert_eq!(provider.model, "qwen2.5:32b");
let prompt = provider.build_fim_prompt(prefix, suffix);
// Should now use stable code pattern (no pipes)
assert!(prompt.contains("<fim_prefix>"));
assert!(prompt.contains("<fim_suffix>"));
assert!(prompt.contains("<fim_middle>"));
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
// Update to starcoder model
provider.update_model("starcoder:7b".to_string());
assert_eq!(provider.model, "starcoder:7b");
let prompt = provider.build_fim_prompt(prefix, suffix);
// Should also use stable code pattern (no pipes)
assert!(prompt.contains("<fim_prefix>"));
assert!(prompt.contains("<fim_suffix>"));
assert!(prompt.contains("<fim_middle>"));
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
// Update to codestral model
provider.update_model("codestral:22b".to_string());
assert_eq!(provider.model, "codestral:22b");
let prompt = provider.build_fim_prompt(prefix, suffix);
// FIM patterns are now handled by Ollama natively, so we just test model updates
provider.update_model("deepseek-coder:6.7b".to_string());
assert_eq!(provider.model, "deepseek-coder:6.7b");
}
// Should use codestral pattern (suffix-first)
assert!(prompt.contains("[SUFFIX]"));
assert!(prompt.contains("[PREFIX]"));
assert!(prompt.starts_with("[SUFFIX]"));
/// Builds a `GenerateRequest` the way the native-FIM path does (prefix as
/// `prompt`, trailing text as `suffix`) and checks its fields and stop list.
#[gpui::test]
async fn test_native_fim_request_structure(_cx: &mut TestAppContext) {
    let (before, after) = ("def fibonacci(n):", " return result");
    let ollama = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        String::from("http://localhost:11434"),
        String::from("qwen2.5-coder:32b"),
    );

    // Test that we create the correct request structure for native FIM
    let sampling = GenerateOptions {
        num_predict: Some(150),
        temperature: Some(0.1),
        top_p: Some(0.95),
        stop: Some(ollama.get_stop_tokens()),
    };
    let request = GenerateRequest {
        model: ollama.model.clone(),
        prompt: String::from(before),
        suffix: Some(String::from(after)),
        stream: false,
        options: Some(sampling),
        keep_alive: None,
        context: None,
    };

    // Verify the request structure uses native FIM approach
    assert_eq!(request.model, "qwen2.5-coder:32b");
    assert_eq!(request.prompt, "def fibonacci(n):");
    assert_eq!(request.suffix, Some(String::from(" return result")));
    assert!(!request.stream);

    // Verify stop tokens are simplified (no FIM-specific tokens)
    let stops = request.options.as_ref().unwrap().stop.as_ref().unwrap();
    for expected in ["\n\n", "```", "<|endoftext|>"] {
        assert!(stops.contains(&expected.to_string()));
    }
    assert_eq!(stops.len(), 3);
}
}