use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
    OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The context length parameter of the model (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    /// The number of seconds to keep the connection open after the last request
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tools
    pub supports_tools: Option<bool>,
    /// Whether to enable think mode
    pub supports_thinking: Option<bool>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = Arc::clone(&self.http_client);
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let tasks = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| {
                    let http_client = Arc::clone(&http_client);
                    let api_url = api_url.clone();
                    async move {
                        let name = model.name.as_str();
                        let capabilities =
                            show_model(http_client.as_ref(), &api_url, name).await?;
                        let ollama_model = ollama::Model::new(
                            name,
                            None,
                            None,
                            Some(capabilities.supports_tools()),
                            Some(capabilities.supports_thinking()),
                        );
                        Ok(ollama_model)
                    }
                });

            // Rate-limit capability fetches
            // since there is an arbitrary number of models available
            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
                .buffer_unordered(5)
                .collect::<Vec<Result<_>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>>>()?;

            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = ollama_models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });
                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }
    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                    supports_tools: model.supports_tools,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => {
                        let content = msg.string_contents();
                        let thinking =
                            msg.content.into_iter().find_map(|content| match content {
                                MessageContent::Thinking { text, .. } if !text.is_empty() => {
                                    Some(text)
                                }
                                _ => None,
                            });
                        ChatMessage::Assistant {
                            content,
                            tool_calls: None,
                            thinking,
                        }
                    }
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            think: self.model.supports_thinking,
            tools: request.tools.into_iter().map(tool_into_ollama).collect(),
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools.unwrap_or(false)
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => false,
            LanguageModelToolChoice::Any => false,
            LanguageModelToolChoice::None => false,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        // Until then, estimate roughly one token per four characters.
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = map_to_language_model_completion_events(stream);
            Ok(stream)
        });

        future.map_ok(|f| f.boxed()).boxed()
    }
}

fn map_to_language_model_completion_events(
    stream: Pin<Box<dyn Stream<Item = Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    // Used for creating unique tool use ids
    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    struct State {
        stream: Pin<Box<dyn Stream<Item = Result<ChatResponseDelta>> + Send>>,
        used_tools: bool,
    }

    // We need to create a ToolUse and Stop event from a single
    // response from the original stream
    let stream = stream::unfold(
        State {
            stream,
            used_tools: false,
        },
        async move |mut state| {
            let response = state.stream.next().await?;
            let delta = match response {
                Ok(delta) => delta,
                Err(e) => {
                    let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
                    return Some((vec![event], state));
                }
            };
            let mut events = Vec::new();

            match delta.message {
                ChatMessage::User { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::System { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Assistant {
                    content,
                    tool_calls,
                    thinking,
                } => {
                    if let Some(text) = thinking {
                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
                            text,
                            signature: None,
                        }));
                    }

                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
                        match tool_call {
                            OllamaToolCall::Function(function) => {
                                let tool_id = format!(
                                    "{}-{}",
                                    &function.name,
                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
                                );
                                let event =
                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                                        id: LanguageModelToolUseId::from(tool_id),
                                        name: Arc::from(function.name),
                                        raw_input: function.arguments.to_string(),
                                        input: function.arguments,
                                        is_input_complete: true,
                                    });
                                events.push(Ok(event));
                                state.used_tools = true;
                            }
                        }
                    } else if !content.is_empty() {
                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                    }
                }
            };

            if delta.done {
                if state.used_tools {
                    state.used_tools = false;
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                } else {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }

            Some((events, state))
        },
    );

    stream.flat_map(futures::stream::iter)
}

struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "Ollama must be running with at least one model installed to use it in the assistant.",
                            ))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}

fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
    ollama::OllamaTool::Function {
        function: OllamaFunctionTool {
            name: tool.name,
            description: Some(tool.description),
            parameters: Some(tool.input_schema),
        },
    }
}