diff --git a/Cargo.lock b/Cargo.lock
index ef2f698d0a..89ccf94c19 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10807,10 +10807,17 @@
 version = "0.1.0"
 dependencies = [
  "anyhow",
  "futures 0.3.31",
+ "gpui",
  "http_client",
+ "indoc",
+ "inline_completion",
+ "language",
+ "multi_buffer",
+ "project",
  "schemars",
  "serde",
  "serde_json",
+ "text",
  "workspace-hack",
 ]
@@ -20005,6 +20012,7 @@ dependencies = [
  "nix 0.29.0",
  "node_runtime",
  "notifications",
+ "ollama",
  "outline",
  "outline_panel",
  "parking_lot",
diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs
index 4e9c887124..bc40bf95b2 100644
--- a/crates/inline_completion_button/src/inline_completion_button.rs
+++ b/crates/inline_completion_button/src/inline_completion_button.rs
@@ -358,6 +358,41 @@ impl Render for InlineCompletionButton {
                 div().child(popover_menu.into_any_element())
             }
+
+            EditPredictionProvider::Ollama => {
+                let enabled = self.editor_enabled.unwrap_or(false);
+                let icon = if enabled {
+                    IconName::AiOllama
+                } else {
+                    IconName::AiOllama // Could add disabled variant
+                };
+
+                let this = cx.entity().clone();
+
+                div().child(
+                    PopoverMenu::new("ollama")
+                        .menu(move |window, cx| {
+                            Some(
+                                this.update(cx, |this, cx| {
+                                    this.build_ollama_context_menu(window, cx)
+                                }),
+                            )
+                        })
+                        .trigger(
+                            IconButton::new("ollama-completion", icon)
+                                .icon_size(IconSize::Small)
+                                .tooltip(|window, cx| {
+                                    Tooltip::for_action(
+                                        "Ollama Completion",
+                                        &ToggleMenu,
+                                        window,
+                                        cx,
+                                    )
+                                }),
+                        )
+                        .with_handle(self.popover_menu_handle.clone()),
+                )
+            }
         }
     }
 }
@@ -805,6 +840,26 @@ impl InlineCompletionButton {
         })
     }
 
+    fn build_ollama_context_menu(
+        &self,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Entity<ContextMenu> {
+        let fs = self.fs.clone();
+        ContextMenu::build(window, cx, |menu, _window, _cx| {
+            menu.entry("Toggle Ollama Completions", None, {
+                let fs = fs.clone();
+                move |_window, cx| {
+                    toggle_inline_completions_globally(fs.clone(), cx);
+                }
+            })
+            .entry("Ollama Settings...", None, |_window, cx| {
+                // TODO: Open Ollama-specific settings
+                cx.open_url("http://localhost:11434");
+            })
+        })
+    }
+
     pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
         let editor = editor.read(cx);
         let snapshot = editor.buffer().read(cx).snapshot(cx);
diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs
index 9dda60b6a6..31187f4f15 100644
--- a/crates/language/src/language_settings.rs
+++ b/crates/language/src/language_settings.rs
@@ -216,6 +216,7 @@ pub enum EditPredictionProvider {
     Copilot,
     Supermaven,
     Zed,
+    Ollama,
 }
 
 impl EditPredictionProvider {
@@ -224,7 +225,8 @@
             EditPredictionProvider::Zed => true,
             EditPredictionProvider::None
             | EditPredictionProvider::Copilot
-            | EditPredictionProvider::Supermaven => false,
+            | EditPredictionProvider::Supermaven
+            | EditPredictionProvider::Ollama => false,
         }
     }
 }
diff --git a/crates/ollama/Cargo.toml b/crates/ollama/Cargo.toml
index 2765f23400..3839d142e8 100644
--- a/crates/ollama/Cargo.toml
+++ b/crates/ollama/Cargo.toml
@@ -9,17 +9,34 @@ license = "GPL-3.0-or-later"
 workspace = true
 
 [lib]
-path = "src/ollama.rs"
+path = "src/lib.rs"
 
 [features]
 default = []
 schemars = ["dep:schemars"]
+test-support = [
+    "gpui/test-support",
+    "http_client/test-support",
+    "language/test-support",
+]
 
 [dependencies]
 anyhow.workspace = true
 futures.workspace = true
+gpui.workspace = true
 http_client.workspace = true
+inline_completion.workspace = true
+language.workspace = true
+multi_buffer.workspace = true
+project.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
+text.workspace = true
 workspace-hack.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+http_client = { workspace = true, features = ["test-support"] }
+indoc.workspace = true
+language = { workspace = true, features = ["test-support"] }
diff --git a/crates/ollama/src/lib.rs b/crates/ollama/src/lib.rs
new file mode 100644
index 0000000000..80b07985c5
--- /dev/null
+++ b/crates/ollama/src/lib.rs
@@ -0,0 +1,5 @@
+mod ollama;
+mod ollama_completion_provider;
+
+pub use ollama::*;
+pub use ollama_completion_provider::*;
diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs
index 109fea7353..ac8251738e 100644
--- a/crates/ollama/src/ollama.rs
+++ b/crates/ollama/src/ollama.rs
@@ -98,6 +98,38 @@ impl Model {
     }
 }
 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct GenerateRequest {
+    pub model: String,
+    pub prompt: String,
+    pub stream: bool,
+    pub options: Option<GenerateOptions>,
+    pub keep_alive: Option<KeepAlive>,
+    pub context: Option<Vec<u64>>,
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct GenerateOptions {
+    pub num_predict: Option<i32>,
+    pub temperature: Option<f32>,
+    pub top_p: Option<f32>,
+    pub stop: Option<Vec<String>>,
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Debug, Serialize, Deserialize)]
+pub struct GenerateResponse {
+    pub response: String,
+    pub done: bool,
+    pub context: Option<Vec<u64>>,
+    pub total_duration: Option<u64>,
+    pub load_duration: Option<u64>,
+    pub prompt_eval_count: Option<u32>,
+    pub eval_count: Option<u32>,
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum ChatMessage {
@@ -359,6 +391,36 @@ pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
     Ok(details)
 }
 
+pub async fn generate(
+    client: &dyn HttpClient,
+    api_url: &str,
+    request: GenerateRequest,
+) -> Result<GenerateResponse> {
+    let uri = format!("{api_url}/api/generate");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let serialized_request = serde_json::to_string(&request)?;
+    let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
+    let mut response = client.send(request).await?;
+    let mut body = String::new();
+    response.body_mut().read_to_string(&mut body).await?;
+
+    anyhow::ensure!(
+        response.status().is_success(),
+        "Failed to connect to Ollama API: {} {}",
+        response.status(),
+        body,
+    );
+
+    let response: GenerateResponse =
+        serde_json::from_str(&body).context("Unable to parse Ollama generate response")?;
+    Ok(response)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/crates/ollama/src/ollama_completion_provider.rs b/crates/ollama/src/ollama_completion_provider.rs
new file mode 100644
index 0000000000..02abe6c935
--- /dev/null
+++ b/crates/ollama/src/ollama_completion_provider.rs
@@ -0,0 +1,363 @@
+use crate::{GenerateOptions, GenerateRequest, generate};
+use anyhow::{Context as AnyhowContext, Result};
+use gpui::{App, Context, Entity, EntityId, Task};
+use http_client::HttpClient;
+use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
+use language::{Anchor, Buffer, ToOffset};
+use project::Project;
+use std::{path::Path, sync::Arc, time::Duration};
+
+pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
+
+pub struct OllamaCompletionProvider {
+    http_client: Arc<dyn HttpClient>,
+    api_url: String,
+    model: String,
+    buffer_id: Option<EntityId>,
+    file_extension: Option<String>,
+    current_completion: Option<String>,
+    pending_refresh: Option<Task<Result<()>>>,
+}
+
+impl OllamaCompletionProvider {
+    pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, model: String) -> Self {
+        Self {
+            http_client,
+            api_url,
+            model,
+            buffer_id: None,
+            file_extension: None,
+            current_completion: None,
+            pending_refresh: None,
+        }
+    }
+
+    fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
+        // Use model-specific FIM patterns
+        match self.model.as_str() {
+            m if m.contains("codellama") => {
+                format!("<PRE> {prefix} <SUF>{suffix} <MID>")
+            }
+            m if m.contains("deepseek") => {
+                format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
+            }
+            m if m.contains("starcoder") => {
+                format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
+            }
+            _ => {
+                // Generic FIM pattern - fallback for models without specific support
+                format!("// Complete the following code:\n{prefix}\n// COMPLETION HERE\n{suffix}")
+            }
+        }
+    }
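+    // For example, with a codellama model, build_fim_prompt("fn add(", ")") above renders
+    // "<PRE> fn add( <SUF>) <MID>"; the model then fills in the code that belongs between
+    // the prefix and the suffix.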
+
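+    // Returns up to ~2 KB of text on each side of the cursor as (prefix, suffix).
+    // Offsets are byte-based, so this is a rough window rather than a token-accurate
+    // context limit.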
+    fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
+        let cursor_offset = cursor_position.to_offset(buffer);
+        let text = buffer.text();
+
+        // Get reasonable context around cursor
+        let context_size = 2000; // 2KB before and after cursor
+
+        let start = cursor_offset.saturating_sub(context_size);
+        let end = (cursor_offset + context_size).min(text.len());
+
+        let prefix = text[start..cursor_offset].to_string();
+        let suffix = text[cursor_offset..end].to_string();
+
+        (prefix, suffix)
+    }
+}
+
+impl EditPredictionProvider for OllamaCompletionProvider {
+    fn name() -> &'static str {
+        "ollama"
+    }
+
+    fn display_name() -> &'static str {
+        "Ollama"
+    }
+
+    fn show_completions_in_menu() -> bool {
+        false
+    }
+
+    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
+        // TODO: Could ping Ollama API to check if it's running
+        true
+    }
+
+    fn is_refreshing(&self) -> bool {
+        self.pending_refresh.is_some()
+    }
+
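+    // Kicks off an (optionally debounced) completion request: snapshots the text around
+    // the cursor, builds a FIM prompt, and makes a single non-streaming /api/generate
+    // call, storing the result in `current_completion` for `suggest` to return.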
+    fn refresh(
+        &mut self,
+        _project: Option<Entity<Project>>,
+        buffer: Entity<Buffer>,
+        cursor_position: Anchor,
+        debounce: bool,
+        cx: &mut Context<Self>,
+    ) {
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let model = self.model.clone();
+
+        self.pending_refresh = Some(cx.spawn(async move |this, cx| {
+            if debounce {
+                cx.background_executor()
+                    .timer(OLLAMA_DEBOUNCE_TIMEOUT)
+                    .await;
+            }
+
+            let (prefix, suffix) = this.update(cx, |this, cx| {
+                let buffer_snapshot = buffer.read(cx);
+                this.buffer_id = Some(buffer.entity_id());
+                this.file_extension = buffer_snapshot.file().and_then(|file| {
+                    Some(
+                        Path::new(file.file_name(cx))
+                            .extension()?
+                            .to_str()?
+                            .to_string(),
+                    )
+                });
+                this.extract_context(buffer_snapshot, cursor_position)
+            })?;
+
+            let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
+
+            let request = GenerateRequest {
+                model: model.clone(),
+                prompt,
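+                // stream: false => Ollama replies with one JSON object instead of an NDJSON stream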
+                stream: false,
+                options: Some(GenerateOptions {
+                    num_predict: Some(150), // Reasonable completion length
+                    temperature: Some(0.1), // Low temperature for more deterministic results
+                    top_p: Some(0.95),
+                    stop: Some(vec![
+                        "\n\n".to_string(),
+                        "```".to_string(),
+                        "
".to_string(), + "".to_string(), + ]), + }), + keep_alive: None, + context: None, + }; + + let response = generate(http_client.as_ref(), &api_url, request) + .await + .context("Failed to get completion from Ollama")?; + + this.update(cx, |this, cx| { + this.pending_refresh = None; + if !response.response.trim().is_empty() { + this.current_completion = Some(response.response); + } else { + this.current_completion = None; + } + cx.notify(); + })?; + + Ok(()) + })); + } + + fn cycle( + &mut self, + _buffer: Entity, + _cursor_position: Anchor, + _direction: Direction, + _cx: &mut Context, + ) { + // Ollama doesn't provide multiple completions in a single request + // Could be implemented by making multiple requests with different temperatures + // or by using different models + } + + fn accept(&mut self, _cx: &mut Context) { + self.current_completion = None; + // TODO: Could send accept telemetry to Ollama if supported + } + + fn discard(&mut self, _cx: &mut Context) { + self.current_completion = None; + // TODO: Could send discard telemetry to Ollama if supported + } + + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: Anchor, + cx: &mut Context, + ) -> Option { + let buffer_id = buffer.entity_id(); + if Some(buffer_id) != self.buffer_id { + return None; + } + + let completion_text = self.current_completion.as_ref()?.clone(); + + if completion_text.trim().is_empty() { + return None; + } + + let buffer_snapshot = buffer.read(cx); + let position = cursor_position.bias_right(buffer_snapshot); + + // Clean up the completion text + let completion_text = completion_text.trim_start().trim_end(); + + Some(InlineCompletion { + id: None, + edits: vec![(position..position, completion_text.to_string())], + edit_preview: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use gpui::{AppContext, TestAppContext}; + use http_client::FakeHttpClient; + use std::sync::Arc; + + #[gpui::test] + async fn test_fim_prompt_patterns(_cx: &mut TestAppContext) { + let provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "codellama:7b".to_string(), + ); + + let prefix = "function hello() {"; + let suffix = "}"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + assert!(prompt.contains("
"));
+        assert!(prompt.contains("<SUF>"));
+        assert!(prompt.contains("<MID>"));
+        assert!(prompt.contains(prefix));
+        assert!(prompt.contains(suffix));
+    }
+
+    #[gpui::test]
+    async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
+        let provider = OllamaCompletionProvider::new(
+            Arc::new(FakeHttpClient::with_404_response()),
+            "http://localhost:11434".to_string(),
+            "deepseek-coder:6.7b".to_string(),
+        );
+
+        let prefix = "def hello():";
+        let suffix = "    pass";
+        let prompt = provider.build_fim_prompt(prefix, suffix);
+
+        assert!(prompt.contains("<|fim▁begin|>"));
+        assert!(prompt.contains("<|fim▁hole|>"));
+        assert!(prompt.contains("<|fim▁end|>"));
+    }
+
+    #[gpui::test]
+    async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
+        let provider = OllamaCompletionProvider::new(
+            Arc::new(FakeHttpClient::with_404_response()),
+            "http://localhost:11434".to_string(),
+            "starcoder:7b".to_string(),
+        );
+
+        let prefix = "def hello():";
+        let suffix = "    pass";
+        let prompt = provider.build_fim_prompt(prefix, suffix);
+
+        assert!(prompt.contains("<fim_prefix>"));
+        assert!(prompt.contains("<fim_suffix>"));
+        assert!(prompt.contains("<fim_middle>"));
+    }
+
+    #[gpui::test]
+    async fn test_extract_context(cx: &mut TestAppContext) {
+        let provider = OllamaCompletionProvider::new(
+            Arc::new(FakeHttpClient::with_404_response()),
+            "http://localhost:11434".to_string(),
+            "codellama:7b".to_string(),
+        );
+
+        // Create a simple buffer using test context
+        let buffer_text = "function example() {\n    let x = 1;\n    let y = 2;\n    // cursor here\n    return x + y;\n}";
+        let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
+
+        // Position cursor at the end of the "// cursor here" line
+        let (prefix, suffix, _cursor_position) = cx.read(|cx| {
+            let buffer_snapshot = buffer.read(cx);
+            let cursor_position = buffer_snapshot.anchor_after(text::Point::new(3, 15)); // Inside "// cursor here", right after "// cursor h"
+            let (prefix, suffix) = provider.extract_context(&buffer_snapshot, cursor_position);
+            (prefix, suffix, cursor_position)
+        });
+
+        assert!(prefix.contains("function example()"));
+        assert!(prefix.contains("// cursor h"));
+        assert!(suffix.contains("ere"));
+        assert!(suffix.contains("return x + y"));
+        assert!(suffix.contains("}"));
+    }
+
+    #[gpui::test]
+    async fn test_suggest_with_completion(cx: &mut TestAppContext) {
+        let provider = cx.new(|_| {
+            OllamaCompletionProvider::new(
+                Arc::new(FakeHttpClient::with_404_response()),
+                "http://localhost:11434".to_string(),
+                "codellama:7b".to_string(),
+            )
+        });
+
+        let buffer_text = "// test";
+        let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
+
+        // Set up a mock completion
+        provider.update(cx, |provider, _| {
+            provider.current_completion = Some("console.log('hello');".to_string());
+            provider.buffer_id = Some(buffer.entity_id());
+        });
+
+        let cursor_position = cx.read(|cx| buffer.read(cx).anchor_after(text::Point::new(0, 7)));
+
+        let completion = provider.update(cx, |provider, cx| {
+            provider.suggest(&buffer, cursor_position, cx)
+        });
+
+        assert!(completion.is_some());
+        let completion = completion.unwrap();
+        assert_eq!(completion.edits.len(), 1);
+        assert_eq!(completion.edits[0].1, "console.log('hello');");
+    }
+
+    #[gpui::test]
+    async fn test_suggest_empty_completion(cx: &mut TestAppContext) {
+        let provider = cx.new(|_| {
+            OllamaCompletionProvider::new(
+                Arc::new(FakeHttpClient::with_404_response()),
+                "http://localhost:11434".to_string(),
+                "codellama:7b".to_string(),
+            )
+        });
+
+        let buffer_text = "// test";
+        let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
+
+        // Set up an empty completion
+        provider.update(cx, |provider, _| {
+            provider.current_completion = Some("   ".to_string()); // Only whitespace
+            provider.buffer_id = Some(buffer.entity_id());
+        });
+
+        let cursor_position = cx.read(|cx| buffer.read(cx).anchor_after(text::Point::new(0, 7)));
+
+        let completion = provider.update(cx, |provider, cx| {
+            provider.suggest(&buffer, cursor_position, cx)
+        });
+
+        assert!(completion.is_none());
+    }
+}
diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml
index 4e426c3837..f37aa371ff 100644
--- a/crates/zed/Cargo.toml
+++ b/crates/zed/Cargo.toml
@@ -71,6 +71,7 @@ image_viewer.workspace = true
 indoc.workspace = true
 inline_completion_button.workspace = true
 inspector_ui.workspace = true
+ollama.workspace = true
 install_cli.workspace = true
 jj_ui.workspace = true
 journal.workspace = true
diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs
index f2e9d21b96..bf6022be8e 100644
--- a/crates/zed/src/zed/inline_completion_registry.rs
+++ b/crates/zed/src/zed/inline_completion_registry.rs
@@ -4,7 +4,9 @@ use copilot::{Copilot, CopilotCompletionProvider};
 use editor::Editor;
 use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
 use language::language_settings::{EditPredictionProvider, all_language_settings};
-use settings::SettingsStore;
+use language_models::AllLanguageModelSettings;
+use ollama::OllamaCompletionProvider;
+use settings::{Settings, SettingsStore};
 use smol::stream::StreamExt;
 use std::{cell::RefCell, rc::Rc, sync::Arc};
 use supermaven::{Supermaven, SupermavenCompletionProvider};
@@ -129,7 +131,8 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
                         }
                         EditPredictionProvider::None
                         | EditPredictionProvider::Copilot
-                        | EditPredictionProvider::Supermaven => {}
+                        | EditPredictionProvider::Supermaven
+                        | EditPredictionProvider::Ollama => {}
                     }
                 }
             }
@@ -283,5 +286,20 @@ fn assign_edit_prediction_provider(
                 editor.set_edit_prediction_provider(Some(provider), window, cx);
             }
         }
+        EditPredictionProvider::Ollama => {
+            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+            let api_url = settings.api_url.clone();
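+            // These values come from the existing Ollama language-model settings; the assumed
+            // settings.json shape is roughly:
+            //   "language_models": { "ollama": { "api_url": "...", "available_models": [{ "name": "codellama:7b", ... }] } }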
+
+            // Use first available model or default
+            let model = settings
+                .available_models
+                .first()
+                .map(|m| m.name.clone())
+                .unwrap_or_else(|| "codellama:7b".to_string());
+
+            let provider =
+                cx.new(|_| OllamaCompletionProvider::new(client.http_client(), api_url, model));
+            editor.set_edit_prediction_provider(Some(provider), window, cx);
+        }
     }
 }