diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index ac8251738e..6194242f2e 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -103,6 +103,7 @@ impl Model { pub struct GenerateRequest { pub model: String, pub prompt: String, + pub suffix: Option, pub stream: bool, pub options: Option, pub keep_alive: Option, @@ -425,6 +426,33 @@ pub async fn generate( mod tests { use super::*; + #[test] + fn test_generate_request_with_suffix_serialization() { + let request = GenerateRequest { + model: "qwen2.5-coder:32b".to_string(), + prompt: "def fibonacci(n):".to_string(), + suffix: Some(" return result".to_string()), + stream: false, + options: Some(GenerateOptions { + num_predict: Some(150), + temperature: Some(0.1), + top_p: Some(0.95), + stop: Some(vec!["<|endoftext|>".to_string()]), + }), + keep_alive: None, + context: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + let parsed: GenerateRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.model, "qwen2.5-coder:32b"); + assert_eq!(parsed.prompt, "def fibonacci(n):"); + assert_eq!(parsed.suffix, Some(" return result".to_string())); + assert!(!parsed.stream); + assert!(parsed.options.is_some()); + } + #[test] fn parse_completion() { let response = serde_json::json!({ diff --git a/crates/ollama/src/ollama_completion_provider.rs b/crates/ollama/src/ollama_completion_provider.rs index 095967aa1f..953bdecbb9 100644 --- a/crates/ollama/src/ollama_completion_provider.rs +++ b/crates/ollama/src/ollama_completion_provider.rs @@ -37,167 +37,21 @@ impl OllamaCompletionProvider { self.model = new_model; } - fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String { - // Use model-specific FIM patterns - let model_lower = self.model.to_lowercase(); - - if model_lower.contains("qwen") && model_lower.contains("coder") { - // QwenCoder models use pipes - format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>") - } else if model_lower.contains("codellama") { - format!("
 {prefix} {suffix} ")
-        } else if model_lower.contains("deepseek") {
-            format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
-        } else if model_lower.contains("codestral") {
-            // Codestral uses suffix-first order
-            format!("[SUFFIX]{suffix}[PREFIX]{prefix}")
-        } else if model_lower.contains("codegemma") {
-            format!("<|fim_prefix|>{prefix}<|fim_suffix|>{suffix}<|fim_middle|>")
-        } else if model_lower.contains("wizardcoder") {
-            // WizardCoder models inherit patterns from their base model
-            if model_lower.contains("deepseek") {
-                format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
-            } else {
-                // Most WizardCoder models use stable code pattern
-                format!("{prefix}{suffix}")
-            }
-        } else if model_lower.contains("starcoder")
-            || model_lower.contains("santacoder")
-            || model_lower.contains("stable")
-            || model_lower.contains("qwen")
-            || model_lower.contains("replit")
-        {
-            // Stable code pattern (no pipes) - used by StarCoder, SantaCoder, StableCode,
-            // non-coder Qwen models, and Replit models
-            format!("{prefix}{suffix}")
-        } else {
-            // Default to stable code pattern for unknown models
-            format!("{prefix}{suffix}")
-        }
-    }
-
     fn get_stop_tokens(&self) -> Vec {
-        let model_lower = self.model.to_lowercase();
-
-        let mut stop_tokens = vec!["\n\n".to_string(), "```".to_string()];
-
-        if model_lower.contains("qwen") && model_lower.contains("coder") {
-            stop_tokens.extend(vec![
-                "<|endoftext|>".to_string(),
-                "<|fim_prefix|>".to_string(),
-                "<|fim_middle|>".to_string(),
-                "<|fim_suffix|>".to_string(),
-                "<|fim_pad|>".to_string(),
-                "<|repo_name|>".to_string(),
-                "<|file_sep|>".to_string(),
-                "<|im_start|>".to_string(),
-                "<|im_end|>".to_string(),
-            ]);
-        } else if model_lower.contains("codellama") {
-            stop_tokens.extend(vec![
-                "
".to_string(),
-                "".to_string(),
-                "".to_string(),
-                "
".to_string(), - ]); - } else if model_lower.contains("deepseek") { - stop_tokens.extend(vec![ - "<|fim▁begin|>".to_string(), - "<|fim▁hole|>".to_string(), - "<|fim▁end|>".to_string(), - "//".to_string(), - "<|end▁of▁sentence|>".to_string(), - ]); - } else if model_lower.contains("codestral") { - stop_tokens.extend(vec!["[PREFIX]".to_string(), "[SUFFIX]".to_string()]); - } else if model_lower.contains("codegemma") { - stop_tokens.extend(vec![ - "<|fim_prefix|>".to_string(), - "<|fim_suffix|>".to_string(), - "<|fim_middle|>".to_string(), - "<|file_separator|>".to_string(), - "<|endoftext|>".to_string(), - ]); - } else if model_lower.contains("wizardcoder") { - // WizardCoder models inherit patterns from their base model - if model_lower.contains("deepseek") { - stop_tokens.extend(vec![ - "<|fim▁begin|>".to_string(), - "<|fim▁hole|>".to_string(), - "<|fim▁end|>".to_string(), - ]); - } else { - stop_tokens.extend(vec![ - "".to_string(), - "".to_string(), - "".to_string(), - "<|endoftext|>".to_string(), - ]); - } - } else if model_lower.contains("starcoder") - || model_lower.contains("santacoder") - || model_lower.contains("stable") - || model_lower.contains("qwen") - || model_lower.contains("replit") - { - // Stable code pattern stop tokens - stop_tokens.extend(vec![ - "".to_string(), - "".to_string(), - "".to_string(), - "<|endoftext|>".to_string(), - ]); - } else { - // Generic stop tokens for unknown models - cover both patterns - stop_tokens.extend(vec![ - "<|fim_prefix|>".to_string(), - "<|fim_suffix|>".to_string(), - "<|fim_middle|>".to_string(), - "".to_string(), - "".to_string(), - "".to_string(), - "<|endoftext|>".to_string(), - ]); - } - - stop_tokens + // Basic stop tokens for code completion + // Ollama handles FIM tokens internally, so we only need general completion stops + vec![ + "\n\n".to_string(), // Double newline often indicates end of completion + "```".to_string(), // Code block delimiter + "<|endoftext|>".to_string(), // Common model end token + ] } fn clean_completion(&self, completion: &str) -> String { let mut cleaned = completion.to_string(); - // Remove common FIM tokens that might appear in responses - let fim_tokens = [ - "<|fim_prefix|>", - "<|fim_suffix|>", - "<|fim_middle|>", - "<|fim_pad|>", - "<|repo_name|>", - "<|file_sep|>", - "<|im_start|>", - "<|im_end|>", - "", - "", - "", - "
",
-            "",
-            "",
-            "
", - "<|fim▁begin|>", - "<|fim▁hole|>", - "<|fim▁end|>", - "<|end▁of▁sentence|>", - "[PREFIX]", - "[SUFFIX]", - "<|file_separator|>", - "<|endoftext|>", - ]; - - for token in &fim_tokens { - cleaned = cleaned.replace(token, ""); - } - - // Remove leading/trailing whitespace and common prefixes + // Basic cleaning - Ollama should handle FIM tokens internally + // Remove leading/trailing whitespace cleaned = cleaned.trim().to_string(); // Remove common unwanted prefixes that models sometimes generate @@ -288,14 +142,13 @@ impl EditPredictionProvider for OllamaCompletionProvider { this.extract_context(buffer_snapshot, cursor_position) })?; - let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?; - let (model, stop_tokens) = this.update(cx, |this, _| (this.model.clone(), this.get_stop_tokens()))?; let request = GenerateRequest { model, - prompt, + prompt: prefix, + suffix: Some(suffix), stream: false, options: Some(GenerateOptions { num_predict: Some(150), // Reasonable completion length @@ -384,216 +237,7 @@ mod tests { use std::sync::Arc; #[gpui::test] - async fn test_fim_prompt_qwen_coder_pattern(_cx: &mut TestAppContext) { - let provider = OllamaCompletionProvider::new( - Arc::new(FakeHttpClient::with_404_response()), - "http://localhost:11434".to_string(), - "qwen2.5-coder:32b".to_string(), - ); - - let prefix = "def hello():"; - let suffix = " pass"; - let prompt = provider.build_fim_prompt(prefix, suffix); - - assert!(prompt.contains("<|fim_prefix|>")); - assert!(prompt.contains("<|fim_suffix|>")); - assert!(prompt.contains("<|fim_middle|>")); - assert!(prompt.contains(prefix)); - assert!(prompt.contains(suffix)); - } - - #[gpui::test] - async fn test_fim_prompt_qwen_non_coder_pattern(_cx: &mut TestAppContext) { - let provider = OllamaCompletionProvider::new( - Arc::new(FakeHttpClient::with_404_response()), - "http://localhost:11434".to_string(), - "qwen2.5:32b".to_string(), - ); - - let prefix = "def hello():"; - let suffix = " pass"; - let prompt = provider.build_fim_prompt(prefix, suffix); - - assert!(prompt.contains("")); - assert!(prompt.contains("")); - assert!(prompt.contains("")); - assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes - assert!(prompt.contains(prefix)); - assert!(prompt.contains(suffix)); - } - - #[gpui::test] - async fn test_fim_prompt_codellama_pattern(_cx: &mut TestAppContext) { - let provider = OllamaCompletionProvider::new( - Arc::new(FakeHttpClient::with_404_response()), - "http://localhost:11434".to_string(), - "codellama:7b".to_string(), - ); - - let prefix = "function hello() {"; - let suffix = "}"; - let prompt = provider.build_fim_prompt(prefix, suffix); - - assert!(prompt.contains("
"));
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(prefix));
-        assert!(prompt.contains(suffix));
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "deepseek-coder:6.7b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        assert!(prompt.contains("<|fim▁begin|>"));
-        assert!(prompt.contains("<|fim▁hole|>"));
-        assert!(prompt.contains("<|fim▁end|>"));
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "starcoder:7b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_replit_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "replit-code:3b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        // Replit should use stable code pattern (no pipes)
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_codestral_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "codestral:22b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        // Codestral uses suffix-first order
-        assert!(prompt.contains("[SUFFIX]"));
-        assert!(prompt.contains("[PREFIX]"));
-        assert!(prompt.starts_with("[SUFFIX]"));
-        assert!(prompt.contains(prefix));
-        assert!(prompt.contains(suffix));
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_codegemma_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "codegemma:7b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        assert!(prompt.contains("<|fim_prefix|>"));
-        assert!(prompt.contains("<|fim_suffix|>"));
-        assert!(prompt.contains("<|fim_middle|>"));
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_wizardcoder_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "wizardcoder:13b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        // WizardCoder should use stable code pattern (no pipes) unless it's deepseek-based
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes
-    }
-
-    #[gpui::test]
-    async fn test_fim_prompt_santacoder_pattern(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "santacoder:1b".to_string(),
-        );
-
-        let prefix = "def hello():";
-        let suffix = "    pass";
-        let prompt = provider.build_fim_prompt(prefix, suffix);
-
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-        assert!(prompt.contains(""));
-    }
-
-    #[gpui::test]
-    async fn test_clean_completion_removes_fim_tokens(_cx: &mut TestAppContext) {
-        let provider = OllamaCompletionProvider::new(
-            Arc::new(FakeHttpClient::with_404_response()),
-            "http://localhost:11434".to_string(),
-            "qwen2.5-coder:32b".to_string(),
-        );
-
-        let completion_with_tokens = "console.log('hello');<|fim_middle|>";
-        let cleaned = provider.clean_completion(completion_with_tokens);
-        assert_eq!(cleaned, "console.log('hello');");
-
-        let completion_with_multiple_tokens = "<|fim_prefix|>console.log('hello');<|fim_suffix|>";
-        let cleaned = provider.clean_completion(completion_with_multiple_tokens);
-        assert_eq!(cleaned, "console.log('hello');");
-
-        let completion_with_starcoder_tokens = "console.log('hello');";
-        let cleaned = provider.clean_completion(completion_with_starcoder_tokens);
-        assert_eq!(cleaned, "console.log('hello');");
-
-        let completion_with_codestral_tokens = "console.log('hello');[SUFFIX]";
-        let cleaned = provider.clean_completion(completion_with_codestral_tokens);
-        assert_eq!(cleaned, "console.log('hello');");
-    }
-
-    #[gpui::test]
-    async fn test_get_stop_tokens_qwen_coder(_cx: &mut TestAppContext) {
+    async fn test_get_stop_tokens(_cx: &mut TestAppContext) {
         let provider = OllamaCompletionProvider::new(
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
@@ -601,40 +245,27 @@ mod tests {
         );
 
         let stop_tokens = provider.get_stop_tokens();
-        assert!(stop_tokens.contains(&"<|fim_prefix|>".to_string()));
-        assert!(stop_tokens.contains(&"<|fim_suffix|>".to_string()));
-        assert!(stop_tokens.contains(&"<|fim_middle|>".to_string()));
+        assert!(stop_tokens.contains(&"\n\n".to_string()));
+        assert!(stop_tokens.contains(&"```".to_string()));
         assert!(stop_tokens.contains(&"<|endoftext|>".to_string()));
-        assert!(stop_tokens.contains(&"<|fim_pad|>".to_string()));
-        assert!(stop_tokens.contains(&"<|repo_name|>".to_string()));
+        assert_eq!(stop_tokens.len(), 3);
     }
 
     #[gpui::test]
-    async fn test_get_stop_tokens_codellama(_cx: &mut TestAppContext) {
+    async fn test_clean_completion_basic(_cx: &mut TestAppContext) {
         let provider = OllamaCompletionProvider::new(
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
-            "codellama:7b".to_string(),
+            "qwen2.5-coder:32b".to_string(),
         );
 
-        let stop_tokens = provider.get_stop_tokens();
-        assert!(stop_tokens.contains(&"
".to_string()));
-        assert!(stop_tokens.contains(&"".to_string()));
-        assert!(stop_tokens.contains(&"".to_string()));
-        assert!(stop_tokens.contains(&"
".to_string())); - } + let completion = " console.log('hello'); "; + let cleaned = provider.clean_completion(completion); + assert_eq!(cleaned, "console.log('hello');"); - #[gpui::test] - async fn test_get_stop_tokens_codestral(_cx: &mut TestAppContext) { - let provider = OllamaCompletionProvider::new( - Arc::new(FakeHttpClient::with_404_response()), - "http://localhost:11434".to_string(), - "codestral:7b".to_string(), - ); - - let stop_tokens = provider.get_stop_tokens(); - assert!(stop_tokens.contains(&"[PREFIX]".to_string())); - assert!(stop_tokens.contains(&"[SUFFIX]".to_string())); + let completion_with_prefix = "// COMPLETION HERE\nconsole.log('hello');"; + let cleaned = provider.clean_completion(completion_with_prefix); + assert_eq!(cleaned, "console.log('hello');"); } #[gpui::test] @@ -738,49 +369,59 @@ mod tests { provider.update_model("qwen2.5-coder:32b".to_string()); assert_eq!(provider.model, "qwen2.5-coder:32b"); - // Test FIM prompt changes with different model - let prefix = "def hello():"; - let suffix = " pass"; - let prompt = provider.build_fim_prompt(prefix, suffix); - - // Should now use qwen coder pattern (with pipes) - assert!(prompt.contains("<|fim_prefix|>")); - assert!(prompt.contains("<|fim_suffix|>")); - assert!(prompt.contains("<|fim_middle|>")); - - // Update to regular Qwen model (non-coder) + // Test updating to different models provider.update_model("qwen2.5:32b".to_string()); assert_eq!(provider.model, "qwen2.5:32b"); - let prompt = provider.build_fim_prompt(prefix, suffix); - - // Should now use stable code pattern (no pipes) - assert!(prompt.contains("")); - assert!(prompt.contains("")); - assert!(prompt.contains("")); - assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes - - // Update to starcoder model provider.update_model("starcoder:7b".to_string()); assert_eq!(provider.model, "starcoder:7b"); - let prompt = provider.build_fim_prompt(prefix, suffix); - - // Should also use stable code pattern (no pipes) - assert!(prompt.contains("")); - assert!(prompt.contains("")); - assert!(prompt.contains("")); - assert!(!prompt.contains("<|fim_prefix|>")); // Should NOT contain pipes - - // Update to codestral model provider.update_model("codestral:22b".to_string()); assert_eq!(provider.model, "codestral:22b"); - let prompt = provider.build_fim_prompt(prefix, suffix); + // FIM patterns are now handled by Ollama natively, so we just test model updates + provider.update_model("deepseek-coder:6.7b".to_string()); + assert_eq!(provider.model, "deepseek-coder:6.7b"); + } - // Should use codestral pattern (suffix-first) - assert!(prompt.contains("[SUFFIX]")); - assert!(prompt.contains("[PREFIX]")); - assert!(prompt.starts_with("[SUFFIX]")); + #[gpui::test] + async fn test_native_fim_request_structure(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "qwen2.5-coder:32b".to_string(), + ); + + let prefix = "def fibonacci(n):"; + let suffix = " return result"; + + // Test that we create the correct request structure for native FIM + let request = GenerateRequest { + model: provider.model.clone(), + prompt: prefix.to_string(), + suffix: Some(suffix.to_string()), + stream: false, + options: Some(GenerateOptions { + num_predict: Some(150), + temperature: Some(0.1), + top_p: Some(0.95), + stop: Some(provider.get_stop_tokens()), + }), + keep_alive: None, + context: None, + }; + + // Verify the request structure uses native FIM approach + assert_eq!(request.model, "qwen2.5-coder:32b"); + assert_eq!(request.prompt, "def fibonacci(n):"); + assert_eq!(request.suffix, Some(" return result".to_string())); + assert!(!request.stream); + + // Verify stop tokens are simplified (no FIM-specific tokens) + let stop_tokens = request.options.as_ref().unwrap().stop.as_ref().unwrap(); + assert!(stop_tokens.contains(&"\n\n".to_string())); + assert!(stop_tokens.contains(&"```".to_string())); + assert!(stop_tokens.contains(&"<|endoftext|>".to_string())); + assert_eq!(stop_tokens.len(), 3); } }