package dev.langchain4j.model.vertexai;

import com.google.cloud.discoveryengine.v1beta.RankRequest;
import com.google.cloud.discoveryengine.v1beta.RankServiceClient;
import com.google.cloud.discoveryengine.v1beta.RankServiceSettings;
import com.google.cloud.discoveryengine.v1beta.RankingConfigName;
import com.google.cloud.discoveryengine.v1beta.RankingRecord;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.scoring.ScoringModel;
import java.io.IOException;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

/* loaded from: input_file:dev/langchain4j/model/vertexai/VertexAiScoringModel.class */
/**
 * {@link ScoringModel} that scores text segments against a query by calling the Google Cloud
 * Discovery Engine ranking service through {@link RankServiceClient}.
 *
 * <p>All configuration is held in final fields set once in the constructor. Instances can be
 * created with {@link #builder()} or directly with the public constructor.
 */
public class VertexAiScoringModel implements ScoringModel {
    /** Metadata key used to look up a segment's title when none is configured. */
    private static final String DEFAULT_TITLE_METADATA_KEY = "title";

    /**
     * Format of the ranking config path; arguments are the project number and the location.
     *
     * <p>NOTE(review): the trailing '.' is preserved exactly as it was; confirm against the
     * ranking API documentation whether the service expects it.
     */
    private static final String RANKING_CONFIG_FORMAT =
            "projects/%s/locations/%s/rankingConfigs/default_ranking_config.";

    private final String model;
    private final String projectId;
    private final String projectNumber;
    private final String location;
    private final String titleMetadataKey;

    /**
     * Fluent builder for {@link VertexAiScoringModel}.
     *
     * <p>{@code model} and {@code titleMetadataKey} are validated as soon as they are set; the
     * remaining values are validated by the {@link VertexAiScoringModel} constructor when
     * {@link #build()} is called.
     */
    public static class Builder {
        private String model;
        private String projectId;
        private String projectNumber;
        private String location;
        private String titleMetadataKey;

        /**
         * Sets the ranking model name.
         *
         * @param model the model name; checked with {@code ValidationUtils.ensureNotBlank}
         * @return this builder
         */
        public Builder model(String model) {
            this.model = ValidationUtils.ensureNotBlank(model, "model");
            return this;
        }

        /**
         * Sets the project id.
         *
         * @param projectId the project id
         * @return this builder
         */
        public Builder projectId(String projectId) {
            this.projectId = projectId;
            return this;
        }

        /**
         * Sets the project number used in the ranking config path.
         *
         * @param projectNumber the project number
         * @return this builder
         */
        public Builder projectNumber(String projectNumber) {
            this.projectNumber = projectNumber;
            return this;
        }

        /**
         * Sets the location used in the ranking config path.
         *
         * @param location the location
         * @return this builder
         */
        public Builder location(String location) {
            this.location = location;
            return this;
        }

        /**
         * Sets the metadata key under which a segment's title is looked up.
         *
         * @param titleMetadataKey the metadata key; checked with
         *     {@code ValidationUtils.ensureNotBlank}
         * @return this builder
         */
        public Builder titleMetadataKey(String titleMetadataKey) {
            this.titleMetadataKey = ValidationUtils.ensureNotBlank(titleMetadataKey, "titleMetadataKey");
            return this;
        }

        /**
         * Creates the scoring model from the values collected so far.
         *
         * @return a new {@link VertexAiScoringModel}
         */
        public VertexAiScoringModel build() {
            return new VertexAiScoringModel(
                    this.projectId, this.projectNumber, this.location, this.model, this.titleMetadataKey);
        }
    }

    /**
     * Creates a scoring model.
     *
     * @param projectId the project id; checked with {@code ValidationUtils.ensureNotBlank}
     * @param projectNumber the project number; checked with {@code ValidationUtils.ensureNotBlank}
     * @param location the location; checked with {@code ValidationUtils.ensureNotBlank}
     * @param model the ranking model name; checked with {@code ValidationUtils.ensureNotBlank}
     * @param titleMetadataKey the metadata key holding a segment's title, or {@code null} to use
     *     {@code "title"}
     */
    public VertexAiScoringModel(
            String projectId, String projectNumber, String location, String model, String titleMetadataKey) {
        this.projectId = ValidationUtils.ensureNotBlank(projectId, "projectId");
        this.projectNumber = ValidationUtils.ensureNotBlank(projectNumber, "projectNumber");
        this.location = ValidationUtils.ensureNotBlank(location, "location");
        this.model = ValidationUtils.ensureNotBlank(model, "model");
        this.titleMetadataKey = titleMetadataKey != null ? titleMetadataKey : DEFAULT_TITLE_METADATA_KEY;
    }

    /**
     * Scores every segment against the query with a single rank request.
     *
     * <p>A new {@link RankServiceClient} is created for each call and is closed before this
     * method returns, whether it completes normally or with an exception.
     *
     * @param segments the segments to score
     * @param query the query the segments are ranked against
     * @return the scores of the returned records, ordered by the numeric id assigned to each
     *     segment in the request (ids follow the iteration order of {@code segments})
     * @throws RuntimeException wrapping the {@link IOException} if the client cannot be created
     */
    @Override
    public Response<List<Double>> scoreAll(List<TextSegment> segments, String query) {
        // try-with-resources guarantees the client is released even when building the request
        // or the rank call itself throws.
        try (RankServiceClient client = RankServiceClient.create(RankServiceSettings.newBuilder().build())) {
            RankRequest request = buildRequest(segments, query);
            List<Double> scores = client.rank(request).getRecordsList().stream()
                    // Ids were assigned sequentially in input order, so sorting by id puts the
                    // scores back in the order of the input segments.
                    .sorted(Comparator.comparingDouble((RankingRecord record) -> Double.parseDouble(record.getId())))
                    .map(record -> Double.valueOf(record.getScore()))
                    .collect(Collectors.toList());
            return Response.from(scores);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Builds the rank request: model, ranking config, query and one record per segment. */
    private RankRequest buildRequest(List<TextSegment> segments, String query) {
        RankRequest.Builder requestBuilder = RankRequest.newBuilder();
        if (this.model != null && !this.model.isEmpty()) {
            requestBuilder.setModel(this.model);
        }
        String rankingConfig = RankingConfigName.newBuilder()
                .setProject(this.projectId)
                .setLocation(this.location)
                .setRankingConfig(String.format(RANKING_CONFIG_FORMAT, this.projectNumber, this.location))
                .build()
                .getRankingConfig();
        // Only the id and score of each returned record are read by scoreAll.
        requestBuilder
                .setRankingConfig(rankingConfig)
                .setQuery(query)
                .setIgnoreRecordDetailsInResponse(true);
        int nextId = 0;
        for (TextSegment segment : segments) {
            requestBuilder.addRecords(toRankingRecord(segment, nextId++));
        }
        return requestBuilder.build();
    }

    /** Converts one segment into a ranking record carrying its text, optional title and id. */
    private RankingRecord toRankingRecord(TextSegment segment, int id) {
        RankingRecord.Builder record = RankingRecord.newBuilder().setContent(segment.text());
        String title = segment.metadata().getString(this.titleMetadataKey);
        if (title != null) {
            record.setTitle(title);
        }
        return record.setId(String.valueOf(id)).build();
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return an empty builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
