ESC
输入关键词搜索文章
目录

Video Processing Model

Video Processing Model

which python

Import necessary libraries

import os
from datetime import datetime
import pathlib
# import random
# import shutil
# import zipfile
# from collections import defaultdict
# import cv2
# import einops
# import imageio
# import numpy as np
# import pandas as pd
# import pydot
# import remotezip as rz
import tensorflow as tf
# import tqdm
from sklearn import *
# from sklearn.metrics import accuracy_score
# from sklearn.model_selection import cross_val_score, train_test_split
from tensorflow import keras
from tensorflow.keras import layers

Hyperparameters

# Optimiser learning rate.
lr = 1e-4
# Number of frames sampled from every video clip.
frame_number = 40
# Clips per batch and size of the tf.data shuffle buffer.
batch_size = 2
shuffle_buffer_size = 16
# Spatial size (pixels) of the frames fed to the model.
HEIGHT = WIDTH = 224

# Dataset root and the two class folders inside it.
video_dir = "Videos"
correct_dir = f"{video_dir}/Correct"
incorrect_dir = f"{video_dir}/Incorrect"

DataLoader

# Paths to the directories containing videos

from pdb import main
import remotezip as rz
import zipfile
from collections import defaultdict
from variables import correct_dir, incorrect_dir
import imageio
import random
import os
import shutil
import tensorflow as tf
import pathlib
import tqdm
import cv2
import numpy as np
# if you change the structure of the zip file, maybe you need to modify following functions: get_class, list_file_for_class
class DataLoader:
    """Read, extract, split and decode the video dataset stored in a zip archive.

    The class of a video is the name of the folder that directly contains it
    inside the archive (see ``get_class``); e.g. ``Videos/Correct/clip.mov``
    belongs to class ``"Correct"``.
    """

    def __init__(
        self,
        correct_dir=correct_dir,
        incorrect_dir=incorrect_dir,
        zip_file_path="Videos.zip",
        zip_url=None,
        to_dir="Videos",
        filetype=".mov",
    ):
        """Store dataset locations and options.

        Args:
          correct_dir: Directory of the "Correct" class (stored only; no method
            of this class reads it).
          incorrect_dir: Directory of the "Incorrect" class (stored only; no
            method of this class reads it).
          zip_file_path: Path of the local zip archive.
          zip_url: URL of a remote zip archive (only used by the ``*_zip_url``
            methods).
          to_dir: Directory the archive is extracted into. WARNING: the
            download/split methods delete it before extracting.
          filetype: Extension of the video files, matched case-insensitively.
        """
        self.correct_dir = correct_dir
        self.incorrect_dir = incorrect_dir
        self.zip_file_path = zip_file_path
        self.zip_url = zip_url
        self.to_dir = to_dir
        self.filetype = filetype

    def _is_video_entry(self, filename):
        """Return True if the zip entry *filename* is a video of ``self.filetype``.

        The extension is compared case-insensitively. macOS archive metadata
        (entries below ``__MACOSX/`` and ``._`` AppleDouble files) carries the
        same extension but is not a playable video, so it is skipped.
        """
        basename = filename.rsplit("/", 1)[-1]
        if filename.startswith("__MACOSX/") or basename.startswith("._"):
            return False
        return filename.lower().endswith(self.filetype.lower())

    def list_files_from_zip_url(self):
        """List the video files in the remote zip archive at ``self.zip_url``.

        Returns:
          List of archive entry names whose extension matches ``self.filetype``.
        """
        with rz.RemoteZip(self.zip_url) as archive:
            return [
                info.filename
                for info in archive.infolist()
                if self._is_video_entry(info.filename)
            ]

    def list_files_from_zip_file(self):
        """List the video files in the local zip archive ``self.zip_file_path``.

        Returns:
          List of archive entry names whose extension matches ``self.filetype``.
        """
        with zipfile.ZipFile(self.zip_file_path, "r") as archive:
            return [
                info.filename
                for info in archive.infolist()
                if self._is_video_entry(info.filename)
            ]

    def get_number_of_files(self):
        """Return the number of video files in the local zip archive."""
        return len(self.list_files_from_zip_file())

    def get_number_of_class(self):
        """Return how many classes the dataset in the local zip archive has."""
        return len(self.list_files_for_class())

    def get_number_of_files_for_class(self, show=False):
        """Count the video files of each class in the local zip archive.

        Args:
          show: If True, also print one line per class.

        Returns:
          ``defaultdict(int)`` mapping class name to number of files.
        """
        numbers = defaultdict(int)
        for cls, cls_files in self.list_files_for_class().items():
            if show:
                print(f"Class: {cls}, Number of files: {len(cls_files)}")
            numbers[cls] = len(cls_files)
        return numbers

    def list_files_for_class(self):
        """Group the video files of the local zip archive by class.

        Returns:
          ``defaultdict(list)`` mapping class name to archive entry names.
        """
        return self.get_files_per_class_from_zip_file()

    def select_subset_of_classes(self, files_per_class=100):
        """Take at most ``files_per_class`` files from every class.

        Args:
          files_per_class: Maximum number of files kept per class.

        Returns:
          Dict mapping class name to the first ``files_per_class`` files of
          that class, in archive order.
        """
        return {
            class_name: class_files[:files_per_class]
            for class_name, class_files in self.list_files_for_class().items()
        }

    def get_class(self, fname):
        """Retrieve the class name of an archive entry.

        Args:
          fname: Archive entry name using "/" separators, e.g.
            ``"Videos/Correct/clip.mov"``.

        Returns:
          The name of the folder directly containing the file (``"Correct"``).

        Raises:
          IndexError: If *fname* has no folder component.
        """
        return fname.split("/")[-2]

    def get_files_per_class_from_zip_file(self):
        """Group the video files of the local zip archive by class.

        Returns:
          ``defaultdict(list)`` mapping class name to archive entry names.
        """
        files_for_class = defaultdict(list)
        for fname in self.list_files_from_zip_file():
            files_for_class[self.get_class(fname)].append(fname)
        return files_for_class

    def select_subset_of_classes_from_zip_file(self, files_per_class=100):
        """Same as ``select_subset_of_classes`` (both read the local zip file)."""
        return self.select_subset_of_classes(files_per_class)

    @staticmethod
    def _remove_empty_dirs(directory, stop_at):
        """Remove *directory* and its empty parents, stopping before *stop_at*."""
        directory = pathlib.Path(directory)
        stop_at = pathlib.Path(stop_at)
        while directory != stop_at and stop_at in directory.parents:
            try:
                directory.rmdir()
            except OSError:
                break  # still holds files that have not been moved yet
            directory = directory.parent

    def _extract_flat(self, archive, file_names, to_dir):
        """Extract *file_names* from an open zip archive to ``to_dir/<class>/<basename>``.

        ``ZipFile.extract`` recreates the archive's folder structure below the
        target folder. Each file is then moved up into its class folder, and the
        intermediate folders are removed once they are empty.
        """
        to_dir = pathlib.Path(to_dir)
        for fn in tqdm.tqdm(file_names):
            class_dir = to_dir / self.get_class(fn)
            archive.extract(fn, str(class_dir))
            unzipped_file = class_dir / fn
            output_file = class_dir / pathlib.Path(fn).name
            if unzipped_file != output_file:
                unzipped_file.rename(output_file)
                self._remove_empty_dirs(unzipped_file.parent, stop_at=class_dir)

    def download_from_zip_file(self, to_dir=None, file_names=None):
        """Extract videos from the local zip archive into per-class folders.

        Args:
          to_dir: Target folder (defaults to ``self.to_dir``). It is DELETED
            first if it already exists.
          file_names: Archive entries to extract (defaults to every video).

        Returns:
          The list of extracted archive entry names.
        """
        # Resolve the default before touching the filesystem; os.path calls
        # raise TypeError on None.
        if to_dir is None:
            to_dir = self.to_dir
        if os.path.isdir(to_dir):
            shutil.rmtree(to_dir)
        if file_names is None:
            file_names = self.list_files_from_zip_file()
        with zipfile.ZipFile(self.zip_file_path, "r") as archive:
            self._extract_flat(archive, file_names, to_dir)
        return file_names

    def download_from_zip_url(self):
        """Extract every video of the remote archive ``self.zip_url`` into ``self.to_dir``.

        ``self.to_dir`` is DELETED first if it already exists.
        """
        if os.path.isdir(self.to_dir):
            shutil.rmtree(self.to_dir)
        file_names = self.list_files_from_zip_url()
        with rz.RemoteZip(self.zip_url) as archive:
            self._extract_flat(archive, file_names, self.to_dir)

    def split_class_lists(self, count, files_for_class):
        """Select ``count`` files across classes, proportionally to class size.

        Args:
          count: Total number of files to select across all classes.
          files_for_class: Dict mapping class name to list of files.

        Returns:
          Tuple ``(split_files, remainder)``. ``split_files`` is the list of
          selected files. ``remainder`` maps each class to its files that were
          not selected.
        """
        total_files = sum(len(files) for files in files_for_class.values())
        if total_files == 0 or count <= 0:
            # Nothing to select (also avoids dividing by zero below).
            return [], {cls: list(files) for cls, files in files_for_class.items()}

        class_proportions = {
            cls: len(files) / total_files for cls, files in files_for_class.items()
        }
        class_counts = {
            cls: int(proportion * count) for cls, proportion in class_proportions.items()
        }

        # int() truncation can leave a few files unassigned. Give them, one
        # each, to the largest classes that still have files left.
        remaining_count = count - sum(class_counts.values())
        for cls in sorted(class_proportions, key=class_proportions.get, reverse=True):
            if remaining_count <= 0:
                break
            if len(files_for_class[cls]) > class_counts[cls]:
                class_counts[cls] += 1
                remaining_count -= 1

        split_files = []
        remainder = {}
        for cls, cls_files in files_for_class.items():
            split_count = class_counts.get(cls, 0)
            split_files.extend(cls_files[:split_count])
            remainder[cls] = cls_files[split_count:]
        return split_files, remainder

    def download_and_split_file(self, splits, num_classes=None):
        """Extract the dataset and split it into e.g. train/validation/test folders.

        ``self.to_dir`` is DELETED first. Each split is written to
        ``self.to_dir/<split_name>/<class>/``.

        Args:
          splits: Dict mapping split name to proportion of the data, e.g.
            ``{"train": 0.7, "validation": 0.2, "test": 0.1}``. Files lost to
            rounding down are not used.
          num_classes: Use only the first ``num_classes`` classes (archive
            order); all classes if None.

        Returns:
          Dict mapping split name to the ``pathlib.Path`` of its folder.
        """
        if os.path.isdir(self.to_dir):
            shutil.rmtree(self.to_dir)
        files_for_class = self.get_files_per_class_from_zip_file()

        if num_classes is None:
            num_classes = len(files_for_class)
        classes = list(files_for_class.keys())[:num_classes]

        # Shuffle within each class so splits are random but class-balanced.
        for cls in classes:
            random.shuffle(files_for_class[cls])
        files_for_class = {cls: files_for_class[cls] for cls in classes}

        # Convert the proportions into absolute file counts.
        total_files = sum(len(files) for files in files_for_class.values())
        split_counts = {name: int(share * total_files) for name, share in splits.items()}

        dirs = {}
        for split_name, split_count in split_counts.items():
            print(f"Split: {split_name}, Count: {split_count}")
            split_dir = pathlib.Path(self.to_dir) / split_name
            # Each split consumes files, so later splits only see the remainder.
            split_files, files_for_class = self.split_class_lists(
                split_count, files_for_class
            )
            self.download_from_zip_file(split_dir, split_files)
            dirs[split_name] = split_dir
        return dirs

    def format_frames(self, frame, output_size):
        """Convert a frame to float32 and resize it with padding to *output_size*.

        Args:
          frame: Image to be converted, resized and padded.
          output_size: ``(height, width)`` of the output frame.

        Returns:
          The formatted frame (``tf.image.convert_image_dtype`` maps uint8 input
          to [0, 1]).
        """
        frame = tf.image.convert_image_dtype(frame, tf.float32)
        frame = tf.image.resize_with_pad(frame, *output_size)
        return frame

    def frames_from_video_file(
        self, video_path, n_frames, output_size=(224, 224), frame_step=15
    ):
        """Sample ``n_frames`` frames, ``frame_step`` apart, from a video file.

        The start position is random when the video is long enough, otherwise
        0. Frames past the end of the video are filled with zeros.

        Args:
          video_path: Path of the video file.
          n_frames: Number of frames to return.
          output_size: ``(height, width)`` of each returned frame.
          frame_step: Distance in frames between two sampled frames.

        Returns:
          NumPy array of shape ``(n_frames, height, width, channels)`` with the
          channel order reversed (OpenCV's BGR -> RGB).

        Raises:
          ValueError: If no frame can be read from the video.
        """
        src = cv2.VideoCapture(str(video_path))
        try:
            # OpenCV reports the frame count as a float; randint needs ints.
            video_length = int(src.get(cv2.CAP_PROP_FRAME_COUNT))
            need_length = 1 + (n_frames - 1) * frame_step
            if need_length > video_length:
                start = 0
            else:
                # randint is inclusive, so the last valid start is the upper bound.
                start = random.randint(0, video_length - need_length)
            src.set(cv2.CAP_PROP_POS_FRAMES, start)

            ret, frame = src.read()
            if not ret:
                raise ValueError(f"Could not read any frame from video: {video_path}")
            result = [self.format_frames(frame, output_size)]

            for _ in range(n_frames - 1):
                for _ in range(frame_step):
                    ret, frame = src.read()
                if ret:
                    result.append(self.format_frames(frame, output_size))
                else:
                    result.append(np.zeros_like(result[0]))
        finally:
            src.release()
        return np.array(result)[..., [2, 1, 0]]

    def to_gif(self, images, path="./animation.gif"):
        """Save a clip as an animated GIF at 10 fps.

        Args:
          images: Array of frames with values in [0, 1], e.g. the output of
            ``frames_from_video_file``.
          path: Output file path.
        """
        converted_images = np.clip(images * 255, 0, 255).astype(np.uint8)
        imageio.mimsave(path, converted_images, fps=10)


