diff --git a/Cargo.lock b/Cargo.lock index 42649b137f..57dac2f3c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9143,6 +9143,7 @@ dependencies = [ "credentials_provider", "deepseek", "editor", + "fs", "futures 0.3.31", "google_ai", "gpui", diff --git a/crates/language_models/Cargo.toml b/crates/language_models/Cargo.toml index b5bfb870f6..ad4e593d4f 100644 --- a/crates/language_models/Cargo.toml +++ b/crates/language_models/Cargo.toml @@ -29,6 +29,7 @@ copilot.workspace = true credentials_provider.workspace = true deepseek = { workspace = true, features = ["schemars"] } editor.workspace = true +fs.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true diff --git a/crates/language_models/src/provider/ollama.rs b/crates/language_models/src/provider/ollama.rs index 3f2d47fba3..b8b55e252e 100644 --- a/crates/language_models/src/provider/ollama.rs +++ b/crates/language_models/src/provider/ollama.rs @@ -1,4 +1,6 @@ -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; +use credentials_provider::CredentialsProvider; +use fs::Fs; use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{Stream, TryFutureExt, stream}; use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task}; @@ -10,17 +12,20 @@ use language_model::{ LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, }; +use menu; use ollama::{ - ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool, - OllamaToolCall, get_models, show_model, stream_chat_completion, + ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OLLAMA_API_KEY_VAR, + OLLAMA_API_URL, OllamaFunctionTool, OllamaToolCall, get_models, show_model, + stream_chat_completion, }; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use settings::{Settings, SettingsStore}; 
+use settings::{Settings, SettingsStore, update_settings_file}; use std::pin::Pin; use std::sync::atomic::{AtomicU64, Ordering}; use std::{collections::HashMap, sync::Arc}; -use ui::{ButtonLike, Indicator, List, prelude::*}; +use ui::{ButtonLike, ElevationIndex, Indicator, List, Tooltip, prelude::*}; +use ui_input::SingleLineInput; use util::ResultExt; use crate::AllLanguageModelSettings; @@ -67,21 +72,61 @@ pub struct State { available_models: Vec, fetch_model_task: Option>>, _subscription: Subscription, + api_key: Option, + api_key_from_env: bool, } impl State { fn is_authenticated(&self) -> bool { - !self.available_models.is_empty() + !self.available_models.is_empty() || self.api_key.is_some() + } + + fn reset_api_key(&self, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .ollama + .api_url + .clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .delete_credentials(&api_url, cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = None; + this.api_key_from_env = false; + cx.notify(); + }) + }) + } + + fn set_api_key(&mut self, api_key: String, cx: &mut Context) -> Task> { + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .ollama + .api_url + .clone(); + cx.spawn(async move |this, cx| { + credentials_provider + .write_credentials(&api_url, "Bearer", api_key.as_bytes(), cx) + .await + .log_err(); + this.update(cx, |this, cx| { + this.api_key = Some(api_key); + cx.notify(); + }) + }) } fn fetch_models(&mut self, cx: &mut Context) -> Task> { let settings = &AllLanguageModelSettings::get_global(cx).ollama; let http_client = Arc::clone(&self.http_client); let api_url = settings.api_url.clone(); + let api_key = self.api_key.clone(); // As a proxy for the server being "authenticated", we'll check if its up by fetching the models cx.spawn(async move |this, cx| { - let models = 
get_models(http_client.as_ref(), &api_url, None).await?; + let models = get_models(http_client.as_ref(), &api_url, api_key.clone(), None).await?; let tasks = models .into_iter() @@ -92,9 +137,11 @@ impl State { .map(|model| { let http_client = Arc::clone(&http_client); let api_url = api_url.clone(); + let api_key = api_key.clone(); async move { let name = model.name.as_str(); - let capabilities = show_model(http_client.as_ref(), &api_url, name).await?; + let capabilities = + show_model(http_client.as_ref(), &api_url, api_key, name).await?; let ollama_model = ollama::Model::new( name, None, @@ -135,8 +182,38 @@ impl State { return Task::ready(Ok(())); } + let credentials_provider = ::global(cx); + let api_url = AllLanguageModelSettings::get_global(cx) + .ollama + .api_url + .clone(); let fetch_models_task = self.fetch_models(cx); - cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?)) + cx.spawn(async move |this, cx| { + let (api_key, from_env) = if let Ok(api_key) = std::env::var(OLLAMA_API_KEY_VAR) { + (Some(api_key), true) + } else { + match credentials_provider.read_credentials(&api_url, cx).await { + Ok(Some((_, api_key))) => ( + Some(String::from_utf8(api_key).context("invalid Ollama API key")?), + false, + ), + Ok(None) => (None, false), + Err(_) => (None, false), + } + }; + + this.update(cx, |this, cx| { + this.api_key = api_key; + this.api_key_from_env = from_env; + cx.notify(); + })?; + + // Always try to fetch models - if no API key is needed (local Ollama), it will work + // If API key is needed and provided, it will work + // If API key is needed and not provided, it will fail gracefully + let _ = fetch_models_task.await; + Ok(()) + }) } } @@ -162,6 +239,8 @@ impl OllamaLanguageModelProvider { available_models: Default::default(), fetch_model_task: None, _subscription: subscription, + api_key: None, + api_key_from_env: false, } }), }; @@ -240,6 +319,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { model, http_client: 
self.http_client.clone(), request_limiter: RateLimiter::new(4), + state: self.state.clone(), }) as Arc }) .collect::>(); @@ -267,7 +347,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider { } fn reset_credentials(&self, cx: &mut App) -> Task> { - self.state.update(cx, |state, cx| state.fetch_models(cx)) + self.state.update(cx, |state, cx| state.reset_api_key(cx)) } } @@ -276,6 +356,7 @@ pub struct OllamaLanguageModel { model: ollama::Model, http_client: Arc, request_limiter: RateLimiter, + state: gpui::Entity, } impl OllamaLanguageModel { @@ -424,15 +505,19 @@ impl LanguageModel for OllamaLanguageModel { let request = self.to_ollama_request(request); let http_client = self.http_client.clone(); - let Ok(api_url) = cx.update(|cx| { + let Ok((api_url, api_key)) = cx.update(|cx| { let settings = &AllLanguageModelSettings::get_global(cx).ollama; - settings.api_url.clone() + ( + settings.api_url.clone(), + self.state.read(cx).api_key.clone(), + ) }) else { return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; let future = self.request_limiter.stream(async move { - let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?; + let stream = + stream_chat_completion(http_client.as_ref(), &api_url, api_key, request).await?; let stream = map_to_language_model_completion_events(stream); Ok(stream) }); @@ -541,12 +626,44 @@ fn map_to_language_model_completion_events( } struct ConfigurationView { + api_key_editor: gpui::Entity, + api_url_editor: gpui::Entity, state: gpui::Entity, loading_models_task: Option>, } impl ConfigurationView { pub fn new(state: gpui::Entity, window: &mut Window, cx: &mut Context) -> Self { + let api_key_editor = cx.new(|cx| { + SingleLineInput::new( + window, + cx, + "ol-000000000000000000000000000000000000000000000000", + ) + .label("API key") + }); + + let api_url = AllLanguageModelSettings::get_global(cx) + .ollama + .api_url + .clone(); + + let api_url_editor = cx.new(|cx| { + let 
input = SingleLineInput::new(window, cx, OLLAMA_API_URL).label("API URL"); + + if !api_url.is_empty() { + input.editor.update(cx, |editor, cx| { + editor.set_text(&*api_url, window, cx); + }); + } + input + }); + + cx.observe(&state, |_, _, cx| { + cx.notify(); + }) + .detach(); + let loading_models_task = Some(cx.spawn_in(window, { let state = state.clone(); async move |this, cx| { @@ -565,6 +682,8 @@ impl ConfigurationView { })); Self { + api_key_editor, + api_url_editor, state, loading_models_task, } @@ -575,103 +694,348 @@ impl ConfigurationView { .update(cx, |state, cx| state.fetch_models(cx)) .detach_and_log_err(cx); } + + fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context) { + let api_key = self + .api_key_editor + .read(cx) + .editor() + .read(cx) + .text(cx) + .trim() + .to_string(); + + // Don't proceed if no API key is provided and we're not authenticated + if api_key.is_empty() && !self.state.read(cx).is_authenticated() { + return; + } + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + if !api_key.is_empty() { + state + .update(cx, |state, cx| state.set_api_key(api_key, cx))? 
+ .await + } else { + Ok(()) + } + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context) { + self.api_key_editor.update(cx, |input, cx| { + input.editor.update(cx, |editor, cx| { + editor.set_text("", window, cx); + }); + }); + + let state = self.state.clone(); + cx.spawn_in(window, async move |_, cx| { + state.update(cx, |state, cx| state.reset_api_key(cx))?.await + }) + .detach_and_log_err(cx); + + cx.notify(); + } + + fn should_render_api_key_editor(&self, cx: &mut Context) -> bool { + self.state.read(cx).api_key.is_none() + } + + fn save_api_url(&mut self, cx: &mut Context) { + let api_url = self + .api_url_editor + .read(cx) + .editor() + .read(cx) + .text(cx) + .trim() + .to_string(); + + let current_url = AllLanguageModelSettings::get_global(cx) + .ollama + .api_url + .clone(); + + let effective_current_url = if current_url.is_empty() { + OLLAMA_API_URL + } else { + ¤t_url + }; + + if !api_url.is_empty() && api_url != effective_current_url { + let fs = ::global(cx); + update_settings_file::(fs, cx, move |settings, _| { + if let Some(settings) = settings.ollama.as_mut() { + settings.api_url = Some(api_url); + } else { + settings.ollama = Some(crate::settings::OllamaSettingsContent { + api_url: Some(api_url), + available_models: None, + }); + } + }); + } + } + + fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context) { + self.api_url_editor.update(cx, |input, cx| { + input.editor.update(cx, |editor, cx| { + editor.set_text(OLLAMA_API_URL, window, cx); + }); + }); + let fs = ::global(cx); + update_settings_file::(fs, cx, |settings, _cx| { + if let Some(settings) = settings.ollama.as_mut() { + settings.api_url = None; + } + }); + cx.notify(); + } } impl Render for ConfigurationView { fn render(&mut self, _: &mut Window, cx: &mut Context) -> impl IntoElement { let is_authenticated = self.state.read(cx).is_authenticated(); - - let ollama_intro = - "Get up & running with Llama 3.3, 
Mistral, Gemma 2, and other LLMs with Ollama."; + let env_var_set = self.state.read(cx).api_key_from_env; if self.loading_models_task.is_some() { - div().child(Label::new("Loading models...")).into_any() - } else { - v_flex() - .gap_2() + div() .child( - v_flex().gap_1().child(Label::new(ollama_intro)).child( - List::new() - .child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant.")) - .child(InstructionListItem::text_only( - "Once installed, try `ollama run llama3.2`", - )), - ), - ) - .child( - h_flex() - .w_full() - .justify_between() + v_flex() .gap_2() .child( h_flex() - .w_full() .gap_2() - .map(|this| { - if is_authenticated { - this.child( - Button::new("ollama-site", "Ollama") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE)) - .into_any_element(), - ) - } else { - this.child( - Button::new( - "download_ollama_button", - "Download Ollama", - ) - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .on_click(move |_, _, cx| { - cx.open_url(OLLAMA_DOWNLOAD_URL) - }) - .into_any_element(), - ) - } - }) - .child( - Button::new("view-models", "View All Models") - .style(ButtonStyle::Subtle) - .icon(IconName::ArrowUpRight) - .icon_size(IconSize::Small) - .icon_color(Color::Muted) - .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)), - ), + .child(Indicator::dot().color(Color::Accent)) + .child(Label::new("Connecting to Ollama...")), ) - .map(|this| { - if is_authenticated { - this.child( - ButtonLike::new("connected") - .disabled(true) - .cursor_style(gpui::CursorStyle::Arrow) - .child( - h_flex() - .gap_2() - .child(Indicator::dot().color(Color::Success)) - .child(Label::new("Connected")) - .into_any_element(), - ), - ) - } else { - this.child( - Button::new("retry_ollama_models", "Connect") - 
.icon_position(IconPosition::Start) - .icon_size(IconSize::XSmall) - .icon(IconName::PlayFilled) - .on_click(cx.listener(move |this, _, _, cx| { - this.retry_connection(cx) - })), - ) - } - }) + .child( + Label::new("Checking for available models and server status") + .size(LabelSize::Small) + .color(Color::Muted), + ), ) .into_any() + } else { + v_flex() + .child( + if !is_authenticated { + v_flex().child( + Label::new("Run powerful language models locally on your machine with Ollama. Get started with Llama 3.3, Mistral, Gemma 2, and hundreds of other models.") + .size(LabelSize::Small) + .color(Color::Muted) + ) + .v_flex() + .gap_2() + .child( + Label::new("Getting Started") + .size(LabelSize::Small) + .color(Color::Default) + ) + .child( + List::new() + .child(InstructionListItem::text_only("1. Download and install Ollama from ollama.com")) + .child(InstructionListItem::text_only("2. Start Ollama and download a model: `ollama run gpt-oss:20b`")) + .child(InstructionListItem::text_only("3. 
Click 'Connect' below to start using Ollama in Zed")) + ).child( + Label::new("API Keys and API URLs are optional, Zed will default to local ollama usage.") + .size(LabelSize::Small) + .color(Color::Muted) + ) + .into_any() + } else { + div().into_any() + } + ) + .child( + if self.should_render_api_key_editor(cx) { + v_flex() + .on_action(cx.listener(Self::save_api_key)) + .child(self.api_key_editor.clone()) + .child( + Label::new( + format!("You can also assign the {OLLAMA_API_KEY_VAR} environment variable and restart Zed.") + ) + .size(LabelSize::XSmall) + .color(Color::Muted), + ).into_any() + } else { + v_flex() + .child( + h_flex() + .p_3() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().elevated_surface_background) + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child( + Label::new( + if env_var_set { + format!("API key set in {OLLAMA_API_KEY_VAR} environment variable.") + } else { + "API key configured.".to_string() + } + ) + ) + ) + .child( + Button::new("reset-api-key", "Reset API Key") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .when(env_var_set, |this| { + this.tooltip(Tooltip::text(format!("To reset your API key, unset the {OLLAMA_API_KEY_VAR} environment variable."))) + }) + .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx))), + ) + ) + .into_any() + } + ) + .child({ + let custom_api_url_set = AllLanguageModelSettings::get_global(cx).ollama.api_url != OLLAMA_API_URL; + + if custom_api_url_set { + v_flex() + .gap_2() + .child( + h_flex() + .p_3() + .justify_between() + .rounded_md() + .border_1() + .border_color(cx.theme().colors().border) + .bg(cx.theme().colors().elevated_surface_background) + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child( + v_flex() + 
.gap_1() + .child( + Label::new( + format!("API URL configured. {}", &AllLanguageModelSettings::get_global(cx).ollama.api_url) + ) + ) + ) + ) + .child( + Button::new("reset-api-url", "Reset API URL") + .label_size(LabelSize::Small) + .icon(IconName::Undo) + .icon_size(IconSize::Small) + .icon_position(IconPosition::Start) + .layer(ElevationIndex::ModalSurface) + .on_click(cx.listener(|this, _, window, cx| { + this.reset_api_url(window, cx) + })) + ) + ) + .into_any() + } else { + v_flex() + .child( + v_flex() + .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| { + this.save_api_url(cx); + cx.notify(); + })) + .gap_2() + .child(self.api_url_editor.clone()) + ) + .into_any() + } + }) + .child( + v_flex() + .gap_2() + .child( + h_flex() + .w_full() + .justify_between() + .gap_2() + .child( + h_flex() + .w_full() + .gap_2() + .map(|this| { + if is_authenticated { + this.child( + Button::new("ollama-site", "Ollama") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE)) + .into_any_element(), + ) + } else { + this.child( + Button::new( + "download_ollama_button", + "Download Ollama", + ) + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| { + cx.open_url(OLLAMA_DOWNLOAD_URL) + }) + .into_any_element(), + ) + } + }) + .child( + Button::new("view-models", "View All Models") + .style(ButtonStyle::Subtle) + .icon(IconName::ArrowUpRight) + .icon_size(IconSize::XSmall) + .icon_color(Color::Muted) + .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)), + ), + ) + .map(|this| { + if is_authenticated { + this.child( + ButtonLike::new("connected") + .disabled(true) + .cursor_style(gpui::CursorStyle::Arrow) + .child( + h_flex() + .gap_2() + .child(Indicator::dot().color(Color::Success)) + .child(Label::new("Connected")) + .into_any_element(), + ), + 
) + } else { + this.child( + Button::new("retry_ollama_models", "Connect") + .icon_position(IconPosition::Start) + .icon_size(IconSize::XSmall) + .icon(IconName::PlayOutlined) + .on_click(cx.listener(move |this, _, _, cx| { + this.retry_connection(cx) + })), + ) + } + }) + ) + ) + .into_any() } } } diff --git a/crates/ollama/src/ollama.rs b/crates/ollama/src/ollama.rs index 64cd1cc0cb..6d17c7c4d7 100644 --- a/crates/ollama/src/ollama.rs +++ b/crates/ollama/src/ollama.rs @@ -6,6 +6,7 @@ use serde_json::Value; use std::time::Duration; pub const OLLAMA_API_URL: &str = "http://localhost:11434"; +pub const OLLAMA_API_KEY_VAR: &str = "OLLAMA_API_KEY"; #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)] @@ -275,14 +276,19 @@ pub async fn complete( pub async fn stream_chat_completion( client: &dyn HttpClient, api_url: &str, + api_key: Option, request: ChatRequest, ) -> Result>> { let uri = format!("{api_url}/api/chat"); - let request_builder = http::Request::builder() + let mut request_builder = http::Request::builder() .method(Method::POST) .uri(uri) .header("Content-Type", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) + } + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; let mut response = client.send(request).await?; if response.status().is_success() { @@ -309,14 +315,19 @@ pub async fn stream_chat_completion( pub async fn get_models( client: &dyn HttpClient, api_url: &str, + api_key: Option, _: Option, ) -> Result> { let uri = format!("{api_url}/api/tags"); - let request_builder = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::GET) .uri(uri) .header("Accept", "application/json"); + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")); + } + let 
request = request_builder.body(AsyncBody::default())?; let mut response = client.send(request).await?; @@ -336,15 +347,25 @@ pub async fn get_models( } /// Fetch details of a model, used to determine model capabilities -pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result { +pub async fn show_model( + client: &dyn HttpClient, + api_url: &str, + api_key: Option, + model: &str, +) -> Result { let uri = format!("{api_url}/api/show"); - let request = HttpRequest::builder() + let mut request_builder = HttpRequest::builder() .method(Method::POST) .uri(uri) - .header("Content-Type", "application/json") - .body(AsyncBody::from( - serde_json::json!({ "model": model }).to_string(), - ))?; + .header("Content-Type", "application/json"); + + if let Some(api_key) = api_key { + request_builder = request_builder.header("Authorization", format!("Bearer {api_key}")) + } + + let request = request_builder.body(AsyncBody::from( + serde_json::json!({ "model": model }).to_string(), + ))?; let mut response = client.send(request).await?; let mut body = String::new(); diff --git a/docs/src/ai/llm-providers.md b/docs/src/ai/llm-providers.md index 5ef6081421..808d4b1992 100644 --- a/docs/src/ai/llm-providers.md +++ b/docs/src/ai/llm-providers.md @@ -378,6 +378,20 @@ If the model is tagged with `thinking` in the Ollama catalog, set this option an The `supports_images` option enables the model's vision capabilities, allowing it to process images included in the conversation context. If the model is tagged with `vision` in the Ollama catalog, set this option and you can use it in Zed. +#### Ollama Authentication + +In addition to running Ollama on your own hardware, which generally does not require authentication, Zed also supports connecting to Ollama servers where API keys are required for authentication. + +One such service is [Ollama Turbo](https://ollama.com/turbo). To configure Zed to use Ollama Turbo: +1. 
Sign in to your Ollama account and subscribe to Ollama Turbo +2. Visit [ollama.com/settings/keys](https://ollama.com/settings/keys) and create an API key +3. Open the settings view (`agent: open settings`) and go to the Ollama section +4. Paste your API key and press enter. +5. For the API URL enter `https://ollama.com` + +Zed will also use the `OLLAMA_API_KEY` environment variable if defined. + ### OpenAI {#openai} 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys)