Ollama model switcher working

This commit is contained in:
Oliver Azevedo Barnes 2025-07-02 13:15:11 -03:00
parent 5debd57040
commit 902a07606b
No known key found for this signature in database
2 changed files with 90 additions and 2 deletions

View file

@ -32,6 +32,11 @@ impl OllamaCompletionProvider {
} }
} }
/// Updates the model used by this provider.
///
/// Accepts anything convertible into a `String` (`&str`, `String`, …) so
/// callers are not forced to allocate up front. Subsequent completions —
/// including model-specific FIM prompt construction — use the new model
/// name.
pub fn update_model(&mut self, new_model: impl Into<String>) {
    self.model = new_model.into();
}
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String { fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
// Use model-specific FIM patterns // Use model-specific FIM patterns
match self.model.as_str() { match self.model.as_str() {
@ -100,7 +105,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
) { ) {
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
let api_url = self.api_url.clone(); let api_url = self.api_url.clone();
let model = self.model.clone();
self.pending_refresh = Some(cx.spawn(async move |this, cx| { self.pending_refresh = Some(cx.spawn(async move |this, cx| {
if debounce { if debounce {
@ -125,8 +129,10 @@ impl EditPredictionProvider for OllamaCompletionProvider {
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?; let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
let model = this.update(cx, |this, _| this.model.clone())?;
let request = GenerateRequest { let request = GenerateRequest {
model: model.clone(), model,
prompt, prompt,
stream: false, stream: false,
options: Some(GenerateOptions { options: Some(GenerateOptions {
@ -360,4 +366,41 @@ mod tests {
assert!(completion.is_none()); assert!(completion.is_none());
} }
#[gpui::test]
async fn test_update_model(_cx: &mut TestAppContext) {
    let mut provider = OllamaCompletionProvider::new(
        Arc::new(FakeHttpClient::with_404_response()),
        "http://localhost:11434".to_string(),
        "codellama:7b".to_string(),
    );

    // The provider starts out with the model it was constructed with.
    assert_eq!(provider.model, "codellama:7b");

    // Switching models is reflected immediately in provider state...
    provider.update_model("deepseek-coder:6.7b".to_string());
    assert_eq!(provider.model, "deepseek-coder:6.7b");

    // ...and the FIM prompt picks up the deepseek-specific markers.
    let (prefix, suffix) = ("def hello():", " pass");
    let deepseek_prompt = provider.build_fim_prompt(prefix, suffix);
    for marker in ["<fim▁begin>", "<fim▁hole>", "<fim▁end>"] {
        assert!(deepseek_prompt.contains(marker));
    }

    // Switching again, to starcoder, swaps in the starcoder markers.
    provider.update_model("starcoder:7b".to_string());
    assert_eq!(provider.model, "starcoder:7b");
    let starcoder_prompt = provider.build_fim_prompt(prefix, suffix);
    for marker in ["<fim_prefix>", "<fim_suffix>", "<fim_middle>"] {
        assert!(starcoder_prompt.contains(marker));
    }
}
} }

View file

@ -3,6 +3,7 @@ use collections::HashMap;
use copilot::{Copilot, CopilotCompletionProvider}; use copilot::{Copilot, CopilotCompletionProvider};
use editor::Editor; use editor::Editor;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity}; use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings}; use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::AllLanguageModelSettings; use language_models::AllLanguageModelSettings;
use ollama::OllamaCompletionProvider; use ollama::OllamaCompletionProvider;
@ -13,6 +14,7 @@ use supermaven::{Supermaven, SupermavenCompletionProvider};
use ui::Window; use ui::Window;
use util::ResultExt; use util::ResultExt;
use workspace::Workspace; use workspace::Workspace;
use zed_actions;
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider}; use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) { pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
@ -135,6 +137,9 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
| EditPredictionProvider::Ollama => {} | EditPredictionProvider::Ollama => {}
} }
} }
} else if provider == EditPredictionProvider::Ollama {
// Update Ollama providers when settings change but provider stays the same
update_ollama_providers(&editors, &client, user_store.clone(), cx);
} }
} }
}) })
@ -147,6 +152,46 @@ fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
} }
} }
/// Re-assigns the Ollama edit-prediction provider on every tracked editor so
/// a settings change (e.g. a different selected model) takes effect while
/// Ollama remains the active provider.
///
/// * `editors` — the registry of live editors and their windows.
/// * `client` / `user_store` — forwarded to provider construction.
///
/// Recreating the provider via `assign_edit_prediction_provider` picks up the
/// latest model / api_url from the global settings inside that call, so no
/// settings are read here. (The previous version fetched the settings and a
/// `_current_model` / `_api_url` up front and then discarded them — dead work
/// removed.)
fn update_ollama_providers(
    editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
    client: &Arc<Client>,
    user_store: Entity<UserStore>,
    cx: &mut App,
) {
    for (editor, window) in editors.borrow().iter() {
        _ = window.update(cx, |_window, window, cx| {
            _ = editor.update(cx, |editor, cx| {
                // Only touch editors whose current provider is Ollama; other
                // providers manage their own settings updates.
                let is_ollama = editor
                    .edit_prediction_provider()
                    .map_or(false, |provider| provider.name() == "ollama");
                if is_ollama {
                    // Full reassignment rebuilds the provider with the
                    // freshly-updated global Ollama settings.
                    assign_edit_prediction_provider(
                        editor,
                        EditPredictionProvider::Ollama,
                        client,
                        user_store.clone(),
                        window,
                        cx,
                    );
                }
            })
        });
    }
}
fn assign_edit_prediction_providers( fn assign_edit_prediction_providers(
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>, editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
provider: EditPredictionProvider, provider: EditPredictionProvider,