diff --git a/crates/ollama/src/ollama_completion_provider.rs b/crates/ollama/src/ollama_completion_provider.rs
index 3112ad831a..095967aa1f 100644
--- a/crates/ollama/src/ollama_completion_provider.rs
+++ b/crates/ollama/src/ollama_completion_provider.rs
@@ -39,21 +39,182 @@ impl OllamaCompletionProvider {
     fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
         // Use model-specific FIM patterns
-        match self.model.as_str() {
-            m if m.contains("codellama") => {
-                format!("<PRE> {prefix} <SUF>{suffix} <MID>")
-            }
-            m if m.contains("deepseek") => {
+        let model_lower = self.model.to_lowercase();
+
+        if model_lower.contains("qwen") && model_lower.contains("coder") {
+            // QwenCoder models use pipes
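+            // e.g. prefix="fn add(" suffix=")" yields "<|fim_prefix|>fn add(<|fim_suffix|>)<|fim_middle|>"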
+            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
+        } else if model_lower.contains("codellama") {
+            format!("
 {prefix} {suffix} ")
+        } else if model_lower.contains("deepseek") {
+            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
+        } else if model_lower.contains("codestral") {
+            // Codestral uses suffix-first order
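+            // e.g. prefix="fn add(" suffix=")" yields "[SUFFIX])[PREFIX]fn add("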
+            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
+        } else if model_lower.contains("codegemma") {
+            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
+        } else if model_lower.contains("wizardcoder") {
+            // WizardCoder models inherit patterns from their base model
+            if model_lower.contains("deepseek") {
                 format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
-            }
-            m if m.contains("starcoder") => {
+            } else {
+                // Most WizardCoder models use stable code pattern
                 format!("{prefix}{suffix}")
             }
-            _ => {
-                // Generic FIM pattern - fallback for models without specific support
-                format!("// Complete the following code:\n{prefix}\n// COMPLETION HERE\n{suffix}")
+        } else if model_lower.contains("starcoder")
+            || model_lower.contains("santacoder")
+            || model_lower.contains("stable")
+            || model_lower.contains("qwen")
+            || model_lower.contains("replit")
+        {
+            // Stable code pattern (no pipes) - used by StarCoder, SantaCoder, StableCode,
+            // non-coder Qwen models, and Replit models
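+            // The completion is expected to be emitted after <fim_middle>, filling the gap between prefix and suffix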
+            format!("{prefix}{suffix}")
+        } else {
+            // Default to stable code pattern for unknown models
+            format!("{prefix}{suffix}")
+        }
+    }
+
+    fn get_stop_tokens(&self) -> Vec<String> {
+        let model_lower = self.model.to_lowercase();
+
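+        // Stops shared by all models: a blank line and a Markdown fence limit run-on output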
+        let mut stop_tokens = vec!["\n\n".to_string(), "```".to_string()];
+
+        if model_lower.contains("qwen") && model_lower.contains("coder") {
+            stop_tokens.extend(vec![
+                "<|endoftext|>".to_string(),
+                "<|fim_prefix|>".to_string(),
+                "<|fim_middle|>".to_string(),
+                "<|fim_suffix|>".to_string(),
+                "<|fim_pad|>".to_string(),
+                "<|repo_name|>".to_string(),
+                "<|file_sep|>".to_string(),
+                "<|im_start|>".to_string(),
+                "<|im_end|>".to_string(),
+            ]);
+        } else if model_lower.contains("codellama") {
+            stop_tokens.extend(vec![
+                "
".to_string(),
+                "".to_string(),
+                "".to_string(),
+                "
".to_string(), + ]); + } else if model_lower.contains("deepseek") { + stop_tokens.extend(vec![ + "<|fim▁begin|>".to_string(), + "<|fim▁hole|>".to_string(), + "<|fim▁end|>".to_string(), + "//".to_string(), + "<|end▁of▁sentence|>".to_string(), + ]); + } else if model_lower.contains("codestral") { + stop_tokens.extend(vec!["[PREFIX]".to_string(), "[SUFFIX]".to_string()]); + } else if model_lower.contains("codegemma") { + stop_tokens.extend(vec![ + "<|fim_prefix|>".to_string(), + "<|fim_suffix|>".to_string(), + "<|fim_middle|>".to_string(), + "<|file_separator|>".to_string(), + "<|endoftext|>".to_string(), + ]); + } else if model_lower.contains("wizardcoder") { + // WizardCoder models inherit patterns from their base model + if model_lower.contains("deepseek") { + stop_tokens.extend(vec![ + "<|fim▁begin|>".to_string(), + "<|fim▁hole|>".to_string(), + "<|fim▁end|>".to_string(), + ]); + } else { + stop_tokens.extend(vec![ + "".to_string(), + "".to_string(), + "".to_string(), + "<|endoftext|>".to_string(), + ]); + } + } else if model_lower.contains("starcoder") + || model_lower.contains("santacoder") + || model_lower.contains("stable") + || model_lower.contains("qwen") + || model_lower.contains("replit") + { + // Stable code pattern stop tokens + stop_tokens.extend(vec![ + "".to_string(), + "".to_string(), + "".to_string(), + "<|endoftext|>".to_string(), + ]); + } else { + // Generic stop tokens for unknown models - cover both patterns + stop_tokens.extend(vec![ + "<|fim_prefix|>".to_string(), + "<|fim_suffix|>".to_string(), + "<|fim_middle|>".to_string(), + "".to_string(), + "".to_string(), + "".to_string(), + "<|endoftext|>".to_string(), + ]); + } + + stop_tokens + } + + fn clean_completion(&self, completion: &str) -> String { + let mut cleaned = completion.to_string(); + + // Remove common FIM tokens that might appear in responses + let fim_tokens = [ + "<|fim_prefix|>", + "<|fim_suffix|>", + "<|fim_middle|>", + "<|fim_pad|>", + "<|repo_name|>", + "<|file_sep|>", + "<|im_start|>", + "<|im_end|>", + "", + "", + "", + "
",
+            "",
+            "",
+            "
", + "<|fim▁begin|>", + "<|fim▁hole|>", + "<|fim▁end|>", + "<|end▁of▁sentence|>", + "[PREFIX]", + "[SUFFIX]", + "<|file_separator|>", + "<|endoftext|>", + ]; + + for token in &fim_tokens { + cleaned = cleaned.replace(token, ""); + } + + // Remove leading/trailing whitespace and common prefixes + cleaned = cleaned.trim().to_string(); + + // Remove common unwanted prefixes that models sometimes generate + let unwanted_prefixes = [ + "// COMPLETION HERE", + "// Complete the following code:", + "// completion:", + "// TODO:", + ]; + + for prefix in &unwanted_prefixes { + if cleaned.starts_with(prefix) { + cleaned = cleaned[prefix.len()..].trim_start().to_string(); } } + + cleaned } fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) { @@ -129,7 +290,8 @@ impl EditPredictionProvider for OllamaCompletionProvider { let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?; - let model = this.update(cx, |this, _| this.model.clone())?; + let (model, stop_tokens) = + this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?; let request = GenerateRequest { model, @@ -139,12 +301,7 @@ impl EditPredictionProvider for OllamaCompletionProvider { num_predict: Some(150), // Reasonable completion length temperature: Some(0.1), // Low temperature for more deterministic results top_p: Some(0.95), - stop: Some(vec![ - "\n\n".to_string(), - "```".to_string(), - "
".to_string(), - "".to_string(), - ]), + stop: Some(stop_tokens), }), keep_alive: None, context: None, @@ -156,8 +313,9 @@ impl EditPredictionProvider for OllamaCompletionProvider { this.update(cx, |this, cx| { this.pending_refresh = None; - if !response.response.trim().is_empty() { - this.current_completion = Some(response.response); + let cleaned_completion = this.clean_completion(&response.response); + if !cleaned_completion.is_empty() { + this.current_completion = Some(cleaned_completion); } else { this.current_completion = None; } @@ -210,12 +368,9 @@ impl EditPredictionProvider for OllamaCompletionProvider { let buffer_snapshot = buffer.read(cx); let position = cursor_position.bias_right(buffer_snapshot); - // Clean up the completion text - let completion_text = completion_text.trim_start().trim_end(); - Some(InlineCompletion { id: None, - edits: vec![(position..position, completion_text.to_string())], + edits: vec![(position..position, completion_text)], edit_preview: None, }) } @@ -229,7 +384,46 @@ mod tests { use std::sync::Arc; #[gpui::test] - async fn test_fim_prompt_patterns(_cx: &mut TestAppContext) { + async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "qwen2.5-coder:32b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + assert!(prompt.contains("<|fim_prefix|>")); + assert!(prompt.contains("<|fim_suffix|>")); + assert!(prompt.contains("<|fim_middle|>")); + assert!(prompt.contains(prefix)); + assert!(prompt.contains(suffix)); + } + + #[gpui::test] + async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "qwen2.5:32b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes + assert!(prompt.contains(prefix)); + assert!(prompt.contains(suffix)); + } + + #[gpui::test] + async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) { let provider = OllamaCompletionProvider::new( Arc::new(FakeHttpClient::with_404_response()), "http://localhost:11434".to_string(), @@ -281,6 +475,168 @@ mod tests { assert!(prompt.contains("")); } + #[gpui::test] + async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "replit-code:3b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + // Replit should use stable code pattern (no pipes) + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes + } + + #[gpui::test] + async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "codestral:22b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " 
pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + // Codestral uses suffix-first order + assert!(prompt.contains("[SUFFIX]")); + assert!(prompt.contains("[PREFIX]")); + assert!(prompt.starts_with("[SUFFIX]")); + assert!(prompt.contains(prefix)); + assert!(prompt.contains(suffix)); + } + + #[gpui::test] + async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "codegemma:7b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + assert!(prompt.contains("<|fim_prefix|>")); + assert!(prompt.contains("<|fim_suffix|>")); + assert!(prompt.contains("<|fim_middle|>")); + } + + #[gpui::test] + async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "wizardcoder:13b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + // WizardCoder should use stable code pattern (no pipes) unless it's deepseek-based + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes + } + + #[gpui::test] + async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "santacoder:1b".to_string(), + ); + + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("")); + } + + #[gpui::test] + async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "qwen2.5-coder:32b".to_string(), + ); + + let completion_with_tokens = "console.log('hello');<|fim_middle|>"; + let cleaned = provider.clean_completion(completion_with_tokens); + assert_eq!(cleaned, "console.log('hello');"); + + let completion_with_multiple_tokens = "<|fim_prefix|>console.log('hello');<|fim_suffix|>"; + let cleaned = provider.clean_completion(completion_with_multiple_tokens); + assert_eq!(cleaned, "console.log('hello');"); + + let completion_with_starcoder_tokens = "console.log('hello');"; + let cleaned = provider.clean_completion(completion_with_starcoder_tokens); + assert_eq!(cleaned, "console.log('hello');"); + + let completion_with_codestral_tokens = "console.log('hello');[SUFFIX]"; + let cleaned = provider.clean_completion(completion_with_codestral_tokens); + assert_eq!(cleaned, "console.log('hello');"); + } + + #[gpui::test] + async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "qwen2.5-coder:32b".to_string(), + ); + + let stop_tokens = provider.get_stop_tokens(); + assert!(stop_tokens.contains(&"<|fim_prefix|>".to_string())); + assert!(stop_tokens.contains(&"<|fim_suffix|>".to_string())); + assert!(stop_tokens.contains(&"<|fim_middle|>".to_string())); + 
assert!(stop_tokens.contains(&"<|endoftext|>".to_string())); + assert!(stop_tokens.contains(&"<|fim_pad|>".to_string())); + assert!(stop_tokens.contains(&"<|repo_name|>".to_string())); + } + + #[gpui::test] + async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "codellama:7b".to_string(), + ); + + let stop_tokens = provider.get_stop_tokens(); + assert!(stop_tokens.contains(&"
".to_string()));
+        assert!(stop_tokens.contains(&"".to_string()));
+        assert!(stop_tokens.contains(&"".to_string()));
+        assert!(stop_tokens.contains(&"
".to_string())); + } + + #[gpui::test] + async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "codestral:7b".to_string(), + ); + + let stop_tokens = provider.get_stop_tokens(); + assert!(stop_tokens.contains(&"[PREFIX]".to_string())); + assert!(stop_tokens.contains(&"[SUFFIX]".to_string())); + } + #[gpui::test] async fn test_extract_context(cx: &mut TestAppContext) { let provider = OllamaCompletionProvider::new( @@ -378,19 +734,31 @@ mod tests { // Verify initial model assert_eq!(provider.model, "codellama:7b"); - // Test updating model - provider.update_model("deepseek-coder:6.7b".to_string()); - assert_eq!(provider.model, "deepseek-coder:6.7b"); + // Test updating model to Qwen Coder + provider.update_model("qwen2.5-coder:32b".to_string()); + assert_eq!(provider.model, "qwen2.5-coder:32b"); // Test FIM prompt changes with different model let prefix = "def hello():"; let suffix = " pass"; let prompt = provider.build_fim_prompt(prefix, suffix); - // Should now use deepseek pattern - assert!(prompt.contains("<|fim▁begin|>")); - assert!(prompt.contains("<|fim▁hole|>")); - assert!(prompt.contains("<|fim▁end|>")); + // Should now use qwen coder pattern (with pipes) + assert!(prompt.contains("<|fim_prefix|>")); + assert!(prompt.contains("<|fim_suffix|>")); + assert!(prompt.contains("<|fim_middle|>")); + + // Update to regular Qwen model (non-coder) + provider.update_model("qwen2.5:32b".to_string()); + assert_eq!(provider.model, "qwen2.5:32b"); + + let prompt = provider.build_fim_prompt(prefix, suffix); + + // Should now use stable code pattern (no pipes) + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes // Update to starcoder model provider.update_model("starcoder:7b".to_string()); @@ -398,9 +766,21 @@ mod tests { let prompt = provider.build_fim_prompt(prefix, suffix); - // Should now use starcoder pattern + // Should also use stable code pattern (no pipes) assert!(prompt.contains("")); assert!(prompt.contains("")); assert!(prompt.contains("")); + assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes + + // Update to codestral model + provider.update_model("codestral:22b".to_string()); + assert_eq!(provider.model, "codestral:22b"); + + let prompt = provider.build_fim_prompt(prefix, suffix); + + // Should use codestral pattern (suffix-first) + assert!(prompt.contains("[SUFFIX]")); + assert!(prompt.contains("[PREFIX]")); + assert!(prompt.starts_with("[SUFFIX]")); } }