package dev.langchain4j.model.workersai;

import dev.langchain4j.data.image.Image;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.image.ImageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.workersai.client.AbstractWorkersAIModel;
import dev.langchain4j.model.workersai.client.WorkersAiImageGenerationRequest;
import dev.langchain4j.model.workersai.spi.WorkersAiImageModelBuilderFactory;
import dev.langchain4j.spi.ServiceHelper;
import java.awt.image.BufferedImage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.util.Base64;
import java.util.Iterator;
import javax.imageio.ImageIO;
import okhttp3.ResponseBody;

/* loaded from: input_file:dev/langchain4j/model/workersai/WorkersAiImageModel.class */
/**
 * {@link ImageModel} backed by a Workers AI image-generation model.
 *
 * <p>Prompts (and, for {@code edit}, a source image and optional mask) are sent through the
 * client inherited from {@link AbstractWorkersAIModel}. The raw response bytes are returned
 * as an {@link Image} tagged {@value #MIME_TYPE}, or written to a file.
 */
public class WorkersAiImageModel extends AbstractWorkersAIModel implements ImageModel {

    /** MIME type attached to every returned {@link Image}; the response bytes are assumed to be PNG. */
    private static final String MIME_TYPE = "image/png";

    /** Fluent builder for {@link WorkersAiImageModel}. */
    public static class Builder {
        // Fields stay public to preserve the existing API surface.
        public String accountId;
        public String apiToken;
        public String modelName;

        /** Sets the account id passed to the client on every request. */
        public Builder accountId(String accountId) {
            this.accountId = accountId;
            return this;
        }

        /** Sets the API token handed to the underlying client. */
        public Builder apiToken(String apiToken) {
            this.apiToken = apiToken;
            return this;
        }

        /** Sets the model name passed to the client on every request. */
        public Builder modelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /** Builds a model from the values configured so far. */
        public WorkersAiImageModel build() {
            return new WorkersAiImageModel(this);
        }
    }

    /**
     * Creates a model from a configured builder.
     *
     * @param builder the builder holding account id, model name and API token
     */
    public WorkersAiImageModel(Builder builder) {
        this(builder.accountId, builder.modelName, builder.apiToken);
    }

    /**
     * Creates a model.
     *
     * @param accountId account id
     * @param modelName model name
     * @param apiToken API token
     */
    public WorkersAiImageModel(String accountId, String modelName, String apiToken) {
        super(accountId, modelName, apiToken);
    }

    /**
     * Returns a builder. If a {@link WorkersAiImageModelBuilderFactory} is registered via
     * {@link ServiceHelper}, the first one found supplies it; otherwise a default {@link Builder}.
     */
    public static Builder builder() {
        for (WorkersAiImageModelBuilderFactory factory
                : ServiceHelper.loadFactories(WorkersAiImageModelBuilderFactory.class)) {
            return factory.get();
        }
        return new Builder();
    }

    /**
     * Generates an image from a text prompt.
     *
     * @param prompt non-blank text prompt
     * @return the generated image, with finish reason {@link FinishReason#STOP}
     */
    @Override
    public Response<Image> generate(String prompt) {
        ValidationUtils.ensureNotBlank(prompt, "Prompt");
        return new Response<>(convertAsImage(executeQuery(prompt, null, null)), null, FinishReason.STOP);
    }

    /**
     * Generates a new image from a source image and a prompt.
     *
     * @param image source image; must carry a URL or base64 data
     * @param prompt non-blank text prompt
     * @return the generated image
     */
    @Override
    public Response<Image> edit(Image image, String prompt) {
        ValidationUtils.ensureNotBlank(prompt, "Prompt");
        ValidationUtils.ensureNotNull(image, "Image");
        // The source image goes in the "image" slot; there is no mask.
        return new Response<>(convertAsImage(executeQuery(prompt, image, null)), null, FinishReason.STOP);
    }

    /**
     * Generates a new image from a source image, a mask and a prompt.
     *
     * @param image source image; must carry a URL or base64 data
     * @param mask mask image; must carry a URL or base64 data
     * @param prompt non-blank text prompt
     * @return the generated image
     */
    @Override
    public Response<Image> edit(Image image, Image mask, String prompt) {
        ValidationUtils.ensureNotBlank(prompt, "Prompt");
        ValidationUtils.ensureNotNull(image, "Image");
        ValidationUtils.ensureNotNull(mask, "Mask");
        return new Response<>(convertAsImage(executeQuery(prompt, image, mask)), null, FinishReason.STOP);
    }

