assistant: Use tools in other providers (#15803)

- [x] OpenAI
- [ ] ~Google~ Moved into a separate branch at:
https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've
run into issues with getting the API to digest our schema without tripping
over itself - the function call parameters come back malformed. We can
resume from that branch if needed.
- [x] Ollama
- [x] Cloud
- [ ] ~Copilot Chat (?)~

Release Notes:

- Added tool calling capabilities to OpenAI and Ollama models.
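
For reference, here's a rough sketch of how the new Ollama tool types below could be wired together (names like `get_weather`, `http_client`, and the base `request` are illustrative, not part of this change):

```rust
// Sketch only: assumes an Arc<dyn HttpClient> in `http_client` and a
// previously built `request: ChatRequest` for the selected model.
let tools = vec![OllamaTool::Function {
    function: OllamaFunctionTool {
        name: "get_weather".into(),
        description: Some("Look up the current weather for a city".into()),
        parameters: Some(serde_json::json!({
            "type": "object",
            "properties": { "city": { "type": "string" } },
            "required": ["city"]
        })),
    },
}];
// `with_tools` also flips the request to non-streaming, since the tool call
// comes back in a single response rather than as deltas.
let request = request.with_tools(tools);
let response = complete(http_client.as_ref(), OLLAMA_API_URL, request).await?;
// The assistant message in `response` may now carry `tool_calls` to dispatch on.
```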
Piotr Osiewicz 2024-08-06 15:45:47 +02:00 committed by GitHub
parent be514f23e1
commit 874f0c0712
5 changed files with 392 additions and 64 deletions


@@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -94,22 +95,63 @@ impl Model {
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant { content: String },
User { content: String },
System { content: String },
Assistant {
content: String,
tool_calls: Option<Vec<OllamaToolCall>>,
},
User {
content: String,
},
System {
content: String,
},
}
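// With `tag = "role"` above, an assistant turn that invokes a tool serializes as, e.g.,
// {"role":"assistant","content":"","tool_calls":[{"function":{"name":"get_weather","arguments":{"city":"Kraków"}}}]}
// ("get_weather" is just an illustrative tool name).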
#[derive(Serialize)]
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum OllamaToolCall {
Function(OllamaFunctionCall),
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionCall {
pub name: String,
pub arguments: Value,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct OllamaFunctionTool {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OllamaTool {
Function { function: OllamaFunctionTool },
}
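// With the serde attributes above, each tool definition serializes as
// {"type":"function","function":{"name":...,"description":...,"parameters":{ JSON schema }}},
// which is the shape of each entry in the request's `tools` array.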
#[derive(Serialize, Debug)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
pub keep_alive: KeepAlive,
pub options: Option<ChatOptions>,
pub tools: Vec<OllamaTool>,
}
impl ChatRequest {
pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
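// Requests that carry tools are sent non-streaming; the tool call comes back in one response.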
self.stream = false;
self.tools = tools;
self
}
}
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
#[derive(Serialize, Default)]
#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
pub num_ctx: Option<usize>,
pub num_predict: Option<isize>,
@@ -118,7 +160,7 @@ pub struct ChatOptions {
pub top_p: Option<f32>,
}
#[derive(Deserialize)]
#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
#[allow(unused)]
pub model: String,
@@ -162,6 +204,38 @@ pub struct ModelDetails {
pub quantization_level: String,
}
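/// Send a non-streaming request to /api/chat and parse the single JSON response
/// (the path `with_tools` forces requests onto).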
pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
request: ChatRequest,
) -> Result<ChatResponseDelta> {
let uri = format!("{api_url}/api/chat");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let mut response = client.send(request).await?;
if response.status().is_success() {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
Ok(response_message)
} else {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
let body_str = std::str::from_utf8(&body)?;
Err(anyhow!(
"Failed to connect to API: {} {}",
response.status(),
body_str
))
}
}
pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,