assistant: Use tools in other providers (#15803)

- [x] OpenAI
- [ ] ~Google~ Moved into a separate branch at:
https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've
ran into issues with having the API digest our schema without tripping
over itself - the function call parameters are malformed and whatnot. We
can resume from that branch if needed.
- [x] Ollama
- [x] Cloud
- [ ] ~Copilot Chat (?)~

Release Notes:

- Added tool calling capabilities to OpenAI and Ollama models.
This commit is contained in:
Piotr Osiewicz 2024-08-06 15:45:47 +02:00 committed by GitHub
parent be514f23e1
commit 874f0c0712
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 392 additions and 64 deletions

View file

@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use serde_json::{Map, Value};
use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter;
@ -121,25 +121,34 @@ pub struct Request {
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<String>,
pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
#[derive(Debug, Deserialize, Serialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Map<String, Value>>,
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
Auto,
Required,
None,
Other(ToolDefinition),
}
#[derive(Deserialize, Serialize, Debug)]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
#[allow(dead_code)]
Function { function: FunctionDefinition },
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {