language_models: Add support for tool use to LM Studio provider (#30589)

Closes #30004

**Quick demo:** https://github.com/user-attachments/assets/0ac93851-81d7-4128-a34b-1f3ae4bcff6d

**Additional notes:** I've tried to stick to the existing code in the OpenAI provider as much as possible, changing little, to keep the diff small. This PR was done in collaboration with @yagil from LM Studio. We agreed on the format in which LM Studio will report a model's tool use support in its upcoming version. With the current stable version nothing changes for users, but once they update to a newer LM Studio, tool use is enabled for them automatically. I think this is much better UX than defaulting to true right now.

Release Notes:

- Added support for tool calls to LM Studio provider

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
parent 6363fdab88
commit 998542b048

3 changed files with 320 additions and 120 deletions
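Concretely, the agreed handshake is a new `capabilities` array on the entries returned by LM Studio's model listing (under `LMSTUDIO_API_URL`), with `"tool_use"` marking tool-capable models; that is exactly what the `Capabilities` type added in this diff checks for. A minimal, self-contained sketch of the detection logic, assuming an illustrative payload and model id, with a trimmed stand-in for the crate's full `ModelEntry`:

```rust
use serde::Deserialize;

// Same shape as the `Capabilities` type added by this PR.
#[derive(Debug, Default, Deserialize, PartialEq)]
#[serde(transparent)]
struct Capabilities(Vec<String>);

impl Capabilities {
    fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }
}

// Trimmed stand-in for the crate's `ModelEntry`, keeping only the
// fields this sketch needs.
#[derive(Debug, Deserialize)]
struct TrimmedModelEntry {
    id: String,
    // Older LM Studio versions omit the key entirely; defaulting to `[]`
    // keeps tool use disabled for them, matching the PR's rollout story.
    #[serde(default)]
    capabilities: Capabilities,
}

fn main() {
    // Hypothetical model-list entries: a newer LM Studio advertising
    // "tool_use", and an older one sending no `capabilities` key at all.
    let new_style = r#"{ "id": "qwen2.5-7b-instruct", "capabilities": ["tool_use"] }"#;
    let old_style = r#"{ "id": "qwen2.5-7b-instruct" }"#;

    let a: TrimmedModelEntry = serde_json::from_str(new_style).unwrap();
    let b: TrimmedModelEntry = serde_json::from_str(old_style).unwrap();
    assert!(a.capabilities.supports_tool_calls());
    assert!(!b.capabilities.supports_tool_calls());
    println!("{} supports tools: {}", a.id, a.capabilities.supports_tool_calls());
}
```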
```diff
@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, value::RawValue};
+use serde_json::Value;
 use std::{convert::TryFrom, sync::Arc, time::Duration};
 
 pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
```
```diff
@@ -47,14 +47,21 @@ pub struct Model {
     pub name: String,
     pub display_name: Option<String>,
     pub max_tokens: usize,
+    pub supports_tool_calls: bool,
 }
 
 impl Model {
-    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
+    pub fn new(
+        name: &str,
+        display_name: Option<&str>,
+        max_tokens: Option<usize>,
+        supports_tool_calls: bool,
+    ) -> Self {
         Self {
             name: name.to_owned(),
             display_name: display_name.map(|s| s.to_owned()),
             max_tokens: max_tokens.unwrap_or(2048),
+            supports_tool_calls,
         }
     }
```
```diff
@@ -69,15 +76,43 @@ impl Model {
     pub fn max_token_count(&self) -> usize {
         self.max_tokens
     }
+
+    pub fn supports_tool_calls(&self) -> bool {
+        self.supports_tool_calls
+    }
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    Auto,
+    Required,
+    None,
+    Other(ToolDefinition),
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+    #[allow(dead_code)]
+    Function { function: FunctionDefinition },
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum ChatMessage {
     Assistant {
         #[serde(default)]
         content: Option<String>,
-        #[serde(default)]
-        tool_calls: Option<Vec<LmStudioToolCall>>,
+        #[serde(default, skip_serializing_if = "Vec::is_empty")]
+        tool_calls: Vec<ToolCall>,
     },
     User {
         content: String,
```
```diff
@@ -85,31 +120,29 @@ pub enum ChatMessage {
     System {
         content: String,
     },
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(rename_all = "lowercase")]
-pub enum LmStudioToolCall {
-    Function(LmStudioFunctionCall),
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct LmStudioFunctionCall {
-    pub name: String,
-    pub arguments: Box<RawValue>,
+    Tool {
+        content: String,
+        tool_call_id: String,
+    },
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LmStudioFunctionTool {
-    pub name: String,
-    pub description: Option<String>,
-    pub parameters: Option<Value>,
+pub struct ToolCall {
+    pub id: String,
+    #[serde(flatten)]
+    pub content: ToolCallContent,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(tag = "type", rename_all = "lowercase")]
-pub enum LmStudioTool {
-    Function { function: LmStudioFunctionTool },
+pub enum ToolCallContent {
+    Function { function: FunctionContent },
 }
 
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+    pub name: String,
+    pub arguments: String,
+}
+
 #[derive(Serialize, Debug)]
```
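For reference, a sketch of the JSON these message types produce on the wire. The type definitions are copied from the hunk above; the call id, function name, and arguments are invented for illustration:

```rust
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
enum ChatMessage {
    Assistant {
        #[serde(default)]
        content: Option<String>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    Tool {
        content: String,
        tool_call_id: String,
    },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ToolCall {
    id: String,
    // `flatten` merges the internally tagged enum into the same object,
    // yielding `{"id": ..., "type": "function", "function": {...}}`.
    #[serde(flatten)]
    content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct FunctionContent {
    name: String,
    arguments: String,
}

fn main() {
    // An assistant turn that calls a (made-up) tool.
    let msg = ChatMessage::Assistant {
        content: None,
        tool_calls: vec![ToolCall {
            id: "call_1".into(),
            content: ToolCallContent::Function {
                function: FunctionContent {
                    name: "get_weather".into(),
                    arguments: r#"{"city":"Berlin"}"#.into(),
                },
            },
        }],
    };
    assert_eq!(
        serde_json::to_value(&msg).unwrap(),
        json!({
            "role": "assistant",
            "content": null,
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": { "name": "get_weather", "arguments": "{\"city\":\"Berlin\"}" }
            }]
        })
    );

    // The tool result sent back, keyed to the call id.
    let reply = ChatMessage::Tool {
        content: "12C, partly cloudy".into(),
        tool_call_id: "call_1".into(),
    };
    assert_eq!(
        serde_json::to_value(&reply).unwrap(),
        json!({ "role": "tool", "content": "12C, partly cloudy", "tool_call_id": "call_1" })
    );
}
```

Keeping `arguments` as a raw JSON string, parsed by the client rather than the serializer, mirrors the OpenAI wire format the PR says it sticks close to.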
```diff
@@ -117,10 +150,16 @@ pub struct ChatCompletionRequest {
     pub model: String,
     pub messages: Vec<ChatMessage>,
     pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub max_tokens: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub stop: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
-    pub tools: Vec<LmStudioTool>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
```
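One detail worth calling out: the tool fields are guarded with `skip_serializing_if`, so a request with no tools serializes without those keys at all, which is what keeps requests to current stable LM Studio versions unchanged. A self-contained sketch with trimmed stand-ins for the request and tool types (`messages` dropped and `tool_choice` simplified to a string, purely for brevity):

```rust
use serde::Serialize;
use serde_json::json;

#[derive(Serialize)]
struct FunctionDefinition {
    name: String,
    description: Option<String>,
    parameters: Option<serde_json::Value>,
}

#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToolDefinition {
    Function { function: FunctionDefinition },
}

// Trimmed stand-in for the crate's `ChatCompletionRequest`.
#[derive(Serialize)]
struct ChatCompletionRequest {
    model: String,
    stream: bool,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<ToolDefinition>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>, // stand-in for `ToolChoice`
}

fn main() {
    // No tools: the body carries no `tools`/`tool_choice` keys, matching
    // the pre-PR request shape byte for byte (model id is illustrative).
    let legacy = ChatCompletionRequest {
        model: "qwen2.5-7b-instruct".into(),
        stream: true,
        tools: vec![],
        tool_choice: None,
    };
    assert_eq!(
        serde_json::to_value(&legacy).unwrap(),
        json!({ "model": "qwen2.5-7b-instruct", "stream": true })
    );

    // With a (made-up) tool, the keys appear in the OpenAI-style shape.
    let with_tools = ChatCompletionRequest {
        model: "qwen2.5-7b-instruct".into(),
        stream: true,
        tools: vec![ToolDefinition::Function {
            function: FunctionDefinition {
                name: "get_weather".into(),
                description: Some("Look up current weather".into()),
                parameters: None,
            },
        }],
        tool_choice: None,
    };
    let body = serde_json::to_value(&with_tools).unwrap();
    assert_eq!(body["tools"][0]["type"], "function");
    assert_eq!(body["tools"][0]["function"]["name"], "get_weather");
}
```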
```diff
@@ -135,8 +174,7 @@ pub struct ChatResponse {
 #[derive(Serialize, Deserialize, Debug)]
 pub struct ChoiceDelta {
     pub index: u32,
-    #[serde(default)]
-    pub delta: serde_json::Value,
+    pub delta: ResponseMessageDelta,
     pub finish_reason: Option<String>,
 }
```
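`ResponseMessageDelta` is defined outside the lines shown here, so the following is only an assumption-labeled illustration of what typed delta parsing buys over the old untyped `serde_json::Value`: the field set below is a hypothetical OpenAI-style delta, guessed rather than taken from this diff.

```rust
use serde::Deserialize;

// Hypothetical shape, assumed for illustration: optional content text plus
// incremental tool-call fragments, as in OpenAI-style streaming.
#[derive(Debug, Deserialize)]
struct ResponseMessageDelta {
    content: Option<String>,
    #[serde(default)]
    tool_calls: Vec<ToolCallChunk>,
}

#[derive(Debug, Deserialize)]
struct ToolCallChunk {
    index: usize,
    id: Option<String>,
    function: Option<FunctionChunk>,
}

#[derive(Debug, Deserialize)]
struct FunctionChunk {
    name: Option<String>,
    arguments: Option<String>,
}

fn main() {
    // A made-up streamed chunk carrying the start of a tool call; with a
    // typed delta, malformed chunks fail loudly at deserialization instead
    // of surfacing as surprises later.
    let chunk = r#"{
        "content": null,
        "tool_calls": [
            { "index": 0, "id": "call_1",
              "function": { "name": "get_weather", "arguments": "{\"ci" } }
        ]
    }"#;
    let delta: ResponseMessageDelta = serde_json::from_str(chunk).unwrap();
    assert!(delta.content.is_none());
    assert_eq!(delta.tool_calls[0].index, 0);
    assert_eq!(delta.tool_calls[0].id.as_deref(), Some("call_1"));
    assert_eq!(
        delta.tool_calls[0].function.as_ref().unwrap().name.as_deref(),
        Some("get_weather")
    );
}
```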
```diff
@@ -164,6 +202,16 @@ pub struct Usage {
     pub total_tokens: u32,
 }
 
+#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
+#[serde(transparent)]
+pub struct Capabilities(Vec<String>);
+
+impl Capabilities {
+    pub fn supports_tool_calls(&self) -> bool {
+        self.0.iter().any(|cap| cap == "tool_use")
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(untagged)]
 pub enum ResponseStreamResult {
```
```diff
@@ -175,16 +223,17 @@ pub enum ResponseStreamResult {
 pub struct ResponseStreamEvent {
     pub created: u32,
     pub model: String,
     pub object: String,
     pub choices: Vec<ChoiceDelta>,
     pub usage: Option<Usage>,
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Deserialize)]
 pub struct ListModelsResponse {
     pub data: Vec<ModelEntry>,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Deserialize, PartialEq)]
 pub struct ModelEntry {
     pub id: String,
     pub object: String,
```
```diff
@@ -196,6 +245,8 @@ pub struct ModelEntry {
     pub state: ModelState,
     pub max_context_length: Option<u32>,
     pub loaded_context_length: Option<u32>,
+    #[serde(default)]
+    pub capabilities: Capabilities,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
```
```diff
@@ -265,7 +316,7 @@ pub async fn stream_chat_completion(
     client: &dyn HttpClient,
     api_url: &str,
     request: ChatCompletionRequest,
-) -> Result<BoxStream<'static, Result<ChatResponse>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
     let uri = format!("{api_url}/chat/completions");
     let request_builder = http::Request::builder()
         .method(Method::POST)
```