diff --git a/crates/ollama/src/ollama_completion_provider.rs b/crates/ollama/src/ollama_completion_provider.rs index 02abe6c935..3112ad831a 100644 --- a/crates/ollama/src/ollama_completion_provider.rs +++ b/crates/ollama/src/ollama_completion_provider.rs @@ -32,6 +32,11 @@ impl OllamaCompletionProvider { } } + /// Updates the model used by this provider + pub fn update_model(&mut self, new_model: String) { + self.model = new_model; + } + fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String { // Use model-specific FIM patterns match self.model.as_str() { @@ -100,7 +105,6 @@ impl EditPredictionProvider for OllamaCompletionProvider { ) { let http_client = self.http_client.clone(); let api_url = self.api_url.clone(); - let model = self.model.clone(); self.pending_refresh = Some(cx.spawn(async move |this, cx| { if debounce { @@ -125,8 +129,10 @@ impl EditPredictionProvider for OllamaCompletionProvider { let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?; + let model = this.update(cx, |this, _| this.model.clone())?; + let request = GenerateRequest { - model: model.clone(), + model, prompt, stream: false, options: Some(GenerateOptions { @@ -360,4 +366,41 @@ mod tests { assert!(completion.is_none()); } + + #[gpui::test] + async fn test_update_model(_cx: &mut TestAppContext) { + let mut provider = OllamaCompletionProvider::new( + Arc::new(FakeHttpClient::with_404_response()), + "http://localhost:11434".to_string(), + "codellama:7b".to_string(), + ); + + // Verify initial model + assert_eq!(provider.model, "codellama:7b"); + + // Test updating model + provider.update_model("deepseek-coder:6.7b".to_string()); + assert_eq!(provider.model, "deepseek-coder:6.7b"); + + // Test FIM prompt changes with different model + let prefix = "def hello():"; + let suffix = " pass"; + let prompt = provider.build_fim_prompt(prefix, suffix); + + // Should now use deepseek pattern + assert!(prompt.contains("<|fim▁begin|>")); + assert!(prompt.contains("<|fim▁hole|>")); + assert!(prompt.contains("<|fim▁end|>")); + + // Update to starcoder model + provider.update_model("starcoder:7b".to_string()); + assert_eq!(provider.model, "starcoder:7b"); + + let prompt = provider.build_fim_prompt(prefix, suffix); + + // Should now use starcoder pattern + assert!(prompt.contains("")); + assert!(prompt.contains("")); + assert!(prompt.contains("")); + } } diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs index bf6022be8e..df4c7a2919 100644 --- a/crates/zed/src/zed/inline_completion_registry.rs +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -3,6 +3,7 @@ use collections::HashMap; use copilot::{Copilot, CopilotCompletionProvider}; use editor::Editor; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; + use language::language_settings::{EditPredictionProvider, all_language_settings}; use language_models::AllLanguageModelSettings; use ollama::OllamaCompletionProvider; @@ -13,6 +14,7 @@ use supermaven::{Supermaven, SupermavenCompletionProvider}; use ui::Window; use util::ResultExt; use workspace::Workspace; +use zed_actions; use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; pub fn init(client: Arc, user_store: Entity, cx: &mut App) { @@ -135,6 +137,9 @@ pub fn init(client: Arc, user_store: Entity, cx: &mut App) { | EditPredictionProvider::Ollama => {} } } + } else if provider == EditPredictionProvider::Ollama { + // Update Ollama providers when settings change but provider stays the same + update_ollama_providers(&editors, &client, user_store.clone(), cx); } } }) @@ -147,6 +152,46 @@ fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) { } } +fn update_ollama_providers( + editors: &Rc, AnyWindowHandle>>>, + client: &Arc, + user_store: Entity, + cx: &mut App, +) { + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + let _current_model = settings + .available_models + .first() + .map(|m| m.name.clone()) + .unwrap_or_else(|| "codellama:7b".to_string()); + + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, window, cx| { + _ = editor.update(cx, |editor, cx| { + if let Some(provider) = editor.edit_prediction_provider() { + // Check if this is an Ollama provider by comparing names + if provider.name() == "ollama" { + // Recreate the provider with the new model + let settings = &AllLanguageModelSettings::get_global(cx).ollama; + let _api_url = settings.api_url.clone(); + + // Get client from the registry context (need to pass it) + // For now, we'll trigger a full reassignment + assign_edit_prediction_provider( + editor, + EditPredictionProvider::Ollama, + &client, + user_store.clone(), + window, + cx, + ); + } + } + }) + }); + } +} + fn assign_edit_prediction_providers( editors: &Rc, AnyWindowHandle>>>, provider: EditPredictionProvider,