language_models: Add support for tool use to LM Studio provider (#30589)

Closes #30004

**Quick demo:**


https://github.com/user-attachments/assets/0ac93851-81d7-4128-a34b-1f3ae4bcff6d

**Additional notes:**

I've tried to stick to the existing code in the OpenAI provider as much as
possible, without changing much, to keep the diff small.

This PR was done in collaboration with @yagil from LM Studio. We agreed
on the format in which an upcoming version of LM Studio will return
information about a model's tool use support. On the current stable
version nothing changes for users, but once they update to a newer
LM Studio, tool use gets enabled for them automatically. I think this is
much better UX than defaulting to true right now.
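
As a point of reference, here is a minimal sketch of how that agreed-upon `capabilities` field is expected to appear in the `/api/v0/models` response and how it maps to tool support. The payload is illustrative, and the types are trimmed-down stand-ins for the `Capabilities` and `ModelEntry` types added to the lmstudio crate below:

```rust
use serde::Deserialize;

// Mirrors the `Capabilities` type added in this PR: an array of capability
// strings, where tool use is signalled by a "tool_use" entry.
#[derive(Debug, Default, Deserialize)]
#[serde(transparent)]
struct Capabilities(Vec<String>);

impl Capabilities {
    fn supports_tool_calls(&self) -> bool {
        self.0.iter().any(|cap| cap == "tool_use")
    }
}

// A trimmed-down model entry; the real `ModelEntry` has more fields.
#[derive(Debug, Deserialize)]
struct ModelEntry {
    id: String,
    // Older LM Studio versions omit `capabilities`, so it defaults to empty
    // and tool use stays disabled until the user upgrades.
    #[serde(default)]
    capabilities: Capabilities,
}

fn main() -> serde_json::Result<()> {
    // Illustrative payload shaped like a single entry from /api/v0/models.
    let sample = r#"{ "id": "qwen2.5-coder-7b", "capabilities": ["tool_use"] }"#;
    let entry: ModelEntry = serde_json::from_str(sample)?;
    assert!(entry.capabilities.supports_tool_calls());
    println!("{} supports tool calls", entry.id);
    Ok(())
}
```

Because `capabilities` falls back to an empty list when absent, models reported by older LM Studio builds simply keep tool use disabled.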


Release Notes:

- Added support for tool calls to the LM Studio provider

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>
Fedor Nezhivoi 2025-05-26 18:54:17 +07:00 committed by GitHub
parent 6363fdab88
commit 998542b048
GPG key ID: B5690EEEBB952194
3 changed files with 320 additions and 120 deletions

View file

@@ -383,7 +383,9 @@ impl AssistantSettingsContent {
            _ => None,
        };
        settings.provider = Some(AssistantProviderContentV1::LmStudio {
-            default_model: Some(lmstudio::Model::new(&model, None, None)),
+            default_model: Some(lmstudio::Model::new(
+                &model, None, None, false,
+            )),
            api_url,
        });
    }

View file

@@ -1,10 +1,13 @@
 use anyhow::{Result, anyhow};
+use collections::HashMap;
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
-    LanguageModelToolChoice,
+    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
+    StopReason, WrappedTextContent,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -12,12 +15,14 @@ use language_model::{
     LanguageModelRequest, RateLimiter, Role,
 };
 use lmstudio::{
-    ChatCompletionRequest, ChatMessage, ModelType, get_models, preload_model,
+    ChatCompletionRequest, ChatMessage, ModelType, ResponseStreamEvent, get_models, preload_model,
     stream_chat_completion,
 };
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr;
 use std::{collections::BTreeMap, sync::Arc};
 use ui::{ButtonLike, Indicator, List, prelude::*};
 use util::ResultExt;
@@ -40,12 +45,10 @@ pub struct LmStudioSettings {
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 pub struct AvailableModel {
-    /// The model name in the LM Studio API. e.g. qwen2.5-coder-7b, phi-4, etc
     pub name: String,
-    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
     pub display_name: Option<String>,
-    /// The model's context window size.
     pub max_tokens: usize,
+    pub supports_tool_calls: bool,
 }
 
 pub struct LmStudioLanguageModelProvider {
@@ -77,7 +80,14 @@ impl State {
                 let mut models: Vec<lmstudio::Model> = models
                     .into_iter()
                     .filter(|model| model.r#type != ModelType::Embeddings)
-                    .map(|model| lmstudio::Model::new(&model.id, None, None))
+                    .map(|model| {
+                        lmstudio::Model::new(
+                            &model.id,
+                            None,
+                            None,
+                            model.capabilities.supports_tool_calls(),
+                        )
+                    })
                     .collect();
 
                 models.sort_by(|a, b| a.name.cmp(&b.name));
@@ -156,12 +166,16 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
         IconName::AiLmStudio
     }
 
-    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        self.provided_models(cx).into_iter().next()
+    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
+        // We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
+        // In a constrained environment where user might not have enough resources it'll be a bad UX to select something
+        // to load by default.
+        None
     }
 
-    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
-        self.default_model(cx)
+    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
+        // See explanation for default_model.
+        None
     }
 
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -184,6 +198,7 @@ impl LanguageModelProvider for LmStudioLanguageModelProvider {
                     name: model.name.clone(),
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
+                    supports_tool_calls: model.supports_tool_calls,
                 },
             );
         }
@@ -237,31 +252,117 @@ pub struct LmStudioLanguageModel {
 impl LmStudioLanguageModel {
     fn to_lmstudio_request(&self, request: LanguageModelRequest) -> ChatCompletionRequest {
+        let mut messages = Vec::new();
+        for message in request.messages {
+            for content in message.content {
+                match content {
+                    MessageContent::Text(text) | MessageContent::Thinking { text, .. } => messages
+                        .push(match message.role {
+                            Role::User => ChatMessage::User { content: text },
+                            Role::Assistant => ChatMessage::Assistant {
+                                content: Some(text),
+                                tool_calls: Vec::new(),
+                            },
+                            Role::System => ChatMessage::System { content: text },
+                        }),
+                    MessageContent::RedactedThinking(_) => {}
+                    MessageContent::Image(_) => {}
+                    MessageContent::ToolUse(tool_use) => {
+                        let tool_call = lmstudio::ToolCall {
+                            id: tool_use.id.to_string(),
+                            content: lmstudio::ToolCallContent::Function {
+                                function: lmstudio::FunctionContent {
+                                    name: tool_use.name.to_string(),
+                                    arguments: serde_json::to_string(&tool_use.input)
+                                        .unwrap_or_default(),
+                                },
+                            },
+                        };
+
+                        if let Some(lmstudio::ChatMessage::Assistant { tool_calls, .. }) =
+                            messages.last_mut()
+                        {
+                            tool_calls.push(tool_call);
+                        } else {
+                            messages.push(lmstudio::ChatMessage::Assistant {
+                                content: None,
+                                tool_calls: vec![tool_call],
+                            });
+                        }
+                    }
+                    MessageContent::ToolResult(tool_result) => {
+                        match &tool_result.content {
+                            LanguageModelToolResultContent::Text(text)
+                            | LanguageModelToolResultContent::WrappedText(WrappedTextContent {
+                                text,
+                                ..
+                            }) => {
+                                messages.push(lmstudio::ChatMessage::Tool {
+                                    content: text.to_string(),
+                                    tool_call_id: tool_result.tool_use_id.to_string(),
+                                });
+                            }
+                            LanguageModelToolResultContent::Image(_) => {
+                                // no support for images for now
+                            }
+                        };
+                    }
+                }
+            }
+        }
+
         ChatCompletionRequest {
             model: self.model.name.clone(),
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => ChatMessage::User {
-                        content: msg.string_contents(),
-                    },
-                    Role::Assistant => ChatMessage::Assistant {
-                        content: Some(msg.string_contents()),
-                        tool_calls: None,
-                    },
-                    Role::System => ChatMessage::System {
-                        content: msg.string_contents(),
-                    },
-                })
-                .collect(),
+            messages,
             stream: true,
             max_tokens: Some(-1),
             stop: Some(request.stop),
-            temperature: request.temperature.or(Some(0.0)),
-            tools: vec![],
+            // In LM Studio you can configure specific settings you'd like to use for your model.
+            // For example Qwen3 is recommended to be used with 0.7 temperature.
+            // It would be a bad UX to silently override these settings from Zed, so we pass no temperature as a default.
+            temperature: request.temperature.or(None),
+            tools: request
+                .tools
+                .into_iter()
+                .map(|tool| lmstudio::ToolDefinition::Function {
+                    function: lmstudio::FunctionDefinition {
+                        name: tool.name,
+                        description: Some(tool.description),
+                        parameters: Some(tool.input_schema),
+                    },
+                })
+                .collect(),
+            tool_choice: request.tool_choice.map(|choice| match choice {
+                LanguageModelToolChoice::Auto => lmstudio::ToolChoice::Auto,
+                LanguageModelToolChoice::Any => lmstudio::ToolChoice::Required,
+                LanguageModelToolChoice::None => lmstudio::ToolChoice::None,
+            }),
         }
     }
+
+    fn stream_completion(
+        &self,
+        request: ChatCompletionRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
+    {
+        let http_client = self.http_client.clone();
+        let Ok(api_url) = cx.update(|cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
+            settings.api_url.clone()
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        let future = self.request_limiter.stream(async move {
+            let request = stream_chat_completion(http_client.as_ref(), &api_url, request);
+            let response = request.await?;
+            Ok(response)
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
+    }
 }
 
 impl LanguageModel for LmStudioLanguageModel {
@@ -282,17 +383,22 @@ impl LanguageModel for LmStudioLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        self.model.supports_tool_calls()
+    }
+
+    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
+        self.supports_tools()
+            && match choice {
+                LanguageModelToolChoice::Auto => true,
+                LanguageModelToolChoice::Any => true,
+                LanguageModelToolChoice::None => true,
+            }
     }
 
     fn supports_images(&self) -> bool {
         false
     }
 
-    fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
-        false
-    }
-
     fn telemetry_id(&self) -> String {
         format!("lmstudio/{}", self.model.id())
     }
@@ -328,85 +434,126 @@ impl LanguageModel for LmStudioLanguageModel {
         >,
     > {
         let request = self.to_lmstudio_request(request);
-        let http_client = self.http_client.clone();
-        let Ok(api_url) = cx.update(|cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).lmstudio;
-            settings.api_url.clone()
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
-        let future = self.request_limiter.stream(async move {
-            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
-
-            // Create a stream mapper to handle content across multiple deltas
-            let stream_mapper = LmStudioStreamMapper::new();
-
-            let stream = response
-                .map(move |response| {
-                    response.and_then(|fragment| stream_mapper.process_fragment(fragment))
-                })
-                .filter_map(|result| async move {
-                    match result {
-                        Ok(Some(content)) => Some(Ok(content)),
-                        Ok(None) => None,
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-
-            Ok(stream)
-        });
-
+        let completions = self.stream_completion(request, cx);
         async move {
-            Ok(future
-                .await?
-                .map(|result| {
-                    result
-                        .map(LanguageModelCompletionEvent::Text)
-                        .map_err(LanguageModelCompletionError::Other)
-                })
-                .boxed())
+            let mapper = LmStudioEventMapper::new();
+            Ok(mapper.map_stream(completions.await?).boxed())
         }
         .boxed()
     }
 }
 
-// This will be more useful when we implement tool calling. Currently keeping it empty.
-struct LmStudioStreamMapper {}
+struct LmStudioEventMapper {
+    tool_calls_by_index: HashMap<usize, RawToolCall>,
+}
 
-impl LmStudioStreamMapper {
+impl LmStudioEventMapper {
     fn new() -> Self {
-        Self {}
+        Self {
+            tool_calls_by_index: HashMap::default(),
+        }
     }
 
-    fn process_fragment(&self, fragment: lmstudio::ChatResponse) -> Result<Option<String>> {
-        // Most of the time, there will be only one choice
-        let Some(choice) = fragment.choices.first() else {
-            return Ok(None);
+    pub fn map_stream(
+        mut self,
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
+    {
+        events.flat_map(move |event| {
+            futures::stream::iter(match event {
+                Ok(event) => self.map_event(event),
+                Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+            })
+        })
+    }
+
+    pub fn map_event(
+        &mut self,
+        event: ResponseStreamEvent,
+    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
+        let Some(choice) = event.choices.into_iter().next() else {
+            return vec![Err(LanguageModelCompletionError::Other(anyhow!(
+                "Response contained no choices"
+            )))];
         };
 
-        // Extract the delta content
-        if let Ok(delta) =
-            serde_json::from_value::<lmstudio::ResponseMessageDelta>(choice.delta.clone())
-        {
-            if let Some(content) = delta.content {
-                if !content.is_empty() {
-                    return Ok(Some(content));
+        let mut events = Vec::new();
+        if let Some(content) = choice.delta.content {
+            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+        }
+
+        if let Some(tool_calls) = choice.delta.tool_calls {
+            for tool_call in tool_calls {
+                let entry = self.tool_calls_by_index.entry(tool_call.index).or_default();
+
+                if let Some(tool_id) = tool_call.id {
+                    entry.id = tool_id;
+                }
+
+                if let Some(function) = tool_call.function {
+                    if let Some(name) = function.name {
+                        // At the time of writing this code LM Studio (0.3.15) is incompatible with the OpenAI API:
+                        // 1. It sends function name in the first chunk
+                        // 2. It sends empty string in the function name field in all subsequent chunks for arguments
+                        // According to https://platform.openai.com/docs/guides/function-calling?api-mode=responses#streaming
+                        // function name field should be sent only inside the first chunk.
+                        if !name.is_empty() {
+                            entry.name = name;
+                        }
+                    }
+
+                    if let Some(arguments) = function.arguments {
+                        entry.arguments.push_str(&arguments);
+                    }
                 }
             }
         }
 
-        // If there's a finish_reason, we're done
-        if choice.finish_reason.is_some() {
-            return Ok(None);
+        match choice.finish_reason.as_deref() {
+            Some("stop") => {
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            Some("tool_calls") => {
+                events.extend(self.tool_calls_by_index.drain().map(|(_, tool_call)| {
+                    match serde_json::Value::from_str(&tool_call.arguments) {
+                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
+                            LanguageModelToolUse {
+                                id: tool_call.id.into(),
+                                name: tool_call.name.into(),
+                                is_input_complete: true,
+                                input,
+                                raw_input: tool_call.arguments,
+                            },
+                        )),
+                        Err(error) => Err(LanguageModelCompletionError::BadInputJson {
+                            id: tool_call.id.into(),
+                            tool_name: tool_call.name.into(),
+                            raw_input: tool_call.arguments.into(),
+                            json_parse_error: error.to_string(),
+                        }),
+                    }
+                }));
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
+            }
+            Some(stop_reason) => {
+                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
+            }
+            None => {}
         }
 
-        Ok(None)
+        events
     }
 }
 
+#[derive(Default)]
+struct RawToolCall {
+    id: String,
+    name: String,
+    arguments: String,
+}
+
 struct ConfigurationView {
     state: gpui::Entity<State>,
     loading_models_task: Option<Task<()>>,
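
To make the delta-accumulation strategy above easier to follow in isolation, here is a small self-contained sketch (with hypothetical sample chunks, not code from this PR) of how `LmStudioEventMapper` folds streamed tool-call deltas into a complete `RawToolCall`, ignoring the empty-name continuation chunks that LM Studio 0.3.15 sends:

```rust
use std::collections::HashMap;

#[derive(Default, Debug)]
struct RawToolCall {
    id: String,
    name: String,
    arguments: String,
}

fn main() {
    let mut tool_calls_by_index: HashMap<usize, RawToolCall> = HashMap::new();

    // Hypothetical deltas: (index, id, name, argument fragment). The first
    // chunk carries the id and name; later chunks repeat an empty name and
    // stream the JSON arguments piece by piece.
    let chunks = [
        (0usize, Some("call_1"), Some("get_weather"), Some(r#"{"city":"#)),
        (0, None, Some(""), Some(r#""Berlin"}"#)),
    ];

    for (index, id, name, arguments) in chunks {
        let entry = tool_calls_by_index.entry(index).or_default();
        if let Some(id) = id {
            entry.id = id.to_string();
        }
        // Only take a non-empty name, mirroring the workaround for the
        // LM Studio 0.3.15 streaming quirk noted in the diff above.
        if let Some(name) = name {
            if !name.is_empty() {
                entry.name = name.to_string();
            }
        }
        if let Some(arguments) = arguments {
            entry.arguments.push_str(arguments);
        }
    }

    let call = &tool_calls_by_index[&0];
    assert_eq!(call.arguments, r#"{"city":"Berlin"}"#);
    println!("{} -> {}", call.name, call.arguments);
}
```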

View file

@@ -2,7 +2,7 @@ use anyhow::{Context as _, Result};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
 use serde::{Deserialize, Serialize};
-use serde_json::{Value, value::RawValue};
+use serde_json::Value;
 use std::{convert::TryFrom, sync::Arc, time::Duration};
 
 pub const LMSTUDIO_API_URL: &str = "http://localhost:1234/api/v0";
@@ -47,14 +47,21 @@ pub struct Model {
     pub name: String,
     pub display_name: Option<String>,
     pub max_tokens: usize,
+    pub supports_tool_calls: bool,
 }
 
 impl Model {
-    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
+    pub fn new(
+        name: &str,
+        display_name: Option<&str>,
+        max_tokens: Option<usize>,
+        supports_tool_calls: bool,
+    ) -> Self {
         Self {
             name: name.to_owned(),
             display_name: display_name.map(|s| s.to_owned()),
             max_tokens: max_tokens.unwrap_or(2048),
+            supports_tool_calls,
         }
     }
@@ -69,15 +76,43 @@ impl Model {
     pub fn max_token_count(&self) -> usize {
         self.max_tokens
     }
+
+    pub fn supports_tool_calls(&self) -> bool {
+        self.supports_tool_calls
+    }
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+    Auto,
+    Required,
+    None,
+    Other(ToolDefinition),
+}
+
+#[derive(Clone, Deserialize, Serialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ToolDefinition {
+    #[allow(dead_code)]
+    Function { function: FunctionDefinition },
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: Option<String>,
+    pub parameters: Option<Value>,
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(tag = "role", rename_all = "lowercase")]
 pub enum ChatMessage {
     Assistant {
         #[serde(default)]
         content: Option<String>,
-        #[serde(default)]
-        tool_calls: Option<Vec<LmStudioToolCall>>,
+        #[serde(default, skip_serializing_if = "Vec::is_empty")]
+        tool_calls: Vec<ToolCall>,
     },
     User {
         content: String,
@@ -85,31 +120,29 @@ pub enum ChatMessage {
     System {
         content: String,
     },
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-#[serde(rename_all = "lowercase")]
-pub enum LmStudioToolCall {
-    Function(LmStudioFunctionCall),
-}
-
-#[derive(Serialize, Deserialize, Debug)]
-pub struct LmStudioFunctionCall {
-    pub name: String,
-    pub arguments: Box<RawValue>,
+    Tool {
+        content: String,
+        tool_call_id: String,
+    },
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LmStudioFunctionTool {
-    pub name: String,
-    pub description: Option<String>,
-    pub parameters: Option<Value>,
+pub struct ToolCall {
+    pub id: String,
+    #[serde(flatten)]
+    pub content: ToolCallContent,
 }
 
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 #[serde(tag = "type", rename_all = "lowercase")]
-pub enum LmStudioTool {
-    Function { function: LmStudioFunctionTool },
+pub enum ToolCallContent {
+    Function { function: FunctionContent },
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct FunctionContent {
+    pub name: String,
+    pub arguments: String,
 }
 
 #[derive(Serialize, Debug)]
@@ -117,10 +150,16 @@ pub struct ChatCompletionRequest {
     pub model: String,
     pub messages: Vec<ChatMessage>,
     pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub max_tokens: Option<i32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
-    pub tools: Vec<LmStudioTool>,
+    #[serde(skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<ToolDefinition>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -135,8 +174,7 @@ pub struct ChatResponse {
 #[derive(Serialize, Deserialize, Debug)]
 pub struct ChoiceDelta {
     pub index: u32,
-    #[serde(default)]
-    pub delta: serde_json::Value,
+    pub delta: ResponseMessageDelta,
     pub finish_reason: Option<String>,
 }
@@ -164,6 +202,16 @@ pub struct Usage {
     pub total_tokens: u32,
 }
 
+#[derive(Debug, Default, Clone, Deserialize, PartialEq)]
+#[serde(transparent)]
+pub struct Capabilities(Vec<String>);
+
+impl Capabilities {
+    pub fn supports_tool_calls(&self) -> bool {
+        self.0.iter().any(|cap| cap == "tool_use")
+    }
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 #[serde(untagged)]
 pub enum ResponseStreamResult {
@@ -175,16 +223,17 @@ pub enum ResponseStreamResult {
 pub struct ResponseStreamEvent {
     pub created: u32,
     pub model: String,
+    pub object: String,
     pub choices: Vec<ChoiceDelta>,
     pub usage: Option<Usage>,
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Deserialize)]
 pub struct ListModelsResponse {
     pub data: Vec<ModelEntry>,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+#[derive(Clone, Debug, Deserialize, PartialEq)]
 pub struct ModelEntry {
     pub id: String,
     pub object: String,
@@ -196,6 +245,8 @@ pub struct ModelEntry {
     pub state: ModelState,
     pub max_context_length: Option<u32>,
     pub loaded_context_length: Option<u32>,
+    #[serde(default)]
+    pub capabilities: Capabilities,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
@@ -265,7 +316,7 @@ pub async fn stream_chat_completion(
     client: &dyn HttpClient,
     api_url: &str,
     request: ChatCompletionRequest,
-) -> Result<BoxStream<'static, Result<ChatResponse>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
     let uri = format!("{api_url}/chat/completions");
     let request_builder = http::Request::builder()
         .method(Method::POST)