ollama: Add tool call support (#29563)

The goal of this PR is to support tool calls when using Ollama. A lot of the
serialization work was already done in
https://github.com/zed-industries/zed/pull/15803; however, the abstraction
over language models always disables tools.

## Changelog:

- Use `serde_json::Value` for the tool call arguments inside `OllamaFunctionCall`. This fixes deserialization of Ollama tool calls.
- Added deserialization tests using JSON from the official Ollama API docs (see the first sketch after this list).
- Fetch model capabilities from the Ollama API during model enumeration in the Ollama provider (see the second sketch below).
- Added a `supports_tools` setting to manually configure whether a model supports tools (an example configuration follows the TODO list).
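
For the first two items, here is a minimal, self-contained sketch of the idea. The struct names and derives are illustrative stand-ins, not the actual definitions in `crates/ollama`: Ollama sends the function arguments as a JSON object, so they are held as a `serde_json::Value`, and the payload below mirrors the shape of the tool call example in the Ollama API docs.

```rust
use serde::Deserialize;
use serde_json::Value;

// Illustrative stand-ins for the crate's tool call types.
#[derive(Deserialize)]
struct FunctionCall {
    name: String,
    // A JSON object, not a string, hence `serde_json::Value`.
    arguments: Value,
}

#[derive(Deserialize)]
struct ToolCall {
    function: FunctionCall,
}

#[test]
fn deserializes_tool_call_shaped_like_the_docs_example() {
    let json = r#"{
        "function": {
            "name": "get_current_weather",
            "arguments": { "format": "celsius", "location": "Paris, FR" }
        }
    }"#;
    let call: ToolCall = serde_json::from_str(json).unwrap();
    assert_eq!(call.function.name, "get_current_weather");
    assert_eq!(call.function.arguments["location"], "Paris, FR");
}
```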

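The capability fetch in `fetch_models` calls `show_model` once per model. Below is a hedged sketch of the underlying check, assuming the `/api/show` response includes a `capabilities` array (for example `["completion", "tools"]`); the real `show_model`/`supports_tools` implementations live in the `ollama` crate and may differ.

```rust
use serde::Deserialize;

// Hypothetical, trimmed-down view of an /api/show response; only the field
// relevant to tool support is modeled here.
#[derive(Deserialize, Default)]
struct ModelShowResponse {
    #[serde(default)]
    capabilities: Vec<String>,
}

impl ModelShowResponse {
    // Treat a model as tool-capable if the server lists "tools" among its capabilities.
    fn supports_tools(&self) -> bool {
        self.capabilities.iter().any(|c| c == "tools")
    }
}

fn main() {
    let body = r#"{ "capabilities": ["completion", "tools"] }"#;
    let parsed: ModelShowResponse = serde_json::from_str(body).unwrap();
    assert!(parsed.supports_tools());
}
```

Because a local Ollama install can have an arbitrary number of models, the diff rate-limits these per-model requests with `buffer_unordered(5)` during enumeration.
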
## TODO:

- [x] Fix tool call serialization/deserialization
- [x] Fetch model capabilities from ollama api
- [x] Add tests for parsing model capabilities 
- [ ] Documentation for the `supports_tools` field in the Ollama language model
config (a hypothetical example follows this list)
- [ ] Convert between generic language model types
- [x] Pass tools to ollama
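
As a sketch of what the missing `supports_tools` documentation could cover, here is a hypothetical Zed `settings.json` snippet that manually marks an Ollama model as tool-capable. The field names follow the `AvailableModel` struct in this diff; the settings path and the concrete model values are assumptions for illustration only.

```json
{
  "language_models": {
    "ollama": {
      "available_models": [
        {
          "name": "llama3.1:latest",
          "display_name": "Llama 3.1",
          "max_tokens": 32768,
          "supports_tools": true
        }
      ]
    }
  }
}
```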

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
tidely, 2025-05-05 19:52:23 +02:00, committed by GitHub
commit 769ec59162 (parent e9616259d0)
3 changed files with 360 additions and 88 deletions

@@ -1,9 +1,11 @@
use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRequestTool, LanguageModelToolUse, LanguageModelToolUseId, StopReason,
};
use language_model::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -11,12 +13,14 @@ use language_model::{
LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
stream_chat_completion,
ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;
@@ -47,6 +51,8 @@ pub struct AvailableModel {
pub max_tokens: usize,
/// The number of seconds to keep the connection open after the last request
pub keep_alive: Option<KeepAlive>,
/// Whether the model supports tools
pub supports_tools: bool,
}
pub struct OllamaLanguageModelProvider {
@@ -68,26 +74,44 @@ impl State {
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let http_client = Arc::clone(&self.http_client);
let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(async move |this, cx| {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
let tasks = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name, None, None))
.collect();
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
async move {
let name = model.name.as_str();
let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
let ollama_model =
ollama::Model::new(name, None, None, capabilities.supports_tools());
Ok(ollama_model)
}
});
models.sort_by(|a, b| a.name.cmp(&b.name));
// Rate-limit capability fetches
// since there is an arbitrary number of models available
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
.buffer_unordered(5)
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect::<Result<Vec<_>>>()?;
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
this.available_models = models;
this.available_models = ollama_models;
cx.notify();
})
})
@@ -189,6 +213,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: model.keep_alive.clone(),
supports_tools: model.supports_tools,
},
);
}
@@ -269,7 +294,7 @@ impl OllamaLanguageModel {
temperature: request.temperature.or(Some(1.0)),
..Default::default()
}),
tools: vec![],
tools: request.tools.into_iter().map(tool_into_ollama).collect(),
}
}
}
@@ -292,7 +317,7 @@ impl LanguageModel for OllamaLanguageModel {
}
fn supports_tools(&self) -> bool {
false
self.model.supports_tools
}
fn telemetry_id(&self) -> String {
@@ -341,39 +366,100 @@ impl LanguageModel for OllamaLanguageModel {
};
let future = self.request_limiter.stream(async move {
let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(delta) => {
let content = match delta.message {
ChatMessage::User { content } => content,
ChatMessage::Assistant { content, .. } => content,
ChatMessage::System { content } => content,
};
Some(Ok(content))
}
Err(error) => Some(Err(error)),
}
})
.boxed();
let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream = map_to_language_model_completion_events(stream);
Ok(stream)
});
async move {
Ok(future
.await?
.map(|result| {
result
.map(LanguageModelCompletionEvent::Text)
.map_err(LanguageModelCompletionError::Other)
})
.boxed())
}
.boxed()
future.map_ok(|f| f.boxed()).boxed()
}
}
fn map_to_language_model_completion_events(
stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
// Used for creating unique tool use ids
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
struct State {
stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
used_tools: bool,
}
// We need to create a ToolUse and Stop event from a single
// response from the original stream
let stream = stream::unfold(
State {
stream,
used_tools: false,
},
async move |mut state| {
let response = state.stream.next().await?;
let delta = match response {
Ok(delta) => delta,
Err(e) => {
let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
return Some((vec![event], state));
}
};
let mut events = Vec::new();
match delta.message {
ChatMessage::User { content } => {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
ChatMessage::System { content } => {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
ChatMessage::Assistant {
content,
tool_calls,
} => {
// Check for tool calls
if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
match tool_call {
OllamaToolCall::Function(function) => {
let tool_id = format!(
"{}-{}",
&function.name,
TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
);
let event =
LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
id: LanguageModelToolUseId::from(tool_id),
name: Arc::from(function.name),
raw_input: function.arguments.to_string(),
input: function.arguments,
is_input_complete: true,
});
events.push(Ok(event));
state.used_tools = true;
}
}
} else {
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
}
};
if delta.done {
if state.used_tools {
state.used_tools = false;
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
} else {
events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
}
}
Some((events, state))
},
);
stream.flat_map(futures::stream::iter)
}
struct ConfigurationView {
state: gpui::Entity<State>,
loading_models_task: Option<Task<()>>,
@@ -509,3 +595,13 @@ impl Render for ConfigurationView {
}
}
}
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
ollama::OllamaTool::Function {
function: OllamaFunctionTool {
name: tool.name,
description: Some(tool.description),
parameters: Some(tool.input_schema),
},
}
}