if __name__ == "__main__":
    # Quick dataset inspection when this module is run as a script. All calls
    # must stay inside this guard: other modules import DataLoader from here,
    # and `dataloader` only exists in this branch.
    dataloader = DataLoader(correct_dir, incorrect_dir)

    # Other inspection helpers:
    # dataloader.list_files_from_zip_file()
    # dataloader.list_files_for_class()
    # dataloader.select_subset_of_classes()
    # dataloader.get_number_of_files()
    # dataloader.get_number_of_class()
    dataloader.get_number_of_files_for_class(show=True)
    # canny and compress all videos
# canny and compress all videos

FrameGenerator

from data_loader import DataLoader
from variables import correct_dir, incorrect_dir
import random
class FrameGenerator:
    """Callable that yields ``(frames, label)`` for every video below ``path``.

    Expected layout: ``path/<class_name>/<video file>``. Each immediate
    sub-folder of ``path`` is one class, and class ids follow the sorted order
    of the sub-folder names.
    """

    def __init__(self, path, n_frames, training=False):
        """Prepare the generator.

        Args:
          path: ``pathlib.Path`` of a split folder (e.g. ``Videos/train``).
          n_frames: Number of frames sampled from each video.
          training: If True, the video order is shuffled on every call.
        """
        self.dataloader = DataLoader(correct_dir, incorrect_dir)
        self.path = path
        self.n_frames = n_frames
        self.training = training
        self.class_names = sorted(p.name for p in self.path.iterdir() if p.is_dir())
        self.class_ids_for_name = {
            name: idx for idx, name in enumerate(self.class_names)
        }
        self.filetype = ".MOV"

    def get_files_and_class_names(self):
        """Return ``(video_paths, classes)`` for all videos one level below ``path``.

        The extension is matched case-insensitively, because extracted files
        keep their original ``.mov``/``.MOV`` spelling and glob is
        case-sensitive on POSIX. Paths are sorted so the order is reproducible.
        """
        wanted = self.filetype.lower()
        video_paths = sorted(
            p
            for p in self.path.glob("*/*")
            if p.is_file() and p.name.lower().endswith(wanted)
        )
        classes = [p.parent.name for p in video_paths]
        return video_paths, classes

    def get_class_name_from_id(self, id):
        """Return the class name for a class id, or None if the id is unknown."""
        return next(
            (name for name, class_id in self.class_ids_for_name.items() if class_id == id),
            None,
        )

    def get_class_id_names(self):
        """Return the mapping from class name to class id."""
        return self.class_ids_for_name

    def __call__(self):
        """Yield ``(video_frames, label)`` for every video; shuffled when training."""
        video_paths, classes = self.get_files_and_class_names()
        pairs = list(zip(video_paths, classes))
        if self.training:
            random.shuffle(pairs)
        for path, name in pairs:
            video_frames = self.dataloader.frames_from_video_file(path, self.n_frames)
            label = self.class_ids_for_name[name]  # integer-encoded label
            yield video_frames, label

