added token count to documents during parsing

This commit is contained in:
KCaverly 2023-08-30 11:05:46 -04:00
parent a7e6a65deb
commit e377ada1a9
4 changed files with 54 additions and 12 deletions

View file

@ -1,6 +1,6 @@
use crate::{
db::dot,
embedding::EmbeddingProvider,
embedding::{DummyEmbeddings, EmbeddingProvider},
parsing::{subtract_ranges, CodeContextRetriever, Document},
semantic_index_settings::SemanticIndexSettings,
SearchResult, SemanticIndex,
@ -227,7 +227,8 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/// A doc comment
@ -314,7 +315,8 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
{
@ -397,7 +399,8 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/* globals importScripts, backend */
@ -495,7 +498,8 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
-- Creates a new class
@ -568,7 +572,8 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
defmodule File.Stream do
@ -684,7 +689,8 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
/**
@ -836,7 +842,8 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
# This concern is inspired by "sudo mode" on GitHub. It
@ -1026,7 +1033,8 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
let mut retriever = CodeContextRetriever::new();
let embedding_provider = Arc::new(DummyEmbeddings {});
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
<?php
@ -1216,6 +1224,10 @@ impl FakeEmbeddingProvider {
#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
fn count_tokens(&self, span: &str) -> usize {
span.len()
}
async fn embed_batch(&self, spans: Vec<&str>) -> Result<Vec<Vec<f32>>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);