    /**
     * Generates an image from a prompt and writes the raw response bytes to a file.
     *
     * @param prompt non-blank text prompt
     * @param destinationFile non-blank path of the file to create or overwrite
     * @return the written file
     * @throws UncheckedIOException if the file cannot be written
     */
    public Response<File> generate(String prompt, String destinationFile) {
        ValidationUtils.ensureNotBlank(prompt, "Prompt");
        ValidationUtils.ensureNotBlank(destinationFile, "Destination file");
        byte[] imageBytes = executeQuery(prompt, null, null);
        // try-with-resources guarantees the stream is closed even if write() fails.
        try (FileOutputStream out = new FileOutputStream(destinationFile)) {
            out.write(imageBytes);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return new Response<>(new File(destinationFile), null, FinishReason.STOP);
    }

    /**
     * Sends one image-generation request and returns the raw response body.
     *
     * @param prompt text prompt
     * @param source optional source image, sent as the request's image pixels
     * @param mask optional mask image, sent as the request's mask pixels
     * @return the response body bytes
     * @throws IllegalStateException if the call is unsuccessful or has no body
     * @throws UncheckedIOException on network or image-decoding I/O failures
     */
    private byte[] executeQuery(String prompt, Image source, Image mask) {
        WorkersAiImageGenerationRequest request = new WorkersAiImageGenerationRequest();
        request.setPrompt(prompt);
        try {
            if (source != null) {
                request.setImage(toPixels(source));
            }
            if (mask != null) {
                request.setMask(toPixels(mask));
            }
            retrofit2.Response<?> response = this.workerAiClient
                    .generateImage(request, this.accountId, this.modelName)
                    .execute();
            if (!response.isSuccessful() || response.body() == null) {
                throw new IllegalStateException(
                        "An error occurred while generating image (HTTP " + response.code() + ").");
            }
            // The body must be closed to release the underlying connection; bytes() reads it fully.
            try (ResponseBody body = (ResponseBody) response.body()) {
                return body.bytes();
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Decodes the image at {@code url} into packed ARGB pixels in row-major order.
     *
     * @param url location of the image
     * @return one packed ARGB {@code int} per pixel, {@code width * height} entries
     * @throws Exception if the image cannot be read
     * @throws IllegalArgumentException if no registered reader can decode the image
     */
    public int[] getPixels(URL url) throws Exception {
        return toArgbPixels(requireDecoded(ImageIO.read(url), url));
    }

    /**
     * Wraps raw response bytes as a base64 {@link Image} tagged {@value #MIME_TYPE}.
     *
     * @param bArr image bytes
     * @return the image
     */
    public Image convertAsImage(byte[] bArr) {
        return Image.builder()
                .base64Data(Base64.getEncoder().encodeToString(bArr))
                .mimeType(MIME_TYPE)
                .build();
    }

    /** Resolves an {@link Image} (URL or base64 data) into packed ARGB pixels. */
    private int[] toPixels(Image image) throws IOException {
        if (image.url() != null) {
            URL url = image.url().toURL();
            return toArgbPixels(requireDecoded(ImageIO.read(url), url));
        }
        if (image.base64Data() != null) {
            byte[] bytes = Base64.getDecoder().decode(image.base64Data());
            return toArgbPixels(requireDecoded(ImageIO.read(new ByteArrayInputStream(bytes)), "base64 data"));
        }
        // Previously such images were silently ignored, turning edit() into plain generation.
        throw new IllegalArgumentException("Image must provide either a URL or base64 data");
    }

    /** ImageIO.read returns null when no reader supports the format; turn that into a clear error. */
    private static BufferedImage requireDecoded(BufferedImage image, Object origin) {
        if (image == null) {
            throw new IllegalArgumentException("Unable to decode image from " + origin);
        }
        return image;
    }

    /** Returns all pixels of {@code image} as packed ARGB ints, row by row. */
    private static int[] toArgbPixels(BufferedImage image) {
        int width = image.getWidth();
        int height = image.getHeight();
        return image.getRGB(0, 0, width, height, null, 0, width);
    }
}
