import cv2
import numpy as np
import os
from api_infer import *

# 摄像头检测和图像捕获部分
def capture_image():
    ID = 0
    while True:
        cap = cv2.VideoCapture(ID)
        ret, frame = cap.read()
        if ret == False:
            ID += 1
            if ID > 10:  # 限制最大尝试的摄像头数量
                print("未找到可用的摄像头")
                return None
        else:
            print(f"找到摄像头 ID: {ID}")
            # 捕获一帧图像
            success, img = cap.read()
            cap.release()
            if success:
                # 保存图像到当前目录
                image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "captured_image.jpg")
                cv2.imwrite(image_path, img)
                print(f"图像已保存到: {image_path}")
                return image_path
            break
    return None

# SNPE模型推理部分
def run_snpe(image_path):
    # 设置文件路径
    current_dir = os.path.dirname(os.path.abspath(__file__))
    dlc_path = os.path.join(current_dir, "MyModel.dlc")
    output_dir = current_dir
    label_file = os.path.join(current_dir, "label.txt")

    # 读取标签文件，返回列表
    def read_label_list():
        with open(label_file, 'r', encoding="utf8") as f:
            data = f.read().splitlines()
        return data

    # 创建 SNPE 运行时环境
    snpe_ort = SnpeContext(dlc_path, [], Runtime.GPU, PerfProfile.BALANCED, LogLevel.INFO)
    assert snpe_ort.Initialize() == 0  # 初始化 SNPE 环境

    # 图像预处理
    img1 = cv2.imread(image_path)  # 读取图片
    img = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)  # 转换为灰度图
    img_inverted = 255 - img  # 反色处理
    _, bit_img = cv2.threshold(img_inverted, 127, 255, cv2.THRESH_BINARY)  # 二值化处理
    bit_img_resized = cv2.resize(bit_img, (28, 28), interpolation=cv2.INTER_AREA)  # 调整大小为 28x28
    _, bit_img_final = cv2.threshold(bit_img_resized, 20, 255, cv2.THRESH_BINARY)  # 再次二值化处理
    
    # 保存预处理后的图片
    preprocessed_path = os.path.join(output_dir, "preprocessed_image.jpg")
    cv2.imwrite(preprocessed_path, bit_img_final)
    print(f"预处理图像已保存到: {preprocessed_path}")

    # 准备模型输入数据
    input_feed = {"serving_default_conv2d_input:0": bit_img_final}

    # 执行模型推理
    outputs = snpe_ort.Execute(["StatefulPartitionedCall:0"], input_feed)
    print("模型输出结果:", outputs)
    
    # 后处理部分
    if outputs is not None:
        for k, v in outputs.items():
            print(f"输出名称: {k}, 输出数据: {v}")
            if k == "StatefulPartitionedCall:0":  # 根据输出名称匹配
                output_data = v  # 获取输出数据
                w = np.argmax(output_data)  # 找到最大值的位置
                label_list = read_label_list()  # 读取标签列表
                print("预测结果:", label_list[w])  # 打印预测标签

    # 释放资源
    assert snpe_ort.Release() == 0

if __name__ == "__main__":
    # 1. 捕获图像
    captured_image_path = capture_image()
    
    if captured_image_path:
        # 2. 运行模型推理
        run_snpe(captured_image_path)
    else:
        print("无法捕获图像，程序退出")