Unity3D Programming
Unity ML-Agent Penguin (21.05.25)
This exercise trains a penguin agent to catch fish and bring them back to feed its baby.
PenguinAgent.cs
```csharp
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;

public class PenguinAgent : Agent
{
    [Tooltip("How fast the agent moves forward")]
    public float moveSpeed = 5f;

    [Tooltip("How fast the agent turns")]
    public float turnSpeed = 180f;

    [Tooltip("Prefab of the heart that appears when the baby is fed")]
    public GameObject heartPrefab;

    [Tooltip("Prefab of the regurgitated fish that appears when the baby is fed")]
    public GameObject regurgitatedFishPrefab;

    private PenguinArea penguinArea;
    new private Rigidbody rigidbody;
    private GameObject baby;
    private bool isFull; // If true, penguin has a full stomach

    /// <summary>
    /// Initial setup, called when the agent is enabled
    /// </summary>
    public override void Initialize()
    {
        base.Initialize();
        penguinArea = GetComponentInParent<PenguinArea>();
        baby = penguinArea.penguinBaby;
        rigidbody = GetComponent<Rigidbody>();
    }

    /// <summary>
    /// Perform actions based on the two discrete action branches:
    /// branch 0 (size 2): 0 = stay, 1 = move forward;
    /// branch 1 (size 3): 0 = no turn, 1 = turn left, 2 = turn right
    /// </summary>
    /// <param name="actionBuffers">The struct of actions to take</param>
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Convert the first action to forward movement
        float forwardAmount = actionBuffers.DiscreteActions[0];

        // Convert the second action to turning left or right
        float turnAmount = 0f;
        int turnAction = actionBuffers.DiscreteActions[1];
        if (turnAction == 1)
        {
            turnAmount = -1f;
        }
        else if (turnAction == 2)
        {
            turnAmount = 1f;
        }

        // Apply movement
        rigidbody.MovePosition(transform.position + transform.forward * forwardAmount * moveSpeed * Time.fixedDeltaTime);
        // Rotate(Vector3) works in local space by default, so rotate around the local up axis
        transform.Rotate(Vector3.up * turnAmount * turnSpeed * Time.fixedDeltaTime);

        // Apply a tiny negative reward every step to encourage action
        if (MaxStep > 0) AddReward(-1f / MaxStep);
    }

    /// <summary>
    /// Read inputs from the keyboard and convert them to a list of actions.
    /// This is called when Behavior Type is set to "Heuristic Only" in the Behavior
    /// Parameters inspector (or when it is "Default", no trainer is connected, and no model is assigned).
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        int forwardAction = 0;
        int turnAction = 0;
        if (Input.GetKey(KeyCode.W))
        {
            // move forward
            forwardAction = 1;
        }
        if (Input.GetKey(KeyCode.A))
        {
            // turn left
            turnAction = 1;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            // turn right
            turnAction = 2;
        }

        // Write through the ActionSegment indexer, which respects the segment's offset
        // (writing to .Array directly ignores it)
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = forwardAction;
        discreteActionsOut[1] = turnAction;
    }

    /// <summary>
    /// When a new episode begins, reset the agent and area
    /// </summary>
    public override void OnEpisodeBegin()
    {
        isFull = false;
        penguinArea.ResetArea();
    }

    /// <summary>
    /// Collect all non-Raycast observations
    /// </summary>
    /// <param name="sensor">The vector sensor to add observations to</param>
    public override void CollectObservations(VectorSensor sensor)
    {
        // Whether the penguin has eaten a fish (1 bool = 1 value)
        sensor.AddObservation(isFull);

        // Distance to the baby (1 float = 1 value)
        sensor.AddObservation(Vector3.Distance(baby.transform.position, transform.position));

        // Direction to baby (1 Vector3 = 3 values)
        sensor.AddObservation((baby.transform.position - transform.position).normalized);

        // Direction penguin is facing (1 Vector3 = 3 values)
        sensor.AddObservation(transform.forward);

        // 1 + 1 + 3 + 3 = 8 total values (the Vector Observation Space Size in Behavior Parameters must match)
    }

    /// <summary>
    /// When the agent collides with something, take action
    /// </summary>
    /// <param name="collision">The collision info</param>
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.transform.CompareTag("fish"))
        {
            // Try to eat the fish
            EatFish(collision.gameObject);
        }
        else if (collision.transform.CompareTag("baby"))
        {
            // Try to feed the baby
            RegurgitateFish();
        }
    }

    /// <summary>
    /// Check if agent is full; if not, eat the fish and get a reward
    /// </summary>
    /// <param name="fishObject">The fish to eat</param>
    private void EatFish(GameObject fishObject)
    {
        if (isFull) return; // Can't eat another fish while full
        isFull = true;
        penguinArea.RemoveSpecificFish(fishObject);
        AddReward(1f);
    }

    /// <summary>
    /// Check if agent is full; if yes, feed the baby
    /// </summary>
    private void RegurgitateFish()
    {
        if (!isFull) return; // Nothing to regurgitate
        isFull = false;

        // Spawn regurgitated fish
        GameObject regurgitatedFish = Instantiate<GameObject>(regurgitatedFishPrefab);
        regurgitatedFish.transform.parent = transform.parent;
        regurgitatedFish.transform.position = baby.transform.position;
        Destroy(regurgitatedFish, 4f);

        // Spawn heart
        GameObject heart = Instantiate<GameObject>(heartPrefab);
        heart.transform.parent = transform.parent;
        heart.transform.position = baby.transform.position + Vector3.up;
        Destroy(heart, 4f);

        AddReward(1f);

        // End the episode once every fish has been eaten and delivered
        if (penguinArea.FishRemaining <= 0)
        {
            EndEpisode();
        }
    }
}
```
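The agent's contract with the trainer is fixed by two pieces of code above. OnActionReceived decodes two discrete branches, and CollectObservations emits 8 floats in a fixed order. The sketch below mirrors both outside Unity so the mapping can be checked in isolation. `PenguinAgentSketch` and its methods are hypothetical names. The code uses System.Numerics.Vector3 instead of UnityEngine.Vector3 and is not part of the tutorial.

```csharp
using System.Collections.Generic;
using System.Numerics;

// Hypothetical engine-free mirror of PenguinAgent's action decoding and
// observation layout; it uses no UnityEngine or ML-Agents types.
public static class PenguinAgentSketch
{
    // Branch 0: 0 = stay, 1 = move forward (used directly as a speed multiplier).
    // Branch 1: 0 = no turn, 1 = turn left (-1), 2 = turn right (+1).
    public static (float Forward, float Turn) DecodeActions(int branch0, int branch1)
    {
        float forward = branch0;
        float turn = branch1 switch
        {
            1 => -1f,
            2 => 1f,
            _ => 0f,
        };
        return (forward, turn);
    }

    // Returns the 8 values CollectObservations adds, in the same order.
    public static float[] BuildObservation(bool isFull, Vector3 penguinPos, Vector3 babyPos, Vector3 penguinForward)
    {
        var obs = new List<float>(8);
        obs.Add(isFull ? 1f : 0f);                      // bool observation -> 1 value
        obs.Add(Vector3.Distance(babyPos, penguinPos)); // distance -> 1 value

        // Unity's .normalized returns a zero vector for (near-)zero input, while
        // System.Numerics.Vector3.Normalize would produce NaN, so guard it here.
        Vector3 diff = babyPos - penguinPos;
        Vector3 toBaby = diff.Length() > 1e-5f ? Vector3.Normalize(diff) : Vector3.Zero;
        obs.AddRange(new[] { toBaby.X, toBaby.Y, toBaby.Z });                         // 3 values
        obs.AddRange(new[] { penguinForward.X, penguinForward.Y, penguinForward.Z }); // 3 values
        return obs.ToArray();
    }
}
```

For a full penguin at the origin facing +Z with the baby at (3, 0, 4), the observation is [1, 5, 0.6, 0, 0.8, 0, 0, 1]. This is a quick way to confirm why the Space Size must be 8.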
PenguinArea.cs
```csharp
using System.Collections.Generic;
using UnityEngine;
using TMPro;

public class PenguinArea : MonoBehaviour
{
    [Tooltip("The agent inside the area")]
    public PenguinAgent penguinAgent;

    [Tooltip("The baby penguin inside the area")]
    public GameObject penguinBaby;

    [Tooltip("The TextMeshPro text that shows the cumulative reward of the agent")]
    public TextMeshPro cumulativeRewardText;

    [Tooltip("Prefab of a live fish")]
    public Fish fishPrefab;

    private List<GameObject> fishList;

    /// <summary>
    /// Reset the area, including fish and penguin placement
    /// </summary>
    public void ResetArea()
    {
        RemoveAllFish();
        PlacePenguin();
        PlaceBaby();
        SpawnFish(4, .5f);
    }

    /// <summary>
    /// Remove a specific fish from the area when it is eaten
    /// </summary>
    /// <param name="fishObject">The fish to remove</param>
    public void RemoveSpecificFish(GameObject fishObject)
    {
        fishList.Remove(fishObject);
        Destroy(fishObject);
    }

    /// <summary>
    /// The number of fish remaining
    /// </summary>
    public int FishRemaining
    {
        get { return fishList.Count; }
    }

    /// <summary>
    /// Choose a random position on the X-Z plane within a partial donut shape
    /// </summary>
    /// <param name="center">The center of the donut</param>
    /// <param name="minAngle">Minimum angle of the wedge</param>
    /// <param name="maxAngle">Maximum angle of the wedge</param>
    /// <param name="minRadius">Minimum distance from the center</param>
    /// <param name="maxRadius">Maximum distance from the center</param>
    /// <returns>A position falling within the specified region</returns>
    public static Vector3 ChooseRandomPosition(Vector3 center, float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        float radius = minRadius;
        float angle = minAngle;

        if (maxRadius > minRadius)
        {
            // Pick a random radius
            radius = UnityEngine.Random.Range(minRadius, maxRadius);
        }

        if (maxAngle > minAngle)
        {
            // Pick a random angle
            angle = UnityEngine.Random.Range(minAngle, maxAngle);
        }

        // Center position + forward vector rotated around the Y axis by "angle" degrees, multiplied by "radius"
        return center + Quaternion.Euler(0f, angle, 0f) * Vector3.forward * radius;
    }

    /// <summary>
    /// Remove all fish from the area
    /// </summary>
    private void RemoveAllFish()
    {
        if (fishList != null)
        {
            for (int i = 0; i < fishList.Count; i++)
            {
                if (fishList[i] != null)
                {
                    Destroy(fishList[i]);
                }
            }
        }

        fishList = new List<GameObject>();
    }

    /// <summary>
    /// Place the penguin in the area
    /// </summary>
    private void PlacePenguin()
    {
        Rigidbody rigidbody = penguinAgent.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinAgent.transform.position = ChooseRandomPosition(transform.position, 0f, 360f, 0f, 9f) + Vector3.up * .5f;
        penguinAgent.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);
    }

    /// <summary>
    /// Place the baby in the area
    /// </summary>
    private void PlaceBaby()
    {
        Rigidbody rigidbody = penguinBaby.GetComponent<Rigidbody>();
        rigidbody.velocity = Vector3.zero;
        rigidbody.angularVelocity = Vector3.zero;
        penguinBaby.transform.position = ChooseRandomPosition(transform.position, -45f, 45f, 4f, 9f) + Vector3.up * .5f;
        penguinBaby.transform.rotation = Quaternion.Euler(0f, 180f, 0f);
    }

    /// <summary>
    /// Spawn some number of fish in the area and set their swim speed
    /// </summary>
    /// <param name="count">The number to spawn</param>
    /// <param name="fishSpeed">The swim speed</param>
    private void SpawnFish(int count, float fishSpeed)
    {
        for (int i = 0; i < count; i++)
        {
            // Spawn and place the fish
            GameObject fishObject = Instantiate<GameObject>(fishPrefab.gameObject);
            fishObject.transform.position = ChooseRandomPosition(transform.position, 100f, 260f, 2f, 13f) + Vector3.up * .5f;
            fishObject.transform.rotation = Quaternion.Euler(0f, UnityEngine.Random.Range(0f, 360f), 0f);

            // Set the fish's parent to this area's transform
            fishObject.transform.SetParent(transform);

            // Keep track of the fish
            fishList.Add(fishObject);

            // Set the fish speed
            fishObject.GetComponent<Fish>().fishSpeed = fishSpeed;
        }
    }

    /// <summary>
    /// Called when the game starts
    /// </summary>
    private void Start()
    {
        ResetArea();
    }

    /// <summary>
    /// Called every frame
    /// </summary>
    private void Update()
    {
        // Update the cumulative reward text
        cumulativeRewardText.text = penguinAgent.GetCumulativeReward().ToString("0.00");
    }
}
```
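ChooseRandomPosition is a polar-to-Cartesian conversion. In Unity's left-handed, Y-up frame, `Quaternion.Euler(0, angle, 0) * Vector3.forward` equals (sin angle, 0, cos angle), so 0° points along +Z and 90° along +X. The sketch below writes that rotation out explicitly. It shows why the baby's wedge (-45° to 45°) always lands in front of the center (+Z) and the fish wedge (100° to 260°) always lands behind it (-Z). `DonutSampler` is a hypothetical name. `System.Random` stands in for `UnityEngine.Random`, whose `Range(float, float)` is inclusive of both ends rather than half-open.

```csharp
using System;
using System.Numerics;

// Hypothetical engine-free version of PenguinArea.ChooseRandomPosition.
public static class DonutSampler
{
    public static Vector3 ChooseRandomPosition(Random rng, Vector3 center,
        float minAngle, float maxAngle, float minRadius, float maxRadius)
    {
        // Same fallbacks as the original: use the minimum when the range is empty
        float radius = maxRadius > minRadius
            ? minRadius + (float)rng.NextDouble() * (maxRadius - minRadius)
            : minRadius;
        float angle = maxAngle > minAngle
            ? minAngle + (float)rng.NextDouble() * (maxAngle - minAngle)
            : minAngle;

        // Unity's yaw rotation of Vector3.forward (0, 0, 1) by "angle" degrees
        double rad = angle * Math.PI / 180.0;
        var direction = new Vector3((float)Math.Sin(rad), 0f, (float)Math.Cos(rad));
        return center + direction * radius;
    }
}
```

The original samples the radius uniformly. That places proportionally more points near the inner radius than an area-uniform sampler would. If an even spread matters, a common alternative is `radius = sqrt(U * (maxR² - minR²) + minR²)` with U uniform in [0, 1). That is a different technique from the tutorial's.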
Fish.cs
```csharp
using UnityEngine;

public class Fish : MonoBehaviour
{
    [Tooltip("The swim speed")]
    public float fishSpeed;

    private float randomizedSpeed = 0f;
    private float nextActionTime = -1f;
    private Vector3 targetPosition;

    /// <summary>
    /// Called every timestep
    /// </summary>
    private void FixedUpdate()
    {
        if (fishSpeed > 0f)
        {
            Swim();
        }
    }

    /// <summary>
    /// Swim between random positions
    /// </summary>
    private void Swim()
    {
        // If it's time for the next action, pick a new speed and destination
        // Else, swim toward the destination
        if (Time.fixedTime >= nextActionTime)
        {
            // Randomize the speed
            randomizedSpeed = fishSpeed * UnityEngine.Random.Range(.5f, 1.5f);

            // Pick a random target
            targetPosition = PenguinArea.ChooseRandomPosition(transform.parent.position, 100f, 260f, 2f, 13f);

            // Rotate toward the target
            transform.rotation = Quaternion.LookRotation(targetPosition - transform.position, Vector3.up);

            // Calculate the time to get there
            float timeToGetThere = Vector3.Distance(transform.position, targetPosition) / randomizedSpeed;
            nextActionTime = Time.fixedTime + timeToGetThere;
        }
        else
        {
            // Make sure that the fish does not swim past the target
            Vector3 moveVector = randomizedSpeed * transform.forward * Time.fixedDeltaTime;
            if (moveVector.magnitude <= Vector3.Distance(transform.position, targetPosition))
            {
                transform.position += moveVector;
            }
            else
            {
                transform.position = targetPosition;
                nextActionTime = Time.fixedTime;
            }
        }
    }
}
```
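The `else` branch of Fish.Swim is a clamped step. The fish advances along its heading by speed × fixedDeltaTime, but snaps onto the target if that step would carry it past. It then sets nextActionTime to the current time so a new target is chosen on the next FixedUpdate. The sketch below isolates that logic, together with the leg duration `distance / speed`, so the "never overshoot" property can be tested. `SwimStep`, `Step`, and `TimeToReach` are hypothetical names. The heading is passed in explicitly, standing in for `transform.forward` after `LookRotation`.

```csharp
using System.Numerics;

// Hypothetical engine-free version of the movement logic in Fish.Swim.
public static class SwimStep
{
    // Time Fish.Swim allots for one leg: distance / speed.
    public static float TimeToReach(Vector3 position, Vector3 target, float speed) =>
        Vector3.Distance(position, target) / speed;

    // Advances one fixed step along a unit heading. If the step would pass the
    // target, the position snaps onto it and true is returned (the point where
    // Fish resets nextActionTime so a new target is picked next FixedUpdate).
    public static bool Step(ref Vector3 position, Vector3 target, Vector3 heading, float speed, float dt)
    {
        Vector3 move = heading * speed * dt;
        if (move.Length() <= Vector3.Distance(position, target))
        {
            position += move;
            return false;
        }

        position = target;
        return true;
    }
}
```

With speed 0.5 and a 0.02 s step, a 1-unit leg takes about 100 steps and ends exactly on the target.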
Results after training for 1,000,000 steps
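The reward scheme in PenguinAgent implies a ceiling for the cumulative reward shown during training. Each episode spawns 4 fish (`SpawnFish(4, .5f)` in ResetArea), and each fish is worth +1 when eaten and +1 when delivered. The episode ends once the last fish is delivered, so the best possible episode approaches 4 × (1 + 1) = 8 minus the step penalty. Assuming OnActionReceived runs every agent step, that penalty sums to about -1 over a full MaxStep-long episode. The helper below is a hypothetical closed-form replay of those rules. The MaxStep value of 5000 used with it is an arbitrary example, since the post does not state the value used.

```csharp
// Hypothetical helper that replays PenguinAgent's reward rules in closed form.
public static class RewardBudget
{
    // +1 in EatFish, +1 in RegurgitateFish, and -1/MaxStep per OnActionReceived call;
    // the penalty is skipped when MaxStep is 0, matching the "if (MaxStep > 0)" guard.
    public static float EpisodeReward(int fishEaten, int fishDelivered, int actionSteps, int maxStep)
    {
        float reward = fishEaten + fishDelivered;
        if (maxStep > 0)
        {
            reward -= actionSteps * (1f / maxStep);
        }
        return reward;
    }
}
```

For example, `EpisodeReward(4, 4, 1000, 5000)` gives about 7.8. An agent that never finds a fish and times out at MaxStep scores about -1.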