Auto detect models WIP
This commit is contained in:
parent
5a1506c3c2
commit
0bdb42e65d
8 changed files with 952 additions and 128 deletions
3
Cargo.lock
generated
3
Cargo.lock
generated
|
@ -8363,6 +8363,7 @@ dependencies = [
|
|||
"fs",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
"gpui_tokio",
|
||||
"http_client",
|
||||
"indoc",
|
||||
"inline_completion",
|
||||
|
@ -8370,6 +8371,7 @@ dependencies = [
|
|||
"language_model",
|
||||
"language_models",
|
||||
"lsp",
|
||||
"ollama",
|
||||
"paths",
|
||||
"project",
|
||||
"regex",
|
||||
|
@ -20253,6 +20255,7 @@ dependencies = [
|
|||
"nix 0.29.0",
|
||||
"node_runtime",
|
||||
"notifications",
|
||||
"ollama",
|
||||
"onboarding",
|
||||
"outline",
|
||||
"outline_panel",
|
||||
|
|
|
@ -25,6 +25,7 @@ indoc.workspace = true
|
|||
inline_completion.workspace = true
|
||||
language.workspace = true
|
||||
language_models.workspace = true
|
||||
ollama.workspace = true
|
||||
|
||||
paths.workspace = true
|
||||
regex.workspace = true
|
||||
|
@ -48,6 +49,9 @@ http_client = { workspace = true, features = ["test-support"] }
|
|||
indoc.workspace = true
|
||||
language_model = { workspace = true, features = ["test-support"] }
|
||||
lsp = { workspace = true, features = ["test-support"] }
|
||||
ollama = { workspace = true, features = ["test-support"] }
|
||||
project = { workspace = true, features = ["test-support"] }
|
||||
serde_json.workspace = true
|
||||
settings = { workspace = true, features = ["test-support"] }
|
||||
theme = { workspace = true, features = ["test-support"] }
|
||||
gpui_tokio.workspace = true
|
||||
|
|
|
@ -21,6 +21,8 @@ use language::{
|
|||
};
|
||||
use language_models::AllLanguageModelSettings;
|
||||
|
||||
use ollama;
|
||||
|
||||
use paths;
|
||||
use regex::Regex;
|
||||
use settings::{Settings, SettingsStore, update_settings_file};
|
||||
|
@ -413,6 +415,10 @@ impl InlineCompletionButton {
|
|||
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
|
||||
.detach();
|
||||
|
||||
if let Some(service) = ollama::OllamaService::global(cx) {
|
||||
cx.observe(&service, |_, _, cx| cx.notify()).detach();
|
||||
}
|
||||
|
||||
Self {
|
||||
editor_subscription: None,
|
||||
editor_enabled: None,
|
||||
|
@ -858,8 +864,30 @@ impl InlineCompletionButton {
|
|||
let settings = AllLanguageModelSettings::get_global(cx);
|
||||
let ollama_settings = &settings.ollama;
|
||||
|
||||
// Clone needed values to avoid borrowing issues
|
||||
let available_models = ollama_settings.available_models.clone();
|
||||
// Get models from both settings and global service discovery
|
||||
let mut available_models = ollama_settings.available_models.clone();
|
||||
|
||||
// Add discovered models from the global Ollama service
|
||||
if let Some(service) = ollama::OllamaService::global(cx) {
|
||||
let discovered_models = service.read(cx).available_models();
|
||||
for model in discovered_models {
|
||||
// Convert from ollama::Model to language_models AvailableModel
|
||||
let available_model = language_models::provider::ollama::AvailableModel {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
keep_alive: model.keep_alive.clone(),
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_vision,
|
||||
supports_thinking: model.supports_thinking,
|
||||
};
|
||||
|
||||
// Add if not already in settings (settings take precedence)
|
||||
if !available_models.iter().any(|m| m.name == model.name) {
|
||||
available_models.push(available_model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// API URL configuration - only show if Ollama settings exist in the user's config
|
||||
let menu = if Self::ollama_settings_exist(cx) {
|
||||
|
@ -878,7 +906,7 @@ impl InlineCompletionButton {
|
|||
let menu = menu.separator().header("Available Models");
|
||||
|
||||
// Add each available model as a menu entry
|
||||
available_models.iter().fold(menu, |menu, model| {
|
||||
let menu = available_models.iter().fold(menu, |menu, model| {
|
||||
let model_name = model.display_name.as_ref().unwrap_or(&model.name);
|
||||
let is_current = available_models
|
||||
.first()
|
||||
|
@ -898,6 +926,13 @@ impl InlineCompletionButton {
|
|||
}
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
// Add refresh models option
|
||||
menu.separator().entry("Refresh Models", None, {
|
||||
move |_window, cx| {
|
||||
Self::refresh_ollama_models(cx);
|
||||
}
|
||||
})
|
||||
} else {
|
||||
menu.separator()
|
||||
|
@ -908,6 +943,11 @@ impl InlineCompletionButton {
|
|||
Self::open_ollama_settings(fs.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
.entry("Refresh Models", None, {
|
||||
move |_window, cx| {
|
||||
Self::refresh_ollama_models(cx);
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Use the common language settings menu
|
||||
|
@ -997,6 +1037,14 @@ impl InlineCompletionButton {
|
|||
});
|
||||
}
|
||||
|
||||
fn refresh_ollama_models(cx: &mut App) {
|
||||
if let Some(service) = ollama::OllamaService::global(cx) {
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
|
||||
let editor = editor.read(cx);
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
|
@ -1188,3 +1236,359 @@ fn toggle_edit_prediction_mode(fs: Arc<dyn Fs>, mode: EditPredictionsMode, cx: &
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::TestAppContext;
|
||||
use http_client;
|
||||
use language_models::provider::ollama::AvailableModel;
|
||||
use ollama::{OllamaService, fake::FakeHttpClient};
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
gpui_tokio::init(cx);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
language::init(cx);
|
||||
language_settings::init(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ollama_menu_shows_discovered_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client with mock models response
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
// Mock /api/tags response
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "qwen2.5-coder:3b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": ["qwen2"],
|
||||
"parameter_size": "3B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "codellama:7b-code",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 2000000,
|
||||
"digest": "def456",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "codellama",
|
||||
"families": ["codellama"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
|
||||
// Mock /api/show response
|
||||
let capabilities = serde_json::json!({
|
||||
"capabilities": ["tools"]
|
||||
});
|
||||
fake_http_client.set_response("/api/show", capabilities.to_string());
|
||||
|
||||
// Create and set global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Wait for model discovery
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models are accessible through the service
|
||||
cx.update(|cx| {
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
let discovered_models = service.read(cx).available_models();
|
||||
assert_eq!(discovered_models.len(), 2);
|
||||
|
||||
let model_names: Vec<&str> =
|
||||
discovered_models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"qwen2.5-coder:3b"));
|
||||
assert!(model_names.contains(&"codellama:7b-code"));
|
||||
} else {
|
||||
panic!("Global service should be available");
|
||||
}
|
||||
});
|
||||
|
||||
// Verify the global service has the expected models
|
||||
service.read_with(cx, |service, _| {
|
||||
let models = service.available_models();
|
||||
assert_eq!(models.len(), 2);
|
||||
|
||||
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"qwen2.5-coder:3b"));
|
||||
assert!(model_names.contains(&"codellama:7b-code"));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ollama_menu_shows_service_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client with models
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "qwen2.5-coder:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": ["qwen2"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
fake_http_client.set_response(
|
||||
"/api/show",
|
||||
serde_json::json!({"capabilities": []}).to_string(),
|
||||
);
|
||||
|
||||
// Create and set global service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(fake_http_client, "http://localhost:11434".to_string(), cx)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Test that discovered models are accessible
|
||||
cx.update(|cx| {
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
let discovered_models = service.read(cx).available_models();
|
||||
assert_eq!(discovered_models.len(), 1);
|
||||
assert_eq!(discovered_models[0].name, "qwen2.5-coder:7b");
|
||||
} else {
|
||||
panic!("Global service should be available");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ollama_menu_refreshes_on_service_update(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
// Initially empty models
|
||||
fake_http_client.set_response("/api/tags", serde_json::json!({"models": []}).to_string());
|
||||
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify the service subscription mechanism works by creating a button
|
||||
let _button = cx.update(|cx| {
|
||||
let fs = fs::FakeFs::new(cx.background_executor().clone());
|
||||
let user_store = cx.new(|cx| {
|
||||
client::UserStore::new(
|
||||
Arc::new(http_client::FakeHttpClient::create(|_| {
|
||||
Box::pin(async { Err(anyhow::anyhow!("not implemented")) })
|
||||
})),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let popover_handle = PopoverMenuHandle::default();
|
||||
|
||||
cx.new(|cx| InlineCompletionButton::new(fs, user_store, popover_handle, cx))
|
||||
});
|
||||
|
||||
// Verify initially no models
|
||||
service.read_with(cx, |service, _| {
|
||||
assert_eq!(service.available_models().len(), 0);
|
||||
});
|
||||
|
||||
// Update mock to return models
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "phi3:mini",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 500000,
|
||||
"digest": "xyz789",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "phi3",
|
||||
"families": ["phi3"],
|
||||
"parameter_size": "3.8B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
fake_http_client.set_response(
|
||||
"/api/show",
|
||||
serde_json::json!({"capabilities": []}).to_string(),
|
||||
);
|
||||
|
||||
// Trigger refresh
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models were refreshed
|
||||
service.read_with(cx, |service, _| {
|
||||
let models = service.available_models();
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].name, "phi3:mini");
|
||||
});
|
||||
|
||||
// The button should have been notified and will rebuild its menu with new models
|
||||
// when next requested (this tests the subscription mechanism)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_refresh_models_button_functionality(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
// Start with one model
|
||||
let initial_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "mistral:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "initial123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "mistral",
|
||||
"families": ["mistral"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", initial_response.to_string());
|
||||
fake_http_client.set_response(
|
||||
"/api/show",
|
||||
serde_json::json!({"capabilities": []}).to_string(),
|
||||
);
|
||||
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify initial model
|
||||
service.read_with(cx, |service, _| {
|
||||
assert_eq!(service.available_models().len(), 1);
|
||||
assert_eq!(service.available_models()[0].name, "mistral:7b");
|
||||
});
|
||||
|
||||
// Update mock to simulate new model available
|
||||
let updated_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "mistral:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "initial123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "mistral",
|
||||
"families": ["mistral"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "gemma2:9b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 2000000,
|
||||
"digest": "new456",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "gemma2",
|
||||
"families": ["gemma2"],
|
||||
"parameter_size": "9B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", updated_response.to_string());
|
||||
|
||||
// Simulate clicking "Refresh Models" button
|
||||
cx.update(|cx| {
|
||||
InlineCompletionButton::refresh_ollama_models(cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models were refreshed
|
||||
service.read_with(cx, |service, _| {
|
||||
let models = service.available_models();
|
||||
assert_eq!(models.len(), 2);
|
||||
|
||||
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"mistral:7b"));
|
||||
assert!(model_names.contains(&"gemma2:9b"));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use anyhow::{Result, anyhow};
|
||||
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
|
||||
use futures::{Stream, TryFutureExt, stream};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
|
||||
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use language_model::{
|
||||
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
|
@ -141,6 +141,29 @@ impl State {
|
|||
}
|
||||
|
||||
impl OllamaLanguageModelProvider {
|
||||
pub fn global(cx: &App) -> Option<Entity<Self>> {
|
||||
cx.try_global::<GlobalOllamaLanguageModelProvider>()
|
||||
.map(|provider| provider.0.clone())
|
||||
}
|
||||
|
||||
pub fn set_global(provider: Entity<Self>, cx: &mut App) {
|
||||
cx.set_global(GlobalOllamaLanguageModelProvider(provider));
|
||||
}
|
||||
|
||||
pub fn available_models_for_completion(&self, cx: &App) -> Vec<ollama::Model> {
|
||||
self.state.read(cx).available_models.clone()
|
||||
}
|
||||
|
||||
pub fn http_client(&self) -> Arc<dyn HttpClient> {
|
||||
self.http_client.clone()
|
||||
}
|
||||
|
||||
pub fn refresh_models(&self, cx: &mut App) {
|
||||
self.state.update(cx, |state, cx| {
|
||||
state.restart_fetch_models_task(cx);
|
||||
});
|
||||
}
|
||||
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
|
||||
let this = Self {
|
||||
http_client: http_client.clone(),
|
||||
|
@ -667,6 +690,10 @@ impl Render for ConfigurationView {
|
|||
}
|
||||
}
|
||||
|
||||
struct GlobalOllamaLanguageModelProvider(Entity<OllamaLanguageModelProvider>);
|
||||
|
||||
impl Global for GlobalOllamaLanguageModelProvider {}
|
||||
|
||||
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
|
||||
ollama::OllamaTool::Function {
|
||||
function: OllamaFunctionTool {
|
||||
|
|
|
@ -30,6 +30,7 @@ language.workspace = true
|
|||
log.workspace = true
|
||||
|
||||
project.workspace = true
|
||||
settings.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
|
|
|
@ -541,14 +541,8 @@ pub mod fake {
|
|||
) {
|
||||
let fake_client = std::sync::Arc::new(FakeHttpClient::new());
|
||||
|
||||
let provider = cx.new(|_| {
|
||||
OllamaCompletionProvider::new(
|
||||
fake_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwencoder".to_string(),
|
||||
None,
|
||||
)
|
||||
});
|
||||
let provider =
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwencoder".to_string(), None, cx));
|
||||
|
||||
(provider, fake_client)
|
||||
}
|
||||
|
|
|
@ -1,43 +1,172 @@
|
|||
use crate::{GenerateOptions, GenerateRequest, generate};
|
||||
use crate::{GenerateOptions, GenerateRequest, Model, generate};
|
||||
use anyhow::{Context as AnyhowContext, Result};
|
||||
use futures::StreamExt;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
use gpui::{App, Context, Entity, EntityId, Task};
|
||||
use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task};
|
||||
use http_client::HttpClient;
|
||||
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
|
||||
use language::{Anchor, Buffer, ToOffset};
|
||||
use settings::SettingsStore;
|
||||
|
||||
use project::Project;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
|
||||
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
|
||||
|
||||
pub struct OllamaCompletionProvider {
|
||||
// Global Ollama service for managing models across all providers
|
||||
pub struct OllamaService {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
api_url: String,
|
||||
available_models: Vec<Model>,
|
||||
fetch_models_task: Option<Task<Result<()>>>,
|
||||
_settings_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl OllamaService {
|
||||
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, cx: &mut App) -> Entity<Self> {
|
||||
cx.new(|cx| {
|
||||
let subscription = cx.observe_global::<SettingsStore>({
|
||||
move |this: &mut OllamaService, cx| {
|
||||
this.restart_fetch_models_task(cx);
|
||||
}
|
||||
});
|
||||
|
||||
let mut service = Self {
|
||||
http_client,
|
||||
api_url,
|
||||
available_models: Vec::new(),
|
||||
fetch_models_task: None,
|
||||
_settings_subscription: subscription,
|
||||
};
|
||||
|
||||
service.restart_fetch_models_task(cx);
|
||||
service
|
||||
})
|
||||
}
|
||||
|
||||
pub fn global(cx: &App) -> Option<Entity<Self>> {
|
||||
cx.try_global::<GlobalOllamaService>()
|
||||
.map(|service| service.0.clone())
|
||||
}
|
||||
|
||||
pub fn set_global(service: Entity<Self>, cx: &mut App) {
|
||||
cx.set_global(GlobalOllamaService(service));
|
||||
}
|
||||
|
||||
pub fn available_models(&self) -> &[Model] {
|
||||
&self.available_models
|
||||
}
|
||||
|
||||
pub fn refresh_models(&mut self, cx: &mut Context<Self>) {
|
||||
self.restart_fetch_models_task(cx);
|
||||
}
|
||||
|
||||
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
|
||||
self.fetch_models_task = Some(self.fetch_models(cx));
|
||||
}
|
||||
|
||||
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
|
||||
let http_client = Arc::clone(&self.http_client);
|
||||
let api_url = self.api_url.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
|
||||
Ok(models) => models,
|
||||
Err(_) => return Ok(()), // Silently fail and use empty list
|
||||
};
|
||||
|
||||
let tasks = models
|
||||
.into_iter()
|
||||
// Filter out embedding models
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
.map(|model| {
|
||||
let http_client = Arc::clone(&http_client);
|
||||
let api_url = api_url.clone();
|
||||
async move {
|
||||
let name = model.name.as_str();
|
||||
let capabilities =
|
||||
crate::show_model(http_client.as_ref(), &api_url, name).await?;
|
||||
let ollama_model = Model::new(
|
||||
name,
|
||||
None,
|
||||
None,
|
||||
Some(capabilities.supports_tools()),
|
||||
Some(capabilities.supports_vision()),
|
||||
Some(capabilities.supports_thinking()),
|
||||
);
|
||||
Ok(ollama_model)
|
||||
}
|
||||
});
|
||||
|
||||
// Rate-limit capability fetches
|
||||
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
|
||||
.buffer_unordered(5)
|
||||
.collect::<Vec<Result<_>>>()
|
||||
.await
|
||||
.into_iter()
|
||||
.collect::<Result<Vec<_>>>()
|
||||
.unwrap_or_default();
|
||||
|
||||
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.available_models = ollama_models;
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct GlobalOllamaService(Entity<OllamaService>);
|
||||
|
||||
impl Global for GlobalOllamaService {}
|
||||
|
||||
pub struct OllamaCompletionProvider {
|
||||
model: String,
|
||||
buffer_id: Option<EntityId>,
|
||||
file_extension: Option<String>,
|
||||
current_completion: Option<String>,
|
||||
pending_refresh: Option<Task<Result<()>>>,
|
||||
api_key: Option<String>,
|
||||
_service_subscription: Option<Subscription>,
|
||||
}
|
||||
|
||||
impl OllamaCompletionProvider {
|
||||
pub fn new(
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
api_url: String,
|
||||
model: String,
|
||||
api_key: Option<String>,
|
||||
) -> Self {
|
||||
pub fn new(model: String, api_key: Option<String>, cx: &mut Context<Self>) -> Self {
|
||||
let subscription = if let Some(service) = OllamaService::global(cx) {
|
||||
Some(cx.observe(&service, |_this, _service, cx| {
|
||||
cx.notify();
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Self {
|
||||
http_client,
|
||||
api_url,
|
||||
model,
|
||||
buffer_id: None,
|
||||
file_extension: None,
|
||||
current_completion: None,
|
||||
pending_refresh: None,
|
||||
api_key,
|
||||
_service_subscription: subscription,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn available_models(&self, cx: &App) -> Vec<Model> {
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
service.read(cx).available_models().to_vec()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn refresh_models(&self, cx: &mut App) {
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,14 +233,28 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
|||
|
||||
fn refresh(
|
||||
&mut self,
|
||||
_project: Option<Entity<Project>>,
|
||||
project: Option<Entity<Project>>,
|
||||
buffer: Entity<Buffer>,
|
||||
cursor_position: Anchor,
|
||||
debounce: bool,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let http_client = self.http_client.clone();
|
||||
let api_url = self.api_url.clone();
|
||||
// Get API settings from the global Ollama service or fallback
|
||||
let (http_client, api_url) = if let Some(service) = OllamaService::global(cx) {
|
||||
let service_ref = service.read(cx);
|
||||
(service_ref.http_client.clone(), service_ref.api_url.clone())
|
||||
} else {
|
||||
// Fallback if global service isn't available
|
||||
(
|
||||
project
|
||||
.as_ref()
|
||||
.map(|p| p.read(cx).client().http_client() as Arc<dyn HttpClient>)
|
||||
.unwrap_or_else(|| {
|
||||
Arc::new(http_client::BlockedHttpClient::new()) as Arc<dyn HttpClient>
|
||||
}),
|
||||
crate::OLLAMA_API_URL.to_string(),
|
||||
)
|
||||
};
|
||||
|
||||
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
|
||||
if debounce {
|
||||
|
@ -156,14 +299,17 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
|||
|
||||
let response = generate(http_client.as_ref(), &api_url, api_key, request)
|
||||
.await
|
||||
.context("Failed to get completion from Ollama")?;
|
||||
.context("Failed to get completion from Ollama");
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.pending_refresh = None;
|
||||
if !response.response.trim().is_empty() {
|
||||
this.current_completion = Some(response.response);
|
||||
} else {
|
||||
this.current_completion = None;
|
||||
match response {
|
||||
Ok(response) if !response.response.trim().is_empty() => {
|
||||
this.current_completion = Some(response.response);
|
||||
}
|
||||
_ => {
|
||||
this.current_completion = None;
|
||||
}
|
||||
}
|
||||
cx.notify();
|
||||
})?;
|
||||
|
@ -248,7 +394,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::fake::Ollama;
|
||||
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
|
||||
|
@ -269,31 +414,238 @@ mod tests {
|
|||
}
|
||||
|
||||
/// Test the complete Ollama completion flow from refresh to suggestion
|
||||
#[test]
|
||||
fn test_get_stop_tokens() {
|
||||
let http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
#[gpui::test]
|
||||
fn test_get_stop_tokens(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Test CodeLlama code model gets stop tokens
|
||||
let codellama_provider = OllamaCompletionProvider::new(
|
||||
http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
"codellama:7b-code".to_string(),
|
||||
None,
|
||||
);
|
||||
let codellama_provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx))
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
codellama_provider.get_stop_tokens(),
|
||||
Some(vec!["<EOT>".to_string()])
|
||||
);
|
||||
codellama_provider.read_with(cx, |provider, _| {
|
||||
assert_eq!(provider.get_stop_tokens(), Some(vec!["<EOT>".to_string()]));
|
||||
});
|
||||
|
||||
// Test non-CodeLlama model doesn't get stop tokens
|
||||
let qwen_provider = OllamaCompletionProvider::new(
|
||||
http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
"qwen2.5-coder:3b".to_string(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(qwen_provider.get_stop_tokens(), None);
|
||||
let qwen_provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
qwen_provider.read_with(cx, |provider, _| {
|
||||
assert_eq!(provider.get_stop_tokens(), None);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_model_discovery(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
|
||||
// Mock /api/tags response (list models)
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "qwen2.5-coder:3b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": ["qwen2"],
|
||||
"parameter_size": "3B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "codellama:7b-code",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 2000000,
|
||||
"digest": "def456",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "codellama",
|
||||
"families": ["codellama"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "nomic-embed-text",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 500000,
|
||||
"digest": "ghi789",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "nomic-embed",
|
||||
"families": ["nomic-embed"],
|
||||
"parameter_size": "137M",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
|
||||
// Mock /api/show responses for model capabilities
|
||||
let qwen_capabilities = serde_json::json!({
|
||||
"capabilities": ["tools", "thinking"]
|
||||
});
|
||||
|
||||
let _codellama_capabilities = serde_json::json!({
|
||||
"capabilities": []
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/show", qwen_capabilities.to_string());
|
||||
|
||||
// Create global Ollama service for testing
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Set it as global
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Create completion provider
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
// Wait for model discovery to complete
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models were discovered through the global provider
|
||||
provider.read_with(cx, |provider, cx| {
|
||||
let models = provider.available_models(cx);
|
||||
assert_eq!(models.len(), 2); // Should exclude nomic-embed-text
|
||||
|
||||
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"codellama:7b-code"));
|
||||
assert!(model_names.contains(&"qwen2.5-coder:3b"));
|
||||
assert!(!model_names.contains(&"nomic-embed-text"));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_model_discovery_api_failure(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client that returns errors
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
fake_http_client.set_error("Connection refused");
|
||||
|
||||
// Create global Ollama service that will fail
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Create completion provider
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
// Wait for model discovery to complete (with failure)
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify graceful handling - should have empty model list
|
||||
provider.read_with(cx, |provider, cx| {
|
||||
let models = provider.available_models(cx);
|
||||
assert_eq!(models.len(), 0);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_refresh_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
|
||||
// Initially return empty model list
|
||||
let empty_response = serde_json::json!({"models": []});
|
||||
fake_http_client.set_response("/api/tags", empty_response.to_string());
|
||||
|
||||
// Create global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify initially empty
|
||||
provider.read_with(cx, |provider, cx| {
|
||||
assert_eq!(provider.available_models(cx).len(), 0);
|
||||
});
|
||||
|
||||
// Update mock to return models
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "qwen2.5-coder:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": ["qwen2"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
|
||||
let capabilities = serde_json::json!({
|
||||
"capabilities": ["tools", "thinking"]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/show", capabilities.to_string());
|
||||
|
||||
// Trigger refresh
|
||||
provider.update(cx, |provider, cx| {
|
||||
provider.refresh_models(cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models were refreshed
|
||||
provider.read_with(cx, |provider, cx| {
|
||||
let models = provider.available_models(cx);
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].name, "qwen2.5-coder:7b");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
|
@ -306,12 +658,28 @@ mod tests {
|
|||
buffer.anchor_before(11) // Position in the middle of the function
|
||||
});
|
||||
|
||||
// Create Ollama provider with fake HTTP client
|
||||
let (provider, fake_http_client) = Ollama::fake(cx);
|
||||
|
||||
// Configure mock HTTP response
|
||||
// Create fake HTTP client and set up global service
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
fake_http_client.set_generate_response("println!(\"Hello\");");
|
||||
|
||||
// Create global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Create provider
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
// Trigger completion refresh (no debounce for test speed)
|
||||
provider.update(cx, |provider, cx| {
|
||||
provider.refresh(None, buffer.clone(), cursor_position, false, cx);
|
||||
|
@ -363,7 +731,26 @@ mod tests {
|
|||
buffer.anchor_after(16) // After "vec"
|
||||
});
|
||||
|
||||
let (provider, fake_http_client) = Ollama::fake(cx);
|
||||
// Create fake HTTP client and set up global service
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
|
||||
// Create global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Create provider
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
// Configure response that starts with what user already typed
|
||||
fake_http_client.set_generate_response("vec![1, 2, 3]");
|
||||
|
@ -393,7 +780,28 @@ mod tests {
|
|||
init_test(cx);
|
||||
|
||||
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
|
||||
let (provider, fake_http_client) = Ollama::fake(cx);
|
||||
|
||||
// Create fake HTTP client and set up global service
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
fake_http_client.set_generate_response("vec![hello, world]");
|
||||
|
||||
// Create global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Create provider
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
// Set up the editor with the Ollama provider
|
||||
editor_cx.update_editor(|editor, window, cx| {
|
||||
|
@ -403,9 +811,6 @@ mod tests {
|
|||
// Set initial state
|
||||
editor_cx.set_state("let items = ˇ");
|
||||
|
||||
// Configure a multi-word completion
|
||||
fake_http_client.set_generate_response("vec![hello, world]");
|
||||
|
||||
// Trigger the completion through the provider
|
||||
let buffer =
|
||||
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
|
||||
|
@ -455,7 +860,28 @@ mod tests {
|
|||
init_test(cx);
|
||||
|
||||
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
|
||||
let (provider, fake_http_client) = Ollama::fake(cx);
|
||||
|
||||
// Create fake HTTP client and set up global service
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
fake_http_client.set_generate_response("bar");
|
||||
|
||||
// Create global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Create provider
|
||||
let provider = cx.update(|cx| {
|
||||
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
|
||||
});
|
||||
|
||||
// Set up the editor with the Ollama provider
|
||||
editor_cx.update_editor(|editor, window, cx| {
|
||||
|
@ -464,9 +890,6 @@ mod tests {
|
|||
|
||||
editor_cx.set_state("fooˇ");
|
||||
|
||||
// Configure completion response that extends the current text
|
||||
fake_http_client.set_generate_response("bar");
|
||||
|
||||
// Trigger the completion through the provider
|
||||
let buffer =
|
||||
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
|
||||
|
|
|
@ -6,7 +6,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
|||
|
||||
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
||||
use language_models::AllLanguageModelSettings;
|
||||
use ollama::OllamaCompletionProvider;
|
||||
use ollama::{OllamaCompletionProvider, OllamaService};
|
||||
use settings::{Settings, SettingsStore};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{cell::RefCell, rc::Rc, sync::Arc};
|
||||
|
@ -18,6 +18,11 @@ use zed_actions;
|
|||
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
|
||||
|
||||
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
// Initialize global Ollama service
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let ollama_service = OllamaService::new(client.http_client(), settings.api_url.clone(), cx);
|
||||
OllamaService::set_global(ollama_service, cx);
|
||||
|
||||
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
|
||||
cx.observe_new({
|
||||
let editors = editors.clone();
|
||||
|
@ -138,8 +143,13 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
|||
}
|
||||
}
|
||||
} else if provider == EditPredictionProvider::Ollama {
|
||||
// Update Ollama providers when settings change but provider stays the same
|
||||
update_ollama_providers(&editors, &client, user_store.clone(), cx);
|
||||
// Update global Ollama service when settings change
|
||||
let _settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -152,46 +162,6 @@ fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
|
|||
}
|
||||
}
|
||||
|
||||
fn update_ollama_providers(
|
||||
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
||||
client: &Arc<Client>,
|
||||
user_store: Entity<UserStore>,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let _current_model = settings
|
||||
.available_models
|
||||
.first()
|
||||
.map(|m| m.name.clone())
|
||||
.unwrap_or_else(|| "codellama:7b".to_string());
|
||||
|
||||
for (editor, window) in editors.borrow().iter() {
|
||||
_ = window.update(cx, |_window, window, cx| {
|
||||
_ = editor.update(cx, |editor, cx| {
|
||||
if let Some(provider) = editor.edit_prediction_provider() {
|
||||
// Check if this is an Ollama provider by comparing names
|
||||
if provider.name() == "ollama" {
|
||||
// Recreate the provider with the new model
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let _api_url = settings.api_url.clone();
|
||||
|
||||
// Get client from the registry context (need to pass it)
|
||||
// For now, we'll trigger a full reassignment
|
||||
assign_edit_prediction_provider(
|
||||
editor,
|
||||
EditPredictionProvider::Ollama,
|
||||
&client,
|
||||
user_store.clone(),
|
||||
window,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn assign_edit_prediction_providers(
|
||||
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
|
||||
provider: EditPredictionProvider,
|
||||
|
@ -333,27 +303,25 @@ fn assign_edit_prediction_provider(
|
|||
}
|
||||
EditPredictionProvider::Ollama => {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let api_key = std::env::var("OLLAMA_API_KEY").ok();
|
||||
|
||||
// Only create provider if models are configured
|
||||
// Note: Only FIM-capable models work with inline completion:
|
||||
// ✓ Supported: qwen2.5-coder:*, starcoder2:*, codeqwen:*
|
||||
// ✗ Not supported: codellama:*, deepseek-coder:*, llama3:*
|
||||
if let Some(first_model) = settings.available_models.first() {
|
||||
let api_url = settings.api_url.clone();
|
||||
let model = first_model.name.clone();
|
||||
|
||||
// Get API key from environment variable only (credentials would require async handling)
|
||||
let api_key = std::env::var("OLLAMA_API_KEY").ok();
|
||||
|
||||
let provider = cx.new(|_| {
|
||||
OllamaCompletionProvider::new(client.http_client(), api_url, model, api_key)
|
||||
});
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
// Get model from settings or use discovered models
|
||||
let model = if let Some(first_model) = settings.available_models.first() {
|
||||
first_model.name.clone()
|
||||
} else if let Some(service) = OllamaService::global(cx) {
|
||||
// Use first discovered model
|
||||
service
|
||||
.read(cx)
|
||||
.available_models()
|
||||
.first()
|
||||
.map(|m| m.name.clone())
|
||||
.unwrap_or_else(|| "qwen2.5-coder:3b".to_string())
|
||||
} else {
|
||||
// No models configured - don't create a provider
|
||||
// User will see "Configure Models" option in the completion menu
|
||||
editor.set_edit_prediction_provider::<OllamaCompletionProvider>(None, window, cx);
|
||||
}
|
||||
"qwen2.5-coder:3b".to_string()
|
||||
};
|
||||
|
||||
let provider = cx.new(|cx| OllamaCompletionProvider::new(model, api_key, cx));
|
||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue