language_models: Add support for tool use to LM Studio provider (#30589)

Closes #30004

**Quick demo:**


https://github.com/user-attachments/assets/0ac93851-81d7-4128-a34b-1f3ae4bcff6d

**Additional notes:**

I've tried to stick to the existing code in the OpenAI provider as much as
possible, without changing much, in order to keep the diff small.

This PR is done in collaboration with @yagil from LM Studio. We agreed
upon the format in which LM Studio will return information about tool
use support for the model in the upcoming version. As of the current stable
version, nothing is going to change for users, but once they update
to a newer LM Studio, tool use gets automatically enabled for them. I
think this is much better UX than defaulting to true right now.


Release Notes:

- Added support for tool calls to LM Studio provider

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
This commit is contained in:
Fedor Nezhivoi 2025-05-26 18:54:17 +07:00 committed by GitHub
parent 6363fdab88
commit 998542b048
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 320 additions and 120 deletions

View file

@ -2,7 +2,7 @@ use anyhow::{Context as _, Result};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::{Value, value::RawValue};
use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
@ -47,14 +47,21 @@ pub struct Model {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: usize,
pub supports_tool_calls: bool,
}
impl Model {
pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
pub fn new(
name: &str,
display_name: Option<&str>,
max_tokens: Option<usize>,
supports_tool_calls: bool,
) -> Self {
Self {
name: name.to_owned(),
display_name: display_name.map(|s| s.to_owned()),
max_tokens: max_tokens.unwrap_or(2048),
supports_tool_calls,
}
}
@ -69,15 +76,43 @@ impl Model {
/// Returns the configured `max_tokens` value for this model.
pub fn max_token_count(&self) -> usize {
    self.max_tokens
}
/// Getter for the `supports_tool_calls` flag set when the model was
/// constructed (presumably from LM Studio's reported capabilities —
/// confirm against the model-listing code path).
pub fn supports_tool_calls(&self) -> bool {
    self.supports_tool_calls
}
}
/// Tool-selection strategy included in a chat completion request.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants (`Auto`,
/// `Required`, `None`) serialize as JSON `null` rather than the strings
/// `"auto"`/`"required"`/`"none"` — confirm this matches what the LM
/// Studio endpoint expects before relying on the unit variants.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    // Pins the choice to one specific tool definition.
    Other(ToolDefinition),
}
/// A tool the model may call, in the OpenAI-compatible wire shape:
/// `{"type": "function", "function": { ... }}`.
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // Only function tools exist today; `dead_code` is allowed because the
    // variant appears to be built/consumed via serde rather than direct
    // construction — confirm at call sites.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
/// Description of a callable function exposed to the model.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    // Optional human-readable description shown to the model.
    pub description: Option<String>,
    // JSON schema for the function's arguments, if any.
    pub parameters: Option<Value>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
Assistant {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<LmStudioToolCall>>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
tool_calls: Vec<ToolCall>,
},
User {
content: String,
@ -85,31 +120,29 @@ pub enum ChatMessage {
System {
content: String,
},
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum LmStudioToolCall {
Function(LmStudioFunctionCall),
}
#[derive(Serialize, Deserialize, Debug)]
pub struct LmStudioFunctionCall {
pub name: String,
pub arguments: Box<RawValue>,
Tool {
content: String,
tool_call_id: String,
},
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LmStudioFunctionTool {
pub name: String,
pub description: Option<String>,
pub parameters: Option<Value>,
pub struct ToolCall {
pub id: String,
#[serde(flatten)]
pub content: ToolCallContent,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum LmStudioTool {
Function { function: LmStudioFunctionTool },
pub enum ToolCallContent {
Function { function: FunctionContent },
}
/// Function name plus its arguments for a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments are kept as the raw JSON-encoded string the model produced.
    pub arguments: String,
}
#[derive(Serialize, Debug)]
@ -117,10 +150,16 @@ pub struct ChatCompletionRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
pub tools: Vec<LmStudioTool>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
}
#[derive(Serialize, Deserialize, Debug)]
@ -135,8 +174,7 @@ pub struct ChatResponse {
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
pub index: u32,
#[serde(default)]
pub delta: serde_json::Value,
pub delta: ResponseMessageDelta,
pub finish_reason: Option<String>,
}
@ -164,6 +202,16 @@ pub struct Usage {
pub total_tokens: u32,
}
/// Capability strings reported by LM Studio for a model (e.g. `"tool_use"`),
/// deserialized transparently from a plain JSON array of strings.
#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
#[serde(transparent)]
pub struct Capabilities(Vec<String>);

impl Capabilities {
    /// True when the reported capabilities include `"tool_use"`.
    pub fn supports_tool_calls(&self) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|capability| capability == "tool_use")
    }
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum ResponseStreamResult {
@ -175,16 +223,17 @@ pub enum ResponseStreamResult {
pub struct ResponseStreamEvent {
pub created: u32,
pub model: String,
pub object: String,
pub choices: Vec<ChoiceDelta>,
pub usage: Option<Usage>,
}
#[derive(Serialize, Deserialize)]
#[derive(Deserialize)]
pub struct ListModelsResponse {
pub data: Vec<ModelEntry>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub struct ModelEntry {
pub id: String,
pub object: String,
@ -196,6 +245,8 @@ pub struct ModelEntry {
pub state: ModelState,
pub max_context_length: Option<u32>,
pub loaded_context_length: Option<u32>,
#[serde(default)]
pub capabilities: Capabilities,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
@ -265,7 +316,7 @@ pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,
request: ChatCompletionRequest,
) -> Result<BoxStream<'static, Result<ChatResponse>>> {
) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
let uri = format!("{api_url}/chat/completions");
let request_builder = http::Request::builder()
.method(Method::POST)