assistant: Use tools in other providers (#15803)
- [x] OpenAI - [ ] ~Google~ Moved into a separate branch at: https://github.com/zed-industries/zed/tree/tool-calls-in-google-ai I've ran into issues with having the API digest our schema without tripping over itself - the function call parameters are malformed and whatnot. We can resume from that branch if needed. - [x] Ollama - [x] Cloud - [ ] ~Copilot Chat (?)~ Release Notes: - Added tool calling capabilities to OpenAI and Ollama models.
This commit is contained in:
parent
be514f23e1
commit
874f0c0712
5 changed files with 392 additions and 64 deletions
|
@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
|||
use isahc::config::Configurable;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::{convert::TryFrom, sync::Arc, time::Duration};
|
||||
|
||||
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
||||
|
@ -94,22 +95,63 @@ impl Model {
|
|||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum ChatMessage {
|
||||
Assistant { content: String },
|
||||
User { content: String },
|
||||
System { content: String },
|
||||
Assistant {
|
||||
content: String,
|
||||
tool_calls: Option<Vec<OllamaToolCall>>,
|
||||
},
|
||||
User {
|
||||
content: String,
|
||||
},
|
||||
System {
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OllamaToolCall {
|
||||
Function(OllamaFunctionCall),
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct OllamaFunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
pub struct OllamaFunctionTool {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parameters: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "lowercase")]
|
||||
pub enum OllamaTool {
|
||||
Function { function: OllamaFunctionTool },
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub stream: bool,
|
||||
pub keep_alive: KeepAlive,
|
||||
pub options: Option<ChatOptions>,
|
||||
pub tools: Vec<OllamaTool>,
|
||||
}
|
||||
|
||||
impl ChatRequest {
|
||||
pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
|
||||
self.stream = false;
|
||||
self.tools = tools;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
|
||||
#[derive(Serialize, Default)]
|
||||
#[derive(Serialize, Default, Debug)]
|
||||
pub struct ChatOptions {
|
||||
pub num_ctx: Option<usize>,
|
||||
pub num_predict: Option<isize>,
|
||||
|
@ -118,7 +160,7 @@ pub struct ChatOptions {
|
|||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ChatResponseDelta {
|
||||
#[allow(unused)]
|
||||
pub model: String,
|
||||
|
@ -162,6 +204,38 @@ pub struct ModelDetails {
|
|||
pub quantization_level: String,
|
||||
}
|
||||
|
||||
pub async fn complete(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
request: ChatRequest,
|
||||
) -> Result<ChatResponseDelta> {
|
||||
let uri = format!("{api_url}/api/chat");
|
||||
let request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
let serialized_request = serde_json::to_string(&request)?;
|
||||
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
|
||||
Ok(response_message)
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
let body_str = std::str::from_utf8(&body)?;
|
||||
Err(anyhow!(
|
||||
"Failed to connect to API: {} {}",
|
||||
response.status(),
|
||||
body_str
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn stream_chat_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue