language_models: Fix non-streaming Copilot Chat models (#28537)
This PR fixes usage of non-streaming Copilot Chat models. Closes https://github.com/zed-industries/zed/issues/28528. Release Notes: - Fixed an issue with using non-streaming Copilot Chat models (e.g., o1, o3-mini).
This commit is contained in:
parent
90f30b5c20
commit
d88694f8da
1 changed files with 30 additions and 19 deletions
|
@ -254,6 +254,7 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
Ok(request) => request,
|
Ok(request) => request,
|
||||||
Err(err) => return futures::future::ready(Err(err)).boxed(),
|
Err(err) => return futures::future::ready(Err(err)).boxed(),
|
||||||
};
|
};
|
||||||
|
let is_streaming = copilot_request.stream;
|
||||||
|
|
||||||
let request_limiter = self.request_limiter.clone();
|
let request_limiter = self.request_limiter.clone();
|
||||||
let future = cx.spawn(async move |cx| {
|
let future = cx.spawn(async move |cx| {
|
||||||
|
@ -261,7 +262,10 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
request_limiter
|
request_limiter
|
||||||
.stream(async move {
|
.stream(async move {
|
||||||
let response = request.await?;
|
let response = request.await?;
|
||||||
Ok(map_to_language_model_completion_events(response))
|
Ok(map_to_language_model_completion_events(
|
||||||
|
response,
|
||||||
|
is_streaming,
|
||||||
|
))
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
});
|
});
|
||||||
|
@ -271,6 +275,7 @@ impl LanguageModel for CopilotChatLanguageModel {
|
||||||
|
|
||||||
pub fn map_to_language_model_completion_events(
|
pub fn map_to_language_model_completion_events(
|
||||||
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
|
||||||
|
is_streaming: bool,
|
||||||
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
struct RawToolCall {
|
struct RawToolCall {
|
||||||
|
@ -289,7 +294,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
events,
|
events,
|
||||||
tool_calls_by_index: HashMap::default(),
|
tool_calls_by_index: HashMap::default(),
|
||||||
},
|
},
|
||||||
|mut state| async move {
|
move |mut state| async move {
|
||||||
if let Some(event) = state.events.next().await {
|
if let Some(event) = state.events.next().await {
|
||||||
match event {
|
match event {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
|
@ -300,7 +305,13 @@ pub fn map_to_language_model_completion_events(
|
||||||
));
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(delta) = choice.delta.as_ref() else {
|
let delta = if is_streaming {
|
||||||
|
choice.delta.as_ref()
|
||||||
|
} else {
|
||||||
|
choice.message.as_ref()
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(delta) = delta else {
|
||||||
return Some((
|
return Some((
|
||||||
vec![Err(anyhow!("Response contained no delta"))],
|
vec![Err(anyhow!("Response contained no delta"))],
|
||||||
state,
|
state,
|
||||||
|
@ -312,26 +323,26 @@ pub fn map_to_language_model_completion_events(
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
|
||||||
}
|
}
|
||||||
|
|
||||||
for tool_call in &delta.tool_calls {
|
for tool_call in &delta.tool_calls {
|
||||||
let entry = state
|
let entry = state
|
||||||
.tool_calls_by_index
|
.tool_calls_by_index
|
||||||
.entry(tool_call.index)
|
.entry(tool_call.index)
|
||||||
.or_default();
|
.or_default();
|
||||||
|
|
||||||
if let Some(tool_id) = tool_call.id.clone() {
|
if let Some(tool_id) = tool_call.id.clone() {
|
||||||
entry.id = tool_id;
|
entry.id = tool_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(function) = tool_call.function.as_ref() {
|
||||||
|
if let Some(name) = function.name.clone() {
|
||||||
|
entry.name = name;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(function) = tool_call.function.as_ref() {
|
if let Some(arguments) = function.arguments.clone() {
|
||||||
if let Some(name) = function.name.clone() {
|
entry.arguments.push_str(&arguments);
|
||||||
entry.name = name;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(arguments) = function.arguments.clone() {
|
|
||||||
entry.arguments.push_str(&arguments);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
match choice.finish_reason.as_deref() {
|
match choice.finish_reason.as_deref() {
|
||||||
Some("stop") => {
|
Some("stop") => {
|
||||||
|
@ -361,7 +372,7 @@ pub fn map_to_language_model_completion_events(
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
Some(stop_reason) => {
|
Some(stop_reason) => {
|
||||||
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
|
log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
|
||||||
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
events.push(Ok(LanguageModelCompletionEvent::Stop(
|
||||||
StopReason::EndTurn,
|
StopReason::EndTurn,
|
||||||
)));
|
)));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue