public class ParkingCarAgent : Agent
{
    [SerializeField]
    private Transform TargetParkingSpot;

    // A reward is given every 'interval' units the agent gets closer to the target.
    [SerializeField]
    private float DistanceRewardInterval = 3f;

    // Thresholds defining when the task is complete.
    [SerializeField]
    private float DistanceThreshold = 2;
    [SerializeField]
    private float RotationThreshold = 20;
    // NOTE(review): "Theshold" is misspelled, but renaming a serialized field drops values already
    // set in the Inspector; rename together with [FormerlySerializedAs] if a rename is wanted.
    [SerializeField]
    private float SpeedTheshold = 5f;

    // Bounds the agent may not leave.
    [SerializeField]
    private Bounds AllowedBounds;

    private DistanceSensor[] distanceSensors;

    // NOTE(review): members elided in this view. The methods below also use carPhysics,
    // previousDistance and Verbose, which are presumably declared here (or inherited) — confirm.
    ...

    /// <summary>
    /// Adds this step's vector observations: the car's speed, its normalized x/z position and
    /// y rotation, the target's normalized position/y rotation relative to the car, and one
    /// normalized reading per distance sensor.
    /// </summary>
    public override void CollectObservations()
    {
        base.CollectObservations();

        // Agent speed, position (x/z) and y rotation.
        // NOTE(review): CurrentSpeed is added as-is (not passed through GetNormalizedValue) — confirm its range suits the policy.
        Vector3 normalizedAgentPosition = GetNormalizedPosition(this.transform.position);
        AddVectorObs(carPhysics.CurrentSpeed);
        AddVectorObs(normalizedAgentPosition.x);
        AddVectorObs(normalizedAgentPosition.z);
        Vector3 normalizedAgentRotation = GetNormalizedRotation(this.transform.rotation);
        AddVectorObs(normalizedAgentRotation.y);

        // Target position / y rotation, expressed relative to the agent.
        Vector3 normalizedTargetPosition = GetNormalizedPosition(TargetParkingSpot.position);
        AddVectorObs(normalizedTargetPosition.x - normalizedAgentPosition.x);
        AddVectorObs(normalizedTargetPosition.z - normalizedAgentPosition.z);
        Vector3 normalizedTargetRotation = GetNormalizedRotation(TargetParkingSpot.rotation);
        // NOTE(review): this difference is not wrapped, so headings of 359° and 1° produce about ±0.99
        // instead of about ±0.006. Mathf.DeltaAngle(agentY, targetY) / 360f would remove the
        // discontinuity, but it changes the observations seen by already-trained models.
        AddVectorObs(normalizedTargetRotation.y - normalizedAgentRotation.y);

        // Refresh every sensor, then add its normalized reading.
        foreach (DistanceSensor sensor in distanceSensors)
        {
            sensor.UpdateSensorReadings();
            AddVectorObs(sensor.NormalizedDistance);
        }
    }

    /// <summary>
    /// Applies the policy's actions to the car and assigns this step's rewards.
    /// </summary>
    /// <remarks>
    /// Rewards: +0.02 each time the car enters a closer DistanceRewardInterval bucket, -0.04 each time
    /// it moves out into a farther (2 * DistanceRewardInterval) bucket, up to +1 (ending the episode)
    /// when it is within DistanceThreshold and no faster than SpeedTheshold, and -1 (ending the
    /// episode) when it leaves AllowedBounds.
    /// </remarks>
    /// <param name="vectorAction">[0] throttle (positive part) / braking (negative part), [1] turning.</param>
    /// <param name="textAction">Only forwarded to the base implementation.</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        base.AgentAction(vectorAction, textAction);
        if (IsDone())
        {
            return;
        }

        // Action inputs (only indices 0 and 1 are read):
        // [0]: Throttle = max(0, value), Braking = max(0, -value)
        // [1]: Turning, passed through unchanged
        carPhysics.CurrentThrottle = Mathf.Max(0, vectorAction[0]);
        carPhysics.CurrentBraking = Mathf.Max(0, -vectorAction[0]);
        carPhysics.CurrentTurning = vectorAction[1];