TensorFlow data input pipeline (hyperparameters)

# Create the training set
from data_loader import DataLoader
from frame_generator import FrameGenerator
from variables import video_dir, lr, frame_number, batch_size, shuffle_buffer_size, HEIGHT, WIDTH
from modules import Conv2Plus1D, ResizeVideo, add_residual_block
dataloader = DataLoader()
# Number of classes in the zip archive; sizes the model's softmax output layer.
classes = dataloader.get_number_of_class()

# Element spec of every dataset: (frames, label).
# frames: float32 clip of shape (n_frames, height, width, 3); label: class id.
# int32 labels keep this spec identical to the one used when the saved model
# is evaluated.
output_signature = (
    tf.TensorSpec(shape=(None, None, None, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(), dtype=tf.int32),
)

# Training set (FrameGenerator shuffles the video order when training=True).
train_ds = tf.data.Dataset.from_generator(
    FrameGenerator(pathlib.Path(video_dir) / "train", frame_number, training=True),
    output_signature=output_signature,
)
# Validation and test sets.
val_ds = tf.data.Dataset.from_generator(
    FrameGenerator(pathlib.Path(video_dir) / "validation", frame_number),
    output_signature=output_signature,
)
test_ds = tf.data.Dataset.from_generator(
    FrameGenerator(pathlib.Path(video_dir) / "test", frame_number),
    output_signature=output_signature,
)

Configure the dataset

# 自动调整并行数据加载(AUTOTUNE)
# Let tf.data tune parallelism / prefetch depth automatically.
AUTOTUNE = tf.data.AUTOTUNE

# Training pipeline: cache decoded clips (needs enough RAM for the whole
# split), reshuffle every epoch, batch, and prefetch so input preparation
# overlaps with training.
train_ds = (
    train_ds.cache()
    .shuffle(buffer_size=shuffle_buffer_size)
    .batch(batch_size)
    .prefetch(buffer_size=AUTOTUNE)
)

# Evaluation pipelines are not shuffled. Order does not change the metrics,
# and a fixed order keeps evaluation reproducible and lets predictions be
# matched back to the videos.
val_ds = val_ds.cache().batch(batch_size).prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=AUTOTUNE)

