152 lines
4.2 KiB
TypeScript
Raw Normal View History

2024-12-26 12:41:37 -03:00
import { MapDocument, URLTrace } from "../../controllers/v1/types";
import { performRanking } from "../ranker";
import { isUrlBlocked } from "../../scraper/WebScraper/utils/blocklist";
import { logger } from "../logger";
2024-10-28 16:02:07 -03:00
import { CohereClient } from "cohere-ai";
import { extractConfig } from "./config";
2024-12-26 12:41:37 -03:00
2024-10-28 16:02:07 -03:00
// Shared Cohere SDK client used by rerankDocuments; the token is read from the
// COHERE_API_KEY environment variable when this module is loaded.
const cohere = new CohereClient({ token: process.env.COHERE_API_KEY });
2024-12-26 12:41:37 -03:00
/**
 * Output of ranking mapped links: the candidate documents together with one
 * scored entry per ranked link.
 */
interface RankingResult {
  mappedLinks: MapDocument[];
  linksAndScores: Array<{
    link: string;
    linkWithContext: string;
    score: number;
    // presumably the link's position in the ranker's input — confirm against performRanking
    originalIndex: number;
  }>;
}
2024-10-28 16:02:07 -03:00
/**
 * Reranks `documents` against `query` using Cohere's v2 rerank endpoint.
 *
 * @param documents - Plain strings or string-valued records to rank.
 * @param query - Query the documents are scored against.
 * @param topN - Number of results requested from Cohere (default 3).
 * @param model - Cohere rerank model identifier.
 * @returns The results ordered from highest to lowest `relevanceScore`, each
 *   with the returned `document`, the `index` reported by Cohere, and the score.
 */
export async function rerankDocuments(
  documents: (string | Record<string, string>)[],
  query: string,
  topN = 3,
  model = "rerank-english-v3.0",
) {
  const { results } = await cohere.v2.rerank({
    documents,
    query,
    topN,
    model,
    returnDocuments: true,
  });

  const byScoreDesc = [...results].sort(
    (left, right) => right.relevanceScore - left.relevanceScore,
  );

  return byScoreDesc.map(({ document, index, relevanceScore }) => ({
    document,
    index,
    relevanceScore,
  }));
}
2024-12-26 12:41:37 -03:00
/**
 * Ranks mapped links by relevance to `searchQuery`, filters them by score, and
 * records the outcome on matching URL traces.
 *
 * Filtering runs in up to three passes:
 * 1. keep links scoring above `extractConfig.INITIAL_SCORE_THRESHOLD`;
 * 2. if fewer than `extractConfig.MIN_REQUIRED_LINKS` survive, retry with
 *    `extractConfig.FALLBACK_SCORE_THRESHOLD`;
 * 3. if that still yields nothing, take the `MIN_REQUIRED_LINKS`
 *    highest-scoring non-blocked links regardless of score.
 * The surviving list is then capped at `extractConfig.MAX_RANKING_LIMIT`.
 *
 * Side effects: mutates the first `urlTraces` entry for each ranked URL
 * (`relevanceScore`, `warning`, `usedInCompletion`). `linksAndScores` itself is
 * not reordered.
 *
 * @param mappedLinks - Candidate documents; each is scored from its url, title and description.
 * @param searchQuery - Query the links are ranked against.
 * @param urlTraces - Trace records annotated in place.
 * @returns The links to use in completion, at most `MAX_RANKING_LIMIT` long.
 */