        // Reward for getting closer. Note: could use sqrMagnitude here for performance.
        float distanceToTarget = Vector3.Distance(this.transform.position, TargetParkingSpot.position);
        if (distanceToTarget < previousDistance)
        {
            // Reward only when a closer DistanceRewardInterval bucket is entered.
            if ((int)(distanceToTarget / DistanceRewardInterval) < (int)(previousDistance / DistanceRewardInterval))
            {
                AddReward(0.02f);
            }
            previousDistance = distanceToTarget;
        }
        else
        {
            // previousDistance is not updated here unless a penalty fires, so it keeps tracking the
            // closest distance reached since the last penalty.
            // Note: '* 2' is a hard-coded value, introduced after tuning the penalty to occur less
            // frequently than the reward, so as not to 'scare' the AI away from corrective maneuvers
            // that first increase the distance to the target parking spot.
            if ((int)(distanceToTarget / (DistanceRewardInterval * 2)) > (int)(previousDistance / (DistanceRewardInterval * 2)))
            {
                if (Verbose)
                {
                    Debug.Log("Distance based penalty");
                }
                AddReward(-0.04f);
                previousDistance = distanceToTarget;
            }
        }

        // Task completion: distance within DistanceThreshold and |speed| within SpeedTheshold.
        // The rotation does not gate completion; it only scales the final reward.
        float rotationDiff = Quaternion.Angle(this.transform.rotation, TargetParkingSpot.rotation);
        if (distanceToTarget <= DistanceThreshold)
        {
            // Angle wrap-around: parking in either direction counts, so 180° is treated as parallel.
            if (rotationDiff > 90)
            {
                rotationDiff = 180 - rotationDiff;
            }

            if (Mathf.Abs(carPhysics.CurrentSpeed) <= SpeedTheshold)
            {
                // Determine how well (= how parallel) the AI parked.
                float reward = 1;
                if (rotationDiff > RotationThreshold)
                {
                    reward = 1 - GetNormalizedValue(rotationDiff, RotationThreshold, 90);
                }
                AddReward(reward);
                Done();
                return;
            }
        }

        // Leaving the allowed area ends the episode with a penalty.
        // NOTE(review): the position is truncated toward zero before the test, so positions within one
        // unit of a boundary can be misclassified (either way). The truncation may be deliberate
        // (e.g. for a flat bounds box) — confirm before switching to AllowedBounds.Contains(transform.position).
        if (!AllowedBounds.Contains(new Vector3Int((int)transform.position.x, (int)transform.position.y, (int)transform.position.z)))
        {
            AddReward(-1.0f);
            Done();
        }
    }

    /// <summary>
    /// Maps a world position into AllowedBounds-relative coordinates: 0 at AllowedBounds.min and
    /// 1 at AllowedBounds.max on each axis (positions outside the bounds fall outside [0, 1]).
    /// </summary>
    private Vector3 GetNormalizedPosition(in Vector3 position)
    {
        float normalizedX = GetNormalizedValue(position.x, AllowedBounds.min.x, AllowedBounds.max.x);
        float normalizedY = GetNormalizedValue(position.y, AllowedBounds.min.y, AllowedBounds.max.y);
        float normalizedZ = GetNormalizedValue(position.z, AllowedBounds.min.z, AllowedBounds.max.z);
        return new Vector3(normalizedX, normalizedY, normalizedZ);
    }

    /// <summary>
    /// Maps each Euler angle of <paramref name="rotation"/> from degrees in [0, 360) to [0, 1).
    /// </summary>
    private Vector3 GetNormalizedRotation(in Quaternion rotation)
    {
        Vector3 euler = rotation.eulerAngles;
        return new Vector3(
            GetNormalizedValue(euler.x, 0, 360),
            GetNormalizedValue(euler.y, 0, 360),
            GetNormalizedValue(euler.z, 0, 360));
    }

    /// <summary>
    /// Linearly maps <paramref name="currentValue"/> so that <paramref name="minValue"/> becomes 0 and
    /// <paramref name="maxValue"/> becomes 1. Values outside the range are not clamped.
    /// </summary>
    /// <returns>The normalized value, or 0 when the range has zero width.</returns>
    private float GetNormalizedValue(float currentValue, float minValue, float maxValue)
    {
        float range = maxValue - minValue;
        // A zero-width range (e.g. a flat AllowedBounds axis) would otherwise divide by zero and
        // yield NaN or Infinity; exact zero is the only value that causes the division to fail.
        if (range == 0f)
        {
            return 0f;
        }
        return (currentValue - minValue) / range;
    }

    // Applies a small penalty when the collided object (or one of its parents) has the checked component.
    // NOTE(review): the generic type arguments of GetComponent/GetComponentInParent appear to be missing in
    // this source (GameObject has no parameterless non-generic overload); restore the intended component type.
    private void OnCollisionEnter(Collision collision)
    {
        if (collision.collider.gameObject.GetComponent() || collision.collider.gameObject.GetComponentInParent())
        {
            AddReward(-0.12f);
        }
    }

    ...
}