基本思想

在这个代码中,首先设计了一个桥类,使得Gym中的环境代码能够和NS3中的真实环境进行交互,比如得到下一时刻的状态和回报、执行动作等等。这个类命名为Ns3ZmqBridge。

在继承自父类gym.Env的环境类Ns3Env中,引用了类Ns3ZmqBridge,使得其能够更好地与环境代码进行通讯。

代码解读

在ns3与gym的结合版本ns3gym中,作者给出了一个示例,用以展示如何实现所需的接口,使得两者可以很好地协同工作。关于ns3gym的设计思路,官方提供的技术报告中已有讲解;下面讲解官方提供的示例环境ns3env的代码:

import os
import sys
import zmq
import time

import numpy as np

import gym
from gym import spaces
from gym.utils import seeding
from enum import IntEnum

from ns3gym.start_sim import start_sim_script, build_ns3_project

# 使用Google的protocol buffer技术
import ns3gym.messages_pb2 as pb
from google.protobuf.any_pb2 import Any


__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "gawlowicz@tkn.tu-berlin.de"

# Bridge between ns-3 and Gym: handles the ZMQ communication with the ns-3 process (the network-environment server)
class Ns3ZmqBridge(object):
    """docstring for Ns3ZmqBridge"""

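    # Constructor
    # port: port on this side that the ns-3 process connects to; the default 0 means "use a random port"
    # startSim: whether to launch the simulation script while this object is being initialised
    # simSeed: seed for the simulation's random number generator
    # simArgs: simulation-related arguments
    # debug: whether debug information is wanted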
    def __init__(self, port=0, startSim=True, simSeed=0, simArgs=None, debug=False):
        """Bind the ZMQ reply socket and optionally launch the simulation script.

        Args:
            port: TCP port to bind. ``0`` asks for a random port in the range
                5001-10000, which is only accepted when ``startSim`` is true.
            startSim: When true, the simulation script is started from here
                through ``start_sim_script``; otherwise a hint is printed and
                the script has to be started by hand.
            simSeed: Seed handed to the simulation script. ``0`` combined
                with ``startSim`` makes the constructor draw a random seed.
            simArgs: Arguments forwarded to ``start_sim_script``. ``None``
                (the default) is treated as an empty dict.
            debug: Forwarded to ``start_sim_script``.

        Raises:
            SystemExit: With a non-zero status when port 0 is requested
                without ``startSim`` or when binding the socket fails.
        """
        super(Ns3ZmqBridge, self).__init__()
        port = int(port)

        # Build a fresh dict per instance; a literal ``{}`` default would be a
        # single object shared by every bridge created without ``simArgs``.
        if simArgs is None:
            simArgs = {}

        self.port = port
        self.startSim = startSim
        self.simSeed = simSeed
        self.simArgs = simArgs
        self.envStopped = False
        self.simPid = None
        self.wafPid = None
        self.ns3Process = None

        # REP socket: this side always receives a message first and then
        # answers it.
        context = zmq.Context()
        self.socket = context.socket(zmq.REP)
        try:
            if port == 0 and self.startSim:
                # Random port: the launched script is told the port below.
                port = self.socket.bind_to_random_port('tcp://*', min_port=5001, max_port=10000, max_tries=100)
                print("Got new port for ns3gm interface: ", port)

            elif port == 0 and not self.startSim:
                # A manually started script cannot know a randomly chosen
                # port, so an explicit port is required in this mode.
                print("Cannot use port %s to bind" % str(port) )
                print("Please specify correct port" )
                sys.exit(1)

            else:
                self.socket.bind ("tcp://*:%s" % str(port))

        except Exception:
            # SystemExit is not an Exception subclass, so the exit requested
            # above passes straight through this handler.
            print("Cannot bind to tcp://*:%s as port is already in use" % str(port) )
            print("Please specify different port or use 0 to get free port" )
            sys.exit(1)

        # Draw a random seed when the simulation is started from here and no
        # seed was given.
        if self.startSim and simSeed == 0:
            maxSeed = np.iinfo(np.uint32).max
            # An explicit uint32 dtype keeps the upper bound valid on
            # platforms whose default NumPy integer is only 32 bits wide.
            simSeed = int(np.random.randint(0, maxSeed, dtype=np.uint32))
            self.simSeed = simSeed

        if self.startSim:
            # Run the simulation script.
            self.ns3Process = start_sim_script(port, simSeed, simArgs, debug)
        else:
            print("Waiting for simulation script to connect on port: tcp://localhost:{}".format(port))
            print('Please start proper ns-3 simulation script using ./waf --run "..."')

        # Spaces are filled in by initialize_env().
        self._action_space = None
        self._observation_space = None

        # State of the most recently received environment message.
        self.forceEnvStop = False
        self.obsData = None
        self.reward = 0
        self.gameOver = False
        self.gameOverReason = None
        self.extraInfo = None
        # True while a received state has not been answered yet.
        self.newStateRx = False

    # 关闭
    def close(self):
        try:
            if not self.envStopped:
                self.envStopped = True
                self.force_env_stop()
                self.rx_env_state()
                self.send_close_command()
                self.ns3Process.kill()
                if self.simPid:
                    os.kill(self.simPid, signal.SIGTERM)
                    self.simPid = None
                if self.wafPid:
                    os.kill(self.wafPid, signal.SIGTERM)
                    self.wafPid = None
        except Exception as e:
            pass

    # 创造空间,状态空间和动作空间都可以创造
    def _create_space(self, spaceDesc):
        """Build a Gym space from a protobuf space description.

        Handles the four description types Discrete, Box, Tuple and Dict;
        Tuple and Dict descriptions are converted recursively.

        Args:
            spaceDesc: Protobuf message exposing ``type`` and a packed
                ``space`` payload.

        Returns:
            The matching ``gym.spaces`` object, or ``None`` when the
            description has none of the four handled types.
        """
        # Discrete
        if spaceDesc.type == pb.Discrete:
            discreteSpacePb = pb.DiscreteSpace()
            spaceDesc.space.Unpack(discreteSpacePb)
            return spaces.Discrete(discreteSpacePb.n)

        # Box
        if spaceDesc.type == pb.Box:
            boxSpacePb = pb.BoxSpace()
            spaceDesc.space.Unpack(boxSpacePb)

            # The builtins replace ``np.int`` / ``np.float``: those names were
            # plain aliases of ``int`` / ``float`` and no longer exist in
            # current NumPy, so using them raises AttributeError.
            if boxSpacePb.dtype == pb.INT:
                mtype = int
            elif boxSpacePb.dtype == pb.UINT:
                mtype = np.uint
            else:
                # DOUBLE and every remaining dtype map to ``float``.
                mtype = float

            return spaces.Box(
                low=boxSpacePb.low,
                high=boxSpacePb.high,
                shape=tuple(boxSpacePb.shape),
                dtype=mtype,
            )

        # Tuple
        if spaceDesc.type == pb.Tuple:
            tupleSpacePb = pb.TupleSpace()
            spaceDesc.space.Unpack(tupleSpacePb)
            subSpaces = tuple(
                self._create_space(pbSubSpaceDesc)
                for pbSubSpaceDesc in tupleSpacePb.element
            )
            return spaces.Tuple(subSpaces)

        # Dict
        if spaceDesc.type == pb.Dict:
            dictSpacePb = pb.DictSpace()
            spaceDesc.space.Unpack(dictSpacePb)
            mySpaceDict = {}
            for pbSubSpaceDesc in dictSpacePb.element:
                mySpaceDict[pbSubSpaceDesc.name] = self._create_space(pbSubSpaceDesc)
            return spaces.Dict(mySpaceDict)

        return None

    # 初始化环境,在刚开始执行或者reset环境时会进行调用
    # 因为之前已经执行过构造函数,而在构造函数中调用了ns3脚本,所以调用这个函数时可以确保脚本是在运行中的
    def initialize_env(self, stepInterval):
        """Run the initial handshake with the simulation.

        Receives the init message, records the two process ids it reports,
        builds the action and observation spaces from it and answers with an
        acknowledgement that does not request a stop.

        Args:
            stepInterval: Accepted but not used by this method.

        Returns:
            Always ``True``.
        """
        rawInit = self.socket.recv()
        initMsg = pb.SimInitMsg()
        initMsg.ParseFromString(rawInit)

        # Remembered so that close() can terminate these processes later.
        self.simPid = int(initMsg.simProcessId)
        self.wafPid = int(initMsg.wafShellProcessId)

        self._action_space = self._create_space(initMsg.actSpace)
        self._observation_space = self._create_space(initMsg.obsSpace)

        # Acknowledge the init message: initialisation done, keep running.
        ack = pb.SimInitAck()
        ack.done = True
        ack.stopSimReq = False
        self.socket.send(ack.SerializeToString())
        return True

    # 得到动作空间
    def get_action_space(self):
        return self._action_space

    # 得到观测空间
    def get_observation_space(self):
        return self._observation_space

    # 强行终止服务端
    def force_env_stop(self):
        self.forceEnvStop = True

    # 接收环境的状态
    # 主要用在当执行动作后,环境的状态和回报得到了更新
    def rx_env_state(self):
        if self.newStateRx:
            return

        request = self.socket.recv()
        envStateMsg = pb.EnvStateMsg()
        envStateMsg.ParseFromString(request)

        # 更新状态信息 和 回报信息
        self.obsData = self._create_data(envStateMsg.obsData)
        self.reward = envStateMsg.reward
        self.gameOver = envStateMsg.isGameOver
        self.gameOverReason = envStateMsg.reason

        if self.gameOver:
            if self.gameOverReason == pb.EnvStateMsg.SimulationEnd:
                self.envStopped = True
                self.send_close_command()
            else:
                self.forceEnvStop = True
                self.send_close_command()

        self.extraInfo = envStateMsg.info
        if not self.extraInfo:
            self.extraInfo = {}

        self.newStateRx = True

    # 向ns3脚本发送终止的消息
    def send_close_command(self):
        """Reply with an action message whose only content is a stop request.

        Returns:
            Always ``True``.
        """
        stopMsg = pb.EnvActMsg()
        stopMsg.stopSimReq = True
        self.socket.send(stopMsg.SerializeToString())
        # The pending state has now been answered.
        self.newStateRx = False
        return True

    # 向ns3发送需要执行的动作
    def send_actions(self, actions):
        """Pack ``actions`` for the stored action space and send them.

        The message also carries a stop request when the forced-stop flag is
        set.

        Args:
            actions: Action value(s) matching the stored action space.

        Returns:
            Always ``True``.
        """
        actMsg = pb.EnvActMsg()
        actMsg.actData.CopyFrom(self._pack_data(actions, self._action_space))
        actMsg.stopSimReq = True if self.forceEnvStop else False

        self.socket.send(actMsg.SerializeToString())
        # The pending state has now been answered.
        self.newStateRx = False
        return True

    def step(self, actions):
        """Send ``actions`` and then receive the resulting environment state.

        Nothing is returned; the received values are stored on the instance
        and read through get_obs(), get_reward(), is_game_over() and
        get_extra_info().
        """
        # Execute the action first ...
        self.send_actions(actions)
        # ... then pick up the resulting state and reward.
        self.rx_env_state()

    # 是否游戏结束
    def is_game_over(self):
        return self.gameOver

    # 对传输来的数据进行解析
    def _create_data(self, dataContainerPb):
        """Convert a received protobuf data container into Python data.

        Args:
            dataContainerPb: Protobuf message exposing ``type`` and a packed
                ``data`` payload.

        Returns:
            Discrete: the contained ``data`` value.
            Box: the repeated field that matches the container dtype, as
            stored in the message (no reshaping is applied).
            Tuple: a tuple of recursively converted elements.
            Dict: a dict mapping element names to recursively converted
            elements.
            ``None`` for any other container type.
        """
        kind = dataContainerPb.type

        # Discrete
        if kind == pb.Discrete:
            discretePayload = pb.DiscreteDataContainer()
            dataContainerPb.data.Unpack(discretePayload)
            return discretePayload.data

        # Box
        if kind == pb.Box:
            boxPayload = pb.BoxDataContainer()
            dataContainerPb.data.Unpack(boxPayload)

            # Pick the repeated field by dtype; anything that is not INT,
            # UINT or DOUBLE is read from the float field.
            fieldByDtype = {
                pb.INT: "intData",
                pb.UINT: "uintData",
                pb.DOUBLE: "doubleData",
            }
            fieldName = fieldByDtype.get(boxPayload.dtype, "floatData")

            # TODO: reshape using shape info
            return getattr(boxPayload, fieldName)

        # Tuple
        if kind == pb.Tuple:
            tuplePayload = pb.TupleDataContainer()
            dataContainerPb.data.Unpack(tuplePayload)
            return tuple(self._create_data(item) for item in tuplePayload.element)

        # Dict
        if kind == pb.Dict:
            dictPayload = pb.DictDataContainer()
            dataContainerPb.data.Unpack(dictPayload)
            converted = {}
            for item in dictPayload.element:
                converted[item.name] = self._create_data(item)
            return converted

        return None

    def get_obs(self):
        return self.obsData

    def get_reward(self):
        return self.reward

    def get_extra_info(self):
        return self.extraInfo

    # 打包数据,用以发送
    def _pack_data(self, actions, spaceDesc):
        """Pack action data into a protobuf data container for sending.

        Tuple and Dict spaces are packed recursively. Each element is packed
        against the matching sub-space of ``spaceDesc`` itself, so composite
        spaces nested inside other composite spaces are handled correctly.

        Args:
            actions: Action value(s) laid out according to ``spaceDesc``: a
                single value for Discrete, a flat sequence for Box, a
                sequence of sub-actions for Tuple, a mapping from name to
                sub-action for Dict.
            spaceDesc: The Gym space describing ``actions``.

        Returns:
            A ``pb.DataContainer``. Its type and payload stay unset when
            ``spaceDesc`` is none of the four handled space kinds.
        """
        dataContainer = pb.DataContainer()

        # Discrete
        if isinstance(spaceDesc, spaces.Discrete):
            dataContainer.type = pb.Discrete
            discreteContainerPb = pb.DiscreteDataContainer()
            discreteContainerPb.data = actions
            dataContainer.data.Pack(discreteContainerPb)

        # Box
        elif isinstance(spaceDesc, spaces.Box):
            dataContainer.type = pb.Box
            boxContainerPb = pb.BoxDataContainer()
            # Only the length is transmitted as shape, i.e. the actions are
            # treated as a flat sequence.
            boxContainerPb.shape.extend([len(actions)])

            if spaceDesc.dtype in ['int', 'int8', 'int16', 'int32', 'int64']:
                boxContainerPb.dtype = pb.INT
                boxContainerPb.intData.extend(actions)

            elif spaceDesc.dtype in ['uint', 'uint8', 'uint16', 'uint32', 'uint64']:
                boxContainerPb.dtype = pb.UINT
                boxContainerPb.uintData.extend(actions)

            elif spaceDesc.dtype in ['float', 'float32', 'float64']:
                boxContainerPb.dtype = pb.FLOAT
                boxContainerPb.floatData.extend(actions)

            # NOTE(review): presumably never reached, since a 'double' dtype
            # should already match 'float'/'float64' above - confirm before
            # relying on DOUBLE being sent.
            elif spaceDesc.dtype in ['double']:
                boxContainerPb.dtype = pb.DOUBLE
                boxContainerPb.doubleData.extend(actions)

            else:
                boxContainerPb.dtype = pb.FLOAT
                boxContainerPb.floatData.extend(actions)

            dataContainer.data.Pack(boxContainerPb)

        # Tuple
        elif isinstance(spaceDesc, spaces.Tuple):
            dataContainer.type = pb.Tuple
            tupleDataPb = pb.TupleDataContainer()

            # Sub-spaces are taken from the space being packed, not from the
            # top-level action space.
            subDataList = [
                self._pack_data(subAction, subSpace)
                for subAction, subSpace in zip(actions, spaceDesc.spaces)
            ]

            tupleDataPb.element.extend(subDataList)
            dataContainer.data.Pack(tupleDataPb)

        # Dict
        elif isinstance(spaceDesc, spaces.Dict):
            dataContainer.type = pb.Dict
            dictDataPb = pb.DictDataContainer()

            subDataList = []
            for sName, subAction in actions.items():
                # Same as above: look the sub-space up on ``spaceDesc``.
                subData = self._pack_data(subAction, spaceDesc.spaces[sName])
                subData.name = sName
                subDataList.append(subData)

            dictDataPb.element.extend(subDataList)
            dataContainer.data.Pack(dictDataPb)

        return dataContainer


class Ns3Env(gym.Env):
    """Gym environment that talks to a simulation through ``Ns3ZmqBridge``.

    A bridge is created when the environment is constructed and again on
    every ``reset()`` that follows at least one ``step()``.
    """

    def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs=None, debug=False):
        """Store the settings, connect to the simulation and seed the RNG.

        Args:
            stepTime: Passed to ``Ns3ZmqBridge.initialize_env``.
            port: Passed to ``Ns3ZmqBridge``.
            startSim: Passed to ``Ns3ZmqBridge``.
            simSeed: Passed to ``Ns3ZmqBridge``.
            simArgs: Passed to ``Ns3ZmqBridge``; ``None`` (the default) is
                treated as an empty dict.
            debug: Passed to ``Ns3ZmqBridge``.
        """
        self.stepTime = stepTime
        self.port = port
        self.startSim = startSim
        self.simSeed = simSeed
        # A fresh dict per instance instead of one shared default object.
        self.simArgs = {} if simArgs is None else simArgs
        self.debug = debug

        # Filled in by _start_bridge().
        self.ns3ZmqBridge = None
        self.action_space = None
        self.observation_space = None

        self.viewer = None
        self.state = None
        self.steps_beyond_done = None

        self._start_bridge()

        # envDirty tells whether step() has been called since the last reset.
        self.envDirty = False
        self.seed()

    def _start_bridge(self):
        """Create a new bridge, run the handshake and receive the first state."""
        self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug)
        self.ns3ZmqBridge.initialize_env(self.stepTime)
        self.action_space = self.ns3ZmqBridge.get_action_space()
        self.observation_space = self.ns3ZmqBridge.get_observation_space()
        # Get the first observation.
        self.ns3ZmqBridge.rx_env_state()

    def seed(self, seed=None):
        """Seed ``self.np_random`` and return the seed in a one-element list."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        """Send ``action`` to the simulation and return the resulting transition.

        Returns:
            The tuple ``(obs, reward, done, extraInfo)`` read from the bridge.
        """
        self.ns3ZmqBridge.step(action)
        self.envDirty = True
        obs = self.ns3ZmqBridge.get_obs()
        reward = self.ns3ZmqBridge.get_reward()
        done = self.ns3ZmqBridge.is_game_over()
        extraInfo = self.ns3ZmqBridge.get_extra_info()
        return (obs, reward, done, extraInfo)

    def reset(self):
        """Return an initial observation, reconnecting when necessary.

        When no step was taken since the last (re)start and a bridge is
        still present, the already received first observation is returned.
        Otherwise any existing bridge is closed and a new one is started.
        """
        # The None check covers reset() after close(), where no bridge is
        # left to read the observation from.
        if not self.envDirty and self.ns3ZmqBridge is not None:
            return self.ns3ZmqBridge.get_obs()

        if self.ns3ZmqBridge:
            self.ns3ZmqBridge.close()
            self.ns3ZmqBridge = None

        self.envDirty = False
        self._start_bridge()
        return self.ns3ZmqBridge.get_obs()

    def render(self, mode='human'):
        """Do nothing; this environment has no rendering."""
        return

    def get_random_action(self):
        """Return a random sample from the action space."""
        act = self.action_space.sample()
        return act

    def close(self):
        """Close the bridge and the viewer, if present."""
        if self.ns3ZmqBridge:
            self.ns3ZmqBridge.close()
            self.ns3ZmqBridge = None

        if self.viewer:
            self.viewer.close()
            # Drop the reference so a repeated close() is a no-op.
            self.viewer = None