// Web-view residue (not part of the program): 270 lines, 8.4 KiB, TypeScript.
// (blame) 2024-12-26 12:41:37 -03:00
import { MapDocument, URLTrace } from "../../controllers/v1/types";
import { performRanking } from "../ranker";
import { isUrlBlocked } from "../../scraper/WebScraper/utils/blocklist";
import { logger } from "../logger";
// (blame) 2024-10-28 16:02:07 -03:00
import { CohereClient } from "cohere-ai";
import { extractConfig } from "./config";
import { searchSimilarPages } from "./index/pinecone";
import { generateOpenAICompletions } from "../../scraper/scrapeURL/transformers/llmExtract";
import { buildRerankerUserPrompt } from "./build-prompts";
import { buildRerankerSystemPrompt } from "./build-prompts";
// (blame) 2024-12-26 12:41:37 -03:00
// (blame) 2024-10-28 16:02:07 -03:00
// Module-level Cohere client; the API token is read from the COHERE_API_KEY
// environment variable when this module is loaded.
const cohere = new CohereClient({
  token: process.env.COHERE_API_KEY,
});
// (blame) 2024-12-26 12:41:37 -03:00
/**
 * Pairing of the mapped documents with their per-link ranking entries.
 *
 * NOTE(review): nothing in the visible part of this file references this
 * interface; the entry shape matches what `filterAndProcessLinks` accepts.
 */
interface RankingResult {
  /** Documents that were submitted for ranking. */
  mappedLinks: MapDocument[];
  /** One entry per ranked link. */
  linksAndScores: Array<{
    /** URL of the ranked link. */
    link: string;
    /** Presumably the text that was scored for this link; confirm against the ranker. */
    linkWithContext: string;
    /** Relevance score assigned to the link. */
    score: number;
    /** Presumably the link's position in the ranker's input; confirm against the ranker. */
    originalIndex: number;
  }>;
}
// (blame) 2024-10-28 16:02:07 -03:00
/**
 * Reranks documents against a query through the Cohere v2 rerank endpoint.
 *
 * @param documents - Plain strings or string-keyed records to be scored.
 * @param query - Query the documents are scored against.
 * @param topN - Forwarded as the request's `topN` option (defaults to 3).
 * @param model - Name of the rerank model to use.
 * @returns One `{ document, index, relevanceScore }` entry per result returned
 *   by the API, ordered by descending `relevanceScore`.
 */
export async function rerankDocuments(
  documents: (string | Record<string, string>)[],
  query: string,
  topN = 3,
  model = "rerank-english-v3.0",
) {
  const rerank = await cohere.v2.rerank({
    documents,
    query,
    topN,
    model,
    returnDocuments: true,
  });

  // Array.prototype.sort works in place; sort a copy so the response object
  // handed back by the SDK is left untouched.
  return [...rerank.results]
    .sort((a, b) => b.relevanceScore - a.relevanceScore)
    .map((x) => ({
      document: x.document,
      index: x.index,
      relevanceScore: x.relevanceScore,
    }));
}
// (blame) 2024-12-26 12:41:37 -03:00
/**
 * Scores mapped links against a search query and returns the most relevant ones.
 *
 * Each link is described as `url / title / description` text and passed to
 * `performRanking`. Links are then selected in up to three passes:
 *  1. score above `INITIAL_SCORE_THRESHOLD_FOR_RELEVANCE`;
 *  2. if fewer than `MIN_REQUIRED_LINKS` were kept, score above
 *     `FALLBACK_SCORE_THRESHOLD_FOR_RELEVANCE`;
 *  3. if still nothing was kept, the `MIN_REQUIRED_LINKS` highest-scoring links.
 * Blocked URLs (`isUrlBlocked`) are dropped in every pass.
 *
 * Side effect: entries of `urlTraces` whose `url` matches a scored link are
 * mutated (`relevanceScore`, `warning`, `usedInCompletion`).
 *
 * @param mappedLinks - Candidate documents to rank.
 * @param searchQuery - Query the candidates are ranked against.
 * @param urlTraces - Trace records updated in place with the ranking outcome.
 * @returns At most `MAX_RANKING_LIMIT_FOR_RELEVANCE` selected documents.
 */