# train_frames, train_labels = next(iter(train_ds))
# print(f"Shape of training set of frames: {train_frames.shape}")
# print(f"Shape of training labels: {train_labels.shape}")

# val_frames, val_labels = next(iter(val_ds))
# print(f"Shape of validation set of frames: {val_frames.shape}")
# print(f"Shape of validation labels: {val_labels.shape}")

# test_frames, test_labels = next(iter(test_ds))
# print(f"Shape of test set of frames: {test_frames.shape}")
# print(f"Shape of test labels: {test_labels.shape}")

Conv2Plus1D

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing import image_dataset_from_directory
import einops
class Conv2Plus1D(keras.layers.Layer):
    """(2+1)D convolution: a spatial Conv3D followed by a temporal Conv3D.

    A ``(t, h, w)`` kernel is factorised into a ``(1, h, w)`` convolution over
    the spatial dimensions and a ``(t, 1, 1)`` convolution over time.
    """

    def __init__(self, filters, kernel_size, padding, name=None, **kwargs):
        """
        Args:
          filters: Number of output channels of both convolutions.
          kernel_size: ``(time, height, width)`` kernel size.
          padding: Padding mode passed to both ``Conv3D`` layers.
          name: Optional layer name.
          **kwargs: Base ``Layer`` arguments (e.g. ``dtype``, ``trainable``),
            which ``from_config`` passes back when a saved model is loaded.
        """
        # Forward name/kwargs so the layer name and dtype survive serialisation.
        super().__init__(name=name, **kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.padding = padding
        self.seq = keras.Sequential(
            [
                # Spatial decomposition
                layers.Conv3D(
                    filters=filters,
                    kernel_size=(1, kernel_size[1], kernel_size[2]),
                    padding=padding,
                ),
                # Temporal decomposition
                layers.Conv3D(
                    filters=filters, kernel_size=(kernel_size[0], 1, 1), padding=padding
                ),
            ]
        )

    def build(self, input_shape):
        # Create the inner Conv3D weights eagerly so they exist before weights
        # are restored on load. Otherwise Keras warns that the layer "has
        # unbuilt state".
        self.seq.build(input_shape)
        super().build(input_shape)

    def call(self, x):
        return self.seq(x)

    def get_config(self):
        # Constructor arguments needed to re-create the layer from its config.
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "padding": self.padding,
        })
        return config