export async function rerankLinks(
  mappedLinks: MapDocument[],
  searchQuery: string,
  urlTraces: URLTrace[],
): Promise<MapDocument[]> {
  const mappedLinksRerank = mappedLinks.map(
    (x) => `url: ${x.url}, title: ${x.title}, description: ${x.description}`,
  );
  const linksAndScores = await performRanking(
    mappedLinksRerank,
    mappedLinks.map((l) => l.url),
    searchQuery,
  );

  // First try with high threshold
  let filteredLinks = filterAndProcessLinks(
    mappedLinks,
    linksAndScores,
    extractConfig.INITIAL_SCORE_THRESHOLD,
  );

  // If we don't have enough high-quality links, try with lower threshold
  if (filteredLinks.length < extractConfig.MIN_REQUIRED_LINKS) {
    logger.info(
      `Only found ${filteredLinks.length} links with score > ${extractConfig.INITIAL_SCORE_THRESHOLD}. Trying lower threshold...`,
    );
    filteredLinks = filterAndProcessLinks(
      mappedLinks,
      linksAndScores,
      extractConfig.FALLBACK_SCORE_THRESHOLD,
    );

    if (filteredLinks.length === 0) {
      // If still no results, take top N results regardless of score
      logger.warn(
        `No links found with score > ${extractConfig.FALLBACK_SCORE_THRESHOLD}. Taking top ${extractConfig.MIN_REQUIRED_LINKS} results.`,
      );
      const linkByUrl = indexFirstBy(mappedLinks, (link) => link.url);
      // Copy before sorting so the ranking output is not reordered in place.
      filteredLinks = [...linksAndScores]
        .sort((a, b) => b.score - a.score)
        .map((x) => linkByUrl.get(x.link))
        .filter(
          (x): x is MapDocument =>
            x !== undefined && x.url !== undefined && !isUrlBlocked(x.url),
        )
        // Slice only after filtering so blocked/unknown URLs don't consume the N slots.
        .slice(0, extractConfig.MIN_REQUIRED_LINKS);
    }
  }

  // One index per lookup target instead of a linear scan per link.
  const traceByUrl = indexFirstBy(urlTraces, (trace) => trace.url);
  const keptUrls = new Set(filteredLinks.map((link) => link.url));

  // Update URL traces with relevance scores and mark filtered out URLs
  for (const { link, score } of linksAndScores) {
    const trace = traceByUrl.get(link);
    if (!trace) continue;
    trace.relevanceScore = score;
    // If URL didn't make it through filtering, record why it was dropped
    if (!keptUrls.has(link)) {
      trace.warning = isUrlBlocked(link)
        ? "URL is blocked"
        : `Relevance score ${score} below threshold`;
      trace.usedInCompletion = false;
    }
  }

  const rankedLinks = filteredLinks.slice(0, extractConfig.MAX_RANKING_LIMIT);

  // Mark URLs that will be used in completion
  for (const link of rankedLinks) {
    const trace = traceByUrl.get(link.url);
    if (trace) {
      trace.usedInCompletion = true;
    }
  }

  // Mark URLs that were dropped due to ranking limit
  for (const link of filteredLinks.slice(extractConfig.MAX_RANKING_LIMIT)) {
    const trace = traceByUrl.get(link.url);
    if (trace) {
      trace.warning = "Excluded due to ranking limit";
      trace.usedInCompletion = false;
    }
  }

  return rankedLinks;
}

/** Indexes `items` by `key`, keeping the first item per key (the match `Array.prototype.find` would return). */
function indexFirstBy<T, K>(
  items: readonly T[],
  key: (item: T) => K,
): Map<K, T> {
  const index = new Map<K, T>();
  for (const item of items) {
    const k = key(item);
    if (!index.has(k)) index.set(k, item);
  }
  return index;
}
/**
 * Returns the documents whose ranking score is strictly above `threshold`,
 * ordered from highest to lowest score. Blocked URLs and scores whose link has
 * no matching document are skipped.
 *
 * The result is explicitly ordered by score so that callers truncating it
 * (e.g. to a ranking limit) always keep the best-scored links, whatever order
 * `linksAndScores` arrives in. The sort is stable, so equally-scored links keep
 * their input order. An already-descending input comes out unchanged.
 *
 * @param mappedLinks - Candidate documents; the first document with a given url is used.
 * @param linksAndScores - Ranking output; each `link` is matched against `MapDocument.url`.
 * @param threshold - Exclusive lower bound on `score`.
 * @returns Matching, non-blocked documents in descending score order. The input arrays are not mutated.
 */
function filterAndProcessLinks(
  mappedLinks: MapDocument[],
  linksAndScores: RankingResult["linksAndScores"],
  threshold: number,
): MapDocument[] {
  // Index once instead of a linear `find` per score; keep the first occurrence as `find` did.
  const linkByUrl = new Map<MapDocument["url"], MapDocument>();
  for (const doc of mappedLinks) {
    if (!linkByUrl.has(doc.url)) linkByUrl.set(doc.url, doc);
  }

  return linksAndScores
    .filter((x) => x.score > threshold)
    // `filter` returned a fresh array, so this sort does not touch the input.
    .sort((a, b) => b.score - a.score)
    .map((x) => linkByUrl.get(x.link))
    .filter(
      (x): x is MapDocument =>
        x !== undefined && x.url !== undefined && !isUrlBlocked(x.url),
    );
}