Use Ollama's suffix field and remove FIM token handling
This commit is contained in:
parent
4b096b9a6b
commit
b50555b87a
2 changed files with 95 additions and 426 deletions
|
@ -103,6 +103,7 @@ impl Model {
|
||||||
pub struct GenerateRequest {
|
pub struct GenerateRequest {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub prompt: String,
|
pub prompt: String,
|
||||||
|
pub suffix: Option<String>,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
pub options: Option<GenerateOptions>,
|
pub options: Option<GenerateOptions>,
|
||||||
pub keep_alive: Option<KeepAlive>,
|
pub keep_alive: Option<KeepAlive>,
|
||||||
|
@ -425,6 +426,33 @@ pub async fn generate(
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_request_with_suffix_serialization() {
|
||||||
|
let request = GenerateRequest {
|
||||||
|
model: "qwen2.5-coder:32b".to_string(),
|
||||||
|
prompt: "def fibonacci(n):".to_string(),
|
||||||
|
suffix: Some(" return result".to_string()),
|
||||||
|
stream: false,
|
||||||
|
options: Some(GenerateOptions {
|
||||||
|
num_predict: Some(150),
|
||||||
|
temperature: Some(0.1),
|
||||||
|
top_p: Some(0.95),
|
||||||
|
stop: Some(vec!["<|endoftext|>".to_string()]),
|
||||||
|
}),
|
||||||
|
keep_alive: None,
|
||||||
|
context: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_string(&request).unwrap();
|
||||||
|
let parsed: GenerateRequest = serde_json::from_str(&json).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(parsed.model, "qwen2.5-coder:32b");
|
||||||
|
assert_eq!(parsed.prompt, "def fibonacci(n):");
|
||||||
|
assert_eq!(parsed.suffix, Some(" return result".to_string()));
|
||||||
|
assert!(!parsed.stream);
|
||||||
|
assert!(parsed.options.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn parse_completion() {
|
fn parse_completion() {
|
||||||
let response = serde_json::json!({
|
let response = serde_json::json!({
|
||||||
|
|
|
@ -37,167 +37,21 @@ impl OllamaCompletionProvider {
|
||||||
self.model = new_model;
|
self.model = new_model;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
|
||||||
// Use model-specific FIM patterns
|
|
||||||
let model_lower = self.model.to_lowercase();
|
|
||||||
|
|
||||||
if model_lower.contains("qwen") && model_lower.contains("coder") {
|
|
||||||
// QwenCoder models use pipes
|
|
||||||
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
|
|
||||||
} else if model_lower.contains("codellama") {
|
|
||||||
format!("<PRE> {prefix} <SUF>{suffix} <MID>")
|
|
||||||
} else if model_lower.contains("deepseek") {
|
|
||||||
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
|
||||||
} else if model_lower.contains("codestral") {
|
|
||||||
// Codestral uses suffix-first order
|
|
||||||
format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
|
|
||||||
} else if model_lower.contains("codegemma") {
|
|
||||||
format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
|
|
||||||
} else if model_lower.contains("wizardcoder") {
|
|
||||||
// WizardCoder models inherit patterns from their base model
|
|
||||||
if model_lower.contains("deepseek") {
|
|
||||||
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
|
||||||
} else {
|
|
||||||
// Most WizardCoder models use stable code pattern
|
|
||||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
|
||||||
}
|
|
||||||
} else if model_lower.contains("starcoder")
|
|
||||||
|| model_lower.contains("santacoder")
|
|
||||||
|| model_lower.contains("stable")
|
|
||||||
|| model_lower.contains("qwen")
|
|
||||||
|| model_lower.contains("replit")
|
|
||||||
{
|
|
||||||
// Stable code pattern (no pipes) - used by StarCoder, SantaCoder, StableCode,
|
|
||||||
// non-coder Qwen models, and Replit models
|
|
||||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
|
||||||
} else {
|
|
||||||
// Default to stable code pattern for unknown models
|
|
||||||
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_stop_tokens(&self) -> Vec<String> {
|
fn get_stop_tokens(&self) -> Vec<String> {
|
||||||
let model_lower = self.model.to_lowercase();
|
// Basic stop tokens for code completion
|
||||||
|
// Ollama handles FIM tokens internally, so we only need general completion stops
|
||||||
let mut stop_tokens = vec!["\n\n".to_string(), "```".to_string()];
|
vec![
|
||||||
|
"\n\n".to_string(), // Double newline often indicates end of completion
|
||||||
if model_lower.contains("qwen") && model_lower.contains("coder") {
|
"```".to_string(), // Code block delimiter
|
||||||
stop_tokens.extend(vec![
|
"<|endoftext|>".to_string(), // Common model end token
|
||||||
"<|endoftext|>".to_string(),
|
]
|
||||||
"<|fim_prefix|>".to_string(),
|
|
||||||
"<|fim_middle|>".to_string(),
|
|
||||||
"<|fim_suffix|>".to_string(),
|
|
||||||
"<|fim_pad|>".to_string(),
|
|
||||||
"<|repo_name|>".to_string(),
|
|
||||||
"<|file_sep|>".to_string(),
|
|
||||||
"<|im_start|>".to_string(),
|
|
||||||
"<|im_end|>".to_string(),
|
|
||||||
]);
|
|
||||||
} else if model_lower.contains("codellama") {
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<PRE>".to_string(),
|
|
||||||
"<SUF>".to_string(),
|
|
||||||
"<MID>".to_string(),
|
|
||||||
"</PRE>".to_string(),
|
|
||||||
]);
|
|
||||||
} else if model_lower.contains("deepseek") {
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<|fim▁begin|>".to_string(),
|
|
||||||
"<|fim▁hole|>".to_string(),
|
|
||||||
"<|fim▁end|>".to_string(),
|
|
||||||
"//".to_string(),
|
|
||||||
"<|end▁of▁sentence|>".to_string(),
|
|
||||||
]);
|
|
||||||
} else if model_lower.contains("codestral") {
|
|
||||||
stop_tokens.extend(vec!["[PREFIX]".to_string(), "[SUFFIX]".to_string()]);
|
|
||||||
} else if model_lower.contains("codegemma") {
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<|fim_prefix|>".to_string(),
|
|
||||||
"<|fim_suffix|>".to_string(),
|
|
||||||
"<|fim_middle|>".to_string(),
|
|
||||||
"<|file_separator|>".to_string(),
|
|
||||||
"<|endoftext|>".to_string(),
|
|
||||||
]);
|
|
||||||
} else if model_lower.contains("wizardcoder") {
|
|
||||||
// WizardCoder models inherit patterns from their base model
|
|
||||||
if model_lower.contains("deepseek") {
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<|fim▁begin|>".to_string(),
|
|
||||||
"<|fim▁hole|>".to_string(),
|
|
||||||
"<|fim▁end|>".to_string(),
|
|
||||||
]);
|
|
||||||
} else {
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<fim_prefix>".to_string(),
|
|
||||||
"<fim_suffix>".to_string(),
|
|
||||||
"<fim_middle>".to_string(),
|
|
||||||
"<|endoftext|>".to_string(),
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
} else if model_lower.contains("starcoder")
|
|
||||||
|| model_lower.contains("santacoder")
|
|
||||||
|| model_lower.contains("stable")
|
|
||||||
|| model_lower.contains("qwen")
|
|
||||||
|| model_lower.contains("replit")
|
|
||||||
{
|
|
||||||
// Stable code pattern stop tokens
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<fim_prefix>".to_string(),
|
|
||||||
"<fim_suffix>".to_string(),
|
|
||||||
"<fim_middle>".to_string(),
|
|
||||||
"<|endoftext|>".to_string(),
|
|
||||||
]);
|
|
||||||
} else {
|
|
||||||
// Generic stop tokens for unknown models - cover both patterns
|
|
||||||
stop_tokens.extend(vec![
|
|
||||||
"<|fim_prefix|>".to_string(),
|
|
||||||
"<|fim_suffix|>".to_string(),
|
|
||||||
"<|fim_middle|>".to_string(),
|
|
||||||
"<fim_prefix>".to_string(),
|
|
||||||
"<fim_suffix>".to_string(),
|
|
||||||
"<fim_middle>".to_string(),
|
|
||||||
"<|endoftext|>".to_string(),
|
|
||||||
]);
|
|
||||||
}
|
|
||||||
|
|
||||||
stop_tokens
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clean_completion(&self, completion: &str) -> String {
|
fn clean_completion(&self, completion: &str) -> String {
|
||||||
let mut cleaned = completion.to_string();
|
let mut cleaned = completion.to_string();
|
||||||
|
|
||||||
// Remove common FIM tokens that might appear in responses
|
// Basic cleaning - Ollama should handle FIM tokens internally
|
||||||
let fim_tokens = [
|
// Remove leading/trailing whitespace
|
||||||
"<|fim_prefix|>",
|
|
||||||
"<|fim_suffix|>",
|
|
||||||
"<|fim_middle|>",
|
|
||||||
"<|fim_pad|>",
|
|
||||||
"<|repo_name|>",
|
|
||||||
"<|file_sep|>",
|
|
||||||
"<|im_start|>",
|
|
||||||
"<|im_end|>",
|
|
||||||
"<fim_prefix>",
|
|
||||||
"<fim_suffix>",
|
|
||||||
"<fim_middle>",
|
|
||||||
"<PRE>",
|
|
||||||
"<SUF>",
|
|
||||||
"<MID>",
|
|
||||||
"</PRE>",
|
|
||||||
"<|fim▁begin|>",
|
|
||||||
"<|fim▁hole|>",
|
|
||||||
"<|fim▁end|>",
|
|
||||||
"<|end▁of▁sentence|>",
|
|
||||||
"[PREFIX]",
|
|
||||||
"[SUFFIX]",
|
|
||||||
"<|file_separator|>",
|
|
||||||
"<|endoftext|>",
|
|
||||||
];
|
|
||||||
|
|
||||||
for token in &fim_tokens {
|
|
||||||
cleaned = cleaned.replace(token, "");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove leading/trailing whitespace and common prefixes
|
|
||||||
cleaned = cleaned.trim().to_string();
|
cleaned = cleaned.trim().to_string();
|
||||||
|
|
||||||
// Remove common unwanted prefixes that models sometimes generate
|
// Remove common unwanted prefixes that models sometimes generate
|
||||||
|
@ -288,14 +142,13 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
this.extract_context(buffer_snapshot, cursor_position)
|
this.extract_context(buffer_snapshot, cursor_position)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
|
||||||
|
|
||||||
let (model, stop_tokens) =
|
let (model, stop_tokens) =
|
||||||
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
|
this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?;
|
||||||
|
|
||||||
let request = GenerateRequest {
|
let request = GenerateRequest {
|
||||||
model,
|
model,
|
||||||
prompt,
|
prompt: prefix,
|
||||||
|
suffix: Some(suffix),
|
||||||
stream: false,
|
stream: false,
|
||||||
options: Some(GenerateOptions {
|
options: Some(GenerateOptions {
|
||||||
num_predict: Some(150), // Reasonable completion length
|
num_predict: Some(150), // Reasonable completion length
|
||||||
|
@ -384,216 +237,7 @@ mod tests {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) {
|
async fn test_get_stop_tokens(_cx: &mut TestAppContext) {
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"qwen2.5-coder:32b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<|fim_prefix|>"));
|
|
||||||
assert!(prompt.contains("<|fim_suffix|>"));
|
|
||||||
assert!(prompt.contains("<|fim_middle|>"));
|
|
||||||
assert!(prompt.contains(prefix));
|
|
||||||
assert!(prompt.contains(suffix));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"qwen2.5:32b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
|
||||||
assert!(prompt.contains(prefix));
|
|
||||||
assert!(prompt.contains(suffix));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"codellama:7b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "function hello() {";
|
|
||||||
let suffix = "}";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<PRE>"));
|
|
||||||
assert!(prompt.contains("<SUF>"));
|
|
||||||
assert!(prompt.contains("<MID>"));
|
|
||||||
assert!(prompt.contains(prefix));
|
|
||||||
assert!(prompt.contains(suffix));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"deepseek-coder:6.7b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<|fim▁begin|>"));
|
|
||||||
assert!(prompt.contains("<|fim▁hole|>"));
|
|
||||||
assert!(prompt.contains("<|fim▁end|>"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"starcoder:7b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"replit-code:3b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
// Replit should use stable code pattern (no pipes)
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"codestral:22b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
// Codestral uses suffix-first order
|
|
||||||
assert!(prompt.contains("[SUFFIX]"));
|
|
||||||
assert!(prompt.contains("[PREFIX]"));
|
|
||||||
assert!(prompt.starts_with("[SUFFIX]"));
|
|
||||||
assert!(prompt.contains(prefix));
|
|
||||||
assert!(prompt.contains(suffix));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"codegemma:7b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<|fim_prefix|>"));
|
|
||||||
assert!(prompt.contains("<|fim_suffix|>"));
|
|
||||||
assert!(prompt.contains("<|fim_middle|>"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"wizardcoder:13b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
// WizardCoder should use stable code pattern (no pipes) unless it's deepseek-based
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"santacoder:1b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"qwen2.5-coder:32b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let completion_with_tokens = "console.log('hello');<|fim_middle|>";
|
|
||||||
let cleaned = provider.clean_completion(completion_with_tokens);
|
|
||||||
assert_eq!(cleaned, "console.log('hello');");
|
|
||||||
|
|
||||||
let completion_with_multiple_tokens = "<|fim_prefix|>console.log('hello');<|fim_suffix|>";
|
|
||||||
let cleaned = provider.clean_completion(completion_with_multiple_tokens);
|
|
||||||
assert_eq!(cleaned, "console.log('hello');");
|
|
||||||
|
|
||||||
let completion_with_starcoder_tokens = "console.log('hello');<fim_middle>";
|
|
||||||
let cleaned = provider.clean_completion(completion_with_starcoder_tokens);
|
|
||||||
assert_eq!(cleaned, "console.log('hello');");
|
|
||||||
|
|
||||||
let completion_with_codestral_tokens = "console.log('hello');[SUFFIX]";
|
|
||||||
let cleaned = provider.clean_completion(completion_with_codestral_tokens);
|
|
||||||
assert_eq!(cleaned, "console.log('hello');");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
|
||||||
async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) {
|
|
||||||
let provider = OllamaCompletionProvider::new(
|
let provider = OllamaCompletionProvider::new(
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
"http://localhost:11434".to_string(),
|
"http://localhost:11434".to_string(),
|
||||||
|
@ -601,40 +245,27 @@ mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let stop_tokens = provider.get_stop_tokens();
|
let stop_tokens = provider.get_stop_tokens();
|
||||||
assert!(stop_tokens.contains(&"<|fim_prefix|>".to_string()));
|
assert!(stop_tokens.contains(&"\n\n".to_string()));
|
||||||
assert!(stop_tokens.contains(&"<|fim_suffix|>".to_string()));
|
assert!(stop_tokens.contains(&"```".to_string()));
|
||||||
assert!(stop_tokens.contains(&"<|fim_middle|>".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
||||||
assert!(stop_tokens.contains(&"<|fim_pad|>".to_string()));
|
assert_eq!(stop_tokens.len(), 3);
|
||||||
assert!(stop_tokens.contains(&"<|repo_name|>".to_string()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) {
|
async fn test_clean_completion_basic(_cx: &mut TestAppContext) {
|
||||||
let provider = OllamaCompletionProvider::new(
|
let provider = OllamaCompletionProvider::new(
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
"http://localhost:11434".to_string(),
|
"http://localhost:11434".to_string(),
|
||||||
"codellama:7b".to_string(),
|
"qwen2.5-coder:32b".to_string(),
|
||||||
);
|
);
|
||||||
|
|
||||||
let stop_tokens = provider.get_stop_tokens();
|
let completion = " console.log('hello'); ";
|
||||||
assert!(stop_tokens.contains(&"<PRE>".to_string()));
|
let cleaned = provider.clean_completion(completion);
|
||||||
assert!(stop_tokens.contains(&"<SUF>".to_string()));
|
assert_eq!(cleaned, "console.log('hello');");
|
||||||
assert!(stop_tokens.contains(&"<MID>".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"</PRE>".to_string()));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[gpui::test]
|
let completion_with_prefix = "// COMPLETION HERE\nconsole.log('hello');";
|
||||||
async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) {
|
let cleaned = provider.clean_completion(completion_with_prefix);
|
||||||
let provider = OllamaCompletionProvider::new(
|
assert_eq!(cleaned, "console.log('hello');");
|
||||||
Arc::new(FakeHttpClient::with_404_response()),
|
|
||||||
"http://localhost:11434".to_string(),
|
|
||||||
"codestral:7b".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let stop_tokens = provider.get_stop_tokens();
|
|
||||||
assert!(stop_tokens.contains(&"[PREFIX]".to_string()));
|
|
||||||
assert!(stop_tokens.contains(&"[SUFFIX]".to_string()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[gpui::test]
|
#[gpui::test]
|
||||||
|
@ -738,49 +369,59 @@ mod tests {
|
||||||
provider.update_model("qwen2.5-coder:32b".to_string());
|
provider.update_model("qwen2.5-coder:32b".to_string());
|
||||||
assert_eq!(provider.model, "qwen2.5-coder:32b");
|
assert_eq!(provider.model, "qwen2.5-coder:32b");
|
||||||
|
|
||||||
// Test FIM prompt changes with different model
|
// Test updating to different models
|
||||||
let prefix = "def hello():";
|
|
||||||
let suffix = " pass";
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
// Should now use qwen coder pattern (with pipes)
|
|
||||||
assert!(prompt.contains("<|fim_prefix|>"));
|
|
||||||
assert!(prompt.contains("<|fim_suffix|>"));
|
|
||||||
assert!(prompt.contains("<|fim_middle|>"));
|
|
||||||
|
|
||||||
// Update to regular Qwen model (non-coder)
|
|
||||||
provider.update_model("qwen2.5:32b".to_string());
|
provider.update_model("qwen2.5:32b".to_string());
|
||||||
assert_eq!(provider.model, "qwen2.5:32b");
|
assert_eq!(provider.model, "qwen2.5:32b");
|
||||||
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
// Should now use stable code pattern (no pipes)
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
|
||||||
|
|
||||||
// Update to starcoder model
|
|
||||||
provider.update_model("starcoder:7b".to_string());
|
provider.update_model("starcoder:7b".to_string());
|
||||||
assert_eq!(provider.model, "starcoder:7b");
|
assert_eq!(provider.model, "starcoder:7b");
|
||||||
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
|
||||||
|
|
||||||
// Should also use stable code pattern (no pipes)
|
|
||||||
assert!(prompt.contains("<fim_prefix>"));
|
|
||||||
assert!(prompt.contains("<fim_suffix>"));
|
|
||||||
assert!(prompt.contains("<fim_middle>"));
|
|
||||||
assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
|
|
||||||
|
|
||||||
// Update to codestral model
|
|
||||||
provider.update_model("codestral:22b".to_string());
|
provider.update_model("codestral:22b".to_string());
|
||||||
assert_eq!(provider.model, "codestral:22b");
|
assert_eq!(provider.model, "codestral:22b");
|
||||||
|
|
||||||
let prompt = provider.build_fim_prompt(prefix, suffix);
|
// FIM patterns are now handled by Ollama natively, so we just test model updates
|
||||||
|
provider.update_model("deepseek-coder:6.7b".to_string());
|
||||||
|
assert_eq!(provider.model, "deepseek-coder:6.7b");
|
||||||
|
}
|
||||||
|
|
||||||
// Should use codestral pattern (suffix-first)
|
#[gpui::test]
|
||||||
assert!(prompt.contains("[SUFFIX]"));
|
async fn test_native_fim_request_structure(_cx: &mut TestAppContext) {
|
||||||
assert!(prompt.contains("[PREFIX]"));
|
let provider = OllamaCompletionProvider::new(
|
||||||
assert!(prompt.starts_with("[SUFFIX]"));
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"qwen2.5-coder:32b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def fibonacci(n):";
|
||||||
|
let suffix = " return result";
|
||||||
|
|
||||||
|
// Test that we create the correct request structure for native FIM
|
||||||
|
let request = GenerateRequest {
|
||||||
|
model: provider.model.clone(),
|
||||||
|
prompt: prefix.to_string(),
|
||||||
|
suffix: Some(suffix.to_string()),
|
||||||
|
stream: false,
|
||||||
|
options: Some(GenerateOptions {
|
||||||
|
num_predict: Some(150),
|
||||||
|
temperature: Some(0.1),
|
||||||
|
top_p: Some(0.95),
|
||||||
|
stop: Some(provider.get_stop_tokens()),
|
||||||
|
}),
|
||||||
|
keep_alive: None,
|
||||||
|
context: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Verify the request structure uses native FIM approach
|
||||||
|
assert_eq!(request.model, "qwen2.5-coder:32b");
|
||||||
|
assert_eq!(request.prompt, "def fibonacci(n):");
|
||||||
|
assert_eq!(request.suffix, Some(" return result".to_string()));
|
||||||
|
assert!(!request.stream);
|
||||||
|
|
||||||
|
// Verify stop tokens are simplified (no FIM-specific tokens)
|
||||||
|
let stop_tokens = request.options.as_ref().unwrap().stop.as_ref().unwrap();
|
||||||
|
assert!(stop_tokens.contains(&"\n\n".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"```".to_string()));
|
||||||
|
assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
|
||||||
|
assert_eq!(stop_tokens.len(), 3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue