Ollama model switcher working
This commit is contained in:
parent
5debd57040
commit
902a07606b
2 changed files with 90 additions and 2 deletions
|
@ -32,6 +32,11 @@ impl OllamaCompletionProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Updates the model used by this provider
|
||||||
|
pub fn update_model(&mut self, new_model: String) {
|
||||||
|
self.model = new_model;
|
||||||
|
}
|
||||||
|
|
||||||
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
||||||
// Use model-specific FIM patterns
|
// Use model-specific FIM patterns
|
||||||
match self.model.as_str() {
|
match self.model.as_str() {
|
||||||
|
@ -100,7 +105,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
) {
|
) {
|
||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let api_url = self.api_url.clone();
|
let api_url = self.api_url.clone();
|
||||||
let model = self.model.clone();
|
|
||||||
|
|
||||||
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
|
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
|
||||||
if debounce {
|
if debounce {
|
||||||
|
@ -125,8 +129,10 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
|
|
||||||
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
||||||
|
|
||||||
|
let model = this.update(cx, |this, _| this.model.clone())?;
|
||||||
|
|
||||||
let request = GenerateRequest {
|
let request = GenerateRequest {
|
||||||
model: model.clone(),
|
model,
|
||||||
prompt,
|
prompt,
|
||||||
stream: false,
|
stream: false,
|
||||||
options: Some(GenerateOptions {
|
options: Some(GenerateOptions {
|
||||||
|
@ -360,4 +366,41 @@ mod tests {
|
||||||
|
|
||||||
assert!(completion.is_none());
|
assert!(completion.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_update_model(_cx: &mut TestAppContext) {
|
||||||
|
let mut provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codellama:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify initial model
|
||||||
|
assert_eq!(provider.model, "codellama:7b");
|
||||||
|
|
||||||
|
// Test updating model
|
||||||
|
provider.update_model("deepseek-coder:6.7b".to_string());
|
||||||
|
assert_eq!(provider.model, "deepseek-coder:6.7b");
|
||||||
|
|
||||||
|
// Test FIM prompt changes with different model
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// Should now use deepseek pattern
|
||||||
|
assert!(prompt.contains("<|fim▁begin|>"));
|
||||||
|
assert!(prompt.contains("<|fim▁hole|>"));
|
||||||
|
assert!(prompt.contains("<|fim▁end|>"));
|
||||||
|
|
||||||
|
// Update to starcoder model
|
||||||
|
provider.update_model("starcoder:7b".to_string());
|
||||||
|
assert_eq!(provider.model, "starcoder:7b");
|
||||||
|
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
// Should now use starcoder pattern
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ use collections::HashMap;
|
||||||
use copilot::{Copilot, CopilotCompletionProvider};
|
use copilot::{Copilot, CopilotCompletionProvider};
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
||||||
|
|
||||||
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
||||||
use language_models::AllLanguageModelSettings;
|
use language_models::AllLanguageModelSettings;
|
||||||
use ollama::OllamaCompletionProvider;
|
use ollama::OllamaCompletionProvider;
|
||||||
|
@ -13,6 +14,7 @@ use supermaven::{Supermaven, SupermavenCompletionProvider};
|
||||||
use ui::Window;
|
use ui::Window;
|
||||||
use util::ResultExt;
|
use util::ResultExt;
|
||||||
use workspace::Workspace;
|
use workspace::Workspace;
|
||||||
|
use zed_actions;
|
||||||
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
|
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
|
||||||
|
|
||||||
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
|
@ -135,6 +137,9 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
| EditPredictionProvider::Ollama => {}
|
| EditPredictionProvider::Ollama => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else if provider == EditPredictionProvider::Ollama {
|
||||||
|
// Update Ollama providers when settings change but provider stays the same
|
||||||
|
update_ollama_providers(&editors, &client, user_store.clone(), cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -147,6 +152,46 @@ fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn update_ollama_providers(
|
||||||
|
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
||||||
|
client: &Arc<Client>,
|
||||||
|
user_store: Entity<UserStore>,
|
||||||
|
cx: &mut App,
|
||||||
|
) {
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||||
|
let _current_model = settings
|
||||||
|
.available_models
|
||||||
|
.first()
|
||||||
|
.map(|m| m.name.clone())
|
||||||
|
.unwrap_or_else(|| "codellama:7b".to_string());
|
||||||
|
|
||||||
|
for (editor, window) in editors.borrow().iter() {
|
||||||
|
_ = window.update(cx, |_window, window, cx| {
|
||||||
|
_ = editor.update(cx, |editor, cx| {
|
||||||
|
if let Some(provider) = editor.edit_prediction_provider() {
|
||||||
|
// Check if this is an Ollama provider by comparing names
|
||||||
|
if provider.name() == "ollama" {
|
||||||
|
// Recreate the provider with the new model
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||||
|
let _api_url = settings.api_url.clone();
|
||||||
|
|
||||||
|
// Get client from the registry context (need to pass it)
|
||||||
|
// For now, we'll trigger a full reassignment
|
||||||
|
assign_edit_prediction_provider(
|
||||||
|
editor,
|
||||||
|
EditPredictionProvider::Ollama,
|
||||||
|
&client,
|
||||||
|
user_store.clone(),
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn assign_edit_prediction_providers(
|
fn assign_edit_prediction_providers(
|
||||||
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
||||||
provider: EditPredictionProvider,
|
provider: EditPredictionProvider,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue