package com.envestnet.aaacli.core;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.*;
import java.util.logging.*;
import java.util.stream.Collectors;

/**
 * Stages the model's Python scripts into a private temporary directory, installs their pip
 * requirements, and runs them with {@code python3}.
 *
 * <p>All child-process output (stdout and stderr combined) is written line by line to the
 * logger, which also writes to {@link #logFile} inside the temporary directory. The paths
 * {@code model/requirements.txt} and {@code model/python} are resolved relative to the current
 * working directory.
 */
public class MLRunner {

    private final File tempScriptFolder;
    private final Logger logger;
    /** Log file inside the temporary directory that receives this runner's log records. */
    public final File logFile;

    /**
     * Creates the temporary working directory and attaches a UTF-8 file handler that writes to
     * {@code MLRunner.log} inside it.
     *
     * @throws IOException if the directory or the log file cannot be created
     */
    public MLRunner() throws IOException {
        // Prepare a temporary directory for Python scripts and logs
        tempScriptFolder = Files.createTempDirectory("python_scripts").toFile();
        // deleteOnExit only removes the directory if it is empty when the JVM exits. The log
        // file and extracted scripts live inside it, so in practice the directory (and the log)
        // is left behind for inspection after the run.
        tempScriptFolder.deleteOnExit();

        // Initialize logging to a file in the temp directory
        logFile = new File(tempScriptFolder, "MLRunner.log");
        // NOTE(review): the logger is shared by name, so every MLRunner instance adds another
        // FileHandler to it; in this file only one instance is created (see main).
        logger = Logger.getLogger(MLRunner.class.getName());
        FileHandler fileHandler = new FileHandler(logFile.getAbsolutePath());
        // Child output is decoded as UTF-8 (see runAndLog), so write the log in UTF-8 as well.
        fileHandler.setEncoding(StandardCharsets.UTF_8.name());
        fileHandler.setFormatter(new SimpleFormatter());
        logger.addHandler(fileHandler);
        logger.setLevel(Level.ALL);

        logger.info("Temporary directory for Python scripts and logs: " + tempScriptFolder.getAbsolutePath());
    }

    /**
     * Logs the interpreter version reported by {@code python3 --version}.
     *
     * @throws IOException if {@code python3} cannot be started or exits with a non-zero code
     * @throws InterruptedException if interrupted while waiting for the process
     */
    public void checkPythonVersion() throws IOException, InterruptedException {
        logger.info("Checking Python version...");
        int exitCode = runAndLog(new ProcessBuilder("python3", "--version"));
        if (exitCode != 0) {
            throw new IOException("python3 --version exited with code " + exitCode);
        }
    }

    /**
     * Copies {@code model/requirements.txt} into the temporary directory and installs it with
     * {@code pip3 install -r}.
     *
     * @throws IOException if the requirements file cannot be copied, {@code pip3} cannot be
     *     started, or {@code pip3} exits with a non-zero code
     * @throws InterruptedException if interrupted while waiting for {@code pip3}
     */
    public void installRequirements() throws IOException, InterruptedException {
        logger.info("Installing Python requirements from requirements.txt...");

        // Define the path to the requirements.txt file in the model directory
        Path requirementsPath = Paths.get("model/requirements.txt");
        File requirementsFile = new File(tempScriptFolder, "requirements.txt");
        Files.copy(requirementsPath, requirementsFile.toPath(), StandardCopyOption.REPLACE_EXISTING);

        // Install Python dependencies; pip output is drained into the log so a full pipe
        // buffer cannot stall the child process.
        int exitCode = runAndLog(
                new ProcessBuilder("pip3", "install", "-r", requirementsFile.getAbsolutePath()));
        if (exitCode != 0) {
            throw new IOException("pip3 install exited with code " + exitCode);
        }
    }

    /**
     * Copies each known script from {@code model/python} into the temporary directory.
     * A script that is missing from the source directory is logged at SEVERE and skipped.
     *
     * @throws IOException if a script exists but cannot be copied
     */
    public void extractPythonScripts() throws IOException {
        logger.info("Extracting Python scripts...");

        // Define the path to the model/python directory
        Path modelPythonPath = Paths.get("model/python");

        String[] scripts = {
                "predict.py",
                // ... other script names ...
        };

        for (String script : scripts) {
            Path scriptPath = modelPythonPath.resolve(script);
            if (!Files.exists(scriptPath)) {
                logger.severe("Script not found: " + scriptPath);
                continue;
            }
            Path target = new File(tempScriptFolder, script).toPath();
            // Script names may contain sub-directories; unlike mkdirs(), this reports failures.
            Files.createDirectories(target.getParent());
            Files.copy(scriptPath, target, StandardCopyOption.REPLACE_EXISTING);
            logger.info("Extracted script: " + script);
        }
    }

    /**
     * Runs {@code python3 <tempdir>/<scriptName>} with the temporary directory as the working
     * directory. The script's output is logged, and its exit code is logged at INFO when zero
     * and SEVERE otherwise.
     *
     * @param scriptName script file name relative to the temporary directory
     * @throws IOException if {@code python3} cannot be started or its output cannot be read
     * @throws InterruptedException if interrupted while waiting for the script
     */
    public void runScript(String scriptName) throws IOException, InterruptedException {
        logger.info("Running Python script: " + scriptName);
        File scriptFile = new File(tempScriptFolder, scriptName);
        int exitCode = runAndLog(new ProcessBuilder("python3", scriptFile.getAbsolutePath()));
        logger.log(exitCode == 0 ? Level.INFO : Level.SEVERE, "Script exited with code : " + exitCode);
    }

    /**
     * Starts {@code builder} in the temporary directory, logs every line of the child's
     * combined stdout/stderr, and waits for it to exit.
     *
     * @return the child's exit code
     */
    private int runAndLog(ProcessBuilder builder) throws IOException, InterruptedException {
        builder.directory(tempScriptFolder);
        builder.redirectErrorStream(true);
        // Make Python (and pip, which runs on Python) emit UTF-8 so the decoding below is reliable.
        builder.environment().put("PYTHONIOENCODING", "utf-8");

        Process process = builder.start();
        // No input is supplied; closing stdin makes a child that tries to read see EOF
        // instead of blocking forever.
        process.getOutputStream().close();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                logger.info(line);
            }
        }
        try {
            return process.waitFor();
        } catch (InterruptedException e) {
            // Do not leave an orphaned child running when the caller is interrupted.
            process.destroy();
            throw e;
        }
    }

    /**
     * Entry point. Checks the Python version, extracts the scripts, installs the requirements,
     * and runs {@code predict.py}. Any failure is logged at SEVERE.
     */
    public static void main(String[] args) {
        try {
            MLRunner executor = new MLRunner();
            System.out.println("Log file location: " + executor.logFile.getAbsolutePath());
            executor.checkPythonVersion();
            executor.extractPythonScripts();
            executor.installRequirements();
            executor.runScript("predict.py");
        } catch (IOException e) {
            Logger.getLogger(MLRunner.class.getName()).log(Level.SEVERE, "MLRunner failed", e);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            Logger.getLogger(MLRunner.class.getName()).log(Level.SEVERE, "MLRunner interrupted", e);
        }
    }
}