ResidualMain

class ResidualMain(keras.layers.Layer):
    """
    Main branch of a residual block: (2+1)D conv -> LayerNorm -> ReLU ->
    (2+1)D conv -> LayerNorm. The skip connection is added by
    ``add_residual_block``.
    """

    def __init__(self, filters, kernel_size, name=None, **kwargs):
        """
        Args:
          filters: Number of output channels.
          kernel_size: ``(time, height, width)`` kernel size of both convolutions.
          name: Optional layer name.
          **kwargs: Base ``Layer`` arguments (e.g. ``dtype``, ``trainable``),
            which ``from_config`` passes back when a saved model is loaded.
        """
        # Forward name/kwargs so the layer name and dtype survive serialisation.
        super().__init__(name=name, **kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.seq = keras.Sequential(
            [
                Conv2Plus1D(filters=filters, kernel_size=kernel_size, padding="same"),
                layers.LayerNormalization(),
                layers.ReLU(),
                Conv2Plus1D(filters=filters, kernel_size=kernel_size, padding="same"),
                layers.LayerNormalization(),
            ]
        )

    def build(self, input_shape):
        # Build all sub-layers up front so their weights exist before loading.
        self.seq.build(input_shape)
        super().build(input_shape)

    def call(self, x):
        return self.seq(x)

    def get_config(self):
        # Constructor arguments needed to re-create the layer from its config.
        config = super().get_config()
        config.update({
            "filters": self.filters,
            "kernel_size": self.kernel_size,
        })
        return config

Project

class Project(keras.layers.Layer):
    """
    Project the last (channel) dimension of a tensor to ``units`` with a Dense
    layer followed by LayerNorm. Used on the skip connection when the channel
    count changes.
    """

    def __init__(self, units, name=None, **kwargs):
        """
        Args:
          units: Size of the projected last dimension.
          name: Optional layer name.
          **kwargs: Base ``Layer`` arguments (e.g. ``dtype``, ``trainable``),
            which ``from_config`` passes back when a saved model is loaded.
        """
        # Forward name/kwargs so the layer name and dtype survive serialisation.
        super().__init__(name=name, **kwargs)
        self.units = units
        self.seq = keras.Sequential([layers.Dense(units), layers.LayerNormalization()])

    def build(self, input_shape):
        # Build the Dense/LayerNorm weights up front so they exist before loading.
        self.seq.build(input_shape)
        super().build(input_shape)

    def call(self, x):
        return self.seq(x)

    def get_config(self):
        # Constructor arguments needed to re-create the layer from its config.
        config = super().get_config()
        config.update({
            "units": self.units,
        })
        return config

def add_residual_block(input, filters, kernel_size):
    """Append a residual block to a functional-API tensor.

    The input goes through ``ResidualMain`` and is added to a shortcut. When
    the channel counts differ, the shortcut is first passed through
    ``Project`` so both tensors have the same last dimension.

    Args:
      input: Keras tensor to extend.
      filters: Output channels of the residual branch.
      kernel_size: ``(time, height, width)`` kernel size of the branch.

    Returns:
      The Keras tensor ``shortcut + branch``.
    """
    branch = ResidualMain(filters, kernel_size)(input)
    target_channels = branch.shape[-1]

    shortcut = input
    if input.shape[-1] != target_channels:
        shortcut = Project(target_channels)(shortcut)

    return layers.add([shortcut, branch])

ResizeVideo

class ResizeVideo(keras.layers.Layer):
    """Resize every frame of a ``(batch, time, height, width, channels)`` video."""

    def __init__(self, height, width, name=None, **kwargs):
        """
        Args:
          height: Output frame height.
          width: Output frame width.
          name: Optional layer name.
          **kwargs: Base ``Layer`` arguments (e.g. ``dtype``, ``trainable``),
            which ``from_config`` passes back when a saved model is loaded.
        """
        # Forward name/kwargs so the layer name and dtype survive serialisation.
        super().__init__(name=name, **kwargs)
        self.height = height
        self.width = width
        self.resizing_layer = layers.Resizing(self.height, self.width)

    def build(self, input_shape):
        # The Resizing sub-layer holds no weights, so there is nothing to create.
        # Defining build() just marks this layer as built.
        super().build(input_shape)

    def call(self, video):
        """
        Resize the frames by folding time into the batch dimension with einops.

        Args:
          video: Tensor of shape ``(batch, time, height, width, channels)``.

        Returns:
          Tensor of shape ``(batch, time, self.height, self.width, channels)``.
        """
        # b = batch, t = time, h = height, w = width, c = channels.
        old_shape = einops.parse_shape(video, "b t h w c")
        images = einops.rearrange(video, "b t h w c -> (b t) h w c")
        images = self.resizing_layer(images)
        videos = einops.rearrange(images, "(b t) h w c -> b t h w c", t=old_shape["t"])
        return videos

    def get_config(self):
        # Constructor arguments needed to re-create the layer from its config.
        config = super().get_config()
        config.update({
            "height": self.height,
            "width": self.width
        })
        return config

Model

# Model input: clips of `frame_number` RGB frames of HEIGHT x WIDTH.
# `input_shape` includes the batch dimension (None); layers.Input takes the
# per-sample shape, hence the [1:].
input_shape = (None, frame_number, HEIGHT, WIDTH, 3)
inputs = layers.Input(shape=input_shape[1:])  # `inputs` avoids shadowing the builtin input()
x = inputs

# Stem: (2+1)D convolution, then halve the spatial resolution.
x = Conv2Plus1D(filters=16, kernel_size=(3, 7, 7), padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = ResizeVideo(HEIGHT // 2, WIDTH // 2)(x)

# Block 1
x = add_residual_block(x, 16, (3, 3, 3))
x = ResizeVideo(HEIGHT // 4, WIDTH // 4)(x)

# Block 2
x = add_residual_block(x, 32, (3, 3, 3))
x = ResizeVideo(HEIGHT // 8, WIDTH // 8)(x)

# Block 3
x = add_residual_block(x, 64, (3, 3, 3))
x = ResizeVideo(HEIGHT // 16, WIDTH // 16)(x)

# Block 4. The frames are already HEIGHT//16 x WIDTH//16, so this resize
# keeps the spatial size unchanged.
x = add_residual_block(x, 64, (3, 3, 3))
x = ResizeVideo(HEIGHT // 16, WIDTH // 16)(x)

# Block 5
x = add_residual_block(x, 128, (3, 3, 3))

# Head: average over time and space, then one softmax probability per class.
x = layers.GlobalAveragePooling3D()(x)
x = layers.Flatten()(x)
x = layers.Dense(classes, activation="softmax")(x)
model = keras.Model(inputs, x)

# Optional inspection:
# frames, label = next(iter(train_ds))
# keras.utils.plot_model(model, expand_nested=True, dpi=60, show_shapes=True)
# model.summary()

运行

# Build the model for batched input (batch, frames, height, width, channels).
model.build((batch_size, frame_number, HEIGHT, WIDTH, 3))

# The final Dense layer already applies softmax, so the loss receives
# probabilities rather than logits.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=lr),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=["accuracy"],
)

# Request placement of training on the first GPU.
with tf.device("/GPU:0"):
    history = model.fit(train_ds, validation_data=val_ds, epochs=50)

Save model


# 获取当前时间并格式化为字符串
# Timestamp (YYYYmmdd_HHMMSS) that makes every saved file name unique.
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")

# One output folder per file format; create them if missing.
h5_dir, keras_dir = "model_h5", "model_keras"
os.makedirs(h5_dir, exist_ok=True)
os.makedirs(keras_dir, exist_ok=True)

# Same base name in both folders, differing only by extension.
h5_filename = os.path.join(h5_dir, "VideoClassifier_" + current_time + ".h5")
keras_filename = os.path.join(keras_dir, "VideoClassifier_" + current_time + ".keras")

# Save the model twice: as an .h5 file and as a .keras file.
tf.keras.models.save_model(model, h5_filename, overwrite=True)
model.save(keras_filename)

print("Model saved as:\n- " + h5_filename + "\n- " + keras_filename)
from matplotlib import pyplot as plt
def plot_history(history):
    """Plot training and validation learning curves (loss and accuracy).

    Args:
      history: The ``History`` object returned by ``model.fit``. Its
        ``history`` dict must contain "loss" and "accuracy". "val_loss" and
        "val_accuracy" are plotted when present; they are missing when
        ``fit`` ran without validation data.
    """
    metrics = history.history
    fig, (ax_loss, ax_acc) = plt.subplots(2)
    fig.set_size_inches(18.5, 10.5)

    # Loss curves.
    ax_loss.set_title("Loss")
    ax_loss.plot(metrics["loss"], label="Train")
    loss_values = list(metrics["loss"])
    if "val_loss" in metrics:
        ax_loss.plot(metrics["val_loss"], label="Validation")
        loss_values += list(metrics["val_loss"])
    ax_loss.set_ylabel("Loss")
    # Round the largest loss up to an integer for the y-axis limit. Fall back
    # to 1 when that is not a valid positive limit (all zero/NaN/inf losses).
    upper = np.ceil(np.nanmax(loss_values)) if loss_values else np.nan
    if not np.isfinite(upper) or upper <= 0:
        upper = 1.0
    ax_loss.set_ylim([0, upper])
    ax_loss.set_xlabel("Epoch")
    ax_loss.legend()

    # Accuracy curves.
    ax_acc.set_title("Accuracy")
    ax_acc.plot(metrics["accuracy"], label="Train")
    if "val_accuracy" in metrics:
        ax_acc.plot(metrics["val_accuracy"], label="Validation")
    ax_acc.set_ylabel("Accuracy")
    ax_acc.set_ylim([0, 1])
    ax_acc.set_xlabel("Epoch")
    ax_acc.legend()

    plt.show()


plot_history(history)

Model Loading

from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import tensorflow as tf
from modules import Conv2Plus1D, Project, ResizeVideo, ResidualMain
from frame_generator import FrameGenerator

# Constants
video_dir = "Videos"  # Update with your video directory path
frame_number = 40
batch_size = 2
AUTOTUNE = tf.data.AUTOTUNE
max_preview = 9  # number of clips plotted; fills the 3x3 grid below
preview_frame_index = 10  # which frame of each clip is displayed

# Load model. The custom layer classes imported from `modules` are supplied
# explicitly so Keras can rebuild the saved architecture.
model = keras.models.load_model(
    "model_keras/VideoClassifier_20250205_181141.keras",
    custom_objects={
        "Conv2Plus1D": Conv2Plus1D,
        "Project": Project,
        "ResizeVideo": ResizeVideo,
        "ResidualMain": ResidualMain,
    },
)
# model.summary()

# Create validation dataset: each element is (frames, label) with frames of
# shape (num_frames, 224, 224, 3) float32 and a scalar int32 label.
val_ds = tf.data.Dataset.from_generator(
    FrameGenerator(pathlib.Path(video_dir + "/validation"), frame_number, training=False),
    output_signature=(
        tf.TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)
val_ds = val_ds.batch(batch_size).prefetch(AUTOTUNE)

# Make predictions and collect results
actual_labels = []
predicted_labels = []
video_frames = []  # only the first `max_preview` clips, kept for display

for frames, labels in val_ds:
    # verbose=0 stops Keras from printing a progress bar for every batch.
    predictions = model.predict(frames, verbose=0)
    predicted_labels.extend(np.argmax(predictions, axis=1))
    actual_labels.extend(labels.numpy())

    # Storing every clip would keep the whole validation set in memory
    # (frame_number x 224 x 224 x 3 float32 values per clip); only the clips
    # that are actually plotted are retained.
    free_slots = max_preview - len(video_frames)
    if free_slots > 0:
        video_frames.extend(frames.numpy()[:free_slots])

# Display results
plt.figure(figsize=(10, 10))
for i, clip in enumerate(video_frames):
    plt.subplot(3, 3, i + 1)
    # Show frame `preview_frame_index` (clamped to the clip length).
    frame = clip[min(preview_frame_index, len(clip) - 1)]
    # NOTE(review): FrameGenerator's pixel range is not visible here. If its
    # frames are already in [0, 1], dividing by 255 again renders near-black
    # images, so only rescale values that look like 0-255 pixels.
    if frame.max() > 1.0:
        frame = frame / 255.0
    plt.imshow(np.clip(frame, 0.0, 1.0))
    plt.title(f"Pred: {predicted_labels[i]}\nTrue: {actual_labels[i]}")
    plt.axis('off')

plt.tight_layout()
plt.savefig('predictions.png')
plt.show()
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'conv2_plus1d_9', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'resize_video_4', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'residual_main_4', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'resize_video_5', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'project_3', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'residual_main_5', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'resize_video_6', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'project_4', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'residual_main_6', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'resize_video_7', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'project_5', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/layers/layer.py:391: UserWarning: `build()` was called on layer 'residual_main_7', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:713: UserWarning: Skipping variable loading for optimizer 'rmsprop', because it has 70 variables whereas the saved optimizer has 138 variables.
  saveable.load_own_variables(weights_store.get(inner_path))
Model: "functional_16"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)              ┃ Output Shape           ┃        Param # ┃ Connected to           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer (InputLayer)  │ (None, 40, 224, 224,   │              0 │ -                      │
│                           │ 3)                     │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ conv2_plus1d_9            │ (None, 40, 224, 224,   │          3,152 │ input_layer[0][0]      │
│ (Conv2Plus1D)             │ 16)                    │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ batch_normalization       │ (None, 40, 224, 224,   │             64 │ conv2_plus1d_9[0][0]   │
│ (BatchNormalization)      │ 16)                    │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ re_lu (ReLU)              │ (None, 40, 224, 224,   │              0 │ batch_normalization[0… │
│                           │ 16)                    │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ resize_video_4            │ (None, 40, 112, 112,   │              0 │ re_lu[0][0]            │
│ (ResizeVideo)             │ 16)                    │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ residual_main_4           │ (None, 40, 112, 112,   │          6,272 │ resize_video_4[0][0]   │
│ (ResidualMain)            │ 16)                    │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ add (Add)                 │ (None, 40, 112, 112,   │              0 │ resize_video_4[0][0],  │
│                           │ 16)                    │                │ residual_main_4[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ resize_video_5            │ (None, 40, 56, 56, 16) │              0 │ add[0][0]              │
│ (ResizeVideo)             │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ project_3 (Project)       │ (None, 40, 56, 56, 32) │            608 │ resize_video_5[0][0]   │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ residual_main_5           │ (None, 40, 56, 56, 32) │         20,224 │ resize_video_5[0][0]   │
│ (ResidualMain)            │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ add_1 (Add)               │ (None, 40, 56, 56, 32) │              0 │ project_3[0][0],       │
│                           │                        │                │ residual_main_5[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ resize_video_6            │ (None, 40, 28, 28, 32) │              0 │ add_1[0][0]            │
│ (ResizeVideo)             │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ project_4 (Project)       │ (None, 40, 28, 28, 64) │          2,240 │ resize_video_6[0][0]   │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ residual_main_6           │ (None, 40, 28, 28, 64) │         80,384 │ resize_video_6[0][0]   │
│ (ResidualMain)            │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ add_2 (Add)               │ (None, 40, 28, 28, 64) │              0 │ project_4[0][0],       │
│                           │                        │                │ residual_main_6[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ resize_video_7            │ (None, 40, 14, 14, 64) │              0 │ add_2[0][0]            │
│ (ResizeVideo)             │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ project_5 (Project)       │ (None, 40, 14, 14,     │          8,576 │ resize_video_7[0][0]   │
│                           │ 128)                   │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ residual_main_7           │ (None, 40, 14, 14,     │        320,512 │ resize_video_7[0][0]   │
│ (ResidualMain)            │ 128)                   │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ add_3 (Add)               │ (None, 40, 14, 14,     │              0 │ project_5[0][0],       │
│                           │ 128)                   │                │ residual_main_7[0][0]  │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ global_average_pooling3d  │ (None, 128)            │              0 │ add_3[0][0]            │
│ (GlobalAveragePooling3D)  │                        │                │                        │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ flatten (Flatten)         │ (None, 128)            │              0 │ global_average_poolin… │
├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤
│ dense_3 (Dense)           │ (None, 4)              │            516 │ flatten[0][0]          │
└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘
 Total params: 885,066 (3.38 MB)
 Trainable params: 442,516 (1.69 MB)
 Non-trainable params: 32 (128.00 B)
 Optimizer params: 442,518 (1.69 MB)
def get_actual_predicted_labels(dataset, classifier=None):
    """
    Collect ground-truth labels and the model's predicted class for every example.

    Labels and predictions are gathered in a single pass over ``dataset``, so
    they stay aligned even if the dataset yields elements in a different order
    each time it is iterated (e.g. when it is shuffled).

    Args:
      dataset: An iterable of ``(features, labels)`` batches, such as a batched
        TensorFlow Dataset.
      classifier: Keras model used for prediction. Defaults to the module-level
        ``model``.

    Return:
      A tuple ``(actual, predicted)`` of 1-D tensors: the ground-truth labels and
      the argmax over the model outputs, in matching order.

    Raises:
      ValueError: If ``dataset`` yields no batches.
    """
    if classifier is None:
        classifier = model

    actual_batches = []
    predicted_batches = []
    for features, labels in dataset:
        scores = classifier.predict_on_batch(features)
        actual_batches.append(labels)
        predicted_batches.append(tf.argmax(scores, axis=1))

    if not actual_batches:
        raise ValueError("dataset yielded no batches; nothing to evaluate")

    actual = tf.concat(actual_batches, axis=0)
    predicted = tf.concat(predicted_batches, axis=0)
    return actual, predicted


def calculate_classification_metrics(y_actual, y_pred, labels, zero_division=0.0):
    """
    Calculate per-class precision and recall from ground-truth and predicted labels.

    Args:
      y_actual: Ground-truth class indices (1-D sequence, NumPy array or tensor).
      y_pred: Predicted class indices, aligned element-by-element with y_actual.
      labels: Class names; ``labels[i]`` names class index ``i``.
      zero_division: Value reported when a ratio is undefined, i.e. a class was
        never predicted (precision) or never occurs (recall). Defaults to 0.0,
        which avoids NaNs and the "invalid value encountered" warning.

    Return:
      Two dicts ``(precision, recall)`` mapping each name in ``labels`` to a float.

    Raises:
      ValueError: If the inputs differ in length or contain negative indices.
    """
    actual = np.asarray(y_actual, dtype=np.int64).reshape(-1)
    pred = np.asarray(y_pred, dtype=np.int64).reshape(-1)
    if actual.shape != pred.shape:
        raise ValueError(
            f"y_actual and y_pred must have the same length, "
            f"got {actual.size} and {pred.size}"
        )
    if actual.size and min(actual.min(), pred.min()) < 0:
        raise ValueError("class indices must be non-negative")

    # Size the matrix to cover every named class and every index present in
    # the data. Then cm[:, i] / cm[i, :] exist for each label even when that
    # class never appears, and predictions outside `labels` still count as
    # false positives/negatives.
    num_classes = len(labels)
    if actual.size:
        num_classes = max(num_classes, int(actual.max()) + 1, int(pred.max()) + 1)
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (actual, pred), 1)  # rows = actual class, columns = predicted class

    tp = np.diag(cm)  # Diagonal represents true positives
    predicted_totals = cm.sum(axis=0)  # column sums: TP + FP per class
    actual_totals = cm.sum(axis=1)  # row sums: TP + FN per class

    def _ratio(numerator, denominator):
        # Undefined ratio (0 / 0) -> caller-chosen value instead of NaN.
        if denominator:
            return float(numerator / denominator)
        return float(zero_division)

    precision = {}
    recall = {}
    for i, name in enumerate(labels):
        precision[name] = _ratio(tp[i], predicted_totals[i])
        recall[name] = _ratio(tp[i], actual_totals[i])
    return precision, recall
# Evaluate the loaded model on the test split.
# NOTE(review): `test_ds` is not defined in this section; presumably it was
# built in an earlier cell — verify it exists and matches the loaded model.
# NOTE(review): the model summary shows a 4-unit output layer (dense_3), but
# only two class names are passed below; confirm the index -> name mapping.
actual, predicted = get_actual_predicted_labels(test_ds)
precision, recall = calculate_classification_metrics(actual, predicted, ["Correct", "Incorrect"]) # Test dataset
12/12 ━━━━━━━━━━━━━━━━━━━━ 3s 94ms/step
<ipython-input-43-b91e42561ef8>:45: RuntimeWarning: invalid value encountered in scalar divide
  precision[labels[i]] = tp[i] / (tp[i] + fp) # Precision
# A notebook cell echoes only its last bare expression, so a bare `precision`
# line followed by `recall` would never display precision. Print both.
print("precision:", precision)
print("recall:", recall)
{'Correct': 0.0, 'Incorrect': 0.625}