export async function rerankLinks(
  mappedLinks: MapDocument[],
  searchQuery: string,
  urlTraces: URLTrace[],
): Promise<MapDocument[]> {
  const mappedLinksRerank = mappedLinks.map(
    (x) => `url: ${x.url}, title: ${x.title}, description: ${x.description}`,
  );

  const linksAndScores = await performRanking(
    mappedLinksRerank,
    mappedLinks.map((l) => l.url),
    searchQuery,
  );

  // First try with high threshold
  let filteredLinks = filterAndProcessLinks(
    mappedLinks,
    linksAndScores,
    extractConfig.RERANKING.INITIAL_SCORE_THRESHOLD_FOR_RELEVANCE,
  );

  // If we don't have enough high-quality links, try with lower threshold
  if (filteredLinks.length < extractConfig.RERANKING.MIN_REQUIRED_LINKS) {
    logger.info(
      `Only found ${filteredLinks.length} links with score > ${extractConfig.RERANKING.INITIAL_SCORE_THRESHOLD_FOR_RELEVANCE}. Trying lower threshold...`,
    );
    filteredLinks = filterAndProcessLinks(
      mappedLinks,
      linksAndScores,
      extractConfig.RERANKING.FALLBACK_SCORE_THRESHOLD_FOR_RELEVANCE,
    );

    if (filteredLinks.length === 0) {
      // If still no results, take top N results regardless of score
      logger.warn(
        `No links found with score > ${extractConfig.RERANKING.FALLBACK_SCORE_THRESHOLD_FOR_RELEVANCE}. Taking top ${extractConfig.RERANKING.MIN_REQUIRED_LINKS} results.`,
      );
      // Sort a copy: Array.prototype.sort reorders in place and would
      // otherwise mutate the array returned by performRanking.
      filteredLinks = [...linksAndScores]
        .sort((a, b) => b.score - a.score)
        .slice(0, extractConfig.RERANKING.MIN_REQUIRED_LINKS)
        .map((x) => mappedLinks.find((link) => link.url === x.link))
        .filter(
          (x): x is MapDocument =>
            x !== undefined && x.url !== undefined && !isUrlBlocked(x.url),
        );
    }
  }

  // First-match lookup of traces by URL (same result as `urlTraces.find`),
  // built once instead of scanning the trace list for every link.
  const traceByUrl = new Map<URLTrace["url"], URLTrace>();
  for (const trace of urlTraces) {
    if (!traceByUrl.has(trace.url)) {
      traceByUrl.set(trace.url, trace);
    }
  }
  const keptUrls = new Set(filteredLinks.map((link) => link.url));

  // Update URL traces with relevance scores and mark filtered out URLs
  linksAndScores.forEach((score) => {
    const trace = traceByUrl.get(score.link);
    if (trace) {
      trace.relevanceScore = score.score;
      // If URL didn't make it through filtering, mark it as filtered out
      if (!keptUrls.has(score.link)) {
        trace.warning = `Relevance score ${score.score} below threshold`;
        trace.usedInCompletion = false;
      }
    }
  });

  const rankingLimit = extractConfig.RERANKING.MAX_RANKING_LIMIT_FOR_RELEVANCE;
  const rankedLinks = filteredLinks.slice(0, rankingLimit);

  // Mark URLs that will be used in completion
  rankedLinks.forEach((link) => {
    const trace = traceByUrl.get(link.url);
    if (trace) {
      trace.usedInCompletion = true;
    }
  });

  // Mark URLs that were dropped due to ranking limit
  filteredLinks.slice(rankingLimit).forEach((link) => {
    const trace = traceByUrl.get(link.url);
    if (trace) {
      trace.warning = "Excluded due to ranking limit";
      trace.usedInCompletion = false;
    }
  });

  return rankedLinks;
}
/**
 * Picks the mapped documents whose ranking score is strictly greater than
 * `threshold`.
 *
 * Each qualifying score entry is resolved to the first document in
 * `mappedLinks` with the same URL; entries with no matching document, with an
 * undefined URL, or with a URL rejected by `isUrlBlocked` are skipped. The
 * output keeps the order of `linksAndScores`.
 *
 * @param mappedLinks - Documents the score entries refer to.
 * @param linksAndScores - Ranking entries (`link` is matched against `url`).
 * @param threshold - Exclusive lower bound on `score`.
 * @returns The documents that pass all checks.
 */
