Auto detect models WIP

This commit is contained in:
parent 5a1506c3c2
commit 0bdb42e65d

8 changed files with 952 additions and 128 deletions
Cargo.lock (generated, 3 changes)

@@ -8363,6 +8363,7 @@ dependencies = [
  "fs",
  "futures 0.3.31",
  "gpui",
+ "gpui_tokio",
  "http_client",
  "indoc",
  "inline_completion",
@@ -8370,6 +8371,7 @@ dependencies = [
  "language_model",
  "language_models",
  "lsp",
+ "ollama",
  "paths",
  "project",
  "regex",
@@ -20253,6 +20255,7 @@ dependencies = [
  "nix 0.29.0",
  "node_runtime",
  "notifications",
+ "ollama",
  "onboarding",
  "outline",
  "outline_panel",
@@ -25,6 +25,7 @@ indoc.workspace = true
 inline_completion.workspace = true
 language.workspace = true
 language_models.workspace = true
+ollama.workspace = true

 paths.workspace = true
 regex.workspace = true
@@ -48,6 +49,9 @@ http_client = { workspace = true, features = ["test-support"] }
 indoc.workspace = true
 language_model = { workspace = true, features = ["test-support"] }
 lsp = { workspace = true, features = ["test-support"] }
+ollama = { workspace = true, features = ["test-support"] }
 project = { workspace = true, features = ["test-support"] }
 serde_json.workspace = true
+settings = { workspace = true, features = ["test-support"] }
 theme = { workspace = true, features = ["test-support"] }
+gpui_tokio.workspace = true
@@ -21,6 +21,8 @@ use language::{
 };
 use language_models::AllLanguageModelSettings;

+use ollama;
+
 use paths;
 use regex::Regex;
 use settings::{Settings, SettingsStore, update_settings_file};
@@ -413,6 +415,10 @@ impl InlineCompletionButton {
         cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
             .detach();

+        if let Some(service) = ollama::OllamaService::global(cx) {
+            cx.observe(&service, |_, _, cx| cx.notify()).detach();
+        }
+
         Self {
             editor_subscription: None,
             editor_enabled: None,
@@ -858,8 +864,30 @@ impl InlineCompletionButton {
         let settings = AllLanguageModelSettings::get_global(cx);
         let ollama_settings = &settings.ollama;

-        // Clone needed values to avoid borrowing issues
-        let available_models = ollama_settings.available_models.clone();
+        // Get models from both settings and global service discovery
+        let mut available_models = ollama_settings.available_models.clone();

+        // Add discovered models from the global Ollama service
+        if let Some(service) = ollama::OllamaService::global(cx) {
+            let discovered_models = service.read(cx).available_models();
+            for model in discovered_models {
+                // Convert from ollama::Model to language_models AvailableModel
+                let available_model = language_models::provider::ollama::AvailableModel {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                    keep_alive: model.keep_alive.clone(),
+                    supports_tools: model.supports_tools,
+                    supports_images: model.supports_vision,
+                    supports_thinking: model.supports_thinking,
+                };
+
+                // Add if not already in settings (settings take precedence)
+                if !available_models.iter().any(|m| m.name == model.name) {
+                    available_models.push(available_model);
+                }
+            }
+        }
+
         // API URL configuration - only show if Ollama settings exist in the user's config
         let menu = if Self::ollama_settings_exist(cx) {
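Note: the settings-precedence merge above reduces to a dedup-by-name append. A minimal, self-contained sketch of that rule (the stand-in struct is illustrative; the real types are ollama::Model and language_models::provider::ollama::AvailableModel):

#[derive(Clone, Debug, PartialEq)]
struct AvailableModel {
    name: String,
}

fn merge_models(
    mut from_settings: Vec<AvailableModel>,
    discovered: Vec<AvailableModel>,
) -> Vec<AvailableModel> {
    for model in discovered {
        // Settings-configured entries win; a discovered model is appended
        // only when no entry with the same name exists yet.
        if !from_settings.iter().any(|m| m.name == model.name) {
            from_settings.push(model);
        }
    }
    from_settings
}

fn main() {
    let settings = vec![AvailableModel { name: "qwen2.5-coder:3b".into() }];
    let discovered = vec![
        AvailableModel { name: "qwen2.5-coder:3b".into() }, // duplicate: skipped
        AvailableModel { name: "codellama:7b-code".into() }, // new: appended
    ];
    assert_eq!(merge_models(settings, discovered).len(), 2);
}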
@@ -878,7 +906,7 @@ impl InlineCompletionButton {
             let menu = menu.separator().header("Available Models");

             // Add each available model as a menu entry
-            available_models.iter().fold(menu, |menu, model| {
+            let menu = available_models.iter().fold(menu, |menu, model| {
                 let model_name = model.display_name.as_ref().unwrap_or(&model.name);
                 let is_current = available_models
                     .first()
@@ -898,6 +926,13 @@ impl InlineCompletionButton {
                     }
                 },
             )
+            });
+
+            // Add refresh models option
+            menu.separator().entry("Refresh Models", None, {
+                move |_window, cx| {
+                    Self::refresh_ollama_models(cx);
+                }
             })
         } else {
             menu.separator()
@@ -908,6 +943,11 @@ impl InlineCompletionButton {
                         Self::open_ollama_settings(fs.clone(), window, cx);
                     }
                 })
+                .entry("Refresh Models", None, {
+                    move |_window, cx| {
+                        Self::refresh_ollama_models(cx);
+                    }
+                })
         };

         // Use the common language settings menu
@@ -997,6 +1037,14 @@ impl InlineCompletionButton {
         });
     }

+    fn refresh_ollama_models(cx: &mut App) {
+        if let Some(service) = ollama::OllamaService::global(cx) {
+            service.update(cx, |service, cx| {
+                service.refresh_models(cx);
+            });
+        }
+    }
+
     pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
         let editor = editor.read(cx);
         let snapshot = editor.buffer().read(cx).snapshot(cx);
@@ -1188,3 +1236,359 @@ fn toggle_edit_prediction_mode(fs: Arc<dyn Fs>, mode: EditPredictionsMode, cx: &
         });
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use clock::FakeSystemClock;
+    use gpui::TestAppContext;
+    use http_client;
+    use language_models::provider::ollama::AvailableModel;
+    use ollama::{OllamaService, fake::FakeHttpClient};
+    use settings::SettingsStore;
+    use std::sync::Arc;
+
+    fn init_test(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            gpui_tokio::init(cx);
+            theme::init(theme::LoadThemes::JustBase, cx);
+            language::init(cx);
+            language_settings::init(cx);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_ollama_menu_shows_discovered_models(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        // Create fake HTTP client with mock models response
+        let fake_http_client = Arc::new(FakeHttpClient::new());
+
+        // Mock /api/tags response
+        let models_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "qwen2.5-coder:3b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "abc123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "qwen2",
+                        "families": ["qwen2"],
+                        "parameter_size": "3B",
+                        "quantization_level": "Q4_0"
+                    }
+                },
+                {
+                    "name": "codellama:7b-code",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 2000000,
+                    "digest": "def456",
+                    "details": {
+                        "format": "gguf",
+                        "family": "codellama",
+                        "families": ["codellama"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", models_response.to_string());
+
+        // Mock /api/show response
+        let capabilities = serde_json::json!({
+            "capabilities": ["tools"]
+        });
+        fake_http_client.set_response("/api/show", capabilities.to_string());
+
+        // Create and set global Ollama service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Wait for model discovery
+        cx.background_executor.run_until_parked();
+
+        // Verify models are accessible through the service
+        cx.update(|cx| {
+            if let Some(service) = OllamaService::global(cx) {
+                let discovered_models = service.read(cx).available_models();
+                assert_eq!(discovered_models.len(), 2);
+
+                let model_names: Vec<&str> =
+                    discovered_models.iter().map(|m| m.name.as_str()).collect();
+                assert!(model_names.contains(&"qwen2.5-coder:3b"));
+                assert!(model_names.contains(&"codellama:7b-code"));
+            } else {
+                panic!("Global service should be available");
+            }
+        });
+
+        // Verify the global service has the expected models
+        service.read_with(cx, |service, _| {
+            let models = service.available_models();
+            assert_eq!(models.len(), 2);
+
+            let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
+            assert!(model_names.contains(&"qwen2.5-coder:3b"));
+            assert!(model_names.contains(&"codellama:7b-code"));
+        });
+    }
+
+    #[gpui::test]
+    async fn test_ollama_menu_shows_service_models(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        // Create fake HTTP client with models
+        let fake_http_client = Arc::new(FakeHttpClient::new());
+
+        let models_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "qwen2.5-coder:7b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "abc123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "qwen2",
+                        "families": ["qwen2"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", models_response.to_string());
+        fake_http_client.set_response(
+            "/api/show",
+            serde_json::json!({"capabilities": []}).to_string(),
+        );
+
+        // Create and set global service
+        let service = cx.update(|cx| {
+            OllamaService::new(fake_http_client, "http://localhost:11434".to_string(), cx)
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Test that discovered models are accessible
+        cx.update(|cx| {
+            if let Some(service) = OllamaService::global(cx) {
+                let discovered_models = service.read(cx).available_models();
+                assert_eq!(discovered_models.len(), 1);
+                assert_eq!(discovered_models[0].name, "qwen2.5-coder:7b");
+            } else {
+                panic!("Global service should be available");
+            }
+        });
+    }
+
+    #[gpui::test]
+    async fn test_ollama_menu_refreshes_on_service_update(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fake_http_client = Arc::new(FakeHttpClient::new());
+
+        // Initially empty models
+        fake_http_client.set_response("/api/tags", serde_json::json!({"models": []}).to_string());
+
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Verify the service subscription mechanism works by creating a button
+        let _button = cx.update(|cx| {
+            let fs = fs::FakeFs::new(cx.background_executor().clone());
+            let user_store = cx.new(|cx| {
+                client::UserStore::new(
+                    Arc::new(http_client::FakeHttpClient::create(|_| {
+                        Box::pin(async { Err(anyhow::anyhow!("not implemented")) })
+                    })),
+                    cx,
+                )
+            });
+            let popover_handle = PopoverMenuHandle::default();
+
+            cx.new(|cx| InlineCompletionButton::new(fs, user_store, popover_handle, cx))
+        });
+
+        // Verify initially no models
+        service.read_with(cx, |service, _| {
+            assert_eq!(service.available_models().len(), 0);
+        });
+
+        // Update mock to return models
+        let models_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "phi3:mini",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 500000,
+                    "digest": "xyz789",
+                    "details": {
+                        "format": "gguf",
+                        "family": "phi3",
+                        "families": ["phi3"],
+                        "parameter_size": "3.8B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", models_response.to_string());
+        fake_http_client.set_response(
+            "/api/show",
+            serde_json::json!({"capabilities": []}).to_string(),
+        );
+
+        // Trigger refresh
+        service.update(cx, |service, cx| {
+            service.refresh_models(cx);
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Verify models were refreshed
+        service.read_with(cx, |service, _| {
+            let models = service.available_models();
+            assert_eq!(models.len(), 1);
+            assert_eq!(models[0].name, "phi3:mini");
+        });
+
+        // The button should have been notified and will rebuild its menu with new models
+        // when next requested (this tests the subscription mechanism)
+    }
+
+    #[gpui::test]
+    async fn test_refresh_models_button_functionality(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fake_http_client = Arc::new(FakeHttpClient::new());
+
+        // Start with one model
+        let initial_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "mistral:7b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "initial123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "mistral",
+                        "families": ["mistral"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", initial_response.to_string());
+        fake_http_client.set_response(
+            "/api/show",
+            serde_json::json!({"capabilities": []}).to_string(),
+        );
+
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Verify initial model
+        service.read_with(cx, |service, _| {
+            assert_eq!(service.available_models().len(), 1);
+            assert_eq!(service.available_models()[0].name, "mistral:7b");
+        });
+
+        // Update mock to simulate new model available
+        let updated_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "mistral:7b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "initial123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "mistral",
+                        "families": ["mistral"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                },
+                {
+                    "name": "gemma2:9b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 2000000,
+                    "digest": "new456",
+                    "details": {
+                        "format": "gguf",
+                        "family": "gemma2",
+                        "families": ["gemma2"],
+                        "parameter_size": "9B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", updated_response.to_string());
+
+        // Simulate clicking "Refresh Models" button
+        cx.update(|cx| {
+            InlineCompletionButton::refresh_ollama_models(cx);
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Verify models were refreshed
+        service.read_with(cx, |service, _| {
+            let models = service.available_models();
+            assert_eq!(models.len(), 2);
+
+            let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
+            assert!(model_names.contains(&"mistral:7b"));
+            assert!(model_names.contains(&"gemma2:9b"));
+        });
+    }
+}
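Note: the /api/tags fixtures above mirror Ollama's list-models endpoint; the discovery path ultimately keys on the model name. A minimal serde sketch of just that field (serde ignores the extra fields such as "digest" and "details" by default; assumes serde with the derive feature and serde_json as dependencies):

use serde::Deserialize;

#[derive(Deserialize)]
struct TagsResponse {
    models: Vec<ModelEntry>,
}

#[derive(Deserialize)]
struct ModelEntry {
    name: String,
}

fn main() {
    let body = r#"{"models": [{"name": "qwen2.5-coder:3b"}]}"#;
    let parsed: TagsResponse = serde_json::from_str(body).expect("valid JSON");
    assert_eq!(parsed.models[0].name, "qwen2.5-coder:3b");
}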
@@ -1,7 +1,7 @@
 use anyhow::{Result, anyhow};
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use futures::{Stream, TryFutureExt, stream};
-use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
+use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -141,6 +141,29 @@ impl State {
 }

 impl OllamaLanguageModelProvider {
+    pub fn global(cx: &App) -> Option<Entity<Self>> {
+        cx.try_global::<GlobalOllamaLanguageModelProvider>()
+            .map(|provider| provider.0.clone())
+    }
+
+    pub fn set_global(provider: Entity<Self>, cx: &mut App) {
+        cx.set_global(GlobalOllamaLanguageModelProvider(provider));
+    }
+
+    pub fn available_models_for_completion(&self, cx: &App) -> Vec<ollama::Model> {
+        self.state.read(cx).available_models.clone()
+    }
+
+    pub fn http_client(&self) -> Arc<dyn HttpClient> {
+        self.http_client.clone()
+    }
+
+    pub fn refresh_models(&self, cx: &mut App) {
+        self.state.update(cx, |state, cx| {
+            state.restart_fetch_models_task(cx);
+        });
+    }
+
     pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
         let this = Self {
             http_client: http_client.clone(),
@@ -667,6 +690,10 @@ impl Render for ConfigurationView {
     }
 }

+struct GlobalOllamaLanguageModelProvider(Entity<OllamaLanguageModelProvider>);
+
+impl Global for GlobalOllamaLanguageModelProvider {}
+
 fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
     ollama::OllamaTool::Function {
         function: OllamaFunctionTool {
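Note: GlobalOllamaLanguageModelProvider follows GPUI's private-newtype global pattern: only this crate can construct the key type, so only it can install or read the global. A rough std-only analogy using OnceLock (illustrative only; GPUI's cx.set_global/cx.try_global are the real mechanism and, unlike OnceLock, participate in the app context and allow replacement):

use std::sync::OnceLock;

struct OllamaProviderHandle {
    api_url: String,
}

// The private static keys the global, like the private newtype above.
static GLOBAL: OnceLock<OllamaProviderHandle> = OnceLock::new();

fn set_global(handle: OllamaProviderHandle) {
    // Ignore the error if a global was already installed.
    let _ = GLOBAL.set(handle);
}

fn global() -> Option<&'static OllamaProviderHandle> {
    GLOBAL.get()
}

fn main() {
    set_global(OllamaProviderHandle { api_url: "http://localhost:11434".into() });
    assert_eq!(global().map(|h| h.api_url.as_str()), Some("http://localhost:11434"));
}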
@@ -30,6 +30,7 @@ language.workspace = true
 log.workspace = true

 project.workspace = true
+settings.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
@@ -541,14 +541,8 @@ pub mod fake {
     ) {
         let fake_client = std::sync::Arc::new(FakeHttpClient::new());

-        let provider = cx.new(|_| {
-            OllamaCompletionProvider::new(
-                fake_client.clone(),
-                "http://localhost:11434".to_string(),
-                "qwencoder".to_string(),
-                None,
-            )
-        });
+        let provider =
+            cx.new(|cx| OllamaCompletionProvider::new("qwencoder".to_string(), None, cx));

         (provider, fake_client)
     }
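Note: a minimal sketch of the canned-response idea behind the test double used throughout this commit. The real ollama::fake::FakeHttpClient implements the HttpClient trait; this stand-in keeps only the path-to-body map:

use std::collections::HashMap;
use std::sync::Mutex;

struct FakeHttpClient {
    responses: Mutex<HashMap<String, String>>,
}

impl FakeHttpClient {
    fn new() -> Self {
        Self { responses: Mutex::new(HashMap::new()) }
    }

    // Register a canned body for a path; later requests look it up.
    fn set_response(&self, path: &str, body: String) {
        self.responses.lock().unwrap().insert(path.to_string(), body);
    }

    fn get(&self, path: &str) -> Option<String> {
        self.responses.lock().unwrap().get(path).cloned()
    }
}

fn main() {
    let client = FakeHttpClient::new();
    client.set_response("/api/tags", r#"{"models": []}"#.to_string());
    assert_eq!(client.get("/api/tags").as_deref(), Some(r#"{"models": []}"#));
    assert!(client.get("/api/show").is_none());
}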
@@ -1,43 +1,172 @@
-use crate::{GenerateOptions, GenerateRequest, generate};
+use crate::{GenerateOptions, GenerateRequest, Model, generate};
 use anyhow::{Context as AnyhowContext, Result};
+use futures::StreamExt;
+use std::{path::Path, sync::Arc, time::Duration};
+
-use gpui::{App, Context, Entity, EntityId, Task};
+use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task};
 use http_client::HttpClient;
 use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
 use language::{Anchor, Buffer, ToOffset};
+use settings::SettingsStore;
+
 use project::Project;
-use std::{path::Path, sync::Arc, time::Duration};

 pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);

-pub struct OllamaCompletionProvider {
+// Global Ollama service for managing models across all providers
+pub struct OllamaService {
     http_client: Arc<dyn HttpClient>,
     api_url: String,
+    available_models: Vec<Model>,
+    fetch_models_task: Option<Task<Result<()>>>,
+    _settings_subscription: Subscription,
+}
+
+impl OllamaService {
+    pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, cx: &mut App) -> Entity<Self> {
+        cx.new(|cx| {
+            let subscription = cx.observe_global::<SettingsStore>({
+                move |this: &mut OllamaService, cx| {
+                    this.restart_fetch_models_task(cx);
+                }
+            });
+
+            let mut service = Self {
+                http_client,
+                api_url,
+                available_models: Vec::new(),
+                fetch_models_task: None,
+                _settings_subscription: subscription,
+            };
+
+            service.restart_fetch_models_task(cx);
+            service
+        })
+    }
+
+    pub fn global(cx: &App) -> Option<Entity<Self>> {
+        cx.try_global::<GlobalOllamaService>()
+            .map(|service| service.0.clone())
+    }
+
+    pub fn set_global(service: Entity<Self>, cx: &mut App) {
+        cx.set_global(GlobalOllamaService(service));
+    }
+
+    pub fn available_models(&self) -> &[Model] {
+        &self.available_models
+    }
+
+    pub fn refresh_models(&mut self, cx: &mut Context<Self>) {
+        self.restart_fetch_models_task(cx);
+    }
+
+    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
+        self.fetch_models_task = Some(self.fetch_models(cx));
+    }
+
+    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let http_client = Arc::clone(&self.http_client);
+        let api_url = self.api_url.clone();
+
+        cx.spawn(async move |this, cx| {
+            let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
+                Ok(models) => models,
+                Err(_) => return Ok(()), // Silently fail and use empty list
+            };
+
+            let tasks = models
+                .into_iter()
+                // Filter out embedding models
+                .filter(|model| !model.name.contains("-embed"))
+                .map(|model| {
+                    let http_client = Arc::clone(&http_client);
+                    let api_url = api_url.clone();
+                    async move {
+                        let name = model.name.as_str();
+                        let capabilities =
+                            crate::show_model(http_client.as_ref(), &api_url, name).await?;
+                        let ollama_model = Model::new(
+                            name,
+                            None,
+                            None,
+                            Some(capabilities.supports_tools()),
+                            Some(capabilities.supports_vision()),
+                            Some(capabilities.supports_thinking()),
+                        );
+                        Ok(ollama_model)
+                    }
+                });
+
+            // Rate-limit capability fetches
+            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
+                .buffer_unordered(5)
+                .collect::<Vec<Result<_>>>()
+                .await
+                .into_iter()
+                .collect::<Result<Vec<_>>>()
+                .unwrap_or_default();
+
+            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
+
+            this.update(cx, |this, cx| {
+                this.available_models = ollama_models;
+                cx.notify();
+            })?;
+
+            Ok(())
+        })
+    }
+}
+
+struct GlobalOllamaService(Entity<OllamaService>);
+
+impl Global for GlobalOllamaService {}
+
+pub struct OllamaCompletionProvider {
     model: String,
     buffer_id: Option<EntityId>,
     file_extension: Option<String>,
     current_completion: Option<String>,
     pending_refresh: Option<Task<Result<()>>>,
     api_key: Option<String>,
+    _service_subscription: Option<Subscription>,
 }

 impl OllamaCompletionProvider {
-    pub fn new(
-        http_client: Arc<dyn HttpClient>,
-        api_url: String,
-        model: String,
-        api_key: Option<String>,
-    ) -> Self {
+    pub fn new(model: String, api_key: Option<String>, cx: &mut Context<Self>) -> Self {
+        let subscription = if let Some(service) = OllamaService::global(cx) {
+            Some(cx.observe(&service, |_this, _service, cx| {
+                cx.notify();
+            }))
+        } else {
+            None
+        };
+
         Self {
-            http_client,
-            api_url,
             model,
             buffer_id: None,
             file_extension: None,
             current_completion: None,
             pending_refresh: None,
             api_key,
+            _service_subscription: subscription,
+        }
+    }
+
+    pub fn available_models(&self, cx: &App) -> Vec<Model> {
+        if let Some(service) = OllamaService::global(cx) {
+            service.read(cx).available_models().to_vec()
+        } else {
+            Vec::new()
+        }
+    }
+
+    pub fn refresh_models(&self, cx: &mut App) {
+        if let Some(service) = OllamaService::global(cx) {
+            service.update(cx, |service, cx| {
+                service.refresh_models(cx);
+            });
         }
     }
 }
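Note: a runnable sketch of the buffer_unordered(5) rate limiting used by fetch_models above, assuming only the futures crate; the async blocks stand in for the per-model /api/show capability requests:

use futures::StreamExt;

fn main() {
    let names = vec!["a", "b", "c", "d", "e", "f", "g"];

    // One future per model, standing in for a capability request.
    let tasks = names.into_iter().map(|name| async move {
        Ok::<String, ()>(format!("capabilities for {name}"))
    });

    futures::executor::block_on(async {
        // At most five requests are in flight at once; completion order is
        // arbitrary, which is why the real code sorts the models afterwards.
        let results: Vec<Result<String, ()>> = futures::stream::iter(tasks)
            .buffer_unordered(5)
            .collect()
            .await;
        assert_eq!(results.len(), 7);
    });
}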
@@ -104,14 +233,28 @@ impl EditPredictionProvider for OllamaCompletionProvider {

     fn refresh(
         &mut self,
-        _project: Option<Entity<Project>>,
+        project: Option<Entity<Project>>,
         buffer: Entity<Buffer>,
         cursor_position: Anchor,
         debounce: bool,
         cx: &mut Context<Self>,
     ) {
-        let http_client = self.http_client.clone();
-        let api_url = self.api_url.clone();
+        // Get API settings from the global Ollama service or fallback
+        let (http_client, api_url) = if let Some(service) = OllamaService::global(cx) {
+            let service_ref = service.read(cx);
+            (service_ref.http_client.clone(), service_ref.api_url.clone())
+        } else {
+            // Fallback if global service isn't available
+            (
+                project
+                    .as_ref()
+                    .map(|p| p.read(cx).client().http_client() as Arc<dyn HttpClient>)
+                    .unwrap_or_else(|| {
+                        Arc::new(http_client::BlockedHttpClient::new()) as Arc<dyn HttpClient>
+                    }),
+                crate::OLLAMA_API_URL.to_string(),
+            )
+        };
+
         self.pending_refresh = Some(cx.spawn(async move |this, cx| {
             if debounce {
@@ -156,15 +299,18 @@ impl EditPredictionProvider for OllamaCompletionProvider {

             let response = generate(http_client.as_ref(), &api_url, api_key, request)
                 .await
-                .context("Failed to get completion from Ollama")?;
+                .context("Failed to get completion from Ollama");

             this.update(cx, |this, cx| {
                 this.pending_refresh = None;
-                if !response.response.trim().is_empty() {
+                match response {
+                    Ok(response) if !response.response.trim().is_empty() => {
                         this.current_completion = Some(response.response);
-                } else {
+                    }
+                    _ => {
                         this.current_completion = None;
+                    }
                 }
                 cx.notify();
             })?;
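Note: the change above stops propagating generate() errors with `?` inside the spawned task, so a failed or empty response now clears any stale completion instead of aborting before the state update. The state transition in isolation:

fn apply_response(current: &mut Option<String>, response: Result<String, String>) {
    match response {
        // Only a successful, non-empty response installs a completion.
        Ok(text) if !text.trim().is_empty() => *current = Some(text),
        // Errors and empty responses both clear any stale completion.
        _ => *current = None,
    }
}

fn main() {
    let mut completion = Some("stale".to_string());
    apply_response(&mut completion, Err("connection refused".to_string()));
    assert_eq!(completion, None);

    apply_response(&mut completion, Ok("println!(\"Hello\");".to_string()));
    assert!(completion.is_some());
}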
@@ -248,7 +394,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::fake::Ollama;

     use gpui::{AppContext, TestAppContext};
@@ -269,31 +414,238 @@ mod tests {
     }

     /// Test the complete Ollama completion flow from refresh to suggestion
-    #[test]
-    fn test_get_stop_tokens() {
-        let http_client = Arc::new(crate::fake::FakeHttpClient::new());
+    #[gpui::test]
+    fn test_get_stop_tokens(cx: &mut TestAppContext) {
+        init_test(cx);

         // Test CodeLlama code model gets stop tokens
-        let codellama_provider = OllamaCompletionProvider::new(
-            http_client.clone(),
-            "http://localhost:11434".to_string(),
-            "codellama:7b-code".to_string(),
-            None,
-        );
-
-        assert_eq!(
-            codellama_provider.get_stop_tokens(),
-            Some(vec!["<EOT>".to_string()])
-        );
+        let codellama_provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx))
+        });
+
+        codellama_provider.read_with(cx, |provider, _| {
+            assert_eq!(provider.get_stop_tokens(), Some(vec!["<EOT>".to_string()]));
+        });

         // Test non-CodeLlama model doesn't get stop tokens
-        let qwen_provider = OllamaCompletionProvider::new(
-            http_client.clone(),
-            "http://localhost:11434".to_string(),
-            "qwen2.5-coder:3b".to_string(),
-            None,
-        );
-        assert_eq!(qwen_provider.get_stop_tokens(), None);
+        let qwen_provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });
+
+        qwen_provider.read_with(cx, |provider, _| {
+            assert_eq!(provider.get_stop_tokens(), None);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_model_discovery(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        // Create fake HTTP client
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+
+        // Mock /api/tags response (list models)
+        let models_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "qwen2.5-coder:3b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "abc123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "qwen2",
+                        "families": ["qwen2"],
+                        "parameter_size": "3B",
+                        "quantization_level": "Q4_0"
+                    }
+                },
+                {
+                    "name": "codellama:7b-code",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 2000000,
+                    "digest": "def456",
+                    "details": {
+                        "format": "gguf",
+                        "family": "codellama",
+                        "families": ["codellama"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                },
+                {
+                    "name": "nomic-embed-text",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 500000,
+                    "digest": "ghi789",
+                    "details": {
+                        "format": "gguf",
+                        "family": "nomic-embed",
+                        "families": ["nomic-embed"],
+                        "parameter_size": "137M",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", models_response.to_string());
+
+        // Mock /api/show responses for model capabilities
+        let qwen_capabilities = serde_json::json!({
+            "capabilities": ["tools", "thinking"]
+        });
+
+        let _codellama_capabilities = serde_json::json!({
+            "capabilities": []
+        });
+
+        fake_http_client.set_response("/api/show", qwen_capabilities.to_string());
+
+        // Create global Ollama service for testing
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        // Set it as global
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Create completion provider
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });
+
+        // Wait for model discovery to complete
+        cx.background_executor.run_until_parked();
+
+        // Verify models were discovered through the global provider
+        provider.read_with(cx, |provider, cx| {
+            let models = provider.available_models(cx);
+            assert_eq!(models.len(), 2); // Should exclude nomic-embed-text
+
+            let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
+            assert!(model_names.contains(&"codellama:7b-code"));
+            assert!(model_names.contains(&"qwen2.5-coder:3b"));
+            assert!(!model_names.contains(&"nomic-embed-text"));
+        });
+    }
+
+    #[gpui::test]
+    async fn test_model_discovery_api_failure(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        // Create fake HTTP client that returns errors
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+        fake_http_client.set_error("Connection refused");
+
+        // Create global Ollama service that will fail
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Create completion provider
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });
+
+        // Wait for model discovery to complete (with failure)
+        cx.background_executor.run_until_parked();
+
+        // Verify graceful handling - should have empty model list
+        provider.read_with(cx, |provider, cx| {
+            let models = provider.available_models(cx);
+            assert_eq!(models.len(), 0);
+        });
+    }
+
+    #[gpui::test]
+    async fn test_refresh_models(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+
+        // Initially return empty model list
+        let empty_response = serde_json::json!({"models": []});
+        fake_http_client.set_response("/api/tags", empty_response.to_string());
+
+        // Create global Ollama service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx))
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Verify initially empty
+        provider.read_with(cx, |provider, cx| {
+            assert_eq!(provider.available_models(cx).len(), 0);
+        });
+
+        // Update mock to return models
+        let models_response = serde_json::json!({
+            "models": [
+                {
+                    "name": "qwen2.5-coder:7b",
+                    "modified_at": "2024-01-01T00:00:00Z",
+                    "size": 1000000,
+                    "digest": "abc123",
+                    "details": {
+                        "format": "gguf",
+                        "family": "qwen2",
+                        "families": ["qwen2"],
+                        "parameter_size": "7B",
+                        "quantization_level": "Q4_0"
+                    }
+                }
+            ]
+        });
+
+        fake_http_client.set_response("/api/tags", models_response.to_string());
+
+        let capabilities = serde_json::json!({
+            "capabilities": ["tools", "thinking"]
+        });
+
+        fake_http_client.set_response("/api/show", capabilities.to_string());
+
+        // Trigger refresh
+        provider.update(cx, |provider, cx| {
+            provider.refresh_models(cx);
+        });
+
+        cx.background_executor.run_until_parked();
+
+        // Verify models were refreshed
+        provider.read_with(cx, |provider, cx| {
+            let models = provider.available_models(cx);
+            assert_eq!(models.len(), 1);
+            assert_eq!(models[0].name, "qwen2.5-coder:7b");
+        });
     }

     #[gpui::test]
@@ -306,12 +658,28 @@ mod tests {
             buffer.anchor_before(11) // Position in the middle of the function
         });

-        // Create Ollama provider with fake HTTP client
-        let (provider, fake_http_client) = Ollama::fake(cx);
-
-        // Configure mock HTTP response
+        // Create fake HTTP client and set up global service
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
         fake_http_client.set_generate_response("println!(\"Hello\");");

+        // Create global Ollama service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Create provider
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });
+
         // Trigger completion refresh (no debounce for test speed)
         provider.update(cx, |provider, cx| {
             provider.refresh(None, buffer.clone(), cursor_position, false, cx);
@@ -363,7 +731,26 @@ mod tests {
             buffer.anchor_after(16) // After "vec"
         });

-        let (provider, fake_http_client) = Ollama::fake(cx);
+        // Create fake HTTP client and set up global service
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+
+        // Create global Ollama service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Create provider
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });
+
         // Configure response that starts with what user already typed
         fake_http_client.set_generate_response("vec![1, 2, 3]");
@@ -393,7 +780,28 @@ mod tests {
         init_test(cx);

         let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
-        let (provider, fake_http_client) = Ollama::fake(cx);
+
+        // Create fake HTTP client and set up global service
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+        fake_http_client.set_generate_response("vec![hello, world]");
+
+        // Create global Ollama service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Create provider
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });

         // Set up the editor with the Ollama provider
         editor_cx.update_editor(|editor, window, cx| {
@@ -403,9 +811,6 @@ mod tests {
         // Set initial state
         editor_cx.set_state("let items = ˇ");

-        // Configure a multi-word completion
-        fake_http_client.set_generate_response("vec![hello, world]");
-
         // Trigger the completion through the provider
         let buffer =
             editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
@@ -455,7 +860,28 @@ mod tests {
         init_test(cx);

         let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
-        let (provider, fake_http_client) = Ollama::fake(cx);
+
+        // Create fake HTTP client and set up global service
+        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
+        fake_http_client.set_generate_response("bar");
+
+        // Create global Ollama service
+        let service = cx.update(|cx| {
+            OllamaService::new(
+                fake_http_client.clone(),
+                "http://localhost:11434".to_string(),
+                cx,
+            )
+        });
+
+        cx.update(|cx| {
+            OllamaService::set_global(service.clone(), cx);
+        });
+
+        // Create provider
+        let provider = cx.update(|cx| {
+            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
+        });

         // Set up the editor with the Ollama provider
         editor_cx.update_editor(|editor, window, cx| {
@@ -464,9 +890,6 @@ mod tests {

         editor_cx.set_state("fooˇ");

-        // Configure completion response that extends the current text
-        fake_http_client.set_generate_response("bar");
-
         // Trigger the completion through the provider
         let buffer =
             editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
@@ -6,7 +6,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};

 use language::language_settings::{EditPredictionProvider, all_language_settings};
 use language_models::AllLanguageModelSettings;
-use ollama::OllamaCompletionProvider;
+use ollama::{OllamaCompletionProvider, OllamaService};
 use settings::{Settings, SettingsStore};
 use smol::stream::StreamExt;
 use std::{cell::RefCell, rc::Rc, sync::Arc};
@@ -18,6 +18,11 @@ use zed_actions;
 use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};

 pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+    // Initialize global Ollama service
+    let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+    let ollama_service = OllamaService::new(client.http_client(), settings.api_url.clone(), cx);
+    OllamaService::set_global(ollama_service, cx);
+
     let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
     cx.observe_new({
         let editors = editors.clone();
@@ -138,8 +143,13 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
                     }
                 }
             } else if provider == EditPredictionProvider::Ollama {
-                // Update Ollama providers when settings change but provider stays the same
-                update_ollama_providers(&editors, &client, user_store.clone(), cx);
+                // Update global Ollama service when settings change
+                let _settings = &AllLanguageModelSettings::get_global(cx).ollama;
+                if let Some(service) = OllamaService::global(cx) {
+                    service.update(cx, |service, cx| {
+                        service.refresh_models(cx);
+                    });
+                }
             }
         }
     })
@@ -152,46 +162,6 @@ fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
     }
 }

-fn update_ollama_providers(
-    editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
-    client: &Arc<Client>,
-    user_store: Entity<UserStore>,
-    cx: &mut App,
-) {
-    let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-    let _current_model = settings
-        .available_models
-        .first()
-        .map(|m| m.name.clone())
-        .unwrap_or_else(|| "codellama:7b".to_string());
-
-    for (editor, window) in editors.borrow().iter() {
-        _ = window.update(cx, |_window, window, cx| {
-            _ = editor.update(cx, |editor, cx| {
-                if let Some(provider) = editor.edit_prediction_provider() {
-                    // Check if this is an Ollama provider by comparing names
-                    if provider.name() == "ollama" {
-                        // Recreate the provider with the new model
-                        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-                        let _api_url = settings.api_url.clone();
-
-                        // Get client from the registry context (need to pass it)
-                        // For now, we'll trigger a full reassignment
-                        assign_edit_prediction_provider(
-                            editor,
-                            EditPredictionProvider::Ollama,
-                            &client,
-                            user_store.clone(),
-                            window,
-                            cx,
-                        );
-                    }
-                }
-            })
-        });
-    }
-}
-
 fn assign_edit_prediction_providers(
     editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
     provider: EditPredictionProvider,
@@ -333,27 +303,25 @@ fn assign_edit_prediction_provider(
         }
         EditPredictionProvider::Ollama => {
             let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-
-            // Only create provider if models are configured
-            // Note: Only FIM-capable models work with inline completion:
-            // ✓ Supported: qwen2.5-coder:*, starcoder2:*, codeqwen:*
-            // ✗ Not supported: codellama:*, deepseek-coder:*, llama3:*
-            if let Some(first_model) = settings.available_models.first() {
-                let api_url = settings.api_url.clone();
-                let model = first_model.name.clone();
-
-                // Get API key from environment variable only (credentials would require async handling)
-                let api_key = std::env::var("OLLAMA_API_KEY").ok();
-
-                let provider = cx.new(|_| {
-                    OllamaCompletionProvider::new(client.http_client(), api_url, model, api_key)
-                });
-                editor.set_edit_prediction_provider(Some(provider), window, cx);
-            } else {
-                // No models configured - don't create a provider
-                // User will see "Configure Models" option in the completion menu
-                editor.set_edit_prediction_provider::<OllamaCompletionProvider>(None, window, cx);
-            }
+            let api_key = std::env::var("OLLAMA_API_KEY").ok();
+
+            // Get model from settings or use discovered models
+            let model = if let Some(first_model) = settings.available_models.first() {
+                first_model.name.clone()
+            } else if let Some(service) = OllamaService::global(cx) {
+                // Use first discovered model
+                service
+                    .read(cx)
+                    .available_models()
+                    .first()
+                    .map(|m| m.name.clone())
+                    .unwrap_or_else(|| "qwen2.5-coder:3b".to_string())
+            } else {
+                "qwen2.5-coder:3b".to_string()
+            };
+
+            let provider = cx.new(|cx| OllamaCompletionProvider::new(model, api_key, cx));
+            editor.set_edit_prediction_provider(Some(provider), window, cx);
         }
     }
 }
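Note: the model-selection order introduced above is: explicitly configured settings first, then the first discovered model, then a hard-coded default (the default name is taken from the diff). The same priority chain in isolation:

fn choose_model(settings_models: &[String], discovered_models: &[String]) -> String {
    settings_models
        .first()
        // Fall back to the first model discovered from the running server.
        .or_else(|| discovered_models.first())
        .cloned()
        // Last resort: the hard-coded default used by the commit.
        .unwrap_or_else(|| "qwen2.5-coder:3b".to_string())
}

fn main() {
    let none: &[String] = &[];
    assert_eq!(choose_model(none, &["phi3:mini".to_string()]), "phi3:mini");
    assert_eq!(choose_model(none, none), "qwen2.5-coder:3b");
}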