import aidlite  # 导入AIDLite库，用于模型推理
import numpy as np  # 导入NumPy库，用于数值计算
import cv2  # 导入OpenCV库，用于图像处理
import os  # 导入os库，用于文件路径操作

#读取标签
def read_label_list():
    with open('label.txt', 'r', encoding="utf8") as f:
        data = f.read().splitlines()  # 读取所有行并去除换行符
    return data

#前处理函数
def preprocess_image(image_path):
    img = cv2.imread(image_path)  # 读取图像文件
    if img is None:
        raise ValueError(f"Failed to load image from {image_path}!")  # 图像加载失败抛出异常
    # 转换为灰度图
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # 将BGR图像转换为灰度图像
    # 调整图像大小
    img = cv2.resize(img, (28, 28))  # 调整图像大小为28x28像素
    # 添加批次和通道维度
    img = np.expand_dims(img, axis=0)  # 添加批次维度，形状变为(1, 28, 28)
    img = np.expand_dims(img, axis=-1)  # 添加通道维度，形状变为(1, 28, 28, 1)
    # 归一化到[0,1]范围
    img = img.astype(np.float32) / 255.0  # 将像素值从0-255缩放到0-1之间
    print("Preprocessed image shape:", img.shape)  # 打印预处理后的图像形状
    return img

#后处理函数
def postprocess_output(output_data, label_list):
    print("Raw output data:", output_data)  # 打印原始输出数据
    # 确保output_data是二维数组
    if output_data.ndim == 1:
        # 如果是一维数组，将其转换为二维数组
        output_data = output_data.reshape(1, -1)  # 重塑为(1, 10)形状
        print("Reshaped output data:", output_data)  # 打印重塑后的输出数据
    # 获取预测结果
    w = np.argmax(output_data)  # 找到输出数组中最大值的索引
    # 根据output_data的形状选择正确的索引方式获取置信度
    if output_data.ndim == 2:
        confidence = output_data[0, w]  # 对于二维数组，使用[0, w]索引
    else:
        confidence = output_data[w]  # 对于一维数组，使用[w]索引
    return label_list[w], confidence

def main():
    # 1. 定义模型和文件路径
    model_path = os.path.abspath('MyModel.tflite')  # 获取模型的绝对路径
    image_path = '77.jpg'  # 待识别的图像文件路径
    
    # 2. 创建AIDLite模型对象
    model = aidlite.Model.create_instance(model_path)  # 创建模型实例
    if model is None:
        print("Create model failed!")  # 模型创建失败处理
        exit(1)  # 退出程序
    
    # 3. 创建配置实例对象
    config = aidlite.Config.create_instance()  # 创建配置实例
    if config is None:
        print("Create config failed!")  # 配置创建失败处理
        exit(1)  # 退出程序
    
    # 设置配置参数
    config.framework_type = aidlite.FrameworkType.TYPE_TFLITE  # 设置框架类型为TFLite
    config.accelerate_type = aidlite.AccelerateType.TYPE_GPU  # 设置加速类型为GPU加速
    config.implement_type = aidlite.ImplementType.TYPE_LOCAL  # 设置实现类型为本地执行
    config.number_of_threads = 4  # 设置线程数为4
    config.is_quantify_model = 0  # 设置是否为量化模型（0表示非量化模型）
    config.fast_timeout = -1  # 设置超时时间（-1表示无超时限制）
    print("Create Config success!")  # 配置创建成功提示
    
    # 4. 创建解释器对象
    interpreter = aidlite.InterpreterBuilder.build_interpretper_from_model_and_config(model, config)
    if interpreter is None:
        print("build_interpretper_from_model_and_config failed!")  # 解释器创建失败处理
        exit(1)  # 退出程序
    print("Create Interpreter success!")  # 解释器创建成功提示
    
    # 初始化解释器
    result = interpreter.init()  # 初始化解释器
    if result != 0:
        print("interpreter->init() failed!")  # 解释器初始化失败处理
        exit(1)  # 退出程序
    print("Interpreter init success!")  # 解释器初始化成功提示
    
    # 加载模型到解释器
    result = interpreter.load_model()  # 加载模型
    if result != 0:
        print("interpreter->load_model() failed!")  # 模型加载失败处理
        exit(1)  # 退出程序
    print("Interpreter load model success!")  # 模型加载成功提示
    
    # 定义输入输出张量的形状
    input_details = [[1, 28, 28, 1]]  # 输入张量形状：批次大小=1, 高度=28, 宽度=28, 通道数=1
    output_details = [[1, 10]]  # 输出张量形状：批次大小=1, 输出类别数=10
    

    # 设置模型属性
    model.set_model_properties(
        input_shapes=input_details, 
        input_data_type=aidlite.DataType.TYPE_FLOAT32, 
        output_shapes=output_details, 
        output_data_type=aidlite.DataType.TYPE_FLOAT32
    )

    # 5. 图像预处理
    try:
        processed_img = preprocess_image(image_path)  # 调用预处理函数处理图像
    except ValueError as e:
        print(e)  # 打印错误信息
        exit(1)  # 退出程序
    
    # 设置输入张量
    result = interpreter.set_input_tensor(0, processed_img)  # 将预处理后的图像数据设置为输入张量
    if result != 0:
        print("interpreter->set_input_tensor() failed!")  # 设置输入张量失败处理
        exit(1)  # 退出程序
    
    # 6. 执行模型推理
    result = interpreter.invoke()  # 执行模型推理
    if result != 0:
        print("interpreter->invoke() failed!")  # 模型推理失败处理
        exit(1)  # 退出程序
    
    # 获取输出张量
    output_data = interpreter.get_output_tensor(0)  # 获取输出张量数据
    if output_data is None:
        print("interpreter->get_output_tensor() failed!")  # 获取输出张量失败处理
        exit(1)  # 退出程序
    
    # 7. 结果后处理
    label_list = read_label_list()  # 读取标签列表
    predicted_label, confidence = postprocess_output(output_data, label_list)  # 调用后处理函数处理输出数据
    # 打印预测结果
    print("Predicted label:", predicted_label)  # 打印预测的标签
    print("Confidence:", confidence)  # 打印置信度

if __name__ == "__main__":
    main()  # 执行主函数