I have a student-teacher setup for self-supervised learning. The teacher network predicts classes, computes the loss, and updates its parameters through backpropagation. Normally, the student network would update its parameters as an exponential moving average (EMA) of the teacher's. In my case, however, the student has some unique layers that must be trained with gradients. As the figure shows, the student contains a unique layer (layer X), while the remaining layers are shared with the teacher: layers 1, 2, and 3 in the student should be updated as the moving average of the corresponding teacher layers, and layer X should be updated by gradient flow. How do I design such a network in PyTorch with DistributedDataParallel (DDP)?
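For reference, here is a minimal sketch of what I mean (the layer sizes and class names `Teacher`/`Student` are just placeholders, not my actual model): the student's shared layers have `requires_grad=False` and are overwritten by an EMA of the teacher, while `layer_x` stays trainable and would go into the optimizer alongside the teacher's parameters.

```python
import copy
import torch
import torch.nn as nn

class Teacher(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(128, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, 10)

    def forward(self, x):
        return self.layer3(self.layer2(self.layer1(x)))

class Student(nn.Module):
    def __init__(self, teacher):
        super().__init__()
        # Shared layers: initialized from the teacher, updated only via EMA.
        self.layer1 = copy.deepcopy(teacher.layer1)
        self.layer2 = copy.deepcopy(teacher.layer2)
        self.layer3 = copy.deepcopy(teacher.layer3)
        for layer in (self.layer1, self.layer2, self.layer3):
            for p in layer.parameters():
                p.requires_grad = False
        # Unique layer X: trained by gradient flow.
        self.layer_x = nn.Linear(128, 128)

    def forward(self, x):
        h = self.layer2(self.layer1(x))
        h = self.layer_x(h)  # unique student-only layer
        return self.layer3(h)

@torch.no_grad()
def ema_update(student, teacher, m=0.99):
    # Update only the shared layers; layer_x is left to the optimizer.
    for name in ("layer1", "layer2", "layer3"):
        s_params = getattr(student, name).parameters()
        t_params = getattr(teacher, name).parameters()
        for ps, pt in zip(s_params, t_params):
            ps.mul_(m).add_(pt, alpha=1 - m)
```

My uncertainty is mainly about the DDP side: since the student's shared layers never receive gradients, I assume I would either wrap only the gradient-trained parts (teacher plus `layer_x`) in DDP, or wrap everything and pass `find_unused_parameters=True`, and then call `ema_update` identically on every rank after each optimizer step so the EMA copies stay in sync. Is that the right approach?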