ollama: Add tool call support (#29563)

The goal of this PR is to support tool calls using ollama. A lot of the
serialization work was done in
https://github.com/zed-industries/zed/pull/15803 however the abstraction
over language models always disables tools.

## Changelog:

- Use `serde_json::Value` inside `OllamaFunctionCall` just as it's used
in `OllamaFunctionTool`. This fixes deserialization of ollama tool
calls.
- Added deserialization tests using json from official ollama api docs.
- Fetch model capabilities during model enumeration from ollama provider
- Added `supports_tools` setting to manually configure if a model
supports tools

## TODO:

- [x] Fix tool call serialization/deserialization
- [x] Fetch model capabilities from ollama api
- [x] Add tests for parsing model capabilities 
- [ ] Documentation for `supports_tools` field for ollama language model
config
- [ ] Convert between generic language model types
- [x] Pass tools to ollama

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Nathan Sobo <nathan@zed.dev>
This commit is contained in:
tidely 2025-05-05 19:52:23 +02:00 committed by GitHub
parent e9616259d0
commit 769ec59162
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 360 additions and 88 deletions

View file

@ -2,42 +2,11 @@ use anyhow::{Context as _, Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::{Value, value::RawValue};
use std::{convert::TryFrom, sync::Arc, time::Duration};
use serde_json::Value;
use std::{sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
/// Role of a chat participant as used by the Ollama chat API.
///
/// Serialized in lowercase ("user", "assistant", "system") to match the
/// wire format expected by `/api/chat`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
User,
Assistant,
System,
}
impl TryFrom<String> for Role {
    type Error = anyhow::Error;

    /// Parses a lowercase role name ("user", "assistant", "system").
    ///
    /// Returns an error for any other string.
    fn try_from(value: String) -> Result<Self> {
        if value == "user" {
            Ok(Self::User)
        } else if value == "assistant" {
            Ok(Self::Assistant)
        } else if value == "system" {
            Ok(Self::System)
        } else {
            Err(anyhow!("invalid role '{value}'"))
        }
    }
}
impl From<Role> for String {
    /// Renders the role as its lowercase API name.
    fn from(role: Role) -> Self {
        let name = match role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        name.to_owned()
    }
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
@ -68,6 +37,7 @@ pub struct Model {
pub display_name: Option<String>,
pub max_tokens: usize,
pub keep_alive: Option<KeepAlive>,
pub supports_tools: bool,
}
fn get_max_tokens(name: &str) -> usize {
@ -93,7 +63,12 @@ fn get_max_tokens(name: &str) -> usize {
}
impl Model {
pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
pub fn new(
name: &str,
display_name: Option<&str>,
max_tokens: Option<usize>,
supports_tools: bool,
) -> Self {
Self {
name: name.to_owned(),
display_name: display_name
@ -101,6 +76,7 @@ impl Model {
.or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
keep_alive: Some(KeepAlive::indefinite()),
supports_tools,
}
}
@ -141,7 +117,7 @@ pub enum OllamaToolCall {
#[derive(Serialize, Deserialize, Debug)]
pub struct OllamaFunctionCall {
pub name: String,
pub arguments: Box<RawValue>,
pub arguments: Value,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@ -229,6 +205,19 @@ pub struct ModelDetails {
pub quantization_level: String,
}
/// Subset of the Ollama `/api/show` response needed to determine model
/// capabilities. All other fields of the payload are ignored.
#[derive(Deserialize, Debug)]
pub struct ModelShow {
/// Capability names reported by the API (e.g. "completion", "tools").
/// Defaults to empty when the field is absent from the response.
#[serde(default)]
pub capabilities: Vec<String>,
}
impl ModelShow {
    /// Reports whether the model advertises the "tools" capability.
    pub fn supports_tools(&self) -> bool {
        // `Vec::contains` would require a `&String` (an extra allocation),
        // so compare each entry as a `&str` instead.
        self.capabilities
            .iter()
            .any(|capability| capability.as_str() == "tools")
    }
}
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
@ -244,14 +233,14 @@ pub async fn complete(
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let mut response = client.send(request).await?;
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
if response.status().is_success() {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
Ok(response_message)
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
Err(anyhow!(
"Failed to connect to API: {} {}",
@ -279,13 +268,9 @@ pub async fn stream_chat_completion(
Ok(reader
.lines()
.filter_map(move |line| async move {
match line {
Ok(line) => {
Some(serde_json::from_str(&line).context("Unable to parse chat response"))
}
Err(e) => Some(Err(e.into())),
}
.map(|line| match line {
Ok(line) => serde_json::from_str(&line).context("Unable to parse chat response"),
Err(e) => Err(e.into()),
})
.boxed())
} else {
@ -332,6 +317,33 @@ pub async fn get_models(
}
}
/// Fetches details of a model from `POST /api/show`, used to determine
/// model capabilities.
///
/// Returns an error if the request fails or the server responds with a
/// non-success status.
pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<ModelShow> {
    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(format!("{api_url}/api/show"))
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(
            serde_json::json!({ "model": model }).to_string(),
        ))?;

    let mut response = client.send(request).await?;

    // Read the full body up front so it can be included in the error message.
    let mut body = String::new();
    response.body_mut().read_to_string(&mut body).await?;

    if !response.status().is_success() {
        return Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ));
    }

    let details: ModelShow = serde_json::from_str(&body)?;
    Ok(details)
}
/// Sends an empty request to Ollama to trigger loading the model
pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
let uri = format!("{api_url}/api/generate");
@ -339,12 +351,13 @@ pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &s
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.body(AsyncBody::from(serde_json::to_string(
&serde_json::json!({
.body(AsyncBody::from(
serde_json::json!({
"model": model,
"keep_alive": "15m",
}),
)?))?;
})
.to_string(),
))?;
let mut response = client.send(request).await?;
@ -361,3 +374,161 @@ pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &s
))
}
}
#[cfg(test)]
mod tests {
use super::*;
// The JSON fixtures below are taken from the official Ollama API docs
// (https://github.com/ollama/ollama/blob/main/docs/api.md). These tests
// only assert that deserialization succeeds and produces the expected
// variants; they make no network calls.
#[test]
fn parse_completion() {
// Final, non-streaming chat response: `done: true` plus run statistics.
let response = serde_json::json!({
"model": "llama3.2",
"created_at": "2023-12-12T14:13:43.416799Z",
"message": {
"role": "assistant",
"content": "Hello! How are you today?"
},
"done": true,
"total_duration": 5191566416u64,
"load_duration": 2154458,
"prompt_eval_count": 26,
"prompt_eval_duration": 383809000,
"eval_count": 298,
"eval_duration": 4799921000u64
});
let _: ChatResponseDelta = serde_json::from_value(response).unwrap();
}
#[test]
fn parse_streaming_completion() {
// Mid-stream chunk: partial content, `done: false`, no statistics yet.
let partial = serde_json::json!({
"model": "llama3.2",
"created_at": "2023-08-04T08:52:19.385406455-07:00",
"message": {
"role": "assistant",
"content": "The",
"images": null
},
"done": false
});
let _: ChatResponseDelta = serde_json::from_value(partial).unwrap();
// Terminal chunk: empty content, `done: true`, and the run statistics.
let last = serde_json::json!({
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458u64,
"load_duration": 1334875,
"prompt_eval_count": 26,
"prompt_eval_duration": 342546000,
"eval_count": 282,
"eval_duration": 4535599000u64
});
let _: ChatResponseDelta = serde_json::from_value(last).unwrap();
}
#[test]
fn parse_tool_call() {
// Assistant turn carrying a tool call. `arguments` arrives as a JSON
// object, which is why `OllamaFunctionCall::arguments` must be a
// `serde_json::Value` rather than a string.
let response = serde_json::json!({
"model": "llama3.2:3b",
"created_at": "2025-04-28T20:02:02.140489Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "weather",
"arguments": {
"city": "london",
}
}
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 2758629166u64,
"load_duration": 1770059875,
"prompt_eval_count": 147,
"prompt_eval_duration": 684637583,
"eval_count": 16,
"eval_duration": 302561917,
});
let result: ChatResponseDelta = serde_json::from_value(response).unwrap();
match result.message {
ChatMessage::Assistant {
content,
tool_calls,
} => {
// A tool-call turn has empty text content and a non-empty call list.
assert!(content.is_empty());
assert!(tool_calls.is_some_and(|v| !v.is_empty()));
}
_ => panic!("Deserialized wrong role"),
}
}
#[test]
fn parse_show_model() {
// Full `/api/show` payload; `ModelShow` keeps only `capabilities` and
// must tolerate (ignore) every other field.
let response = serde_json::json!({
"license": "LLAMA 3.2 COMMUNITY LICENSE AGREEMENT...",
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": ["llama"],
"parameter_size": "3.2B",
"quantization_level": "Q4_K_M"
},
"model_info": {
"general.architecture": "llama",
"general.basename": "Llama-3.2",
"general.file_type": 15,
"general.finetune": "Instruct",
"general.languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
"general.parameter_count": 3212749888u64,
"general.quantization_version": 2,
"general.size_label": "3B",
"general.tags": ["facebook", "meta", "pytorch", "llama", "llama-3", "text-generation"],
"general.type": "model",
"llama.attention.head_count": 24,
"llama.attention.head_count_kv": 8,
"llama.attention.key_length": 128,
"llama.attention.layer_norm_rms_epsilon": 0.00001,
"llama.attention.value_length": 128,
"llama.block_count": 28,
"llama.context_length": 131072,
"llama.embedding_length": 3072,
"llama.feed_forward_length": 8192,
"llama.rope.dimension_count": 128,
"llama.rope.freq_base": 500000,
"llama.vocab_size": 128256,
"tokenizer.ggml.bos_token_id": 128000,
"tokenizer.ggml.eos_token_id": 128009,
"tokenizer.ggml.merges": null,
"tokenizer.ggml.model": "gpt2",
"tokenizer.ggml.pre": "llama-bpe",
"tokenizer.ggml.token_type": null,
"tokenizer.ggml.tokens": null
},
"tensors": [
{ "name": "rope_freqs.weight", "type": "F32", "shape": [64] },
{ "name": "token_embd.weight", "type": "Q4_K_S", "shape": [3072, 128256] }
],
"capabilities": ["completion", "tools"],
"modified_at": "2025-04-29T21:24:41.445877632+03:00"
});
let result: ModelShow = serde_json::from_value(response).unwrap();
assert!(result.supports_tools());
assert!(result.capabilities.contains(&"tools".to_string()));
assert!(result.capabilities.contains(&"completion".to_string()));
}
}