function filterAndProcessLinks(
  mappedLinks: MapDocument[],
  linksAndScores: {
    link: string;
    linkWithContext: string;
    score: number;
    originalIndex: number;
  }[],
  threshold: number,
): MapDocument[] {
  return linksAndScores.flatMap((ranked): MapDocument[] => {
    // Negated form keeps NaN scores out, exactly like `score > threshold`.
    if (!(ranked.score > threshold)) {
      return [];
    }

    const doc = mappedLinks.find((candidate) => candidate.url === ranked.link);
    if (doc === undefined || doc.url === undefined) {
      return [];
    }
    if (isUrlBlocked(doc.url)) {
      return [];
    }

    return [doc];
  });
}
/** Outcome of `rerankLinksWithLLM`. */
export type RerankerResult = {
  /** Documents kept by the reranker, most relevant first. */
  mapDocument: Array<MapDocument>;
  /** Sum of the token counts reported by the completions that were used. */
  tokensUsed: number;
};
/**
 * Asks an LLM which of the mapped links are relevant to a search query.
 *
 * The links are split into chunks of `chunkSize`; every chunk is sent in
 * parallel through `generateOpenAICompletions` with a JSON schema asking for
 * `{ url, relevanceScore }` pairs. An attempt that exceeds `TIMEOUT_MS` or
 * throws is retried, up to `MAX_RETRIES` extra attempts; a chunk that never
 * succeeds contributes no links. The returned URLs are sorted by descending
 * `relevanceScore` and mapped back to the original documents; URLs that do
 * not match any input document are dropped.
 *
 * @param mappedLinks - Candidate documents.
 * @param searchQuery - Query used to build the user prompt.
 * @param urlTraces - Not read by this function; kept for interface compatibility.
 * @returns The relevant documents (most relevant first) and the number of
 *   tokens reported by the completions that finished in time.
 */
export async function rerankLinksWithLLM(
  mappedLinks: MapDocument[],
  searchQuery: string,
  urlTraces: URLTrace[],
): Promise<RerankerResult> {
  const chunkSize = 100;
  const chunks: MapDocument[][] = [];
  const TIMEOUT_MS = 20000;
  const MAX_RETRIES = 2;
  let totalTokensUsed = 0;

  // Split mappedLinks into chunks of `chunkSize` links
  for (let i = 0; i < mappedLinks.length; i += chunkSize) {
    chunks.push(mappedLinks.slice(i, i + chunkSize));
  }

  const schema = {
    type: "object",
    properties: {
      relevantLinks: {
        type: "array",
        items: {
          type: "object",
          properties: {
            url: { type: "string" },
            relevanceScore: { type: "number" },
          },
          required: ["url", "relevanceScore"],
        },
      },
    },
    required: ["relevantLinks"],
  };

  const results = await Promise.all(
    chunks.map(async (chunk, chunkIndex) => {
      const linksContent = chunk
        .map(
          (link) =>
            `URL: ${link.url}${link.title ? `\nTitle: ${link.title}` : ""}${link.description ? `\nDescription: ${link.description}` : ""}`,
        )
        .join("\n\n");

      for (let retry = 0; retry <= MAX_RETRIES; retry++) {
        const attemptLogger = logger.child({
          method: "rerankLinksWithLLM",
          chunk: chunkIndex + 1,
          retry,
        });
        let timeoutHandle: ReturnType<typeof setTimeout> | undefined;

        try {
          const timeoutPromise = new Promise<null>((resolve) => {
            timeoutHandle = setTimeout(() => resolve(null), TIMEOUT_MS);
          });

          const completionPromise = generateOpenAICompletions(
            attemptLogger,
            {
              mode: "llm",
              systemPrompt: buildRerankerSystemPrompt(),
              prompt: buildRerankerUserPrompt(searchQuery),
              schema: schema,
            },
            linksContent,
            undefined,
            true,
          );

          const completion = await Promise.race([
            completionPromise,
            timeoutPromise,
          ]);

          if (!completion) {
            // Timed out: try again. NOTE(review): the pending completion is
            // not cancelled here, only ignored.
            continue;
          }

          // Count tokens before inspecting the payload so that completions
          // without any relevant links are still accounted for.
          totalTokensUsed += completion.numTokens || 0;

          if (!completion.extract?.relevantLinks) {
            return [];
          }

          return completion.extract.relevantLinks;
        } catch (error) {
          attemptLogger.warn(
            `Error processing chunk ${chunkIndex + 1} attempt ${retry + 1}:`,
            { error },
          );
          if (retry === MAX_RETRIES) {
            return [];
          }
        } finally {
          // Always release the timer; otherwise every attempt leaves a
          // pending timeout behind for up to TIMEOUT_MS.
          clearTimeout(timeoutHandle);
        }
      }

      return [];
    }),
  );

  // Flatten results and sort by relevance score
  const flattenedResults = results
    .flat()
    .sort((a, b) => b.relevanceScore - a.relevanceScore);

  // First-match lookup by URL (same result as `mappedLinks.find`), built once
  // instead of scanning the full list for every returned link.
  const linkByUrl = new Map<MapDocument["url"], MapDocument>();
  for (const link of mappedLinks) {
    if (!linkByUrl.has(link.url)) {
      linkByUrl.set(link.url, link);
    }
  }

  // Map back to MapDocument format, keeping only relevant links
  const relevantLinks = flattenedResults
    .map((result) => linkByUrl.get(result.url))
    .filter((link): link is MapDocument => link !== undefined);

  return {
    mapDocument: relevantLinks,
    tokensUsed: totalTokensUsed,
  };
}