Improved FIM token handling per model
This commit is contained in:
parent
902a07606b
commit
af66570bfe
1 changed files with 412 additions and 32 deletions
|
@ -39,21 +39,182 @@ impl OllamaCompletionProvider {
|
||||||
|
|
||||||
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
||||||
// Use model-specific FIM patterns
|
// Use model-specific FIM patterns
|
||||||
match self.model.as_str() {
|
let model_lower = self.model.to_lowercase();
|
||||||
m if m.contains("codellama") => {
|
|
||||||
|
if model_lower.contains("qwen") && model_lower.contains("coder") {
|
||||||
|
// QwenCoder models use pipes
|
||||||
|
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
|
||||||
|
} else if model_lower.contains("codellama") {
|
||||||
format!("<PRE> {prefix} <SUF>{suffix} <MID>")
|
format!("<PRE> {prefix} <SUF>{suffix} <MID>")
|
||||||
}
|
} else if model_lower.contains("deepseek") {
|
||||||
m if m.contains("deepseek") => {
|
|
||||||
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
||||||
}
|
} else if model_lower.contains("codestral") {
|
||||||
m if m.contains("starcoder") => {
|
// Codestral uses suffix-first order
|
||||||
|
format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
|
||||||
|
} else if model_lower.contains("codegemma") {
|
||||||
|
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
|
||||||
|
} else if model_lower.contains("wizardcoder") {
|
||||||
|
// WizardCoder models inherit patterns from their base model
|
||||||
|
if model_lower.contains("deepseek") {
|
||||||
|
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
||||||
|
} else {
|
||||||
|
// Most WizardCoder models use stable code pattern
|
||||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||||
}
|
}
|
||||||
_ => {
|
} else if model_lower.contains("starcoder")
|
||||||
// Generic FIM pattern - fallback for models without specific support
|
|| model_lower.contains("santacoder")
|
||||||
format!("// Complete the following code:\n{prefix}\n// COMPLETION HERE\n{suffix}")
|
|| model_lower.contains("stable")
|
||||||
|
|| model_lower.contains("qwen")
|
||||||
|
|| model_lower.contains("replit")
|
||||||
|
{
|
||||||
|
// Stable code pattern (no pipes) - used by StarCoder, SantaCoder, StableCode,
|
||||||
|
// non-coder Qwen models, and Replit models
|
||||||
|
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||||
|
} else {
|
||||||
|
// Default to stable code pattern for unknown models
|
||||||
|
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_stop_tokens(&self) -> Vec<String> {
|
||||||
|
let model_lower = self.model.to_lowercase();
|
||||||
|
|
||||||
|
let mut stop_tokens = vec!["\n\n".to_string(), "```".to_string()];
|
||||||
|
|
||||||
|
if model_lower.contains("qwen") && model_lower.contains("coder") {
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<|endoftext|>".to_string(),
|
||||||
|
"<|fim_prefix|>".to_string(),
|
||||||
|
"<|fim_middle|>".to_string(),
|
||||||
|
"<|fim_suffix|>".to_string(),
|
||||||
|
"<|fim_pad|>".to_string(),
|
||||||
|
"<|repo_name|>".to_string(),
|
||||||
|
"<|file_sep|>".to_string(),
|
||||||
|
"<|im_start|>".to_string(),
|
||||||
|
"<|im_end|>".to_string(),
|
||||||
|
]);
|
||||||
|
} else if model_lower.contains("codellama") {
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<PRE>".to_string(),
|
||||||
|
"<SUF>".to_string(),
|
||||||
|
"<MID>".to_string(),
|
||||||
|
"</PRE>".to_string(),
|
||||||
|
]);
|
||||||
|
} else if model_lower.contains("deepseek") {
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<|fim▁begin|>".to_string(),
|
||||||
|
"<|fim▁hole|>".to_string(),
|
||||||
|
"<|fim▁end|>".to_string(),
|
||||||
|
"//".to_string(),
|
||||||
|
"<|end▁of▁sentence|>".to_string(),
|
||||||
|
]);
|
||||||
|
} else if model_lower.contains("codestral") {
|
||||||
|
stop_tokens.extend(vec!["[PREFIX]".to_string(), "[SUFFIX]".to_string()]);
|
||||||
|
} else if model_lower.contains("codegemma") {
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<|fim_prefix|>".to_string(),
|
||||||
|
"<|fim_suffix|>".to_string(),
|
||||||
|
"<|fim_middle|>".to_string(),
|
||||||
|
"<|file_separator|>".to_string(),
|
||||||
|
"<|endoftext|>".to_string(),
|
||||||
|
]);
|
||||||
|
} else if model_lower.contains("wizardcoder") {
|
||||||
|
// WizardCoder models inherit patterns from their base model
|
||||||
|
if model_lower.contains("deepseek") {
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<|fim▁begin|>".to_string(),
|
||||||
|
"<|fim▁hole|>".to_string(),
|
||||||
|
"<|fim▁end|>".to_string(),
|
||||||
|
]);
|
||||||
|
} else {
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<fim_prefix>".to_string(),
|
||||||
|
"<fim_suffix>".to_string(),
|
||||||
|
"<fim_middle>".to_string(),
|
||||||
|
"<|endoftext|>".to_string(),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
} else if model_lower.contains("starcoder")
|
||||||
|
|| model_lower.contains("santacoder")
|
||||||
|
|| model_lower.contains("stable")
|
||||||
|
|| model_lower.contains("qwen")
|
||||||
|
|| model_lower.contains("replit")
|
||||||
|
{
|
||||||
|
// Stable code pattern stop tokens
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<fim_prefix>".to_string(),
|
||||||
|
"<fim_suffix>".to_string(),
|
||||||
|
"<fim_middle>".to_string(),
|
||||||
|
"<|endoftext|>".to_string(),
|
||||||
|
]);
|
||||||
|
} else {
|
||||||
|
// Generic stop tokens for unknown models - cover both patterns
|
||||||
|
stop_tokens.extend(vec![
|
||||||
|
"<|fim_prefix|>".to_string(),
|
||||||
|
"<|fim_suffix|>".to_string(),
|
||||||
|
"<|fim_middle|>".to_string(),
|
||||||
|
"<fim_prefix>".to_string(),
|
||||||
|
"<fim_suffix>".to_string(),
|
||||||
|
"<fim_middle>".to_string(),
|
||||||
|
"<|endoftext|>".to_string(),
|
||||||
|
]);
|
||||||
|
}
|
||||||
|
|
||||||
|
stop_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clean_completion(&self, completion: &str) -> String {
|
||||||
|
let mut cleaned = completion.to_string();
|
||||||
|
|
||||||
|
// Remove common FIM tokens that might appear in responses
|
||||||
|
let fim_tokens = [
|
||||||
|
"<|fim_prefix|>",
|
||||||
|
"<|fim_suffix|>",
|
||||||
|
"<|fim_middle|>",
|
||||||
|
"<|fim_pad|>",
|
||||||
|
"<|repo_name|>",
|
||||||
|
"<|file_sep|>",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
"<fim_prefix>",
|
||||||
|
"<fim_suffix>",
|
||||||
|
"<fim_middle>",
|
||||||
|
"<PRE>",
|
||||||
|
"<SUF>",
|
||||||
|
"<MID>",
|
||||||
|
"</PRE>",
|
||||||
|
"<|fim▁begin|>",
|
||||||
|
"<|fim▁hole|>",
|
||||||
|
"<|fim▁end|>",
|
||||||
|
"<|end▁of▁sentence|>",
|
||||||
|
"[PREFIX]",
|
||||||
|
"[SUFFIX]",
|
||||||
|
"<|file_separator|>",
|
||||||
|
"<|endoftext|>",
|
||||||
|
];
|
||||||
|
|
||||||
|
for token in &fim_tokens {
|
||||||
|
cleaned = cleaned.replace(token, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove leading/trailing whitespace and common prefixes
|
||||||
|
cleaned = cleaned.trim().to_string();
|
||||||
|
|
||||||
|
// Remove common unwanted prefixes that models sometimes generate
|
||||||
|
let unwanted_prefixes = [
|
||||||
|
"// COMPLETION HERE",
|
||||||
|
"// Complete the following code:",
|
||||||
|
"// completion:",
|
||||||
|
"// TODO:",
|
||||||
|
];
|
||||||
|
|
||||||
|
for prefix in &unwanted_prefixes {
|
||||||
|
if cleaned.starts_with(prefix) {
|
||||||
|
cleaned = cleaned[prefix.len()..].trim_start().to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cleaned
|
||||||
}
|
}
|
||||||
|
|
||||||
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
|
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
|
||||||
|
@ -129,7 +290,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
|
|
||||||
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
||||||
|
|
||||||
let model = this.update(cx, |this, _| this.model.clone())?;
|
let (model, stop_tokens) =
|
||||||
|
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
|
||||||
|
|
||||||
let request = GenerateRequest {
|
let request = GenerateRequest {
|
||||||
model,
|
model,
|
||||||
|
@ -139,12 +301,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
num_predict: Some(150), // Reasonable completion length
|
num_predict: Some(150), // Reasonable completion length
|
||||||
temperature: Some(0.1), // Low temperature for more deterministic results
|
temperature: Some(0.1), // Low temperature for more deterministic results
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
stop: Some(vec![
|
stop: Some(stop_tokens),
|
||||||
"\n\n".to_string(),
|
|
||||||
"```".to_string(),
|
|
||||||
"</PRE>".to_string(),
|
|
||||||
"<SUF>".to_string(),
|
|
||||||
]),
|
|
||||||
}),
|
}),
|
||||||
keep_alive: None,
|
keep_alive: None,
|
||||||
context: None,
|
context: None,
|
||||||
|
@ -156,8 +313,9 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
|
|
||||||
this.update(cx, |this, cx| {
|
this.update(cx, |this, cx| {
|
||||||
this.pending_refresh = None;
|
this.pending_refresh = None;
|
||||||
if !response.response.trim().is_empty() {
|
let cleaned_completion = this.clean_completion(&response.response);
|
||||||
this.current_completion = Some(response.response);
|
if !cleaned_completion.is_empty() {
|
||||||
|
this.current_completion = Some(cleaned_completion);
|
||||||
} else {
|
} else {
|
||||||
this.current_completion = None;
|
this.current_completion = None;
|
||||||
}
|
}
|
||||||
|
@ -210,12 +368,9 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
let buffer_snapshot = buffer.read(cx);
|
let buffer_snapshot = buffer.read(cx);
|
||||||
let position = cursor_position.bias_right(buffer_snapshot);
|
let position = cursor_position.bias_right(buffer_snapshot);
|
||||||
|
|
||||||
// Clean up the completion text
|
|
||||||
let completion_text = completion_text.trim_start().trim_end();
|
|
||||||
|
|
||||||
Some(InlineCompletion {
|
Some(InlineCompletion {
|
||||||
id: None,
|
id: None,
|
||||||
edits: vec![(position..position, completion_text.to_string())],
|
edits: vec![(position..position, completion_text)],
|
||||||
edit_preview: None,
|
edit_preview: None,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -229,7 +384,46 @@ mod tests {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_fim_prompt_patterns(_cx: &mut TestAppContext) {
|
async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"qwen2.5-coder:32b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<|fim_prefix|>"));
|
||||||
|
assert!(prompt.contains("<|fim_suffix|>"));
|
||||||
|
assert!(prompt.contains("<|fim_middle|>"));
|
||||||
|
assert!(prompt.contains(prefix));
|
||||||
|
assert!(prompt.contains(suffix));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"qwen2.5:32b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||||
|
assert!(prompt.contains(prefix));
|
||||||
|
assert!(prompt.contains(suffix));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) {
|
||||||
let provider = OllamaCompletionProvider::new(
|
let provider = OllamaCompletionProvider::new(
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
"http://localhost:11434".to_string(),
|
"http://localhost:11434".to_string(),
|
||||||
|
@ -281,6 +475,168 @@ mod tests {
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"replit-code:3b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// Replit should use stable code pattern (no pipes)
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codestral:22b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// Codestral uses suffix-first order
|
||||||
|
assert!(prompt.contains("[SUFFIX]"));
|
||||||
|
assert!(prompt.contains("[PREFIX]"));
|
||||||
|
assert!(prompt.starts_with("[SUFFIX]"));
|
||||||
|
assert!(prompt.contains(prefix));
|
||||||
|
assert!(prompt.contains(suffix));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codegemma:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<|fim_prefix|>"));
|
||||||
|
assert!(prompt.contains("<|fim_suffix|>"));
|
||||||
|
assert!(prompt.contains("<|fim_middle|>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"wizardcoder:13b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// WizardCoder should use stable code pattern (no pipes) unless it's deepseek-based
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"santacoder:1b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"qwen2.5-coder:32b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let completion_with_tokens = "console.log('hello');<|fim_middle|>";
|
||||||
|
let cleaned = provider.clean_completion(completion_with_tokens);
|
||||||
|
assert_eq!(cleaned, "console.log('hello');");
|
||||||
|
|
||||||
|
let completion_with_multiple_tokens = "<|fim_prefix|>console.log('hello');<|fim_suffix|>";
|
||||||
|
let cleaned = provider.clean_completion(completion_with_multiple_tokens);
|
||||||
|
assert_eq!(cleaned, "console.log('hello');");
|
||||||
|
|
||||||
|
let completion_with_starcoder_tokens = "console.log('hello');<fim_middle>";
|
||||||
|
let cleaned = provider.clean_completion(completion_with_starcoder_tokens);
|
||||||
|
assert_eq!(cleaned, "console.log('hello');");
|
||||||
|
|
||||||
|
let completion_with_codestral_tokens = "console.log('hello');[SUFFIX]";
|
||||||
|
let cleaned = provider.clean_completion(completion_with_codestral_tokens);
|
||||||
|
assert_eq!(cleaned, "console.log('hello');");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"qwen2.5-coder:32b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let stop_tokens = provider.get_stop_tokens();
|
||||||
|
assert!(stop_tokens.contains(&"<|fim_prefix|>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<|fim_suffix|>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<|fim_middle|>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<|fim_pad|>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<|repo_name|>".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codellama:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let stop_tokens = provider.get_stop_tokens();
|
||||||
|
assert!(stop_tokens.contains(&"<PRE>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<SUF>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<MID>".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"</PRE>".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codestral:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let stop_tokens = provider.get_stop_tokens();
|
||||||
|
assert!(stop_tokens.contains(&"[PREFIX]".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"[SUFFIX]".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_extract_context(cx: &mut TestAppContext) {
|
async fn test_extract_context(cx: &mut TestAppContext) {
|
||||||
let provider = OllamaCompletionProvider::new(
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
@ -378,19 +734,31 @@ mod tests {
|
||||||
// Verify initial model
|
// Verify initial model
|
||||||
assert_eq!(provider.model, "codellama:7b");
|
assert_eq!(provider.model, "codellama:7b");
|
||||||
|
|
||||||
// Test updating model
|
// Test updating model to Qwen Coder
|
||||||
provider.update_model("deepseek-coder:6.7b".to_string());
|
provider.update_model("qwen2.5-coder:32b".to_string());
|
||||||
assert_eq!(provider.model, "deepseek-coder:6.7b");
|
assert_eq!(provider.model, "qwen2.5-coder:32b");
|
||||||
|
|
||||||
// Test FIM prompt changes with different model
|
// Test FIM prompt changes with different model
|
||||||
let prefix = "def hello():";
|
let prefix = "def hello():";
|
||||||
let suffix = " pass";
|
let suffix = " pass";
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
// Should now use deepseek pattern
|
// Should now use qwen coder pattern (with pipes)
|
||||||
assert!(prompt.contains("<|fim▁begin|>"));
|
assert!(prompt.contains("<|fim_prefix|>"));
|
||||||
assert!(prompt.contains("<|fim▁hole|>"));
|
assert!(prompt.contains("<|fim_suffix|>"));
|
||||||
assert!(prompt.contains("<|fim▁end|>"));
|
assert!(prompt.contains("<|fim_middle|>"));
|
||||||
|
|
||||||
|
// Update to regular Qwen model (non-coder)
|
||||||
|
provider.update_model("qwen2.5:32b".to_string());
|
||||||
|
assert_eq!(provider.model, "qwen2.5:32b");
|
||||||
|
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// Should now use stable code pattern (no pipes)
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||||
|
|
||||||
// Update to starcoder model
|
// Update to starcoder model
|
||||||
provider.update_model("starcoder:7b".to_string());
|
provider.update_model("starcoder:7b".to_string());
|
||||||
|
@ -398,9 +766,21 @@ mod tests {
|
||||||
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
// Should now use starcoder pattern
|
// Should also use stable code pattern (no pipes)
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||||
|
|
||||||
|
// Update to codestral model
|
||||||
|
provider.update_model("codestral:22b".to_string());
|
||||||
|
assert_eq!(provider.model, "codestral:22b");
|
||||||
|
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// Should use codestral pattern (suffix-first)
|
||||||
|
assert!(prompt.contains("[SUFFIX]"));
|
||||||
|
assert!(prompt.contains("[PREFIX]"));
|
||||||
|
assert!(prompt.starts_with("[SUFFIX]"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue