diff --git a/Cargo.lock b/Cargo.lock index 42649b137f..e1f06fdd85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5001,6 +5001,7 @@ version = "0.1.0" dependencies = [ "anyhow", "client", + "clock", "cloud_llm_client", "copilot", "edit_prediction", @@ -5009,9 +5010,14 @@ dependencies = [ "fs", "futures 0.3.31", "gpui", + "gpui_tokio", + "http_client", "indoc", "language", + "language_model", + "language_models", "lsp", + "ollama", "paths", "project", "regex", @@ -11065,11 +11071,22 @@ name = "ollama" version = "0.1.0" dependencies = [ "anyhow", + "client", + "edit_prediction", + "editor", "futures 0.3.31", + "gpui", "http_client", + "language", + "log", + "project", "schemars", "serde", "serde_json", + "settings", + "text", + "theme", + "workspace", "workspace-hack", ] @@ -20479,6 +20496,7 @@ dependencies = [ "nix 0.29.0", "node_runtime", "notifications", + "ollama", "onboarding", "outline", "outline_panel", diff --git a/crates/edit_prediction_button/Cargo.toml b/crates/edit_prediction_button/Cargo.toml index 07447280fa..65465c44ef 100644 --- a/crates/edit_prediction_button/Cargo.toml +++ b/crates/edit_prediction_button/Cargo.toml @@ -21,9 +21,13 @@ editor.workspace = true feature_flags.workspace = true fs.workspace = true gpui.workspace = true +http_client.workspace = true indoc.workspace = true edit_prediction.workspace = true language.workspace = true +language_models.workspace = true +ollama.workspace = true + paths.workspace = true project.workspace = true regex.workspace = true @@ -37,11 +41,18 @@ zed_actions.workspace = true zeta.workspace = true [dev-dependencies] +clock.workspace = true +client = { workspace = true, features = ["test-support"] } copilot = { workspace = true, features = ["test-support"] } editor = { workspace = true, features = ["test-support"] } futures.workspace = true +http_client = { workspace = true, features = ["test-support"] } indoc.workspace = true +language_model = { workspace = true, features = ["test-support"] } lsp = { workspace = true, features = ["test-support"] } +ollama = { workspace = true, features = ["test-support"] } project = { workspace = true, features = ["test-support"] } serde_json.workspace = true +settings = { workspace = true, features = ["test-support"] } theme = { workspace = true, features = ["test-support"] } +gpui_tokio.workspace = true diff --git a/crates/edit_prediction_button/src/edit_prediction_button.rs b/crates/edit_prediction_button/src/edit_prediction_button.rs index 0e3fe8cb1a..e4efb35384 100644 --- a/crates/edit_prediction_button/src/edit_prediction_button.rs +++ b/crates/edit_prediction_button/src/edit_prediction_button.rs @@ -10,11 +10,17 @@ use gpui::{ Focusable, IntoElement, ParentElement, Render, Subscription, WeakEntity, actions, div, pulsating_between, }; + use indoc::indoc; use language::{ EditPredictionsMode, File, Language, language_settings::{self, AllLanguageSettings, EditPredictionProvider, all_language_settings}, }; +use language_models::AllLanguageModelSettings; + +use ollama; + +use paths; use project::DisableAiSettings; use regex::Regex; use settings::{Settings, SettingsStore, update_settings_file}; @@ -357,6 +363,41 @@ impl Render for EditPredictionButton { div().child(popover_menu.into_any_element()) } + + EditPredictionProvider::Ollama => { + let enabled = self.editor_enabled.unwrap_or(false); + let icon = if enabled { + IconName::AiOllama + } else { + IconName::AiOllama // Could add disabled variant + }; + + let this = cx.entity().clone(); + + div().child( + PopoverMenu::new("ollama") + .menu(move |window, cx| { + Some( + this.update(cx, |this, cx| { + this.build_ollama_context_menu(window, cx) + }), + ) + }) + .trigger( + IconButton::new("ollama-completion", icon) + .icon_size(IconSize::Small) + .tooltip(|window, cx| { + Tooltip::for_action( + "Ollama Completion", + &ToggleMenu, + window, + cx, + ) + }), + ) + .with_handle(self.popover_menu_handle.clone()), + ) + } } } } @@ -375,10 +416,14 @@ impl EditPredictionButton { cx.observe_global::(move |_, cx| cx.notify()) .detach(); + if let Some(service) = ollama::State::global(cx) { + cx.observe(&service, |_, _, cx| cx.notify()).detach(); + } + Self { editor_subscription: None, editor_enabled: None, - editor_show_predictions: true, + editor_show_predictions: false, editor_focus_handle: None, language: None, file: None, @@ -492,6 +537,7 @@ impl EditPredictionButton { EditPredictionProvider::Zed | EditPredictionProvider::Copilot | EditPredictionProvider::Supermaven + | EditPredictionProvider::Ollama ) { menu = menu .separator() @@ -813,6 +859,238 @@ impl EditPredictionButton { }) } + /// Builds a simplified context menu for Ollama with essential features: + /// - API URL configuration that opens settings at the correct location + /// - Model selection from available models + /// - Common language settings (buffer/language/global toggles, privacy settings) + /// + /// The menu focuses on core functionality without connection status or external links. + fn build_ollama_context_menu( + &self, + window: &mut Window, + cx: &mut Context, + ) -> Entity { + let fs = self.fs.clone(); + ContextMenu::build(window, cx, |menu, window, cx| { + let settings = AllLanguageModelSettings::get_global(cx); + let ollama_settings = &settings.ollama; + + // Get models from both settings and global service discovery + let mut available_models = ollama_settings.available_models.clone(); + + // Add discovered models from the global Ollama service + if let Some(service) = ollama::State::global(cx) { + let discovered_models = service.read(cx).available_models(); + for model in discovered_models { + // Convert from ollama::Model to language_models AvailableModel + let available_model = language_models::provider::ollama::AvailableModel { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + keep_alive: model.keep_alive.clone(), + supports_tools: model.supports_tools, + supports_images: model.supports_vision, + supports_thinking: model.supports_thinking, + }; + + // Add if not already in settings (settings take precedence) + if !available_models.iter().any(|m| m.name == model.name) { + available_models.push(available_model); + } + } + } + + // Check if ollama settings exist before building menu + let has_ollama_settings = Self::ollama_settings_exist_in_content( + &std::fs::read_to_string(paths::settings_file()).unwrap_or_default(), + ); + + // API URL configuration - only show if Ollama settings exist in the user's config + let menu = if has_ollama_settings { + menu.entry("Configure API URL", None, { + let fs = fs.clone(); + move |window, cx| { + Self::open_ollama_settings(fs.clone(), window, cx); + } + }) + } else { + menu + }; + + // Model selection section + let menu = if !available_models.is_empty() { + let menu = menu.separator().header("Available Models"); + + // Add each available model as a menu entry + let menu = available_models.iter().fold(menu, |menu, model| { + let model_name = model.display_name.as_ref().unwrap_or(&model.name); + let is_current = ollama_settings + .available_models + .first() + .map(|current_model| current_model.name == model.name) + .unwrap_or(false); + + menu.toggleable_entry( + model_name.clone(), + is_current, + IconPosition::Start, + None, + { + let model_name = model.name.clone(); + let fs = fs.clone(); + move |_window, cx| { + Self::switch_ollama_model(fs.clone(), model_name.clone(), cx); + } + }, + ) + }); + + // Add refresh models option + menu.separator().entry("Refresh Models", None, { + move |_window, cx| { + Self::refresh_ollama_models(cx); + } + }) + } else { + menu.separator() + .header("No Models Configured") + .entry("Configure Models", None, { + let fs = fs.clone(); + move |window, cx| { + Self::open_ollama_settings(fs.clone(), window, cx); + } + }) + .entry("Refresh Models", None, { + move |_window, cx| { + Self::refresh_ollama_models(cx); + } + }) + }; + + // Use the common language settings menu + self.build_language_settings_menu(menu, window, cx) + }) + } + + /// Opens Zed settings and navigates directly to the Ollama models configuration. + /// Uses improved regex patterns to locate the exact setting in the JSON structure. + fn open_ollama_settings(_fs: Arc, window: &mut Window, cx: &mut App) { + if let Some(workspace) = window.root::().flatten() { + let workspace = workspace.downgrade(); + window + .spawn(cx, async move |cx| { + let settings_editor = workspace + .update_in(cx, |_, window, cx| { + create_and_open_local_file(paths::settings_file(), window, cx, || { + settings::initial_user_settings_content().as_ref().into() + }) + })? + .await? + .downcast::() + .unwrap(); + + let _ = settings_editor + .downgrade() + .update_in(cx, |item, window, cx| { + let text = item.buffer().read(cx).snapshot(cx).text(); + + // Look for language_models.ollama section with precise pattern + // This matches the full nested structure to avoid false matches + let ollama_pattern = r#""language_models"\s*:\s*\{[\s\S]*?"ollama"\s*:\s*\{[\s\S]*?"available_models"\s*:\s*\[\s*\]"#; + let regex = regex::Regex::new(ollama_pattern).unwrap(); + + if let Some(captures) = regex.captures(&text) { + let full_match = captures.get(0).unwrap(); + + // Position cursor after the opening bracket of available_models array + let bracket_pos = full_match.as_str().rfind('[').unwrap(); + let cursor_pos = full_match.start() + bracket_pos + 1; + + // Place cursor inside the available_models array + item.change_selections( + SelectionEffects::scroll(Autoscroll::newest()), + window, + cx, + |selections| { + selections.select_ranges(vec![cursor_pos..cursor_pos]); + }, + ); + return Ok::<(), anyhow::Error>(()); + } + + Ok::<(), anyhow::Error>(()) + })?; + + Ok::<(), anyhow::Error>(()) + }) + .detach_and_log_err(cx); + } + } + + fn ollama_settings_exist_in_content(content: &str) -> bool { + let api_url_pattern = r#""language_models"\s*:\s*\{[\s\S]*?"ollama"\s*:\s*\{[\s\S]*?"api_url"\s*:\s*"([^"]*)"#; + let regex = regex::Regex::new(api_url_pattern).unwrap(); + regex.is_match(content) + } + + fn switch_ollama_model(fs: Arc, model_name: String, cx: &mut App) { + update_settings_file::(fs, cx, move |settings, cx| { + // Ensure ollama settings exist + if settings.ollama.is_none() { + settings.ollama = Some(language_models::OllamaSettingsContent { + api_url: None, + available_models: Some(Vec::new()), + }); + } + + let ollama_settings = settings.ollama.as_mut().unwrap(); + + // Ensure available_models exists + if ollama_settings.available_models.is_none() { + ollama_settings.available_models = Some(Vec::new()); + } + + let models = ollama_settings.available_models.as_mut().unwrap(); + + // Check if model is already in settings + if let Some(index) = models.iter().position(|m| m.name == model_name) { + // Move existing model to the front + let selected_model = models.remove(index); + models.insert(0, selected_model); + } else { + // Model not in settings - check if it's a discovered model and add it + if let Some(service) = ollama::State::global(cx) { + let discovered_models = service.read(cx).available_models(); + if let Some(discovered_model) = + discovered_models.iter().find(|m| m.name == model_name) + { + // Convert from ollama::Model to language_models AvailableModel + let available_model = language_models::provider::ollama::AvailableModel { + name: discovered_model.name.clone(), + display_name: discovered_model.display_name.clone(), + max_tokens: discovered_model.max_tokens, + keep_alive: discovered_model.keep_alive.clone(), + supports_tools: discovered_model.supports_tools, + supports_images: discovered_model.supports_vision, + supports_thinking: discovered_model.supports_thinking, + }; + + // Add the discovered model to the front of the list + models.insert(0, available_model); + } + } + } + }); + } + + fn refresh_ollama_models(cx: &mut App) { + if let Some(service) = ollama::State::global(cx) { + service.update(cx, |service, cx| { + service.refresh_models(cx); + }); + } + } + pub fn update_enabled(&mut self, editor: Entity, cx: &mut Context) { let editor = editor.read(cx); let snapshot = editor.buffer().read(cx).snapshot(cx); @@ -997,3 +1275,691 @@ fn toggle_edit_prediction_mode(fs: Arc, mode: EditPredictionsMode, cx: & }); } } + +#[cfg(test)] +mod tests { + use super::*; + use clock::FakeSystemClock; + use gpui::TestAppContext; + use http_client; + use ollama::{State, fake::FakeHttpClient}; + use settings::SettingsStore; + use std::sync::Arc; + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + gpui_tokio::init(cx); + theme::init(theme::LoadThemes::JustBase, cx); + language::init(cx); + language_settings::init(cx); + + // Initialize language_models settings for tests that need them + // Create client and user store for language_models::init + client::init_settings(cx); + let clock = Arc::new(FakeSystemClock::new()); + let http = http_client::FakeHttpClient::with_404_response(); + let client = client::Client::new(clock, http, cx); + let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx)); + + client::init(&client, cx); + language_model::init(client.clone(), cx); + language_models::init(user_store, client, cx); + }); + } + + #[gpui::test] + async fn test_ollama_menu_shows_discovered_models(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client with mock models response + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Mock /api/tags response + let models_response = serde_json::json!({ + "models": [ + { + "name": "qwen2.5-coder:3b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "qwen2", + "families": ["qwen2"], + "parameter_size": "3B", + "quantization_level": "Q4_0" + } + }, + { + "name": "codellama:7b-code", + "modified_at": "2024-01-01T00:00:00Z", + "size": 2000000, + "digest": "def456", + "details": { + "format": "gguf", + "family": "codellama", + "families": ["codellama"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + + // Mock /api/show response + let capabilities = serde_json::json!({ + "capabilities": ["tools"] + }); + fake_http_client.set_response("/api/show", capabilities.to_string()); + + // Create and set global Ollama service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Wait for model discovery + cx.background_executor.run_until_parked(); + + // Verify models are accessible through the service + cx.update(|cx| { + if let Some(service) = State::global(cx) { + let discovered_models = service.read(cx).available_models(); + assert_eq!(discovered_models.len(), 2); + + let model_names: Vec<&str> = + discovered_models.iter().map(|m| m.name.as_str()).collect(); + assert!(model_names.contains(&"qwen2.5-coder:3b")); + assert!(model_names.contains(&"codellama:7b-code")); + } else { + panic!("Global service should be available"); + } + }); + + // Verify the global service has the expected models + service.read_with(cx, |service, _| { + let models = service.available_models(); + assert_eq!(models.len(), 2); + + let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + assert!(model_names.contains(&"qwen2.5-coder:3b")); + assert!(model_names.contains(&"codellama:7b-code")); + }); + } + + #[gpui::test] + async fn test_ollama_menu_shows_service_models(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client with models + let fake_http_client = Arc::new(FakeHttpClient::new()); + + let models_response = serde_json::json!({ + "models": [ + { + "name": "qwen2.5-coder:7b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "qwen2", + "families": ["qwen2"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + fake_http_client.set_response( + "/api/show", + serde_json::json!({"capabilities": []}).to_string(), + ); + + // Create and set global service + let service = cx.update(|cx| { + State::new( + fake_http_client, + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + cx.background_executor.run_until_parked(); + + // Test that discovered models are accessible + cx.update(|cx| { + if let Some(service) = State::global(cx) { + let discovered_models = service.read(cx).available_models(); + assert_eq!(discovered_models.len(), 1); + assert_eq!(discovered_models[0].name, "qwen2.5-coder:7b"); + } else { + panic!("Global service should be available"); + } + }); + } + + #[gpui::test] + async fn test_ollama_menu_refreshes_on_service_update(cx: &mut TestAppContext) { + init_test(cx); + + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Initially empty models + fake_http_client.set_response("/api/tags", serde_json::json!({"models": []}).to_string()); + + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify the service subscription mechanism works by creating a button + let _button = cx.update(|cx| { + let fs = fs::FakeFs::new(cx.background_executor().clone()); + let clock = Arc::new(FakeSystemClock::new()); + let http = http_client::FakeHttpClient::with_404_response(); + let client = client::Client::new(clock, http, cx); + let user_store = cx.new(|cx| client::UserStore::new(client, cx)); + let popover_handle = PopoverMenuHandle::default(); + + cx.new(|cx| EditPredictionButton::new(fs, user_store, popover_handle, cx)) + }); + + // Verify initially no models + service.read_with(cx, |service, _| { + assert_eq!(service.available_models().len(), 0); + }); + + // Update mock to return models + let models_response = serde_json::json!({ + "models": [ + { + "name": "phi3:mini", + "modified_at": "2024-01-01T00:00:00Z", + "size": 500000, + "digest": "xyz789", + "details": { + "format": "gguf", + "family": "phi3", + "families": ["phi3"], + "parameter_size": "3.8B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + fake_http_client.set_response( + "/api/show", + serde_json::json!({"capabilities": []}).to_string(), + ); + + // Trigger refresh + service.update(cx, |service, cx| { + service.refresh_models(cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify models were refreshed + service.read_with(cx, |service, _| { + let models = service.available_models(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "phi3:mini"); + }); + + // The button should have been notified and will rebuild its menu with new models + // when next requested (this tests the subscription mechanism) + } + + #[gpui::test] + async fn test_refresh_models_button_functionality(cx: &mut TestAppContext) { + init_test(cx); + + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Start with one model + let initial_response = serde_json::json!({ + "models": [ + { + "name": "mistral:7b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "initial123", + "details": { + "format": "gguf", + "family": "mistral", + "families": ["mistral"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", initial_response.to_string()); + fake_http_client.set_response( + "/api/show", + serde_json::json!({"capabilities": []}).to_string(), + ); + + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify initial model + service.read_with(cx, |service, _| { + assert_eq!(service.available_models().len(), 1); + assert_eq!(service.available_models()[0].name, "mistral:7b"); + }); + + // Update mock to simulate new model available + let updated_response = serde_json::json!({ + "models": [ + { + "name": "mistral:7b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "initial123", + "details": { + "format": "gguf", + "family": "mistral", + "families": ["mistral"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + }, + { + "name": "gemma2:9b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 2000000, + "digest": "new456", + "details": { + "format": "gguf", + "family": "gemma2", + "families": ["gemma2"], + "parameter_size": "9B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", updated_response.to_string()); + + // Simulate clicking "Refresh Models" button + cx.update(|cx| { + EditPredictionButton::refresh_ollama_models(cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify models were refreshed + service.read_with(cx, |service, _| { + let models = service.available_models(); + assert_eq!(models.len(), 2); + + let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + assert!(model_names.contains(&"mistral:7b")); + assert!(model_names.contains(&"gemma2:9b")); + }); + } + + #[gpui::test] + async fn test_ollama_menu_shows_discovered_models_for_selection(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client with mock models response + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Mock /api/tags response with a model not in settings + let models_response = serde_json::json!({ + "models": [ + { + "name": "discovered-model:latest", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + fake_http_client.set_response( + "/api/show", + serde_json::json!({"capabilities": []}).to_string(), + ); + + // Create and set global service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify model is discovered by the service + let discovered_model_exists = cx.update(|cx| { + if let Some(service) = State::global(cx) { + let discovered_models = service.read(cx).available_models(); + discovered_models + .iter() + .any(|m| m.name == "discovered-model:latest") + } else { + false + } + }); + assert!( + discovered_model_exists, + "Model should be discovered by service" + ); + + // Verify initial settings are empty + let settings_empty = cx.update(|cx| { + let settings = AllLanguageModelSettings::get_global(cx); + settings.ollama.available_models.is_empty() + }); + assert!(settings_empty, "Settings should initially be empty"); + + // Test the core logic: when a discovered model is selected, it should be available + // In the UI context, the menu should show discovered models even if not in settings + let menu_shows_discovered_model = cx.update(|cx| { + let settings = AllLanguageModelSettings::get_global(cx); + let ollama_settings = &settings.ollama; + + // Get models from both settings and global service discovery (like the UI does) + let mut available_models = ollama_settings.available_models.clone(); + + // Add discovered models from the global Ollama service + if let Some(service) = ollama::State::global(cx) { + let discovered_models = service.read(cx).available_models(); + for model in discovered_models { + // Convert from ollama::Model to language_models AvailableModel + let available_model = language_models::provider::ollama::AvailableModel { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + keep_alive: model.keep_alive.clone(), + supports_tools: model.supports_tools, + supports_images: model.supports_vision, + supports_thinking: model.supports_thinking, + }; + + // Add if not already in settings (settings take precedence) + if !available_models.iter().any(|m| m.name == model.name) { + available_models.push(available_model); + } + } + } + + available_models + .iter() + .any(|m| m.name == "discovered-model:latest") + }); + + assert!( + menu_shows_discovered_model, + "Menu should show discovered models even when not in settings" + ); + } + + #[gpui::test] + async fn test_ollama_discovered_model_menu_integration(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client with mock models response + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Mock /api/tags response with a model not in settings + let models_response = serde_json::json!({ + "models": [ + { + "name": "discovered-model:latest", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + fake_http_client.set_response( + "/api/show", + serde_json::json!({"capabilities": []}).to_string(), + ); + + // Create and set global service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + cx.background_executor.run_until_parked(); + + // Test the core functionality: discovered models should be available for the UI + // This simulates what the build_ollama_context_menu function does + cx.update(|cx| { + let settings = AllLanguageModelSettings::get_global(cx); + let ollama_settings = &settings.ollama; + + // Get models from both settings and global service discovery (like the UI does) + let mut available_models = ollama_settings.available_models.clone(); + + // Add discovered models from the global Ollama service + if let Some(service) = ollama::State::global(cx) { + let discovered_models = service.read(cx).available_models(); + for model in discovered_models { + // Convert from ollama::Model to language_models AvailableModel + let available_model = language_models::provider::ollama::AvailableModel { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + keep_alive: model.keep_alive.clone(), + supports_tools: model.supports_tools, + supports_images: model.supports_vision, + supports_thinking: model.supports_thinking, + }; + + // Add if not already in settings (settings take precedence) + if !available_models.iter().any(|m| m.name == model.name) { + available_models.push(available_model); + } + } + } + + // The key test: discovered models should now be available for selection + assert_eq!(available_models.len(), 1); + assert_eq!(available_models[0].name, "discovered-model:latest"); + + // Verify that the switch_ollama_model function can find the discovered model + // by checking it exists in the service + if let Some(service) = ollama::State::global(cx) { + let discovered_models = service.read(cx).available_models(); + let found_model = discovered_models + .iter() + .find(|m| m.name == "discovered-model:latest"); + assert!( + found_model.is_some(), + "Model should be discoverable by the service for selection" + ); + } + }); + } + + #[gpui::test] + async fn test_switch_ollama_model_with_discovered_model(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client with mock models response + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Mock /api/tags response with a model not in settings + let models_response = serde_json::json!({ + "models": [ + { + "name": "test-model:latest", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + fake_http_client.set_response( + "/api/show", + serde_json::json!({"capabilities": []}).to_string(), + ); + + // Create and set global service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify model is discovered by service + let discovered = cx.update(|cx| { + if let Some(service) = ollama::State::global(cx) { + let models = service.read(cx).available_models(); + models.iter().any(|m| m.name == "test-model:latest") + } else { + false + } + }); + assert!(discovered, "Model should be discovered by service"); + + // Test that switch_ollama_model function can handle discovered models + // This test focuses on the function's ability to find and convert discovered models + // rather than testing file system persistence + let fs = fs::FakeFs::new(cx.background_executor.clone()) as Arc; + + // The key test: the function should be able to process a discovered model + // We test this by verifying the function doesn't panic and can access the service + cx.update(|cx| { + // Verify the service is accessible within the function context + if let Some(service) = ollama::State::global(cx) { + let discovered_models = service.read(cx).available_models(); + let target_model = discovered_models + .iter() + .find(|m| m.name == "test-model:latest"); + + assert!( + target_model.is_some(), + "Target model should be discoverable" + ); + + // Test the conversion logic that switch_ollama_model uses + if let Some(discovered_model) = target_model { + let available_model = language_models::provider::ollama::AvailableModel { + name: discovered_model.name.clone(), + display_name: discovered_model.display_name.clone(), + max_tokens: discovered_model.max_tokens, + keep_alive: discovered_model.keep_alive.clone(), + supports_tools: discovered_model.supports_tools, + supports_images: discovered_model.supports_vision, + supports_thinking: discovered_model.supports_thinking, + }; + + // Verify the conversion worked correctly + assert_eq!(available_model.name, "test-model:latest"); + } + } + + // Call the actual function to ensure it doesn't panic with discovered models + // Note: In a test environment, the file system changes may not persist to + // the global settings, but the function should execute without errors + EditPredictionButton::switch_ollama_model(fs, "test-model:latest".to_string(), cx); + }); + + // Allow any async operations to complete + cx.background_executor.run_until_parked(); + } +} diff --git a/crates/editor/src/edit_prediction_tests.rs b/crates/editor/src/edit_prediction_tests.rs index bba632e81f..f6d7cddd6d 100644 --- a/crates/editor/src/edit_prediction_tests.rs +++ b/crates/editor/src/edit_prediction_tests.rs @@ -562,3 +562,36 @@ impl EditPredictionProvider for FakeNonZedEditPredictionProvider { self.completion.clone() } } + +#[gpui::test] +async fn test_partial_accept_edit_prediction(cx: &mut gpui::TestAppContext) { + init_test(cx, |_| {}); + + let mut cx = EditorTestContext::new(cx).await; + let provider = cx.new(|_| FakeEditPredictionProvider::default()); + assign_editor_completion_provider(provider.clone(), &mut cx); + + cx.set_state("let x = ˇ;"); + + // Propose a completion with multiple words + propose_edits( + &provider, + vec![(Point::new(0, 8)..Point::new(0, 8), "hello world")], + &mut cx, + ); + + cx.update_editor(|editor, window, cx| editor.update_visible_edit_prediction(window, cx)); + + // Verify the completion is shown + cx.assert_editor_state("let x = ˇ;"); + cx.editor(|editor, _, _| { + assert!(editor.has_active_edit_prediction()); + }); + + // Accept partial completion - should accept first word + cx.update_editor(|editor, window, cx| { + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + }); + + cx.assert_editor_state("let x = helloˇ;"); +} diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 80680ae9c0..7bde8e178b 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -9157,6 +9157,7 @@ impl Editor { ) -> IconName { match provider { Some(provider) => match provider.provider.name() { + "ollama" => IconName::AiOllama, "copilot" => IconName::Copilot, "supermaven" => IconName::Supermaven, _ => IconName::ZedPredict, @@ -9206,6 +9207,7 @@ impl Editor { use text::ToPoint as _; if target.text_anchor.to_point(snapshot).row > cursor_point.row { + // For move predictions, still use directional icons Icon::new(IconName::ZedPredictDown) } else { Icon::new(IconName::ZedPredictUp) diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index 0f82d3997f..207663a182 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -215,6 +215,7 @@ pub enum EditPredictionProvider { Copilot, Supermaven, Zed, + Ollama, } impl EditPredictionProvider { @@ -223,7 +224,8 @@ impl EditPredictionProvider { EditPredictionProvider::Zed => true, EditPredictionProvider::None | EditPredictionProvider::Copilot - | EditPredictionProvider::Supermaven => false, + | EditPredictionProvider::Supermaven + | EditPredictionProvider::Ollama => false, } } } diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 3f2d47fba3..8f137d0514 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -1,7 +1,7 @@ use anyhow::{Result, anyhow}; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{Stream, TryFutureExt, stream}; -use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; +use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, Subscription, Task}; use http_client::HttpClient; use language_model::{ AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, @@ -81,7 +81,7 @@ impl State { // As a proxy for the server being "authenticated", we'll check if its up by fetching the models cx.spawn(async move |this, cx| { - let models = get_models(http_client.as_ref(), &api_url, None).await?; + let models = get_models(http_client.as_ref(), &api_url, None, None).await?; let tasks = models .into_iter() @@ -94,7 +94,8 @@ impl State { let api_url = api_url.clone(); async move { let name = model.name.as_str(); - let capabilities = show_model(http_client.as_ref(), &api_url, name).await?; + let capabilities = + show_model(http_client.as_ref(), &api_url, None, name).await?; let ollama_model = ollama::Model::new( name, None, @@ -141,6 +142,29 @@ impl State { } impl OllamaLanguageModelProvider { + pub fn global(cx: &App) -> Option> { + cx.try_global::() + .map(|provider| provider.0.clone()) + } + + pub fn set_global(provider: Entity, cx: &mut App) { + cx.set_global(GlobalOllamaLanguageModelProvider(provider)); + } + + pub fn available_models_for_completion(&self, cx: &App) -> Vec { + self.state.read(cx).available_models.clone() + } + + pub fn http_client(&self) -> Arc { + self.http_client.clone() + } + + pub fn refresh_models(&self, cx: &mut App) { + self.state.update(cx, |state, cx| { + state.restart_fetch_models_task(cx); + }); + } + pub fn new(http_client: Arc, cx: &mut App) -> Self { let this = Self { http_client: http_client.clone(), @@ -676,6 +700,10 @@ impl Render for ConfigurationView { } } +struct GlobalOllamaLanguageModelProvider(Entity); + +impl Global for GlobalOllamaLanguageModelProvider {} + fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool { ollama::OllamaTool::Function { function: OllamaFunctionTool { diff --git a/crates/ollama/Cargo.toml b/crates/ollama/Cargo.toml index 2765f23400..d07ba3d859 100644 --- a/crates/ollama/Cargo.toml +++ b/crates/ollama/Cargo.toml @@ -9,17 +9,42 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/ollama.rs" +path = "src/lib.rs" [features] default = [] schemars = ["dep:schemars"] +test-support = [ + "gpui/test-support", + "http_client/test-support", + "language/test-support", +] [dependencies] anyhow.workspace = true futures.workspace = true +gpui.workspace = true http_client.workspace = true +edit_prediction.workspace = true +language.workspace = true + +log.workspace = true + +project.workspace = true +settings.workspace = true schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true +text.workspace = true workspace-hack.workspace = true + +[dev-dependencies] +client = { workspace = true, features = ["test-support"] } +editor = { workspace = true, features = ["test-support"] } +gpui = { workspace = true, features = ["test-support"] } +http_client = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/ollama/src/lib.rs b/crates/ollama/src/lib.rs new file mode 100644 index 0000000000..4fcc61be95 --- /dev/null +++ b/crates/ollama/src/lib.rs @@ -0,0 +1,8 @@ +mod ollama; +mod ollama_completion_provider; + +pub use ollama::*; +pub use ollama_completion_provider::*; + +#[cfg(any(test, feature = "test-support"))] +pub use ollama::fake; diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 64cd1cc0cb..244446f963 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -6,6 +6,7 @@ use serde_json::Value; use std::time::Duration; pub const OLLAMA_API_URL: &str = "http://localhost:11434"; +pub const OLLAMA_API_KEY_VAR: &str = "OLLAMA_API_KEY"; #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -99,6 +100,39 @@ impl Model { } } +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateRequest { + pub model: String, + pub prompt: String, + pub suffix: Option, + pub stream: bool, + pub options: Option, + pub keep_alive: Option, + pub context: Option>, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateOptions { + pub num_predict: Option, + pub temperature: Option, + pub top_p: Option, + pub stop: Option>, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateResponse { + pub response: String, + pub done: bool, + pub context: Option>, + pub total_duration: Option, + pub load_duration: Option, + pub prompt_eval_count: Option, + pub eval_count: Option, +} + #[derive(Serialize, Deserialize, Debug)] #[serde(tag = "role", rename_all = "lowercase")] pub enum ChatMessage { @@ -309,14 +343,19 @@ pub async fn stream_chat_completion( pub async fn get_models( client: &dyn HttpClient, api_url: &str, + api_key: Option, _: Option, ) -> Result> { let uri = format!("{api_url}/api/tags"); - let request_builder = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::GET) .uri(uri) .header("Accept", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) + } + let request = request_builder.body(AsyncBody::default())?; let mut response = client.send(request).await?; @@ -336,15 +375,25 @@ pub async fn get_models( } /// Fetch details of a model, used to determine model capabilities -pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result { +pub async fn show_model( + client: &dyn HttpClient, + api_url: &str, + api_key: Option, + model: &str, +) -> Result { let uri = format!("{api_url}/api/show"); - let request = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) - .header("Content-Type", "application/json") - .body(AsyncBody::from( - serde_json::json!({ "model": model }).to_string(), - ))?; + .header("Content-Type", "application/json"); + + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")); + } + + let request = request_builder.body(AsyncBody::from( + serde_json::json!({ "model": model }).to_string(), + ))?; let mut response = client.send(request).await?; let mut body = String::new(); @@ -360,10 +409,198 @@ pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Ok(details) } +pub async fn generate( + client: &dyn HttpClient, + api_url: &str, + api_key: Option, + request: GenerateRequest, +) -> Result { + let uri = format!("{api_url}/api/generate"); + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Content-Type", "application/json"); + + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) + } + + let serialized_request = serde_json::to_string(&request)?; + let request = request_builder.body(AsyncBody::from(serialized_request))?; + + let mut response = match client.send(request).await { + Ok(response) => response, + Err(err) => { + log::error!("Ollama server unavailable at {}: {}", api_url, err); + return Err(err); + } + }; + + let mut body = String::new(); + response.body_mut().read_to_string(&mut body).await?; + + anyhow::ensure!( + response.status().is_success(), + "Failed to connect to Ollama API: {} {}", + response.status(), + body, + ); + + let response: GenerateResponse = + serde_json::from_str(&body).context("Unable to parse Ollama generate response")?; + Ok(response) +} + +#[cfg(any(test, feature = "test-support"))] +pub mod fake { + use super::*; + use crate::ollama_completion_provider::OllamaCompletionProvider; + use gpui::AppContext; + use http_client::{AsyncBody, Response, Url}; + use std::collections::HashMap; + use std::sync::{Arc, Mutex}; + + pub struct FakeHttpClient { + responses: Arc>>, + requests: Arc>>, // (path, body) + } + + impl FakeHttpClient { + pub fn new() -> Self { + Self { + responses: Arc::new(Mutex::new(HashMap::new())), + requests: Arc::new(Mutex::new(Vec::new())), + } + } + + pub fn set_response(&self, path: &str, response: String) { + self.responses + .lock() + .unwrap() + .insert(path.to_string(), response); + } + + pub fn set_generate_response(&self, completion_text: &str) { + let response = serde_json::json!({ + "response": completion_text, + "done": true, + "context": [], + "total_duration": 1000000_u64, + "load_duration": 1000000_u64, + "prompt_eval_count": 10, + "prompt_eval_duration": 1000000_u64, + "eval_count": 20, + "eval_duration": 1000000_u64 + }); + self.set_response("/api/generate", response.to_string()); + } + + pub fn set_error(&self, path: &str) { + // Remove any existing response to force an error + self.responses.lock().unwrap().remove(path); + } + + pub fn get_requests(&self) -> Vec<(String, String)> { + self.requests.lock().unwrap().clone() + } + + pub fn clear_requests(&self) { + self.requests.lock().unwrap().clear(); + } + } + + impl HttpClient for FakeHttpClient { + fn type_name(&self) -> &'static str { + "FakeHttpClient" + } + + fn user_agent(&self) -> Option<&http::HeaderValue> { + None + } + + fn proxy(&self) -> Option<&Url> { + None + } + + fn send( + &self, + req: http_client::Request, + ) -> futures::future::BoxFuture<'static, Result, anyhow::Error>> + { + let path = req.uri().path().to_string(); + let responses = Arc::clone(&self.responses); + let requests = Arc::clone(&self.requests); + + Box::pin(async move { + // Store the request + requests.lock().unwrap().push((path.clone(), String::new())); + + let responses = responses.lock().unwrap(); + + if let Some(response_body) = responses.get(&path).cloned() { + let response = Response::builder() + .status(200) + .header("content-type", "application/json") + .body(AsyncBody::from(response_body)) + .unwrap(); + Ok(response) + } else { + Err(anyhow::anyhow!("No mock response set for {}", path)) + } + }) + } + } + + pub struct Ollama; + + impl Ollama { + pub fn fake( + cx: &mut gpui::TestAppContext, + ) -> ( + gpui::Entity, + std::sync::Arc, + ) { + let fake_client = std::sync::Arc::new(FakeHttpClient::new()); + + let provider = + cx.new(|cx| OllamaCompletionProvider::new("qwencoder".to_string(), None, cx)); + + (provider, fake_client) + } + } +} + #[cfg(test)] mod tests { use super::*; + #[test] + fn test_generate_request_with_suffix_serialization() { + let request = GenerateRequest { + model: "qwen2.5-coder:32b".to_string(), + prompt: "def fibonacci(n):".to_string(), + suffix: Some(" return result".to_string()), + stream: false, + options: Some(GenerateOptions { + num_predict: Some(150), + temperature: Some(0.1), + top_p: Some(0.95), + stop: None, + }), + keep_alive: None, + context: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + let parsed: GenerateRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.model, "qwen2.5-coder:32b"); + assert_eq!(parsed.prompt, "def fibonacci(n):"); + assert_eq!(parsed.suffix, Some(" return result".to_string())); + assert!(!parsed.stream); + assert!(parsed.options.is_some()); + } + #[test] fn parse_completion() { let response = serde_json::json!({ @@ -585,4 +822,64 @@ mod tests { assert_eq!(message_images.len(), 1); assert_eq!(message_images[0].as_str().unwrap(), base64_image); } + + #[test] + fn test_generate_request_with_api_key_serialization() { + let request = GenerateRequest { + model: "qwen2.5-coder:32b".to_string(), + prompt: "def fibonacci(n):".to_string(), + suffix: Some(" return result".to_string()), + stream: false, + options: Some(GenerateOptions { + num_predict: Some(150), + temperature: Some(0.1), + top_p: Some(0.95), + stop: None, + }), + keep_alive: None, + context: None, + }; + + // Test with API key + let json = serde_json::to_string(&request).unwrap(); + let parsed: GenerateRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.model, "qwen2.5-coder:32b"); + assert_eq!(parsed.prompt, "def fibonacci(n):"); + assert_eq!(parsed.suffix, Some(" return result".to_string())); + assert!(!parsed.stream); + assert!(parsed.options.is_some()); + + // Note: The API key parameter is passed to the generate function itself, + // not included in the GenerateRequest struct that gets serialized to JSON + } + + #[test] + fn test_generate_request_with_stop_tokens() { + let request = GenerateRequest { + model: "codellama:7b-code".to_string(), + prompt: "def fibonacci(n):".to_string(), + suffix: Some(" return result".to_string()), + stream: false, + options: Some(GenerateOptions { + num_predict: Some(150), + temperature: Some(0.1), + top_p: Some(0.95), + stop: Some(vec!["".to_string()]), + }), + keep_alive: None, + context: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + let parsed: GenerateRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.model, "codellama:7b-code"); + assert_eq!(parsed.prompt, "def fibonacci(n):"); + assert_eq!(parsed.suffix, Some(" return result".to_string())); + assert!(!parsed.stream); + assert!(parsed.options.is_some()); + let options = parsed.options.unwrap(); + assert_eq!(options.stop, Some(vec!["".to_string()])); + } } diff --git a/crates/ollama/src/ollama_completion_provider.rs b/crates/ollama/src/ollama_completion_provider.rs new file mode 100644 index 0000000000..8ae8054fcf --- /dev/null +++ b/crates/ollama/src/ollama_completion_provider.rs @@ -0,0 +1,1260 @@ +use crate::{GenerateOptions, GenerateRequest, Model, generate}; +use anyhow::{Context as AnyhowContext, Result}; +use futures::StreamExt; +use std::{path::Path, sync::Arc, time::Duration}; + +use edit_prediction::{Direction, EditPrediction, EditPredictionProvider}; +use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task}; +use http_client::HttpClient; +use language::{Anchor, Buffer, ToOffset}; +use settings::SettingsStore; + +use project::Project; + +pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); + +// Structure for passing settings model data without circular dependencies +#[derive(Clone, Debug)] +pub struct SettingsModel { + pub name: String, + pub display_name: Option, + pub max_tokens: u64, + pub supports_tools: Option, + pub supports_images: Option, + pub supports_thinking: Option, +} + +impl SettingsModel { + pub fn to_model(&self) -> Model { + Model::new( + &self.name, + self.display_name.as_deref(), + Some(self.max_tokens), + self.supports_tools, + self.supports_images, + self.supports_thinking, + ) + } +} + +// Global Ollama service for managing models across all providers +pub struct State { + http_client: Arc, + api_url: String, + api_key: Option, + available_models: Vec, + fetch_models_task: Option>>, + _settings_subscription: Subscription, +} + +impl State { + pub fn new( + http_client: Arc, + api_url: String, + api_key: Option, + cx: &mut App, + ) -> Entity { + cx.new(|cx| { + let subscription = cx.observe_global::({ + move |this: &mut State, cx| { + this.restart_fetch_models_task(cx); + } + }); + + let mut service = Self { + http_client, + api_url, + api_key, + available_models: Vec::new(), + fetch_models_task: None, + _settings_subscription: subscription, + }; + + // TODO: why a secod refresh here? + service.restart_fetch_models_task(cx); + service + }) + } + + pub fn global(cx: &App) -> Option> { + cx.try_global::() + .map(|service| service.0.clone()) + } + + pub fn set_global(service: Entity, cx: &mut App) { + cx.set_global(GlobalOllamaState(service)); + } + + pub fn available_models(&self) -> &[Model] { + &self.available_models + } + + pub fn refresh_models(&mut self, cx: &mut Context) { + self.restart_fetch_models_task(cx); + } + + pub fn set_settings_models( + &mut self, + settings_models: Vec, + cx: &mut Context, + ) { + // Convert settings models to our Model type + self.available_models = settings_models + .into_iter() + .map(|settings_model| settings_model.to_model()) + .collect(); + self.restart_fetch_models_task(cx); + } + + pub fn set_api_key(&mut self, api_key: Option, cx: &mut Context) { + if self.api_key != api_key { + self.api_key = api_key; + self.restart_fetch_models_task(cx); + } + } + + fn restart_fetch_models_task(&mut self, cx: &mut Context) { + self.fetch_models_task = Some(self.fetch_models(cx)); + } + + fn fetch_models(&mut self, cx: &mut Context) -> Task> { + let http_client = Arc::clone(&self.http_client); + let api_url = self.api_url.clone(); + let api_key = self.api_key.clone(); + + cx.spawn(async move |this, cx| { + // Get the current settings models to merge with API models + let settings_models = this.update(cx, |this, _cx| { + // Get just the names of models from settings to avoid duplicates + this.available_models + .iter() + .map(|m| m.name.clone()) + .collect::>() + })?; + + // Fetch models from API + let api_models = match crate::get_models( + http_client.as_ref(), + &api_url, + api_key.clone(), + None, + ) + .await + { + Ok(models) => models, + Err(_) => return Ok(()), // Silently fail if API is unavailable + }; + + let tasks = api_models + .into_iter() + // Filter out embedding models + .filter(|model| !model.name.contains("-embed")) + // Filter out models that are already defined in settings + .filter(|model| !settings_models.contains(&model.name)) + .map(|model| { + let http_client = Arc::clone(&http_client); + let api_url = api_url.clone(); + let api_key = api_key.clone(); + async move { + let name = model.name.as_str(); + let capabilities = + crate::show_model(http_client.as_ref(), &api_url, api_key, name) + .await?; + let ollama_model = Model::new( + name, + None, + None, + Some(capabilities.supports_tools()), + Some(capabilities.supports_vision()), + Some(capabilities.supports_thinking()), + ); + Ok(ollama_model) + } + }); + + // Rate-limit capability fetches for API-discovered models + let api_discovered_models: Vec<_> = futures::stream::iter(tasks) + .buffer_unordered(5) + .collect::>>() + .await + .into_iter() + .collect::>>() + .unwrap_or_default(); + + this.update(cx, |this, cx| { + // Append API-discovered models to existing settings models + this.available_models.extend(api_discovered_models); + // Sort all models by name + this.available_models.sort_by(|a, b| a.name.cmp(&b.name)); + cx.notify(); + })?; + + Ok(()) + }) + } +} + +struct GlobalOllamaState(Entity); + +impl Global for GlobalOllamaState {} + +// TODO refactor to OllamaEditPredictionProvider +pub struct OllamaCompletionProvider { + model: String, + buffer_id: Option, + file_extension: Option, + current_completion: Option, + pending_refresh: Option>>, + api_key: Option, + _service_subscription: Option, +} + +impl OllamaCompletionProvider { + pub fn new(model: String, api_key: Option, cx: &mut Context) -> Self { + // Update the global service with the API key if one is provided + if let Some(service) = State::global(cx) { + service.update(cx, |service, cx| { + service.set_api_key(api_key.clone(), cx); + }); + } + + let subscription = if let Some(service) = State::global(cx) { + Some(cx.observe(&service, |_this, _service, cx| { + cx.notify(); + })) + } else { + None + }; + + Self { + model, + buffer_id: None, + file_extension: None, + current_completion: None, + pending_refresh: None, + api_key, + _service_subscription: subscription, + } + } + + pub fn available_models(&self, cx: &App) -> Vec { + if let Some(service) = State::global(cx) { + service.read(cx).available_models().to_vec() + } else { + Vec::new() + } + } + + pub fn refresh_models(&self, cx: &mut App) { + if let Some(service) = State::global(cx) { + service.update(cx, |service, cx| { + service.refresh_models(cx); + }); + } + } + + /// Updates the model used by this provider + pub fn update_model(&mut self, model: String) { + self.model = model; + } + + /// Updates the file extension used by this provider + pub fn update_file_extension(&mut self, new_file_extension: String) { + self.file_extension = Some(new_file_extension); + } + + fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) { + let cursor_offset = cursor_position.to_offset(buffer); + let text = buffer.text(); + + // Get reasonable context around cursor + let context_size = 2000; // 2KB before and after cursor + + let start = cursor_offset.saturating_sub(context_size); + let end = (cursor_offset + context_size).min(text.len()); + + let prefix = text[start..cursor_offset].to_string(); + let suffix = text[cursor_offset..end].to_string(); + + (prefix, suffix) + } + + /// Get stop tokens for the current model + /// For now we only handle the case for codellama:7b-code model + /// that we found was including the stop token in the completion suggestion. + /// We wanted to avoid going down this route and let Ollama abstract all template tokens away. + /// But apparently, and surprisingly for a llama model, Ollama misses this case. + fn get_stop_tokens(&self) -> Option> { + if self.model.contains("codellama") && self.model.contains("code") { + Some(vec!["".to_string()]) + } else { + None + } + } +} + +impl EditPredictionProvider for OllamaCompletionProvider { + fn name() -> &'static str { + "ollama" + } + + fn display_name() -> &'static str { + "Ollama" + } + + fn show_completions_in_menu() -> bool { + true + } + + fn is_enabled(&self, _buffer: &Entity, _cursor_position: Anchor, _cx: &App) -> bool { + true + } + + fn is_refreshing(&self) -> bool { + self.pending_refresh.is_some() + } + + fn refresh( + &mut self, + project: Option>, + buffer: Entity, + cursor_position: Anchor, + debounce: bool, + cx: &mut Context, + ) { + // Get API settings from the global Ollama service or fallback + let (http_client, api_url) = if let Some(service) = State::global(cx) { + let service_ref = service.read(cx); + (service_ref.http_client.clone(), service_ref.api_url.clone()) + } else { + // Fallback if global service isn't available + ( + project + .as_ref() + .map(|p| p.read(cx).client().http_client() as Arc) + .unwrap_or_else(|| { + Arc::new(http_client::BlockedHttpClient::new()) as Arc + }), + crate::OLLAMA_API_URL.to_string(), + ) + }; + + self.pending_refresh = Some(cx.spawn(async move |this, cx| { + if debounce { + cx.background_executor() + .timer(OLLAMA_DEBOUNCE_TIMEOUT) + .await; + } + + let (prefix, suffix) = this.update(cx, |this, cx| { + let buffer_snapshot = buffer.read(cx); + this.buffer_id = Some(buffer.entity_id()); + this.file_extension = buffer_snapshot.file().and_then(|file| { + Some( + Path::new(file.file_name(cx)) + .extension()? + .to_str()? + .to_string(), + ) + }); + this.extract_context(buffer_snapshot, cursor_position) + })?; + + let (model, api_key) = + this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?; + + let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?; + + let request = GenerateRequest { + model, + prompt: prefix, + suffix: Some(suffix), + stream: false, + options: Some(GenerateOptions { + num_predict: Some(150), // Reasonable completion length + temperature: Some(0.1), // Low temperature for more deterministic results + top_p: Some(0.95), + stop: stop_tokens, + }), + keep_alive: None, + context: None, + }; + + let response = generate(http_client.as_ref(), &api_url, api_key, request) + .await + .context("Failed to get completion from Ollama"); + + this.update(cx, |this, cx| { + this.pending_refresh = None; + match response { + Ok(response) if !response.response.trim().is_empty() => { + this.current_completion = Some(response.response); + } + _ => { + this.current_completion = None; + } + } + cx.notify(); + })?; + + Ok(()) + })); + } + + fn cycle( + &mut self, + _buffer: Entity, + _cursor_position: Anchor, + _direction: Direction, + _cx: &mut Context, + ) { + // Ollama doesn't provide multiple completions in a single request + // Could be implemented by making multiple requests with different temperatures + // or by using different models + } + + fn accept(&mut self, _cx: &mut Context) { + self.current_completion = None; + } + + fn discard(&mut self, _cx: &mut Context) { + self.current_completion = None; + } + + fn suggest( + &mut self, + buffer: &Entity, + cursor_position: Anchor, + cx: &mut Context, + ) -> Option { + let buffer_id = buffer.entity_id(); + if Some(buffer_id) != self.buffer_id { + return None; + } + + let completion_text = self.current_completion.as_ref()?.clone(); + + if completion_text.trim().is_empty() { + return None; + } + + let buffer_snapshot = buffer.read(cx); + let cursor_offset = cursor_position.to_offset(buffer_snapshot); + + // Get text before cursor to check what's already been typed + let text_before_cursor = buffer_snapshot + .text_for_range(0..cursor_offset) + .collect::(); + + // Find how much of the completion has already been typed by checking + // if the text before the cursor ends with a prefix of our completion + let mut prefix_len = 0; + for i in 1..=completion_text.len().min(text_before_cursor.len()) { + if text_before_cursor.ends_with(&completion_text[..i]) { + prefix_len = i; + } + } + + // Only suggest the remaining part of the completion + let remaining_completion = &completion_text[prefix_len..]; + + if remaining_completion.trim().is_empty() { + return None; + } + + let position = cursor_position.bias_right(buffer_snapshot); + + Some(EditPrediction { + id: None, + edits: vec![(position..position, remaining_completion.to_string())], + edit_preview: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::fake::FakeHttpClient; + + use gpui::{AppContext, TestAppContext}; + + use client; + use language::Buffer; + use project::Project; + use settings::SettingsStore; + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + theme::init(theme::LoadThemes::JustBase, cx); + client::init_settings(cx); + language::init(cx); + editor::init_settings(cx); + Project::init_settings(cx); + workspace::init_settings(cx); + }); + } + + /// Test the complete Ollama completion flow from refresh to suggestion + #[gpui::test] + fn test_get_stop_tokens(cx: &mut TestAppContext) { + init_test(cx); + + // Test CodeLlama code model gets stop tokens + let codellama_provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx)) + }); + + codellama_provider.read_with(cx, |provider, _| { + assert_eq!(provider.get_stop_tokens(), Some(vec!["".to_string()])); + }); + + // Test non-CodeLlama model doesn't get stop tokens + let qwen_provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + qwen_provider.read_with(cx, |provider, _| { + assert_eq!(provider.get_stop_tokens(), None); + }); + } + + #[gpui::test] + async fn test_model_discovery(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + + // Mock /api/tags response (list models) + let models_response = serde_json::json!({ + "models": [ + { + "name": "qwen2.5-coder:3b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "qwen2", + "families": ["qwen2"], + "parameter_size": "3B", + "quantization_level": "Q4_0" + } + }, + { + "name": "codellama:7b-code", + "modified_at": "2024-01-01T00:00:00Z", + "size": 2000000, + "digest": "def456", + "details": { + "format": "gguf", + "family": "codellama", + "families": ["codellama"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + }, + { + "name": "nomic-embed-text", + "modified_at": "2024-01-01T00:00:00Z", + "size": 500000, + "digest": "ghi789", + "details": { + "format": "gguf", + "family": "nomic-embed", + "families": ["nomic-embed"], + "parameter_size": "137M", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + + // Mock /api/show responses for model capabilities + let qwen_capabilities = serde_json::json!({ + "capabilities": ["tools", "thinking"] + }); + + let _codellama_capabilities = serde_json::json!({ + "capabilities": [] + }); + + fake_http_client.set_response("/api/show", qwen_capabilities.to_string()); + + // Create global Ollama service for testing + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + // Set it as global + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Create completion provider + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + // Wait for model discovery to complete + cx.background_executor.run_until_parked(); + + // Verify models were discovered through the global provider + provider.read_with(cx, |provider, cx| { + let models = provider.available_models(cx); + assert_eq!(models.len(), 2); // Should exclude nomic-embed-text + + let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + assert!(model_names.contains(&"codellama:7b-code")); + assert!(model_names.contains(&"qwen2.5-coder:3b")); + assert!(!model_names.contains(&"nomic-embed-text")); + }); + } + + #[gpui::test] + async fn test_model_discovery_api_failure(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client that returns errors + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + fake_http_client.set_error("Connection refused"); + + // Create global Ollama service that will fail + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Create completion provider + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + // Wait for model discovery to complete (with failure) + cx.background_executor.run_until_parked(); + + // Verify graceful handling - should have empty model list + provider.read_with(cx, |provider, cx| { + let models = provider.available_models(cx); + assert_eq!(models.len(), 0); + }); + } + + #[gpui::test] + async fn test_refresh_models(cx: &mut TestAppContext) { + init_test(cx); + + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + + // Initially return empty model list + let empty_response = serde_json::json!({"models": []}); + fake_http_client.set_response("/api/tags", empty_response.to_string()); + + // Create global Ollama service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx)) + }); + + cx.background_executor.run_until_parked(); + + // Verify initially empty + provider.read_with(cx, |provider, cx| { + assert_eq!(provider.available_models(cx).len(), 0); + }); + + // Update mock to return models + let models_response = serde_json::json!({ + "models": [ + { + "name": "qwen2.5-coder:7b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "qwen2", + "families": ["qwen2"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + + let capabilities = serde_json::json!({ + "capabilities": ["tools", "thinking"] + }); + + fake_http_client.set_response("/api/show", capabilities.to_string()); + + // Trigger refresh + provider.update(cx, |provider, cx| { + provider.refresh_models(cx); + }); + + cx.background_executor.run_until_parked(); + + // Verify models were refreshed + provider.read_with(cx, |provider, cx| { + let models = provider.available_models(cx); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "qwen2.5-coder:7b"); + }); + } + + #[gpui::test] + async fn test_full_completion_flow(cx: &mut TestAppContext) { + init_test(cx); + + // Create a buffer with realistic code content + let buffer = cx.update(|cx| cx.new(|cx| Buffer::local("fn test() {\n \n}", cx))); + let cursor_position = buffer.read_with(cx, |buffer, _| { + buffer.anchor_before(11) // Position in the middle of the function + }); + + // Create fake HTTP client and set up global service + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + fake_http_client.set_generate_response("println!(\"Hello\");"); + + // Create global Ollama service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Create provider + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + // Trigger completion refresh (no debounce for test speed) + provider.update(cx, |provider, cx| { + provider.refresh(None, buffer.clone(), cursor_position, false, cx); + }); + + // Wait for completion task to complete + cx.background_executor.run_until_parked(); + + // Verify completion was processed and stored + provider.read_with(cx, |provider, _cx| { + assert!(provider.current_completion.is_some()); + assert_eq!( + provider.current_completion.as_ref().unwrap(), + "println!(\"Hello\");" + ); + assert!(!provider.is_refreshing()); + }); + + // Test suggestion logic returns the completion + let suggestion = cx.update(|cx| { + provider.update(cx, |provider, cx| { + provider.suggest(&buffer, cursor_position, cx) + }) + }); + + assert!(suggestion.is_some()); + let suggestion = suggestion.unwrap(); + assert_eq!(suggestion.edits.len(), 1); + assert_eq!(suggestion.edits[0].1, "println!(\"Hello\");"); + + // Verify acceptance clears the completion + provider.update(cx, |provider, cx| { + provider.accept(cx); + }); + + provider.read_with(cx, |provider, _cx| { + assert!(provider.current_completion.is_none()); + }); + } + + /// Test that partial typing is handled correctly - only suggests untyped portion + #[gpui::test] + async fn test_partial_typing_handling(cx: &mut TestAppContext) { + init_test(cx); + + // Create buffer where user has partially typed "vec" + let buffer = cx.update(|cx| cx.new(|cx| Buffer::local("let result = vec", cx))); + let cursor_position = buffer.read_with(cx, |buffer, _| { + buffer.anchor_after(16) // After "vec" + }); + + // Create fake HTTP client and set up global service + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + + // Create global Ollama service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Create provider + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + // Configure response that starts with what user already typed + fake_http_client.set_generate_response("vec![1, 2, 3]"); + + provider.update(cx, |provider, cx| { + provider.refresh(None, buffer.clone(), cursor_position, false, cx); + }); + + cx.background_executor.run_until_parked(); + + // Should suggest only the remaining part after "vec" + let suggestion = cx.update(|cx| { + provider.update(cx, |provider, cx| { + provider.suggest(&buffer, cursor_position, cx) + }) + }); + + // Verify we get a reasonable suggestion + if let Some(suggestion) = suggestion { + assert_eq!(suggestion.edits.len(), 1); + assert!(suggestion.edits[0].1.contains("1, 2, 3")); + } + } + + #[gpui::test] + async fn test_accept_partial_ollama_suggestion(cx: &mut TestAppContext) { + init_test(cx); + + let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await; + + // Create fake HTTP client and set up global service + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + fake_http_client.set_generate_response("vec![hello, world]"); + + // Create global Ollama service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Create provider + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + // Set up the editor with the Ollama provider + editor_cx.update_editor(|editor, window, cx| { + editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); + }); + + // Set initial state + editor_cx.set_state("let items = ˇ"); + + // Trigger the completion through the provider + let buffer = + editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone()); + let cursor_position = editor_cx.buffer_snapshot().anchor_after(12); + + provider.update(cx, |provider, cx| { + provider.refresh(None, buffer, cursor_position, false, cx); + }); + + cx.background_executor.run_until_parked(); + + editor_cx.update_editor(|editor, window, cx| { + editor.refresh_edit_prediction(false, true, window, cx); + }); + + cx.background_executor.run_until_parked(); + + editor_cx.update_editor(|editor, window, cx| { + // Verify we have an active completion + assert!(editor.has_active_edit_prediction()); + + // The display text should show the full completion + assert_eq!(editor.display_text(cx), "let items = vec![hello, world]"); + // But the actual text should only show what's been typed + assert_eq!(editor.text(cx), "let items = "); + + // Accept first partial - should accept "vec" (alphabetic characters) + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + + // Assert the buffer now contains the first partially accepted text + assert_eq!(editor.text(cx), "let items = vec"); + // Completion should still be active for remaining text + assert!(editor.has_active_edit_prediction()); + + // Accept second partial - should accept "![" (non-alphabetic characters) + editor.accept_partial_edit_prediction(&Default::default(), window, cx); + + // Assert the buffer now contains both partial acceptances + assert_eq!(editor.text(cx), "let items = vec!["); + // Completion should still be active for remaining text + assert!(editor.has_active_edit_prediction()); + }); + } + + #[gpui::test] + async fn test_completion_invalidation(cx: &mut TestAppContext) { + init_test(cx); + + let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await; + + // Create fake HTTP client and set up global service + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + fake_http_client.set_generate_response("bar"); + + // Create global Ollama service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Create provider + let provider = cx.update(|cx| { + cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)) + }); + + // Set up the editor with the Ollama provider + editor_cx.update_editor(|editor, window, cx| { + editor.set_edit_prediction_provider(Some(provider.clone()), window, cx); + }); + + editor_cx.set_state("fooˇ"); + + // Trigger the completion through the provider + let buffer = + editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone()); + let cursor_position = editor_cx.buffer_snapshot().anchor_after(3); // After "foo" + + provider.update(cx, |provider, cx| { + provider.refresh(None, buffer, cursor_position, false, cx); + }); + + cx.background_executor.run_until_parked(); + + editor_cx.update_editor(|editor, window, cx| { + editor.refresh_edit_prediction(false, true, window, cx); + }); + + cx.background_executor.run_until_parked(); + + editor_cx.update_editor(|editor, window, cx| { + assert!(editor.has_active_edit_prediction()); + assert_eq!(editor.display_text(cx), "foobar"); + assert_eq!(editor.text(cx), "foo"); + + // Backspace within the original text - completion should remain + editor.backspace(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); + assert_eq!(editor.display_text(cx), "fobar"); + assert_eq!(editor.text(cx), "fo"); + + editor.backspace(&Default::default(), window, cx); + assert!(editor.has_active_edit_prediction()); + assert_eq!(editor.display_text(cx), "fbar"); + assert_eq!(editor.text(cx), "f"); + + // This backspace removes all original text - should invalidate completion + editor.backspace(&Default::default(), window, cx); + assert!(!editor.has_active_edit_prediction()); + assert_eq!(editor.display_text(cx), ""); + assert_eq!(editor.text(cx), ""); + }); + } + + #[gpui::test] + async fn test_settings_model_merging(cx: &mut TestAppContext) { + init_test(cx); + + // Create fake HTTP client that returns some API models + let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new()); + + // Mock /api/tags response (list models) + let models_response = serde_json::json!({ + "models": [ + { + "name": "api-model-1", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "7B", + "quantization_level": "Q4_0" + } + }, + { + "name": "shared-model", + "modified_at": "2024-01-01T00:00:00Z", + "size": 2000000, + "digest": "def456", + "details": { + "format": "gguf", + "family": "llama", + "families": ["llama"], + "parameter_size": "13B", + "quantization_level": "Q4_0" + } + } + ] + }); + + fake_http_client.set_response("/api/tags", models_response.to_string()); + + // Mock /api/show responses for each model + let show_response = serde_json::json!({ + "capabilities": ["tools", "vision"] + }); + fake_http_client.set_response("/api/show", show_response.to_string()); + + // Create service + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + // Add settings models (including one that overlaps with API) + let settings_models = vec![ + SettingsModel { + name: "custom-model-1".to_string(), + display_name: Some("Custom Model 1".to_string()), + max_tokens: 4096, + supports_tools: Some(true), + supports_images: Some(false), + supports_thinking: Some(false), + }, + SettingsModel { + name: "shared-model".to_string(), // This should take precedence over API + display_name: Some("Custom Shared Model".to_string()), + max_tokens: 8192, + supports_tools: Some(true), + supports_images: Some(true), + supports_thinking: Some(true), + }, + ]; + + cx.update(|cx| { + service.update(cx, |service, cx| { + service.set_settings_models(settings_models, cx); + }); + }); + + // Wait for models to be fetched and merged + cx.run_until_parked(); + + // Verify merged models + let models = cx.update(|cx| service.read(cx).available_models().to_vec()); + + assert_eq!(models.len(), 3); // 2 settings models + 1 unique API model + + // Models should be sorted alphabetically, so check by name + let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect(); + assert_eq!( + model_names, + vec!["api-model-1", "custom-model-1", "shared-model"] + ); + + // Check custom model from settings + let custom_model = models.iter().find(|m| m.name == "custom-model-1").unwrap(); + assert_eq!( + custom_model.display_name, + Some("Custom Model 1".to_string()) + ); + assert_eq!(custom_model.max_tokens, 4096); + + // Settings model should override API model for shared-model + let shared_model = models.iter().find(|m| m.name == "shared-model").unwrap(); + assert_eq!( + shared_model.display_name, + Some("Custom Shared Model".to_string()) + ); + assert_eq!(shared_model.max_tokens, 8192); + assert_eq!(shared_model.supports_tools, Some(true)); + assert_eq!(shared_model.supports_vision, Some(true)); + assert_eq!(shared_model.supports_thinking, Some(true)); + + // API-only model should be included + let api_model = models.iter().find(|m| m.name == "api-model-1").unwrap(); + assert!(api_model.display_name.is_none()); // API models don't have custom display names + } + + #[gpui::test] + async fn test_api_key_passed_to_requests(cx: &mut TestAppContext) { + init_test(cx); + + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Set up responses for model discovery with API key + fake_http_client.set_response( + "/api/tags", + serde_json::json!({ + "models": [ + { + "name": "qwen2.5-coder:3b", + "modified_at": "2024-01-01T00:00:00Z", + "size": 1000000, + "digest": "abc123", + "details": { + "format": "gguf", + "family": "qwen2.5", + "families": ["qwen2.5"], + "parameter_size": "3B", + "quantization_level": "Q4_0" + } + } + ] + }) + .to_string(), + ); + + // Set up show model response + fake_http_client.set_response( + "/api/show", + serde_json::json!({ + "capabilities": { + "tools": true, + "vision": false, + "thinking": false + } + }) + .to_string(), + ); + + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + Some("test-api-key".to_string()), + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Wait for model fetching to complete + cx.background_executor.run_until_parked(); + + // Verify that requests were made + let requests = fake_http_client.get_requests(); + assert!(!requests.is_empty(), "Expected HTTP requests to be made"); + + // Note: We can't easily test the Authorization header with the current FakeHttpClient + // implementation, but the important thing is that the API key gets passed through + // to the HTTP requests without panicking. + } + + #[gpui::test] + async fn test_api_key_update_triggers_refresh(cx: &mut TestAppContext) { + init_test(cx); + + let fake_http_client = Arc::new(FakeHttpClient::new()); + + // Set up initial response + fake_http_client.set_response( + "/api/tags", + serde_json::json!({ + "models": [] + }) + .to_string(), + ); + + let service = cx.update(|cx| { + State::new( + fake_http_client.clone(), + "http://localhost:11434".to_string(), + None, + cx, + ) + }); + + cx.update(|cx| { + State::set_global(service.clone(), cx); + }); + + // Clear initial requests + fake_http_client.clear_requests(); + + // Update API key + service.update(cx, |service, cx| { + service.set_api_key(Some("new-api-key".to_string()), cx); + }); + + // Wait for refresh to complete + cx.background_executor.run_until_parked(); + + // Verify new requests were made + let requests = fake_http_client.get_requests(); + assert!( + !requests.is_empty(), + "Expected new requests after API key update" + ); + } +} diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 6f4ead9ebb..d2595623df 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -82,6 +82,7 @@ image_viewer.workspace = true indoc.workspace = true edit_prediction_button.workspace = true inspector_ui.workspace = true +ollama.workspace = true install_cli.workspace = true jj_ui.workspace = true journal.workspace = true diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index bc2d757fd1..d994962e7e 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -3,8 +3,11 @@ use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; + use language::language_settings::{EditPredictionProvider, all_language_settings}; -use settings::SettingsStore; +use language_models::AllLanguageModelSettings; +use ollama::{OLLAMA_API_KEY_VAR, OllamaCompletionProvider, SettingsModel, State}; +use settings::{Settings as _, SettingsStore}; use std::{cell::RefCell, rc::Rc, sync::Arc}; use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; @@ -12,6 +15,33 @@ use workspace::Workspace; use zeta::{ProviderDataCollection, ZetaEditPredictionProvider}; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { + // Initialize global Ollama service + let (api_url, settings_models) = { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + let api_url = settings.api_url.clone(); + let settings_models: Vec = settings + .available_models + .iter() + .map(|model| SettingsModel { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + supports_tools: model.supports_tools, + supports_images: model.supports_images, + supports_thinking: model.supports_thinking, + }) + .collect(); + (api_url, settings_models) + }; + + let ollama_service = State::new(client.http_client(), api_url, None, cx); + + ollama_service.update(cx, |service, cx| { + service.set_settings_models(settings_models, cx); + }); + + State::set_global(ollama_service, cx); + let editors: Rc, AnyWindowHandle>>> = Rc::default(); cx.observe_new({ let editors = editors.clone(); @@ -89,6 +119,27 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { user_store.clone(), cx, ); + } else if provider == EditPredictionProvider::Ollama { + // Update global Ollama service when settings change + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + if let Some(service) = State::global(cx) { + let settings_models: Vec = settings + .available_models + .iter() + .map(|model| SettingsModel { + name: model.name.clone(), + display_name: model.display_name.clone(), + max_tokens: model.max_tokens, + supports_tools: model.supports_tools, + supports_images: model.supports_images, + supports_thinking: model.supports_thinking, + }) + .collect(); + + service.update(cx, |service, cx| { + service.set_settings_models(settings_models, cx); + }); + } } } }) @@ -229,5 +280,81 @@ fn assign_edit_prediction_provider( editor.set_edit_prediction_provider(Some(provider), window, cx); } } + EditPredictionProvider::Ollama => { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + let api_key = std::env::var(OLLAMA_API_KEY_VAR).ok(); + + // Get model from settings or use discovered models + let model = if let Some(first_model) = settings.available_models.first() { + Some(first_model.name.clone()) + } else if let Some(service) = State::global(cx) { + // Use first discovered model + service + .read(cx) + .available_models() + .first() + .map(|m| m.name.clone()) + } else { + None + }; + + if let Some(model) = model { + let provider = cx.new(|cx| OllamaCompletionProvider::new(model, api_key, cx)); + editor.set_edit_prediction_provider(Some(provider), window, cx); + } else { + log::error!( + "No Ollama models available. Please configure models in settings or pull models using 'ollama pull '" + ); + editor.set_edit_prediction_provider::(None, window, cx); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::zed::tests::init_test; + use editor::{Editor, MultiBuffer}; + use gpui::TestAppContext; + use language::Buffer; + use language_models::{AllLanguageModelSettings, provider::ollama::OllamaSettings}; + + #[gpui::test] + async fn test_assign_edit_prediction_provider_with_no_ollama_models(cx: &mut TestAppContext) { + let app_state = init_test(cx); + + let buffer = cx.new(|cx| Buffer::local("test content", cx)); + let multibuffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx)); + let (editor, cx) = + cx.add_window_view(|window, cx| Editor::for_multibuffer(multibuffer, None, window, cx)); + + // Override settings to have empty available_models + cx.update(|_, cx| { + let new_settings = AllLanguageModelSettings { + ollama: OllamaSettings { + api_url: "http://localhost:11434".to_string(), + available_models: vec![], // Empty models list + }, + ..Default::default() + }; + AllLanguageModelSettings::override_global(new_settings, cx); + }); + + // Call assign_edit_prediction_provider with Ollama provider + // This should complete without panicking even when no models are available + let result = editor.update_in(cx, |editor, window, cx| { + assign_edit_prediction_provider( + editor, + language::language_settings::EditPredictionProvider::Ollama, + &app_state.client, + app_state.user_store.clone(), + window, + cx, + ) + }); + + // Assert that assign_edit_prediction_provider returns () + assert_eq!(result, ()); } } diff --git a/docs/src/ai/edit-prediction.md b/docs/src/ai/edit-prediction.md index 7843b08ff7..11599402b1 100644 --- a/docs/src/ai/edit-prediction.md +++ b/docs/src/ai/edit-prediction.md @@ -44,7 +44,7 @@ On Linux, `alt-tab` is often used by the window manager for switching windows, s {#action editor::AcceptPartialEditPrediction} ({#kb editor::AcceptPartialEditPrediction}) can be used to accept the current edit prediction up to the next word boundary. -See the [Configuring GitHub Copilot](#github-copilot) and [Configuring Supermaven](#supermaven) sections below for configuration of other providers. Only text insertions at the current cursor are supported for these providers, whereas the Zeta model provides multiple predictions including deletions. +See the [Configuring GitHub Copilot](#github-copilot), [Configuring Supermaven](#supermaven), and [Configuring Ollama](#ollama) sections below for configuration of other providers. Only text insertions at the current cursor are supported for these providers, whereas the Zeta model provides multiple predictions including deletions. ## Configuring Edit Prediction Keybindings {#edit-predictions-keybinding} @@ -304,6 +304,74 @@ To use Supermaven as your provider, set this within `settings.json`: You should be able to sign-in to Supermaven by clicking on the Supermaven icon in the status bar and following the setup instructions. +## Configuring Ollama {#ollama} + +To use Ollama as your edit prediction provider, set this within `settings.json`: + +```json +{ + "features": { + "edit_prediction_provider": "ollama" + } +} +``` + +### Setup + +1. Download and install Ollama from [ollama.com/download](https://ollama.com/download) +2. Pull completion-capable models, for example: + + ```sh + ollama pull qwen2.5-coder:3b + ollama pull codellama:7b + ``` + +3. Ensure Ollama is running: + + ```sh + ollama serve + ``` + +4. Configure the model in your language model settings + +The Edit Prediction menu will automatically detect available models. When one is newly selected in the menu, it will be added to your `settings.json`, and put at the top of the list. You can then manually configure it in the settings file if you need more control. + + + +```json +{ + "language_models": { + "ollama": { + "api_url": "http://localhost:11434", + "available_models": [ + { + "name": "qwen2.5-coder:3b", + "display_name": "Qwen 2.5 Coder 3B", + "max_tokens": 8192 + }, + { + "name": "codellama:7b", + "display_name": "CodeLlama 7B", + "max_tokens": 8192 + } + ] + } + } +} +``` + +You can also switch between them in the menu, and the order of the models in the settings file will be updated behind the scenes. + +The settings allows for configuring Ollama's API url too, so one can use Ollama either locally or hosted. The Edit Prediction menu includes a shortcut for it that will open the settings file where the url is set. + +### Authentication + +Ollama itself doesn't require an API key, but when running it remotely it's a good idea and common practice to setup a proxy server in front of it that does. When sending edit prediction requests to it, Zed will forward the API key as an authentication header so the proxy can authenticate against it: + +```bash +export OLLAMA_API_KEY=your_api_key_here +``` + ## See also You may also use the [Agent Panel](./agent-panel.md) or the [Inline Assistant](./inline-assistant.md) to interact with language models, see the [AI documentation](./overview.md) for more information on the other AI features in Zed. diff --git a/docs/src/completions.md b/docs/src/completions.md index d14cf61d82..5e7c43ac2e 100644 --- a/docs/src/completions.md +++ b/docs/src/completions.md @@ -3,7 +3,7 @@ Zed supports two sources for completions: 1. "Code Completions" provided by Language Servers (LSPs) automatically installed by Zed or via [Zed Language Extensions](languages.md). -2. "Edit Predictions" provided by Zed's own Zeta model or by external providers like [GitHub Copilot](#github-copilot) or [Supermaven](#supermaven). +2. "Edit Predictions" provided by Zed's own Zeta model or by external providers like [GitHub Copilot](#github-copilot), [Supermaven](#supermaven), or [Ollama](#ollama). ## Language Server Code Completions {#code-completions}