297 lines
9.2 KiB
TypeScript
Raw Normal View History

2024-12-26 12:41:37 -03:00
import { MapDocument, URLTrace } from "../../controllers/v1/types";
import { performRanking } from "../ranker";
import { isUrlBlocked } from "../../scraper/WebScraper/utils/blocklist";
import { logger } from "../logger";
2024-10-28 16:02:07 -03:00
import { CohereClient } from "cohere-ai";
import { extractConfig } from "./config";
import { searchSimilarPages } from "./index/pinecone";
import { generateCompletions } from "../../scraper/scrapeURL/transformers/llmExtract";
import { buildRerankerUserPrompt } from "./build-prompts";
import { buildRerankerSystemPrompt } from "./build-prompts";
2025-01-24 18:09:25 -03:00
import { dumpToFile } from "./helpers/dump-to-file";
2024-12-26 12:41:37 -03:00
2024-10-28 16:02:07 -03:00
const cohere = new CohereClient({
2024-12-11 19:51:08 -03:00
token: process.env.COHERE_API_KEY,
2024-10-28 16:02:07 -03:00
});
2024-12-26 12:41:37 -03:00
interface RankingResult {
mappedLinks: MapDocument[];
linksAndScores: {
link: string;
linkWithContext: string;
score: number;
originalIndex: number;
}[];
}
2024-10-28 16:02:07 -03:00
export async function rerankDocuments(
documents: (string | Record<string, string>)[],
query: string,
topN = 3,
2024-12-11 19:51:08 -03:00
model = "rerank-english-v3.0",
2024-10-28 16:02:07 -03:00
) {
const rerank = await cohere.v2.rerank({
documents,
query,
topN,
model,
2024-12-11 19:51:08 -03:00
returnDocuments: true,
2024-10-28 16:02:07 -03:00
});
2024-12-11 19:46:11 -03:00
return rerank.results
.sort((a, b) => b.relevanceScore - a.relevanceScore)
.map((x) => ({
document: x.document,
index: x.index,
2024-12-11 19:51:08 -03:00
relevanceScore: x.relevanceScore,
2024-12-11 19:46:11 -03:00
}));
2024-10-28 16:02:07 -03:00
}
2024-12-26 12:41:37 -03:00
export async function rerankLinks(
mappedLinks: MapDocument[],
searchQuery: string,
urlTraces: URLTrace[],
): Promise<MapDocument[]> {
// console.log("Going to rerank links");
2024-12-26 12:41:37 -03:00
const mappedLinksRerank = mappedLinks.map(
(x) => `url: ${x.url}, title: ${x.title}, description: ${x.description}`,
);
const linksAndScores = await performRanking(
mappedLinksRerank,
mappedLinks.map((l) => l.url),
2025-01-22 18:47:44 -03:00
searchQuery,
2024-12-26 12:41:37 -03:00
);
// First try with high threshold
let filteredLinks = filterAndProcessLinks(
mappedLinks,
linksAndScores,
extractConfig.RERANKING.INITIAL_SCORE_THRESHOLD_FOR_RELEVANCE,
2024-12-26 12:41:37 -03:00
);
// If we don't have enough high-quality links, try with lower threshold
if (filteredLinks.length < extractConfig.RERANKING.MIN_REQUIRED_LINKS) {
2024-12-26 12:41:37 -03:00
logger.info(
`Only found ${filteredLinks.length} links with score > ${extractConfig.RERANKING.INITIAL_SCORE_THRESHOLD_FOR_RELEVANCE}. Trying lower threshold...`,
2024-12-26 12:41:37 -03:00
);
filteredLinks = filterAndProcessLinks(
mappedLinks,
linksAndScores,
extractConfig.RERANKING.FALLBACK_SCORE_THRESHOLD_FOR_RELEVANCE,
2024-12-26 12:41:37 -03:00
);
if (filteredLinks.length === 0) {
// If still no results, take top N results regardless of score
logger.warn(
`No links found with score > ${extractConfig.RERANKING.FALLBACK_SCORE_THRESHOLD_FOR_RELEVANCE}. Taking top ${extractConfig.RERANKING.MIN_REQUIRED_LINKS} results.`,
2024-12-26 12:41:37 -03:00
);
filteredLinks = linksAndScores
.sort((a, b) => b.score - a.score)
.slice(0, extractConfig.RERANKING.MIN_REQUIRED_LINKS)
2024-12-26 12:41:37 -03:00
.map((x) => mappedLinks.find((link) => link.url === x.link))
.filter(
(x): x is MapDocument =>
x !== undefined && x.url !== undefined && !isUrlBlocked(x.url),
);
}
}
// Update URL traces with relevance scores and mark filtered out URLs
linksAndScores.forEach((score) => {
const trace = urlTraces.find((t) => t.url === score.link);
if (trace) {
trace.relevanceScore = score.score;
// If URL didn't make it through filtering, mark it as filtered out
2025-01-10 18:35:10 -03:00
if (!filteredLinks.some((link) => link.url === score.link)) {
2024-12-26 12:41:37 -03:00
trace.warning = `Relevance score ${score.score} below threshold`;
trace.usedInCompletion = false;
}
}
});
2025-01-22 18:47:44 -03:00
const rankedLinks = filteredLinks.slice(
0,
extractConfig.RERANKING.MAX_RANKING_LIMIT_FOR_RELEVANCE,
);
2024-12-26 12:41:37 -03:00
// Mark URLs that will be used in completion
2025-01-10 18:35:10 -03:00
rankedLinks.forEach((link) => {
const trace = urlTraces.find((t) => t.url === link.url);
2024-12-26 12:41:37 -03:00
if (trace) {
trace.usedInCompletion = true;
}
});
// Mark URLs that were dropped due to ranking limit
2025-01-22 18:47:44 -03:00
filteredLinks
.slice(extractConfig.RERANKING.MAX_RANKING_LIMIT_FOR_RELEVANCE)
.forEach((link) => {
const trace = urlTraces.find((t) => t.url === link.url);
if (trace) {
trace.warning = "Excluded due to ranking limit";
trace.usedInCompletion = false;
}
});
2024-12-26 12:41:37 -03:00
// console.log("Reranked links: ", rankedLinks.length);
2024-12-26 12:41:37 -03:00
return rankedLinks;
}
function filterAndProcessLinks(
mappedLinks: MapDocument[],
linksAndScores: {
link: string;
linkWithContext: string;
score: number;
originalIndex: number;
}[],
threshold: number,
): MapDocument[] {
return linksAndScores
.filter((x) => x.score > threshold)
.map((x) => mappedLinks.find((link) => link.url === x.link))
.filter(
(x): x is MapDocument =>
x !== undefined && x.url !== undefined && !isUrlBlocked(x.url),
);
}
export type RerankerResult = {
2025-01-24 18:09:25 -03:00
mapDocument: (MapDocument & { relevanceScore?: number; reason?: string })[];
tokensUsed: number;
2025-01-22 18:47:44 -03:00
};
export type RerankerOptions = {
links: MapDocument[];
searchQuery: string;
urlTraces: URLTrace[];
};
export async function rerankLinksWithLLM(options: RerankerOptions): Promise<RerankerResult> {
const { links, searchQuery, urlTraces } = options;
2025-01-24 18:35:36 -03:00
const chunkSize = 100;
const chunks: MapDocument[][] = [];
const TIMEOUT_MS = 20000;
const MAX_RETRIES = 2;
let totalTokensUsed = 0;
2025-01-22 18:47:44 -03:00
// Split links into chunks of 200
for (let i = 0; i < links.length; i += chunkSize) {
chunks.push(links.slice(i, i + chunkSize));
}
// console.log(`Total links: ${mappedLinks.length}, Number of chunks: ${chunks.length}`);
const schema = {
type: "object",
properties: {
relevantLinks: {
type: "array",
items: {
type: "object",
properties: {
url: { type: "string" },
2025-01-22 18:47:44 -03:00
relevanceScore: { type: "number" },
2025-01-24 18:09:25 -03:00
reason: { type: "string", description: "The reason why you chose the score for this link given the intent." },
},
required: ["url", "relevanceScore", "reason"],
2025-01-22 18:47:44 -03:00
},
},
},
2025-01-22 18:47:44 -03:00
required: ["relevantLinks"],
};
const results = await Promise.all(
chunks.map(async (chunk, chunkIndex) => {
// console.log(`Processing chunk ${chunkIndex + 1}/${chunks.length} with ${chunk.length} links`);
2025-01-22 18:47:44 -03:00
const linksContent = chunk
.map(
(link) =>
`URL: ${link.url}${link.title ? `\nTitle: ${link.title}` : ""}${link.description ? `\nDescription: ${link.description}` : ""}`,
)
.join("\n\n");
for (let retry = 0; retry <= MAX_RETRIES; retry++) {
try {
const timeoutPromise = new Promise<null>((resolve) => {
setTimeout(() => resolve(null), TIMEOUT_MS);
});
2025-01-24 18:09:25 -03:00
// dumpToFile(new Date().toISOString(),[buildRerankerSystemPrompt(), buildRerankerUserPrompt(searchQuery), schema, linksContent])
const completionPromise = generateCompletions({
logger: logger.child({
2025-01-22 18:47:44 -03:00
method: "rerankLinksWithLLM",
chunk: chunkIndex + 1,
retry,
}),
options: {
mode: "llm",
systemPrompt: buildRerankerSystemPrompt(),
prompt: buildRerankerUserPrompt(searchQuery),
2025-01-22 18:47:44 -03:00
schema: schema,
},
markdown: linksContent,
isExtractEndpoint: true
});
2025-01-22 18:47:44 -03:00
const completion = await Promise.race([
completionPromise,
timeoutPromise,
]);
if (!completion) {
// console.log(`Chunk ${chunkIndex + 1}: Timeout on attempt ${retry + 1}`);
continue;
}
if (!completion.extract?.relevantLinks) {
// console.warn(`Chunk ${chunkIndex + 1}: No relevant links found in completion response`);
return [];
}
totalTokensUsed += completion.numTokens || 0;
// console.log(`Chunk ${chunkIndex + 1}: Found ${completion.extract.relevantLinks.length} relevant links`);
return completion.extract.relevantLinks;
} catch (error) {
2025-01-22 18:47:44 -03:00
console.warn(
`Error processing chunk ${chunkIndex + 1} attempt ${retry + 1}:`,
error,
);
if (retry === MAX_RETRIES) {
// console.log(`Chunk ${chunkIndex + 1}: Max retries reached, returning empty array`);
return [];
}
}
}
return [];
2025-01-22 18:47:44 -03:00
}),
);
// console.log(`Processed ${results.length} chunks`);
// Flatten results and sort by relevance score
2025-01-22 18:47:44 -03:00
const flattenedResults = results
.flat()
.sort((a, b) => b.relevanceScore - a.relevanceScore);
// console.log(`Total relevant links found: ${flattenedResults.length}`);
// Map back to MapDocument format, keeping only relevant links
const relevantLinks = flattenedResults
.map((result) => {
const link = links.find((link) => link.url === result.url);
if (link) {
2025-01-24 18:09:25 -03:00
return { ...link, relevanceScore: result.relevanceScore ? parseFloat(result.relevanceScore) : 0, reason: result.reason };
}
return undefined;
})
.filter((link): link is NonNullable<typeof link> => link !== undefined);
return {
mapDocument: relevantLinks,
tokensUsed: totalTokensUsed,
};
}