diff --git a/Cargo.lock b/Cargo.lock
index b1df6c1e3a..f16b67f49a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10945,6 +10945,7 @@ name = "ollama"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "client",
  "editor",
  "futures 0.3.31",
  "gpui",
diff --git a/crates/ollama/Cargo.toml b/crates/ollama/Cargo.toml
index b0756573a1..d682917ef5 100644
--- a/crates/ollama/Cargo.toml
+++ b/crates/ollama/Cargo.toml
@@ -27,6 +27,7 @@ gpui.workspace = true
 http_client.workspace = true
 inline_completion.workspace = true
 language.workspace = true
+log.workspace = true
 project.workspace = true
@@ -38,6 +39,7 @@ text.workspace = true
 workspace-hack.workspace = true
 
 [dev-dependencies]
+client = { workspace = true, features = ["test-support"] }
 editor = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 http_client = { workspace = true, features = ["test-support"] }
diff --git a/crates/ollama/src/ollama_completion_provider.rs b/crates/ollama/src/ollama_completion_provider.rs
index c515c715ec..5cb6705d96 100644
--- a/crates/ollama/src/ollama_completion_provider.rs
+++ b/crates/ollama/src/ollama_completion_provider.rs
@@ -13,6 +13,30 @@ use project::Project;
 
 pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
 
+// Structure for passing settings model data without circular dependencies
+#[derive(Clone, Debug)]
+pub struct SettingsModel {
+    pub name: String,
+    pub display_name: Option<String>,
+    pub max_tokens: u64,
+    pub supports_tools: Option<bool>,
+    pub supports_images: Option<bool>,
+    pub supports_thinking: Option<bool>,
+}
+
+impl SettingsModel {
+    pub fn to_model(&self) -> Model {
+        Model::new(
+            &self.name,
+            self.display_name.as_deref(),
+            Some(self.max_tokens),
+            self.supports_tools,
+            self.supports_images,
+            self.supports_thinking,
+        )
+    }
+}
+
 // Global Ollama service for managing models across all providers
 pub struct OllamaService {
     http_client: Arc<dyn HttpClient>,
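Note: `SettingsModel` mirrors the settings-side model entry so that the `ollama` crate can receive model configuration without depending on the crate that defines the settings types. A minimal standalone sketch of that pattern, with stand-in module and field names (the real structs carry more fields):

mod ollama {
    /// Simplified stand-in for this crate's `Model` type.
    pub struct Model {
        pub name: String,
        pub display_name: Option<String>,
    }

    /// Mirror of a settings entry, owned by this crate so the dependency
    /// arrow points settings -> ollama, never the other way around.
    pub struct SettingsModel {
        pub name: String,
        pub display_name: Option<String>,
    }

    impl SettingsModel {
        pub fn to_model(&self) -> Model {
            Model {
                name: self.name.clone(),
                display_name: self.display_name.clone(),
            }
        }
    }
}

fn main() {
    // A higher-level settings layer fills in the carrier struct...
    let carrier = ollama::SettingsModel {
        name: "llama3.2:3b".into(), // hypothetical model name
        display_name: Some("Llama 3.2 3B".into()),
    };
    // ...and the service converts it into its own model type.
    let model = carrier.to_model();
    assert_eq!(model.name, "llama3.2:3b");
    assert_eq!(model.display_name.as_deref(), Some("Llama 3.2 3B"));
}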
@@ -61,6 +85,19 @@ impl OllamaService {
         self.restart_fetch_models_task(cx);
     }
 
+    pub fn set_settings_models(
+        &mut self,
+        settings_models: Vec<SettingsModel>,
+        cx: &mut Context<Self>,
+    ) {
+        // Convert settings models to our Model type
+        self.available_models = settings_models
+            .into_iter()
+            .map(|settings_model| settings_model.to_model())
+            .collect();
+        self.restart_fetch_models_task(cx);
+    }
+
     fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
         self.fetch_models_task = Some(self.fetch_models(cx));
     }
@@ -70,15 +107,27 @@ impl OllamaService {
         let api_url = self.api_url.clone();
 
         cx.spawn(async move |this, cx| {
-            let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
+            // Get the current settings models to merge with API models
+            let settings_models = this.update(cx, |this, _cx| {
+                // Get just the names of models from settings to avoid duplicates
+                this.available_models
+                    .iter()
+                    .map(|m| m.name.clone())
+                    .collect::<Vec<String>>()
+            })?;
+
+            // Fetch models from API
+            let api_models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
                 Ok(models) => models,
-                Err(_) => return Ok(()), // Silently fail and use empty list
+                Err(_) => return Ok(()), // Silently fail if API is unavailable
             };
 
-            let tasks = models
+            let tasks = api_models
                 .into_iter()
                 // Filter out embedding models
                 .filter(|model| !model.name.contains("-embed"))
+                // Filter out models that are already defined in settings
+                .filter(|model| !settings_models.contains(&model.name))
                 .map(|model| {
                     let http_client = Arc::clone(&http_client);
                     let api_url = api_url.clone();
@@ -98,8 +147,8 @@ impl OllamaService {
                 }
             });
 
-            // Rate-limit capability fetches
-            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
+            // Rate-limit capability fetches for API-discovered models
+            let api_discovered_models: Vec<_> = futures::stream::iter(tasks)
                 .buffer_unordered(5)
                 .collect::<Vec<Result<Model>>>()
                 .await
                 .into_iter()
                 .collect::<Result<Vec<Model>>>()
                 .unwrap_or_default();
@@ -107,10 +156,11 @@
-            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
-
             this.update(cx, |this, cx| {
-                this.available_models = ollama_models;
+                // Append API-discovered models to existing settings models
+                this.available_models.extend(api_discovered_models);
+                // Sort all models by name
+                this.available_models.sort_by(|a, b| a.name.cmp(&b.name));
                 cx.notify();
             })?;
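The merge policy implemented above (settings-defined models win by name, API-discovered models fill in the rest, and the result is sorted) can be expressed in isolation. A self-contained sketch with an illustrative `ModelEntry` type, mirroring the expectations of the test added below:

use std::collections::HashSet;

#[derive(Debug)]
struct ModelEntry {
    name: String,
    from_settings: bool,
}

/// Settings entries take precedence; API entries whose names collide
/// with a settings entry are dropped; the merged list is sorted by name.
fn merge_models(settings: Vec<ModelEntry>, api: Vec<ModelEntry>) -> Vec<ModelEntry> {
    let known: HashSet<String> = settings.iter().map(|m| m.name.clone()).collect();
    let mut merged = settings;
    merged.extend(api.into_iter().filter(|m| !known.contains(&m.name)));
    merged.sort_by(|a, b| a.name.cmp(&b.name));
    merged
}

fn main() {
    let settings = vec![
        ModelEntry { name: "custom-model-1".into(), from_settings: true },
        ModelEntry { name: "shared-model".into(), from_settings: true },
    ];
    let api = vec![
        ModelEntry { name: "api-model-1".into(), from_settings: false },
        ModelEntry { name: "shared-model".into(), from_settings: false },
    ];
    let merged = merge_models(settings, api);
    let names: Vec<&str> = merged.iter().map(|m| m.name.as_str()).collect();
    assert_eq!(names, ["api-model-1", "custom-model-1", "shared-model"]);
    // The duplicate "shared-model" from the API was dropped in favor of settings.
    assert!(merged.iter().all(|m| m.name != "shared-model" || m.from_settings));
}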
@@ -397,6 +447,7 @@ mod tests {
     use gpui::{AppContext, TestAppContext};
+    use client;
     use language::Buffer;
     use project::Project;
     use settings::SettingsStore;
@@ -406,6 +457,7 @@
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
         theme::init(theme::LoadThemes::JustBase, cx);
+        client::init_settings(cx);
         language::init(cx);
         editor::init_settings(cx);
         Project::init_settings(cx);
@@ -930,4 +982,125 @@
             assert_eq!(editor.text(cx), "");
         });
     }
+
+    #[gpui::test]
+    async fn test_settings_model_merging(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        // Create fake HTTP client that returns some API models
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+
+        // Mock /api/tags response (list models)
+        let models_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "api-model-1",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "abc123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "llama",
+                        "families": ["llama"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                },
+                {
+                    "name": "shared-model",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 2000000,
+                    "digest": "def456",
+                    "details": {
+                        "format": "gguf",
+                        "family": "llama",
+                        "families": ["llama"],
+                        "parameter_size": "13B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", models_response.to_string());
+
+        // Mock /api/show responses for each model
+        let show_response = serde_json::json!({
+            "capabilities": ["tools", "vision"]
+        });
+        fake_http_client.set_response("/api/show", show_response.to_string());
+
+        // Create service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        // Add settings models (including one that overlaps with API)
+        let settings_models = vec![
+            SettingsModel {
+                name: "custom-model-1".to_string(),
+                display_name: Some("Custom Model 1".to_string()),
+                max_tokens: 4096,
+                supports_tools: Some(true),
+                supports_images: Some(false),
+                supports_thinking: Some(false),
+            },
+            SettingsModel {
+                name: "shared-model".to_string(), // This should take precedence over API
+                display_name: Some("Custom Shared Model".to_string()),
+                max_tokens: 8192,
+                supports_tools: Some(true),
+                supports_images: Some(true),
+                supports_thinking: Some(true),
+            },
+        ];
+
+        cx.update(|cx| {
+            service.update(cx, |service, cx| {
+                service.set_settings_models(settings_models, cx);
+            });
+        });
+
+        // Wait for models to be fetched and merged
+        cx.run_until_parked();
+
+        // Verify merged models
+        let models = cx.update(|cx| service.read(cx).available_models().to_vec());
+
+        assert_eq!(models.len(), 3); // 2 settings models + 1 unique API model
+
+        // Models should be sorted alphabetically, so check by name
+        let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
+        assert_eq!(
+            model_names,
+            vec!["api-model-1", "custom-model-1", "shared-model"]
+        );
+
+        // Check custom model from settings
+        let custom_model = models.iter().find(|m| m.name == "custom-model-1").unwrap();
+        assert_eq!(
+            custom_model.display_name,
+            Some("Custom Model 1".to_string())
+        );
+        assert_eq!(custom_model.max_tokens, 4096);
+
+        // Settings model should override API model for shared-model
+        let shared_model = models.iter().find(|m| m.name == "shared-model").unwrap();
+        assert_eq!(
+            shared_model.display_name,
+            Some("Custom Shared Model".to_string())
+        );
+        assert_eq!(shared_model.max_tokens, 8192);
+        assert_eq!(shared_model.supports_tools, Some(true));
+        assert_eq!(shared_model.supports_vision, Some(true));
+        assert_eq!(shared_model.supports_thinking, Some(true));
+
+        // API-only model should be included
+        let api_model = models.iter().find(|m| m.name == "api-model-1").unwrap();
+        assert!(api_model.display_name.is_none()); // API models don't have custom display names
+    }
 }
diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs
index 4bf1650d77..5f1c08659e 100644
--- a/crates/zed/src/zed/inline_completion_registry.rs
+++ b/crates/zed/src/zed/inline_completion_registry.rs
@@ -6,7 +6,7 @@
 use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
 use language::language_settings::{EditPredictionProvider, all_language_settings};
 use language_models::AllLanguageModelSettings;
-use ollama::{OllamaCompletionProvider, OllamaService};
+use ollama::{OllamaCompletionProvider, OllamaService, SettingsModel};
 use settings::{Settings as _, SettingsStore};
 use smol::stream::StreamExt;
 use std::{cell::RefCell, rc::Rc, sync::Arc};
@@ -19,8 +19,30 @@ use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
 
 pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
     // Initialize global Ollama service
-    let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-    let ollama_service = OllamaService::new(client.http_client(), settings.api_url.clone(), cx);
+    let (api_url, settings_models) = {
+        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+        let api_url = settings.api_url.clone();
+        let settings_models: Vec<SettingsModel> = settings
+            .available_models
+            .iter()
+            .map(|model| SettingsModel {
+                name: model.name.clone(),
+                display_name: model.display_name.clone(),
+                max_tokens: model.max_tokens,
+                supports_tools: model.supports_tools,
+                supports_images: model.supports_images,
+                supports_thinking: model.supports_thinking,
+            })
+            .collect();
+        (api_url, settings_models)
+    };
+
+    let ollama_service = OllamaService::new(client.http_client(), api_url, cx);
+
+    ollama_service.update(cx, |service, cx| {
+        service.set_settings_models(settings_models, cx);
+    });
+
     OllamaService::set_global(ollama_service, cx);
 
     let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@@ -144,10 +166,23 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
                    }
                } else if provider == EditPredictionProvider::Ollama {
                    // Update global Ollama service when settings change
-                   let _settings = &AllLanguageModelSettings::get_global(cx).ollama;
+                   let settings = &AllLanguageModelSettings::get_global(cx).ollama;
                    if let Some(service) = OllamaService::global(cx) {
+                       let settings_models: Vec<SettingsModel> = settings
+                           .available_models
+                           .iter()
+                           .map(|model| SettingsModel {
+                               name: model.name.clone(),
+                               display_name: model.display_name.clone(),
+                               max_tokens: model.max_tokens,
+                               supports_tools: model.supports_tools,
+                               supports_images: model.supports_images,
+                               supports_thinking: model.supports_thinking,
+                           })
+                           .collect();
+
                        service.update(cx, |service, cx| {
-                           service.refresh_models(cx);
+                           service.set_settings_models(settings_models, cx);
                        });
                    }
                }
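For reference, the settings entries being mirrored in this file come from the user's `language_models.ollama.available_models` configuration. A sketch of that shape as a `serde_json::json!` literal (assumes the `serde_json` crate is available; the concrete model values are invented for illustration):

// Illustrative only: the user-settings shape that feeds the mapping above.
// Field names line up with the `SettingsModel` fields introduced in this PR.
fn main() {
    let ollama_settings = serde_json::json!({
        "language_models": {
            "ollama": {
                "api_url": "http://localhost:11434",
                "available_models": [{
                    "name": "shared-model",
                    "display_name": "Custom Shared Model",
                    "max_tokens": 8192,
                    "supports_tools": true,
                    "supports_images": true,
                    "supports_thinking": true
                }]
            }
        }
    });
    println!("{ollama_settings:#}");
}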