// crates/language_models/src/provider/ollama.rs
use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use ollama::{
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
OllamaToolCall, get_models, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::HashMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;
use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
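/// Settings for the Ollama provider as configured in Zed's settings file.
///
/// A minimal sketch of the corresponding JSON; the `language_models.ollama`
/// settings path and all field values shown here are illustrative:
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "api_url": "http://localhost:11434",
///       "available_models": [
///         {
///           "name": "llama3.2:latest",
///           "display_name": "Llama 3.2",
///           "max_tokens": 16384,
///           "supports_tools": true
///         }
///       ]
///     }
///   }
/// }
/// ```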
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// The model name in the Ollama API (e.g. "llama3.2:latest")
pub name: String,
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
    /// The model's context length, in tokens (aka num_ctx or n_ctx)
pub max_tokens: u64,
    /// How long the model stays loaded in memory after the last request (Ollama's keep_alive)
pub keep_alive: Option<KeepAlive>,
/// Whether the model supports tools
pub supports_tools: Option<bool>,
/// Whether the model supports vision
pub supports_images: Option<bool>,
    /// Whether the model supports thinking (enables Ollama's think mode when the request allows it)
pub supports_thinking: Option<bool>,
}
pub struct OllamaLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Entity<State>,
}
pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
fetch_model_task: Option<Task<Result<()>>>,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
}
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = Arc::clone(&self.http_client);
let api_url = settings.api_url.clone();
        // As a proxy for the server being "authenticated", check that it's up by fetching the models
cx.spawn(async move |this, cx| {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let tasks = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
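                // (e.g. "nomic-embed-text" would be skipped)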
.filter(|model| !model.name.contains("-embed"))
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
async move {
let name = model.name.as_str();
let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
let ollama_model = ollama::Model::new(
name,
None,
None,
Some(capabilities.supports_tools()),
Some(capabilities.supports_vision()),
Some(capabilities.supports_thinking()),
);
Ok(ollama_model)
}
});
            // Rate-limit the capability fetches, since an arbitrary number of
            // models may be installed
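            // (`buffer_unordered(5)` keeps at most five `show_model` requests in flight at once)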
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
.buffer_unordered(5)
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
this.available_models = ollama_models;
cx.notify();
})
})
}
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
let task = self.fetch_models(cx);
self.fetch_model_task.replace(task);
}
fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
if self.is_authenticated() {
return Task::ready(Ok(()));
}
let fetch_models_task = self.fetch_models(cx);
cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
}
}
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
move |this: &mut State, cx| {
let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
if &settings != new_settings {
settings = new_settings.clone();
this.restart_fetch_models_task(cx);
cx.notify();
}
}
});
State {
http_client,
available_models: Default::default(),
fetch_model_task: None,
_subscription: subscription,
}
}),
};
this.state
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
this
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
PROVIDER_ID
}
fn name(&self) -> LanguageModelProviderName {
PROVIDER_NAME
}
fn icon(&self) -> IconName {
IconName::AiOllama
}
fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // We deliberately don't select a default model, since doing so could make
        // Ollama load a model the user never asked for. On resource-constrained
        // machines, loading a model by default would be a bad experience.
None
}
fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
// See explanation for default_model.
None
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
let mut models: HashMap<String, ollama::Model> = HashMap::new();
// Add models from the Ollama API
for model in self.state.read(cx).available_models.iter() {
models.insert(model.name.clone(), model.clone());
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.ollama
.available_models
.iter()
{
models.insert(
model.name.clone(),
ollama::Model {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: model.keep_alive.clone(),
supports_tools: model.supports_tools,
supports_vision: model.supports_images,
supports_thinking: model.supports_thinking,
},
);
}
let mut models = models
.into_values()
.map(|model| {
Arc::new(OllamaLanguageModel {
id: LanguageModelId::from(model.name.clone()),
model,
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect::<Vec<_>>();
models.sort_by_key(|model| model.name());
models
}
fn is_authenticated(&self, cx: &App) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(
&self,
_target_agent: language_model::ConfigurationViewTargetAgent,
window: &mut Window,
cx: &mut App,
) -> AnyView {
let state = self.state.clone();
cx.new(|cx| ConfigurationView::new(state, window, cx))
.into()
}
fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl OllamaLanguageModel {
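    /// Converts Zed's `LanguageModelRequest` into an Ollama `ChatRequest`.
    ///
    /// A rough sketch of the JSON this produces, assuming
    /// `stream_chat_completion` posts to Ollama's `/api/chat` endpoint
    /// (values are illustrative):
    ///
    /// ```json
    /// {
    ///   "model": "llama3.2:latest",
    ///   "messages": [{ "role": "user", "content": "Hello!" }],
    ///   "stream": true,
    ///   "options": { "num_ctx": 16384, "stop": [], "temperature": 1.0 },
    ///   "think": false,
    ///   "tools": []
    /// }
    /// ```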
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
let supports_vision = self.model.supports_vision.unwrap_or(false);
ChatRequest {
model: self.model.name.clone(),
messages: request
.messages
.into_iter()
.map(|msg| {
let images = if supports_vision {
msg.content
.iter()
.filter_map(|content| match content {
MessageContent::Image(image) => Some(image.source.to_string()),
_ => None,
})
.collect::<Vec<String>>()
} else {
vec![]
};
match msg.role {
Role::User => ChatMessage::User {
content: msg.string_contents(),
images: if images.is_empty() {
None
} else {
Some(images)
},
},
Role::Assistant => {
let content = msg.string_contents();
let thinking =
msg.content.into_iter().find_map(|content| match content {
MessageContent::Thinking { text, .. } if !text.is_empty() => {
Some(text)
}
_ => None,
});
ChatMessage::Assistant {
content,
tool_calls: None,
images: if images.is_empty() {
None
} else {
Some(images)
},
thinking,
}
}
Role::System => ChatMessage::System {
content: msg.string_contents(),
},
}
})
.collect(),
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
stream: true,
options: Some(ChatOptions {
num_ctx: Some(self.model.max_tokens),
stop: Some(request.stop),
temperature: request.temperature.or(Some(1.0)),
..Default::default()
}),
think: self
.model
.supports_thinking
.map(|supports_thinking| supports_thinking && request.thinking_allowed),
tools: request.tools.into_iter().map(tool_into_ollama).collect(),
}
}
}
impl LanguageModel for OllamaLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
PROVIDER_ID
}
fn provider_name(&self) -> LanguageModelProviderName {
PROVIDER_NAME
}
fn supports_tools(&self) -> bool {
self.model.supports_tools.unwrap_or(false)
}
fn supports_images(&self) -> bool {
self.model.supports_vision.unwrap_or(false)
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
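        // `ChatRequest` has no tool-choice field, so none of these modes can be
        // forwarded to Ollama.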
match choice {
LanguageModelToolChoice::Auto => false,
LanguageModelToolChoice::Any => false,
LanguageModelToolChoice::None => false,
}
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
}
fn max_token_count(&self) -> u64 {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
_cx: &App,
) -> BoxFuture<'static, Result<u64>> {
// There is no endpoint for this _yet_ in Ollama
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
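        // Rough heuristic: assume ~4 characters per token, so e.g. a
        // 2,000-character conversation estimates to ~500 tokens.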
let token_count = request
.messages
.iter()
.map(|msg| msg.string_contents().chars().count())
.sum::<usize>()
/ 4;
async move { Ok(token_count as u64) }.boxed()
}
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<
BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
LanguageModelCompletionError,
>,
> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
};
let future = self.request_limiter.stream(async move {
let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream = map_to_language_model_completion_events(stream);
Ok(stream)
});
future.map_ok(|f| f.boxed()).boxed()
}
}
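/// Adapts Ollama's streamed `ChatResponseDelta`s into Zed's
/// `LanguageModelCompletionEvent`s. A single delta can expand into several
/// events; for example, the final delta of a tool-calling turn yields
/// roughly: `Thinking` (if present), `ToolUse`, `UsageUpdate`, then
/// `Stop(StopReason::ToolUse)`.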
fn map_to_language_model_completion_events(
stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
// Used for creating unique tool use ids
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
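    // e.g. the first call to a tool named `get_weather` gets the id "get_weather-0"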
struct State {
stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
used_tools: bool,
}
    // A single response delta may need to expand into both a ToolUse and a
    // Stop event, so we unfold the stream into batches of events
let stream = stream::unfold(
State {
stream,
used_tools: false,
},
async move |mut state| {
let response = state.stream.next().await?;
let delta = match response {
Ok(delta) => delta,
Err(e) => {
let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
return Some((vec![event], state));
}
};
let mut events = Vec::new();
match delta.message {
ChatMessage::User { content, images: _ } => {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
ChatMessage::System { content } => {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
ChatMessage::Assistant {
content,
tool_calls,
images: _,
thinking,
} => {
if let Some(text) = thinking {
events.push(Ok(LanguageModelCompletionEvent::Thinking {
text,
signature: None,
}));
}
if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
match tool_call {
OllamaToolCall::Function(function) => {
let tool_id = format!(
"{}-{}",
&function.name,
TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
);
let event =
LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
id: LanguageModelToolUseId::from(tool_id),
name: Arc::from(function.name),
raw_input: function.arguments.to_string(),
input: function.arguments,
is_input_complete: true,
});
events.push(Ok(event));
state.used_tools = true;
}
}
} else if !content.is_empty() {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
}
};
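            // Ollama reports token counts (prompt_eval_count / eval_count) only
            // on the final delta of a turn, when `done` is true.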
if delta.done {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
input_tokens: delta.prompt_eval_count.unwrap_or(0),
output_tokens: delta.eval_count.unwrap_or(0),
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
if state.used_tools {
state.used_tools = false;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
} else {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
}
Some((events, state))
},
);
stream.flat_map(futures::stream::iter)
}
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
}
impl ConfigurationView {
pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
let loading_models_task = Some(cx.spawn_in(window, {
let state = state.clone();
async move |this, cx| {
if let Some(task) = state
.update(cx, |state, cx| state.authenticate(cx))
.log_err()
{
task.await.log_err();
}
this.update(cx, |this, cx| {
this.loading_models_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
state,
loading_models_task,
}
}
fn retry_connection(&self, cx: &mut App) {
self.state
.update(cx, |state, cx| state.fetch_models(cx))
.detach_and_log_err(cx);
}
}
impl Render for ConfigurationView {
fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let is_authenticated = self.state.read(cx).is_authenticated();
let ollama_intro =
"Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";
if self.loading_models_task.is_some() {
div().child(Label::new("Loading models...")).into_any()
} else {
v_flex()
.gap_2()
.child(
v_flex().gap_1().child(Label::new(ollama_intro)).child(
List::new()
.child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
.child(InstructionListItem::text_only(
"Once installed, try `ollama run llama3.2`",
)),
),
)
.child(
h_flex()
.w_full()
.justify_between()
.gap_2()
.child(
h_flex()
.w_full()
.gap_2()
.map(|this| {
if is_authenticated {
this.child(
Button::new("ollama-site", "Ollama")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
.into_any_element(),
)
} else {
this.child(
Button::new(
"download_ollama_button",
"Download Ollama",
)
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| {
cx.open_url(OLLAMA_DOWNLOAD_URL)
})
.into_any_element(),
)
}
})
.child(
Button::new("view-models", "View All Models")
.style(ButtonStyle::Subtle)
.icon(IconName::ArrowUpRight)
.icon_size(IconSize::Small)
.icon_color(Color::Muted)
.on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
),
)
.map(|this| {
if is_authenticated {
this.child(
ButtonLike::new("connected")
.disabled(true)
.cursor_style(gpui::CursorStyle::Arrow)
.child(
h_flex()
.gap_2()
.child(Indicator::dot().color(Color::Success))
.child(Label::new("Connected"))
.into_any_element(),
),
)
} else {
this.child(
Button::new("retry_ollama_models", "Connect")
.icon_position(IconPosition::Start)
.icon_size(IconSize::XSmall)
.icon(IconName::PlayFilled)
.on_click(cx.listener(move |this, _, _, cx| {
this.retry_connection(cx)
})),
)
}
})
)
.into_any()
}
}
}
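/// Converts a Zed tool definition into Ollama's function-tool format.
/// Serialized, the result should look roughly like the sketch below;
/// `get_weather` and its schema are hypothetical, and `parameters` carries
/// the tool's JSON schema:
///
/// ```json
/// {
///   "type": "function",
///   "function": {
///     "name": "get_weather",
///     "description": "Get the current weather for a city",
///     "parameters": {
///       "type": "object",
///       "properties": { "city": { "type": "string" } },
///       "required": ["city"]
///     }
///   }
/// }
/// ```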
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
ollama::OllamaTool::Function {
function: OllamaFunctionTool {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
}
}