package dev.langchain4j.rag;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator;
import dev.langchain4j.rag.content.injector.ContentInjector;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/* loaded from: input_file:dev/langchain4j/rag/DefaultRetrievalAugmentor.class */
/**
 * {@link RetrievalAugmentor} implementation that chains four collaborators:
 * a {@link QueryTransformer} turns the user message into one or more queries,
 * a {@link QueryRouter} selects the {@link ContentRetriever}s for each query,
 * a {@link ContentAggregator} merges everything that was retrieved, and
 * a {@link ContentInjector} adds the aggregated contents to the chat message.
 *
 * <p>When there is exactly one query routed to exactly one retriever, retrieval
 * runs on the calling thread. In every other non-empty case the work is
 * submitted to the configured {@link Executor} and the caller blocks until all
 * results are available.
 */
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
    private final QueryTransformer queryTransformer;
    private final QueryRouter queryRouter;
    private final ContentAggregator contentAggregator;
    private final ContentInjector contentInjector;
    private final Executor executor;

    /**
     * Fluent builder for {@link DefaultRetrievalAugmentor}. Every setter stores
     * the given value as-is; fallbacks and validation are applied by the
     * augmentor's constructor when {@link #build()} is called.
     */
    public static class DefaultRetrievalAugmentorBuilder {
        private QueryTransformer queryTransformer;
        private QueryRouter queryRouter;
        private ContentAggregator contentAggregator;
        private ContentInjector contentInjector;
        private Executor executor;

        DefaultRetrievalAugmentorBuilder() {
        }

        /**
         * Convenience setter that wraps the given retriever in a
         * {@link DefaultQueryRouter}. It writes to the same field as
         * {@link #queryRouter(QueryRouter)}, so whichever of the two is
         * called last wins.
         *
         * @param contentRetriever the retriever to route to; checked with
         *                         {@code ValidationUtils.ensureNotNull}
         * @return this builder
         */
        public DefaultRetrievalAugmentorBuilder contentRetriever(ContentRetriever contentRetriever) {
            this.queryRouter = new DefaultQueryRouter(ValidationUtils.ensureNotNull(contentRetriever, "contentRetriever"));
            return this;
        }

        /**
         * @param queryTransformer the transformer to use
         * @return this builder
         */
        public DefaultRetrievalAugmentorBuilder queryTransformer(QueryTransformer queryTransformer) {
            this.queryTransformer = queryTransformer;
            return this;
        }

        /**
         * @param queryRouter the router to use; replaces any router previously set,
         *                    including one created by {@link #contentRetriever(ContentRetriever)}
         * @return this builder
         */
        public DefaultRetrievalAugmentorBuilder queryRouter(QueryRouter queryRouter) {
            this.queryRouter = queryRouter;
            return this;
        }

        /**
         * @param contentAggregator the aggregator to use
         * @return this builder
         */
        public DefaultRetrievalAugmentorBuilder contentAggregator(ContentAggregator contentAggregator) {
            this.contentAggregator = contentAggregator;
            return this;
        }

        /**
         * @param contentInjector the injector to use
         * @return this builder
         */
        public DefaultRetrievalAugmentorBuilder contentInjector(ContentInjector contentInjector) {
            this.contentInjector = contentInjector;
            return this;
        }

        /**
         * @param executor the executor used for asynchronous routing and retrieval
         * @return this builder
         */
        public DefaultRetrievalAugmentorBuilder executor(Executor executor) {
            this.executor = executor;
            return this;
        }

        /**
         * Creates the augmentor from the values collected so far.
         *
         * @return a new {@link DefaultRetrievalAugmentor}
         */
        public DefaultRetrievalAugmentor build() {
            return new DefaultRetrievalAugmentor(
                    this.queryTransformer,
                    this.queryRouter,
                    this.contentAggregator,
                    this.contentInjector,
                    this.executor);
        }
    }

    /**
     * Creates an augmentor from its collaborators.
     *
     * <p>{@code queryRouter} is passed through {@code ValidationUtils.ensureNotNull}.
     * Every other argument is passed through {@code Utils.getOrDefault} with a
     * fallback supplier: {@link DefaultQueryTransformer}, {@link DefaultContentAggregator},
     * {@link DefaultContentInjector}, and a freshly created thread pool respectively.
     *
     * @param queryTransformer  turns the original query into the queries to execute
     * @param queryRouter       selects the retrievers for a query
     * @param contentAggregator merges the retrieved contents into a single list
     * @param contentInjector   adds the aggregated contents to the chat message
     * @param executor          runs asynchronous routing and retrieval tasks
     */
    public DefaultRetrievalAugmentor(QueryTransformer queryTransformer, QueryRouter queryRouter, ContentAggregator contentAggregator, ContentInjector contentInjector, Executor executor) {
        // The Supplier casts are kept on purpose: they pin the supplier-taking overload
        // of getOrDefault regardless of whether the target type is a functional interface.
        this.queryTransformer = Utils.getOrDefault(queryTransformer, (Supplier<QueryTransformer>) DefaultQueryTransformer::new);
        this.queryRouter = ValidationUtils.ensureNotNull(queryRouter, "queryRouter");
        this.contentAggregator = Utils.getOrDefault(contentAggregator, (Supplier<ContentAggregator>) DefaultContentAggregator::new);
        this.contentInjector = Utils.getOrDefault(contentInjector, (Supplier<ContentInjector>) DefaultContentInjector::new);
        this.executor = Utils.getOrDefault(executor, (Supplier<Executor>) DefaultRetrievalAugmentor::createDefaultExecutor);
    }

    /**
     * Builds the fallback executor: a pool with no core threads, no upper bound
     * on the number of threads, a one-second idle keep-alive, and a
     * {@link SynchronousQueue} that hands each task directly to a thread.
     */
    private static ExecutorService createDefaultExecutor() {
        return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1L, TimeUnit.SECONDS, new SynchronousQueue<>());
    }

    /**
     * Augments the chat message carried by the request with retrieved contents.
     *
     * @param augmentationRequest the request holding the chat message and metadata
     * @return the result containing the injected chat message and the aggregated contents
     * @throws IllegalArgumentException if the chat message is not a {@link UserMessage}
     * @throws NullPointerException     if the request or its chat message is {@code null}
     */
    @Override
    public AugmentationResult augment(AugmentationRequest augmentationRequest) {
        ChatMessage chatMessage = augmentationRequest.chatMessage();
        if (!(chatMessage instanceof UserMessage)) {
            throw new IllegalArgumentException("Unsupported message type: " + chatMessage.type());
        }
        UserMessage userMessage = (UserMessage) chatMessage;

        Query originalQuery = Query.from(userMessage.singleText(), augmentationRequest.metadata());
        Collection<Query> queries = this.queryTransformer.transform(originalQuery);
        Map<Query, Collection<List<Content>>> queryToContents = process(queries);
        List<Content> contents = this.contentAggregator.aggregate(queryToContents);

        return AugmentationResult.builder()
                .chatMessage(this.contentInjector.inject(contents, chatMessage))
                .contents(contents)
                .build();
    }

    /**
     * Routes every query and retrieves its contents.
     *
     * @param queries the queries produced by the query transformer
     * @return for each query, one content list per retriever it was routed to;
     *         an empty map when there are no queries or a single query is routed nowhere
     */
    private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
        int queryCount = queries.size();
        if (queryCount == 0) {
            return Collections.emptyMap();
        }
        if (queryCount == 1) {
            return processSingle(queries.iterator().next());
        }
        return processAll(queries);
    }

    /** Handles a lone query; routing happens on the calling thread. */
    private Map<Query, Collection<List<Content>>> processSingle(Query query) {
        Collection<ContentRetriever> retrievers = this.queryRouter.route(query);
        int retrieverCount = retrievers.size();

        if (retrieverCount == 1) {
            // One query, one retriever: no reason to involve the executor.
            ContentRetriever onlyRetriever = retrievers.iterator().next();
            Collection<List<Content>> contents = Collections.singletonList(onlyRetriever.retrieve(query));
            return Collections.singletonMap(query, contents);
        }
        if (retrieverCount > 1) {
            Collection<List<Content>> contents = retrieveFromAll(retrievers, query).join();
            return Collections.singletonMap(query, contents);
        }
        return Collections.emptyMap();
    }

    /** Handles several queries by routing and retrieving each one on the executor. */
    private Map<Query, Collection<List<Content>>> processAll(Collection<Query> queries) {
        Map<Query, CompletableFuture<Collection<List<Content>>>> pending = new ConcurrentHashMap<>();
        queries.forEach(query -> {
            CompletableFuture<Collection<List<Content>>> futureContents = CompletableFuture
                    .supplyAsync(() -> this.queryRouter.route(query), this.executor)
                    .thenCompose(retrievers -> retrieveFromAll(retrievers, query));
            pending.put(query, futureContents);
        });
        return join(pending);
    }

    /**
     * Starts one retrieval task per retriever on the executor.
     *
     * @param retrievers the retrievers selected for the query
     * @param query      the query to hand to every retriever
     * @return a future that completes with one content list per retriever,
     *         in the iteration order of {@code retrievers}
     */
    private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers, Query query) {
        List<CompletableFuture<List<Content>>> futures = retrievers.stream()
                .map(retriever -> CompletableFuture.supplyAsync(() -> retriever.retrieve(query), this.executor))
                .collect(Collectors.toList());

        // Generic arrays cannot be created, hence the raw component type for allOf's varargs.
        return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0]))
                .thenApply(ignored -> futures.stream()
                        .map(CompletableFuture::join)
                        .collect(Collectors.toList()));
    }

    /**
     * Blocks until every future in the map has completed and unwraps the results.
     *
     * @param queryToFutureContents the pending retrievals, keyed by query
     * @return a map from the same keys to the completed values
     */
    private static Map<Query, Collection<List<Content>>> join(Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
        CompletableFuture<Void> allDone =
                CompletableFuture.allOf(queryToFutureContents.values().toArray(new CompletableFuture[0]));

        return allDone.thenApply(ignored -> {
            // The per-entry join() calls return immediately: allDone has already completed.
            Map<Query, Collection<List<Content>>> resolved = queryToFutureContents.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().join()));
            return resolved;
        }).join();
    }

    /**
     * @return a new, empty {@link DefaultRetrievalAugmentorBuilder}
     */
    public static DefaultRetrievalAugmentorBuilder builder() {
        return new DefaultRetrievalAugmentorBuilder();
    }
}
