NS3Gym中的环境脚本(Gym端)代码讲解
基本思想
在这个代码中,首先设计一个桥类,使得 Gym 中的环境代码能够和 NS3 中的真实环境进行交互,比如得到下一时刻的状态和回报,以及执行动作等等,这个类命名为 `Ns3ZmqBridge`。
在继承自父类 `gym.Env` 的环境类 `Ns3Env` 中,引用了类 `Ns3ZmqBridge`,使得其能够更好地与环境代码进行通讯。
代码解读
在 `ns3` 与 `gym` 的结合版本 `ns3gym` 中,作者给出了一个示例,用以展示如何实现需要提供的接口,以使得两者可以很好地协同工作。关于 `ns3gym` 的设计思路已经在官方提供的技术报告中讲解,下面讲解官方提供的示例环境 `ns3env` 的代码:
import os
import signal
import sys
import time
from enum import IntEnum

import gym
import numpy as np
import zmq
from google.protobuf.any_pb2 import Any
from gym import spaces
from gym.utils import seeding

# Message classes generated with Google's protocol buffers
import ns3gym.messages_pb2 as pb
from ns3gym.start_sim import start_sim_script, build_ns3_project
__author__ = "Piotr Gawlowicz"
__copyright__ = "Copyright (c) 2018, Technische Universität Berlin"
__version__ = "0.1.0"
__email__ = "gawlowicz@tkn.tu-berlin.de"
# ns3和gym通信的桥梁,主要用来和ns3之间(网络环境服务器)之间进行通信
class Ns3ZmqBridge(object):
    """ZeroMQ bridge between the Gym-side environment and an ns-3 simulation.

    The bridge binds a ``zmq.REP`` socket, optionally launches the ns-3
    simulation script, and then exchanges protobuf messages (defined in
    ``ns3gym.messages_pb2``) with it: it receives the space descriptions and
    environment states, and replies with actions or a stop request.
    """

    def __init__(self, port=0, startSim=True, simSeed=0, simArgs=None, debug=False):
        """Bind the reply socket and optionally start the simulation script.

        Args:
            port: TCP port to bind. ``0`` asks for a random port (searched
                between 5001 and 10000), which is only accepted when
                ``startSim`` is true.
            startSim: When true, the simulation is launched through
                ``start_sim_script``; otherwise the bridge only prints a hint
                and waits for an externally started script.
            simSeed: Seed handed to the simulation. ``0`` combined with
                ``startSim`` draws a random seed instead.
            simArgs: Dict of extra simulation arguments forwarded to
                ``start_sim_script``. ``None`` (the default) means an empty
                dict; a fresh dict is created per instance so that instances
                never share one mutable default.
            debug: Forwarded unchanged to ``start_sim_script``.

        The process exits through ``sys.exit()`` when no usable port is given
        or when binding fails.
        """
        super(Ns3ZmqBridge, self).__init__()
        port = int(port)
        simArgs = {} if simArgs is None else simArgs

        self.port = port
        self.startSim = startSim
        self.simSeed = simSeed
        self.simArgs = simArgs
        self.envStopped = False
        self.simPid = None
        self.wafPid = None
        self.ns3Process = None

        # A random port is useless if the simulation is started by hand,
        # because the external script would not know where to connect.
        if port == 0 and not self.startSim:
            print("Cannot use port %s to bind" % str(port) )
            print("Please specify correct port" )
            sys.exit()

        context = zmq.Context()
        self.socket = context.socket(zmq.REP)
        try:
            if port == 0:
                port = self.socket.bind_to_random_port('tcp://*', min_port=5001, max_port=10000, max_tries=100)
                print("Got new port for ns3gm interface: ", port)
            else:
                self.socket.bind ("tcp://*:%s" % str(port))
        except Exception:
            print("Cannot bind to tcp://*:%s as port is already in use" % str(port) )
            print("Please specify different port or use 0 to get free port" )
            sys.exit()

        # Seed 0 means "pick one for me" when we launch the simulation.
        if startSim and simSeed == 0:
            maxSeed = np.iinfo(np.uint32).max
            simSeed = np.random.randint(0, maxSeed)
            self.simSeed = simSeed

        if self.startSim:
            # Run the simulation script against the port we just bound.
            self.ns3Process = start_sim_script(port, simSeed, simArgs, debug)
        else:
            print("Waiting for simulation script to connect on port: tcp://localhost:{}".format(port))
            print('Please start proper ns-3 simulation script using ./waf --run "..."')

        # State filled in by initialize_env() / rx_env_state().
        self._action_space = None
        self._observation_space = None
        self.forceEnvStop = False
        self.obsData = None
        self.reward = 0
        self.gameOver = False
        self.gameOverReason = None
        self.extraInfo = None
        self.newStateRx = False

    @staticmethod
    def _terminate_pid(pid):
        """Send SIGTERM to ``pid`` if it is set; ignore an already-gone process."""
        if pid:
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError:
                # Typically the process has exited already; nothing to clean up.
                pass

    def close(self):
        """Stop the simulation and terminate its processes (best effort).

        Does nothing when the environment is already marked as stopped.
        Otherwise it performs the stop handshake with the simulation and then
        terminates the launched process and the recorded PIDs. Every cleanup
        step is guarded on its own, so a failure in one step no longer skips
        the remaining ones. No exception is propagated to the caller.
        """
        if self.envStopped:
            return
        self.envStopped = True

        try:
            self.force_env_stop()
            self.rx_env_state()
            self.send_close_command()
        except Exception as err:
            # Best effort: the processes are still terminated below.
            print("Failed to stop the ns-3 simulation gracefully: ", err)

        # ns3Process stays None when the simulation was started externally.
        if self.ns3Process is not None:
            try:
                self.ns3Process.kill()
            except OSError:
                pass

        self._terminate_pid(self.simPid)
        self.simPid = None
        self._terminate_pid(self.wafPid)
        self.wafPid = None

    def _create_space(self, spaceDesc):
        """Build a gym space from a protobuf space description.

        Handles the Discrete, Box, Tuple and Dict descriptions (Tuple and Dict
        recursively). Returns ``None`` for any other description type.
        """
        space = None

        if spaceDesc.type == pb.Discrete:
            discreteSpacePb = pb.DiscreteSpace()
            spaceDesc.space.Unpack(discreteSpacePb)
            space = spaces.Discrete(discreteSpacePb.n)

        elif spaceDesc.type == pb.Box:
            boxSpacePb = pb.BoxSpace()
            spaceDesc.space.Unpack(boxSpacePb)
            # The builtins replace the np.int / np.float aliases, which were
            # plain aliases of int / float and no longer exist in recent NumPy.
            dtypeByPbType = {pb.INT: int, pb.UINT: np.uint, pb.DOUBLE: float}
            mtype = dtypeByPbType.get(boxSpacePb.dtype, float)
            space = spaces.Box(
                low=boxSpacePb.low,
                high=boxSpacePb.high,
                shape=tuple(boxSpacePb.shape),
                dtype=mtype,
            )

        elif spaceDesc.type == pb.Tuple:
            tupleSpacePb = pb.TupleSpace()
            spaceDesc.space.Unpack(tupleSpacePb)
            subSpaces = [self._create_space(sub) for sub in tupleSpacePb.element]
            space = spaces.Tuple(tuple(subSpaces))

        elif spaceDesc.type == pb.Dict:
            dictSpacePb = pb.DictSpace()
            spaceDesc.space.Unpack(dictSpacePb)
            namedSpaces = {}
            for sub in dictSpacePb.element:
                namedSpaces[sub.name] = self._create_space(sub)
            space = spaces.Dict(namedSpaces)

        return space

    def initialize_env(self, stepInterval):
        """Perform the initial handshake with the simulation.

        Receives the init message, records the simulation and waf-shell
        process ids (used later by ``close``), builds the action and
        observation spaces and acknowledges the initialization.

        Args:
            stepInterval: Currently unused by this method.

        Returns:
            Always ``True``.
        """
        request = self.socket.recv()
        simInitMsg = pb.SimInitMsg()
        simInitMsg.ParseFromString(request)

        # Remember the process ids so that close() can terminate them.
        self.simPid = int(simInitMsg.simProcessId)
        self.wafPid = int(simInitMsg.wafShellProcessId)

        self._action_space = self._create_space(simInitMsg.actSpace)
        self._observation_space = self._create_space(simInitMsg.obsSpace)

        # Tell the simulation that initialization is done and it may continue.
        reply = pb.SimInitAck()
        reply.done = True
        reply.stopSimReq = False
        self.socket.send(reply.SerializeToString())
        return True

    def get_action_space(self):
        """Return the action space built by ``initialize_env``."""
        return self._action_space

    def get_observation_space(self):
        """Return the observation space built by ``initialize_env``."""
        return self._observation_space

    def force_env_stop(self):
        """Flag that the next action message must request a simulation stop."""
        self.forceEnvStop = True

    def rx_env_state(self):
        """Receive the next environment state unless one is already pending.

        Updates observation, reward, game-over flag/reason and extra info.
        When the state reports game over, a close command is sent right away
        as the reply.
        """
        if self.newStateRx:
            return

        request = self.socket.recv()
        envStateMsg = pb.EnvStateMsg()
        envStateMsg.ParseFromString(request)

        self.obsData = self._create_data(envStateMsg.obsData)
        self.reward = envStateMsg.reward
        self.gameOver = envStateMsg.isGameOver
        self.gameOverReason = envStateMsg.reason

        if self.gameOver:
            if self.gameOverReason == pb.EnvStateMsg.SimulationEnd:
                self.envStopped = True
            else:
                self.forceEnvStop = True
            self.send_close_command()

        self.extraInfo = envStateMsg.info
        if not self.extraInfo:
            self.extraInfo = {}

        self.newStateRx = True

    def send_close_command(self):
        """Reply to the simulation with a stop request. Returns ``True``."""
        reply = pb.EnvActMsg()
        reply.stopSimReq = True
        self.socket.send(reply.SerializeToString())
        self.newStateRx = False
        return True

    def send_actions(self, actions):
        """Pack ``actions`` for the action space and send them. Returns ``True``.

        The message also carries a stop request when ``force_env_stop`` was
        called before.
        """
        reply = pb.EnvActMsg()
        actionMsg = self._pack_data(actions, self._action_space)
        reply.actData.CopyFrom(actionMsg)
        reply.stopSimReq = bool(self.forceEnvStop)
        self.socket.send(reply.SerializeToString())
        self.newStateRx = False
        return True

    def step(self, actions):
        """Send ``actions`` and then receive the resulting environment state."""
        self.send_actions(actions)
        self.rx_env_state()

    def is_game_over(self):
        """Return the game-over flag of the last received state."""
        return self.gameOver

    def _create_data(self, dataContainerPb):
        """Convert a protobuf data container into Python data.

        Discrete yields the stored value, Box yields the repeated protobuf
        field matching the container dtype (not reshaped), Tuple yields a
        tuple and Dict a dict of recursively converted elements. Any other
        container type yields ``None``.
        """
        if dataContainerPb.type == pb.Discrete:
            discreteContainerPb = pb.DiscreteDataContainer()
            dataContainerPb.data.Unpack(discreteContainerPb)
            return discreteContainerPb.data

        if dataContainerPb.type == pb.Box:
            boxContainerPb = pb.BoxDataContainer()
            dataContainerPb.data.Unpack(boxContainerPb)
            if boxContainerPb.dtype == pb.INT:
                data = boxContainerPb.intData
            elif boxContainerPb.dtype == pb.UINT:
                data = boxContainerPb.uintData
            elif boxContainerPb.dtype == pb.DOUBLE:
                data = boxContainerPb.doubleData
            else:
                data = boxContainerPb.floatData
            # TODO: reshape using shape info
            return data

        if dataContainerPb.type == pb.Tuple:
            tupleDataPb = pb.TupleDataContainer()
            dataContainerPb.data.Unpack(tupleDataPb)
            return tuple(self._create_data(sub) for sub in tupleDataPb.element)

        if dataContainerPb.type == pb.Dict:
            dictDataPb = pb.DictDataContainer()
            dataContainerPb.data.Unpack(dictDataPb)
            myDataDict = {}
            for sub in dictDataPb.element:
                myDataDict[sub.name] = self._create_data(sub)
            return myDataDict

        return None

    def get_obs(self):
        """Return the observation of the last received state."""
        return self.obsData

    def get_reward(self):
        """Return the reward of the last received state."""
        return self.reward

    def get_extra_info(self):
        """Return the extra info of the last received state."""
        return self.extraInfo

    def _pack_data(self, actions, spaceDesc):
        """Pack ``actions`` into a protobuf ``DataContainer`` for ``spaceDesc``.

        Args:
            actions: Action value(s) matching ``spaceDesc``: a scalar for
                Discrete, a sized sequence for Box, a sequence for Tuple and
                a mapping for Dict.
            spaceDesc: The gym space describing ``actions``.

        Returns:
            The filled container. It is left without type/data when
            ``spaceDesc`` is none of the four supported space kinds.
        """
        dataContainer = pb.DataContainer()

        if isinstance(spaceDesc, spaces.Discrete):
            dataContainer.type = pb.Discrete
            discreteContainerPb = pb.DiscreteDataContainer()
            discreteContainerPb.data = actions
            dataContainer.data.Pack(discreteContainerPb)

        elif isinstance(spaceDesc, spaces.Box):
            dataContainer.type = pb.Box
            boxContainerPb = pb.BoxDataContainer()
            boxContainerPb.shape.extend([len(actions)])

            if (spaceDesc.dtype in ['int', 'int8', 'int16', 'int32', 'int64']):
                boxContainerPb.dtype = pb.INT
                boxContainerPb.intData.extend(actions)
            elif (spaceDesc.dtype in ['uint', 'uint8', 'uint16', 'uint32', 'uint64']):
                boxContainerPb.dtype = pb.UINT
                boxContainerPb.uintData.extend(actions)
            elif (spaceDesc.dtype in ['float', 'float32', 'float64']):
                boxContainerPb.dtype = pb.FLOAT
                boxContainerPb.floatData.extend(actions)
            elif (spaceDesc.dtype in ['double']):
                # NOTE(review): a float64 dtype presumably already matches the
                # 'float' names above, so this branch looks unreachable for
                # NumPy dtypes - confirm whether DOUBLE should ever be sent.
                boxContainerPb.dtype = pb.DOUBLE
                boxContainerPb.doubleData.extend(actions)
            else:
                boxContainerPb.dtype = pb.FLOAT
                boxContainerPb.floatData.extend(actions)
            dataContainer.data.Pack(boxContainerPb)

        elif isinstance(spaceDesc, spaces.Tuple):
            dataContainer.type = pb.Tuple
            tupleDataPb = pb.TupleDataContainer()
            # Use the sub-spaces of the space being packed (not of the
            # top-level action space) so nested containers pack correctly.
            subDataList = [
                self._pack_data(subAction, subSpace)
                for subAction, subSpace in zip(actions, spaceDesc.spaces)
            ]
            tupleDataPb.element.extend(subDataList)
            dataContainer.data.Pack(tupleDataPb)

        elif isinstance(spaceDesc, spaces.Dict):
            dataContainer.type = pb.Dict
            dictDataPb = pb.DictDataContainer()
            subDataList = []
            for sName, subAction in actions.items():
                # Same as for Tuple: look the sub-space up in spaceDesc itself.
                subData = self._pack_data(subAction, spaceDesc.spaces[sName])
                subData.name = sName
                subDataList.append(subData)
            dictDataPb.element.extend(subDataList)
            dataContainer.data.Pack(dictDataPb)

        return dataContainer
class Ns3Env(gym.Env):
    """Gym environment backed by an ns-3 simulation through ``Ns3ZmqBridge``."""

    def __init__(self, stepTime=0, port=0, startSim=True, simSeed=0, simArgs=None, debug=False):
        """Store the configuration, connect to the simulation and seed the env.

        Args:
            stepTime: Passed to ``Ns3ZmqBridge.initialize_env``.
            port: Port passed to ``Ns3ZmqBridge``.
            startSim: Passed to ``Ns3ZmqBridge``.
            simSeed: Passed to ``Ns3ZmqBridge``.
            simArgs: Dict of extra simulation arguments passed to
                ``Ns3ZmqBridge``. ``None`` (the default) means an empty dict,
                created per instance instead of sharing one mutable default.
            debug: Passed to ``Ns3ZmqBridge``.
        """
        self.stepTime = stepTime
        self.port = port
        self.startSim = startSim
        self.simSeed = simSeed
        self.simArgs = {} if simArgs is None else simArgs
        self.debug = debug

        # Filled in by _start_bridge()
        self.ns3ZmqBridge = None
        self.action_space = None
        self.observation_space = None

        self.viewer = None
        self.state = None
        self.steps_beyond_done = None

        self._start_bridge()
        # envDirty tells whether step() has been called since the last reset.
        self.envDirty = False
        self.seed()

    def _start_bridge(self):
        """Create a bridge, initialize it, read the spaces and the first state."""
        self.ns3ZmqBridge = Ns3ZmqBridge(self.port, self.startSim, self.simSeed, self.simArgs, self.debug)
        self.ns3ZmqBridge.initialize_env(self.stepTime)
        self.action_space = self.ns3ZmqBridge.get_action_space()
        self.observation_space = self.ns3ZmqBridge.get_observation_space()
        # Fetch the first observation so that reset() can return it.
        self.ns3ZmqBridge.rx_env_state()

    def seed(self, seed=None):
        """Seed the environment's random generator and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def step(self, action):
        """Execute ``action`` and return ``(obs, reward, done, extraInfo)``."""
        self.ns3ZmqBridge.step(action)
        self.envDirty = True
        bridge = self.ns3ZmqBridge
        return (
            bridge.get_obs(),
            bridge.get_reward(),
            bridge.is_game_over(),
            bridge.get_extra_info(),
        )

    def reset(self):
        """Reset the environment and return the first observation.

        When no step was taken since the last (re)start, the already received
        observation is returned. Otherwise the current bridge is closed and a
        new one is created and initialized.
        """
        if not self.envDirty:
            return self.ns3ZmqBridge.get_obs()

        if self.ns3ZmqBridge:
            self.ns3ZmqBridge.close()
            self.ns3ZmqBridge = None

        self.envDirty = False
        self._start_bridge()
        return self.ns3ZmqBridge.get_obs()

    def render(self, mode='human'):
        """Rendering is not implemented; returns ``None``."""
        return

    def get_random_action(self):
        """Return a random sample from the action space."""
        return self.action_space.sample()

    def close(self):
        """Close the bridge (if any) and the viewer (if any)."""
        if self.ns3ZmqBridge:
            self.ns3ZmqBridge.close()
            self.ns3ZmqBridge = None

        if self.viewer:
            self.viewer.close()
🗞️ Recent Posts