Use Ollama's suffix field and remove FIM token handling
This commit is contained in:
parent
4b096b9a6b
commit
b50555b87a
2 changed files with 95 additions and 426 deletions
|
@ -37,167 +37,21 @@ impl OllamaCompletionProvider {
|
|||
self.model = new_model;
|
||||
}
|
||||
|
||||
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
||||
// Use model-specific FIM patterns
|
||||
let model_lower = self.model.to_lowercase();
|
||||
|
||||
if model_lower.contains("qwen") && model_lower.contains("coder") {
|
||||
// QwenCoder models use pipes
|
||||
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
|
||||
} else if model_lower.contains("codellama") {
|
||||
format!("<PRE> {prefix} <SUF>{suffix} <MID>")
|
||||
} else if model_lower.contains("deepseek") {
|
||||
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
||||
} else if model_lower.contains("codestral") {
|
||||
// Codestral uses suffix-first order
|
||||
format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
|
||||
} else if model_lower.contains("codegemma") {
|
||||
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
|
||||
} else if model_lower.contains("wizardcoder") {
|
||||
// WizardCoder models inherit patterns from their base model
|
||||
if model_lower.contains("deepseek") {
|
||||
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
||||
} else {
|
||||
// Most WizardCoder models use stable code pattern
|
||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||
}
|
||||
} else if model_lower.contains("starcoder")
|
||||
|| model_lower.contains("santacoder")
|
||||
|| model_lower.contains("stable")
|
||||
|| model_lower.contains("qwen")
|
||||
|| model_lower.contains("replit")
|
||||
{
|
||||
// Stable code pattern (no pipes) - used by StarCoder, SantaCoder, StableCode,
|
||||
// non-coder Qwen models, and Replit models
|
||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||
} else {
|
||||
// Default to stable code pattern for unknown models
|
||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||
}
|
||||
}
|
||||
|
||||
fn get_stop_tokens(&self) -> Vec<String> {
|
||||
let model_lower = self.model.to_lowercase();
|
||||
|
||||
let mut stop_tokens = vec!["\n\n".to_string(), "```".to_string()];
|
||||
|
||||
if model_lower.contains("qwen") && model_lower.contains("coder") {
|
||||
stop_tokens.extend(vec![
|
||||
"<|endoftext|>".to_string(),
|
||||
"<|fim_prefix|>".to_string(),
|
||||
"<|fim_middle|>".to_string(),
|
||||
"<|fim_suffix|>".to_string(),
|
||||
"<|fim_pad|>".to_string(),
|
||||
"<|repo_name|>".to_string(),
|
||||
"<|file_sep|>".to_string(),
|
||||
"<|im_start|>".to_string(),
|
||||
"<|im_end|>".to_string(),
|
||||
]);
|
||||
} else if model_lower.contains("codellama") {
|
||||
stop_tokens.extend(vec![
|
||||
"<PRE>".to_string(),
|
||||
"<SUF>".to_string(),
|
||||
"<MID>".to_string(),
|
||||
"</PRE>".to_string(),
|
||||
]);
|
||||
} else if model_lower.contains("deepseek") {
|
||||
stop_tokens.extend(vec![
|
||||
"<|fim▁begin|>".to_string(),
|
||||
"<|fim▁hole|>".to_string(),
|
||||
"<|fim▁end|>".to_string(),
|
||||
"//".to_string(),
|
||||
"<|end▁of▁sentence|>".to_string(),
|
||||
]);
|
||||
} else if model_lower.contains("codestral") {
|
||||
stop_tokens.extend(vec!["[PREFIX]".to_string(), "[SUFFIX]".to_string()]);
|
||||
} else if model_lower.contains("codegemma") {
|
||||
stop_tokens.extend(vec![
|
||||
"<|fim_prefix|>".to_string(),
|
||||
"<|fim_suffix|>".to_string(),
|
||||
"<|fim_middle|>".to_string(),
|
||||
"<|file_separator|>".to_string(),
|
||||
"<|endoftext|>".to_string(),
|
||||
]);
|
||||
} else if model_lower.contains("wizardcoder") {
|
||||
// WizardCoder models inherit patterns from their base model
|
||||
if model_lower.contains("deepseek") {
|
||||
stop_tokens.extend(vec![
|
||||
"<|fim▁begin|>".to_string(),
|
||||
"<|fim▁hole|>".to_string(),
|
||||
"<|fim▁end|>".to_string(),
|
||||
]);
|
||||
} else {
|
||||
stop_tokens.extend(vec![
|
||||
"<fim_prefix>".to_string(),
|
||||
"<fim_suffix>".to_string(),
|
||||
"<fim_middle>".to_string(),
|
||||
"<|endoftext|>".to_string(),
|
||||
]);
|
||||
}
|
||||
} else if model_lower.contains("starcoder")
|
||||
|| model_lower.contains("santacoder")
|
||||
|| model_lower.contains("stable")
|
||||
|| model_lower.contains("qwen")
|
||||
|| model_lower.contains("replit")
|
||||
{
|
||||
// Stable code pattern stop tokens
|
||||
stop_tokens.extend(vec![
|
||||
"<fim_prefix>".to_string(),
|
||||
"<fim_suffix>".to_string(),
|
||||
"<fim_middle>".to_string(),
|
||||
"<|endoftext|>".to_string(),
|
||||
]);
|
||||
} else {
|
||||
// Generic stop tokens for unknown models - cover both patterns
|
||||
stop_tokens.extend(vec![
|
||||
"<|fim_prefix|>".to_string(),
|
||||
"<|fim_suffix|>".to_string(),
|
||||
"<|fim_middle|>".to_string(),
|
||||
"<fim_prefix>".to_string(),
|
||||
"<fim_suffix>".to_string(),
|
||||
"<fim_middle>".to_string(),
|
||||
"<|endoftext|>".to_string(),
|
||||
]);
|
||||
}
|
||||
|
||||
stop_tokens
|
||||
// Basic stop tokens for code completion
|
||||
// Ollama handles FIM tokens internally, so we only need general completion stops
|
||||
vec![
|
||||
"\n\n".to_string(), // Double newline often indicates end of completion
|
||||
"```".to_string(), // Code block delimiter
|
||||
"<|endoftext|>".to_string(), // Common model end token
|
||||
]
|
||||
}
|
||||
|
||||
fn clean_completion(&self, completion: &str) -> String {
|
||||
let mut cleaned = completion.to_string();
|
||||
|
||||
// Remove common FIM tokens that might appear in responses
|
||||
let fim_tokens = [
|
||||
"<|fim_prefix|>",
|
||||
"<|fim_suffix|>",
|
||||
"<|fim_middle|>",
|
||||
"<|fim_pad|>",
|
||||
"<|repo_name|>",
|
||||
"<|file_sep|>",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<fim_prefix>",
|
||||
"<fim_suffix>",
|
||||
"<fim_middle>",
|
||||
"<PRE>",
|
||||
"<SUF>",
|
||||
"<MID>",
|
||||
"</PRE>",
|
||||
"<|fim▁begin|>",
|
||||
"<|fim▁hole|>",
|
||||
"<|fim▁end|>",
|
||||
"<|end▁of▁sentence|>",
|
||||
"[PREFIX]",
|
||||
"[SUFFIX]",
|
||||
"<|file_separator|>",
|
||||
"<|endoftext|>",
|
||||
];
|
||||
|
||||
for token in &fim_tokens {
|
||||
cleaned = cleaned.replace(token, "");
|
||||
}
|
||||
|
||||
// Remove leading/trailing whitespace and common prefixes
|
||||
// Basic cleaning - Ollama should handle FIM tokens internally
|
||||
// Remove leading/trailing whitespace
|
||||
cleaned = cleaned.trim().to_string();
|
||||
|
||||
// Remove common unwanted prefixes that models sometimes generate
|
||||
|
@ -288,14 +142,13 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
|||
this.extract_context(buffer_snapshot, cursor_position)
|
||||
})?;
|
||||
|
||||
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
||||
|
||||
let (model, stop_tokens) =
|
||||
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
|
||||
|
||||
let request = GenerateRequest {
|
||||
model,
|
||||
prompt,
|
||||
prompt: prefix,
|
||||
suffix: Some(suffix),
|
||||
stream: false,
|
||||
options: Some(GenerateOptions {
|
||||
num_predict: Some(150), // Reasonable completion length
|
||||
|
@ -384,216 +237,7 @@ mod tests {
|
|||
use std::sync::Arc;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwen2.5-coder:32b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<|fim_prefix|>"));
|
||||
assert!(prompt.contains("<|fim_suffix|>"));
|
||||
assert!(prompt.contains("<|fim_middle|>"));
|
||||
assert!(prompt.contains(prefix));
|
||||
assert!(prompt.contains(suffix));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwen2.5:32b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||
assert!(prompt.contains(prefix));
|
||||
assert!(prompt.contains(suffix));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codellama:7b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "function hello() {";
|
||||
let suffix = "}";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<PRE>"));
|
||||
assert!(prompt.contains("<SUF>"));
|
||||
assert!(prompt.contains("<MID>"));
|
||||
assert!(prompt.contains(prefix));
|
||||
assert!(prompt.contains(suffix));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"deepseek-coder:6.7b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<|fim▁begin|>"));
|
||||
assert!(prompt.contains("<|fim▁hole|>"));
|
||||
assert!(prompt.contains("<|fim▁end|>"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"starcoder:7b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"replit-code:3b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
// Replit should use stable code pattern (no pipes)
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codestral:22b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
// Codestral uses suffix-first order
|
||||
assert!(prompt.contains("[SUFFIX]"));
|
||||
assert!(prompt.contains("[PREFIX]"));
|
||||
assert!(prompt.starts_with("[SUFFIX]"));
|
||||
assert!(prompt.contains(prefix));
|
||||
assert!(prompt.contains(suffix));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codegemma:7b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<|fim_prefix|>"));
|
||||
assert!(prompt.contains("<|fim_suffix|>"));
|
||||
assert!(prompt.contains("<|fim_middle|>"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"wizardcoder:13b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
// WizardCoder should use stable code pattern (no pipes) unless it's deepseek-based
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"santacoder:1b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwen2.5-coder:32b".to_string(),
|
||||
);
|
||||
|
||||
let completion_with_tokens = "console.log('hello');<|fim_middle|>";
|
||||
let cleaned = provider.clean_completion(completion_with_tokens);
|
||||
assert_eq!(cleaned, "console.log('hello');");
|
||||
|
||||
let completion_with_multiple_tokens = "<|fim_prefix|>console.log('hello');<|fim_suffix|>";
|
||||
let cleaned = provider.clean_completion(completion_with_multiple_tokens);
|
||||
assert_eq!(cleaned, "console.log('hello');");
|
||||
|
||||
let completion_with_starcoder_tokens = "console.log('hello');<fim_middle>";
|
||||
let cleaned = provider.clean_completion(completion_with_starcoder_tokens);
|
||||
assert_eq!(cleaned, "console.log('hello');");
|
||||
|
||||
let completion_with_codestral_tokens = "console.log('hello');[SUFFIX]";
|
||||
let cleaned = provider.clean_completion(completion_with_codestral_tokens);
|
||||
assert_eq!(cleaned, "console.log('hello');");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) {
|
||||
async fn test_get_stop_tokens(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
|
@ -601,40 +245,27 @@ mod tests {
|
|||
);
|
||||
|
||||
let stop_tokens = provider.get_stop_tokens();
|
||||
assert!(stop_tokens.contains(&"<|fim_prefix|>".to_string()));
|
||||
assert!(stop_tokens.contains(&"<|fim_suffix|>".to_string()));
|
||||
assert!(stop_tokens.contains(&"<|fim_middle|>".to_string()));
|
||||
assert!(stop_tokens.contains(&"\n\n".to_string()));
|
||||
assert!(stop_tokens.contains(&"```".to_string()));
|
||||
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
||||
assert!(stop_tokens.contains(&"<|fim_pad|>".to_string()));
|
||||
assert!(stop_tokens.contains(&"<|repo_name|>".to_string()));
|
||||
assert_eq!(stop_tokens.len(), 3);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) {
|
||||
async fn test_clean_completion_basic(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codellama:7b".to_string(),
|
||||
"qwen2.5-coder:32b".to_string(),
|
||||
);
|
||||
|
||||
let stop_tokens = provider.get_stop_tokens();
|
||||
assert!(stop_tokens.contains(&"<PRE>".to_string()));
|
||||
assert!(stop_tokens.contains(&"<SUF>".to_string()));
|
||||
assert!(stop_tokens.contains(&"<MID>".to_string()));
|
||||
assert!(stop_tokens.contains(&"</PRE>".to_string()));
|
||||
}
|
||||
let completion = " console.log('hello'); ";
|
||||
let cleaned = provider.clean_completion(completion);
|
||||
assert_eq!(cleaned, "console.log('hello');");
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codestral:7b".to_string(),
|
||||
);
|
||||
|
||||
let stop_tokens = provider.get_stop_tokens();
|
||||
assert!(stop_tokens.contains(&"[PREFIX]".to_string()));
|
||||
assert!(stop_tokens.contains(&"[SUFFIX]".to_string()));
|
||||
let completion_with_prefix = "// COMPLETION HERE\nconsole.log('hello');";
|
||||
let cleaned = provider.clean_completion(completion_with_prefix);
|
||||
assert_eq!(cleaned, "console.log('hello');");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -738,49 +369,59 @@ mod tests {
|
|||
provider.update_model("qwen2.5-coder:32b".to_string());
|
||||
assert_eq!(provider.model, "qwen2.5-coder:32b");
|
||||
|
||||
// Test FIM prompt changes with different model
|
||||
let prefix = "def hello():";
|
||||
let suffix = " pass";
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
// Should now use qwen coder pattern (with pipes)
|
||||
assert!(prompt.contains("<|fim_prefix|>"));
|
||||
assert!(prompt.contains("<|fim_suffix|>"));
|
||||
assert!(prompt.contains("<|fim_middle|>"));
|
||||
|
||||
// Update to regular Qwen model (non-coder)
|
||||
// Test updating to different models
|
||||
provider.update_model("qwen2.5:32b".to_string());
|
||||
assert_eq!(provider.model, "qwen2.5:32b");
|
||||
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
// Should now use stable code pattern (no pipes)
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||
|
||||
// Update to starcoder model
|
||||
provider.update_model("starcoder:7b".to_string());
|
||||
assert_eq!(provider.model, "starcoder:7b");
|
||||
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
|
||||
// Should also use stable code pattern (no pipes)
|
||||
assert!(prompt.contains("<fim_prefix>"));
|
||||
assert!(prompt.contains("<fim_suffix>"));
|
||||
assert!(prompt.contains("<fim_middle>"));
|
||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
||||
|
||||
// Update to codestral model
|
||||
provider.update_model("codestral:22b".to_string());
|
||||
assert_eq!(provider.model, "codestral:22b");
|
||||
|
||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||
// FIM patterns are now handled by Ollama natively, so we just test model updates
|
||||
provider.update_model("deepseek-coder:6.7b".to_string());
|
||||
assert_eq!(provider.model, "deepseek-coder:6.7b");
|
||||
}
|
||||
|
||||
// Should use codestral pattern (suffix-first)
|
||||
assert!(prompt.contains("[SUFFIX]"));
|
||||
assert!(prompt.contains("[PREFIX]"));
|
||||
assert!(prompt.starts_with("[SUFFIX]"));
|
||||
#[gpui::test]
|
||||
async fn test_native_fim_request_structure(_cx: &mut TestAppContext) {
|
||||
let provider = OllamaCompletionProvider::new(
|
||||
Arc::new(FakeHttpClient::with_404_response()),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwen2.5-coder:32b".to_string(),
|
||||
);
|
||||
|
||||
let prefix = "def fibonacci(n):";
|
||||
let suffix = " return result";
|
||||
|
||||
// Test that we create the correct request structure for native FIM
|
||||
let request = GenerateRequest {
|
||||
model: provider.model.clone(),
|
||||
prompt: prefix.to_string(),
|
||||
suffix: Some(suffix.to_string()),
|
||||
stream: false,
|
||||
options: Some(GenerateOptions {
|
||||
num_predict: Some(150),
|
||||
temperature: Some(0.1),
|
||||
top_p: Some(0.95),
|
||||
stop: Some(provider.get_stop_tokens()),
|
||||
}),
|
||||
keep_alive: None,
|
||||
context: None,
|
||||
};
|
||||
|
||||
// Verify the request structure uses native FIM approach
|
||||
assert_eq!(request.model, "qwen2.5-coder:32b");
|
||||
assert_eq!(request.prompt, "def fibonacci(n):");
|
||||
assert_eq!(request.suffix, Some(" return result".to_string()));
|
||||
assert!(!request.stream);
|
||||
|
||||
// Verify stop tokens are simplified (no FIM-specific tokens)
|
||||
let stop_tokens = request.options.as_ref().unwrap().stop.as_ref().unwrap();
|
||||
assert!(stop_tokens.contains(&"\n\n".to_string()));
|
||||
assert!(stop_tokens.contains(&"```".to_string()));
|
||||
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
||||
assert_eq!(stop_tokens.len(), 3);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue