"""This module bundles the VirshDomain and its child with Graphical extension."""
import datetime as dt
import ipaddress
import os
import pathlib
import signal
import subprocess
import textwrap
import threading
import time
import xml.etree.ElementTree as ET
from typing import Literal, Optional
from zoneinfo import ZoneInfo
import cv2
import libvirt
import numpy as np
import usb
from libvirt import VIR_DOMAIN_AFFECT_CURRENT, VIR_DOMAIN_XML_SECURE
from fortrace.core.qemu_monitor import QEMUMonitorSession
from fortrace.core.virsh_console import VirshConsole
from fortrace.utility.desktop_environments.desktop_environment import (
DesktopEnvironmentType,
)
from fortrace.utility.desktop_environments.desktop_environment_factory import (
get_desktop_env,
)
from fortrace.utility.distribution_constants import OSType, ShellType
from fortrace.utility.exceptions import (
ConfigurationError,
DomainException,
ForTraceException,
)
from fortrace.utility.image_processing.image_similarity import nrmse
from fortrace.utility.logger_helper import setup_logger
logger = setup_logger(__name__)
[docs]
class VirshDomain:
"""Single VM/domain under control of ForTrace++."""
# Console session created by open_pty(); stays None until a PTY has been opened.
_pty: Optional[
VirshConsole
] # TODO: use list here and scan config for available console channels
# Address taken from the DHCP lease in boot(); None while no lease is known.
_ipv4: Optional[ipaddress.IPv4Address]
# dumpcap process spawned by start_sniffer(); None while no capture is running.
_sniffer: Optional[subprocess.Popen[bytes]]
def __init__(
    self,
    domain_name: str,
    connection: libvirt.virConnect,
    domain_network: list[str] | None,
    os_type: OSType,
    session_output: os.PathLike,
):
    """Create an object to interact with a specific libvirt domain.

    The domain's XML description is loaded as a template. Every interface of
    type ``network`` gets a ``<link>`` child whose state is ``up`` if the
    interface's source network equals ``domain_network`` and ``down``
    otherwise (all of them are ``down`` if ``domain_network`` is None).
    Finally, the domain's output directory is created below
    ``session_output``.

    Args:
        domain_name: The name of the domain to be created
        connection: an active libvirt connection to the hypervisor
        domain_network: names of networks the VM should connect to
        os_type: operating system of the domain
        session_output: path to session output folder

    Raises:
        libvirt.libvirtError: if the domain cannot be looked up
    """
    self._pty = None
    self._conn = connection
    try:
        self._domain = self._conn.lookupByName(domain_name)
    except libvirt.libvirtError:
        logger.error("Cannot find domain %s", domain_name)
        # Materialize the names: a bare generator would only log its repr.
        logger.info(
            "Available domains: %s",
            [dom.name() for dom in self._conn.listAllDomains()],
        )
        raise
    self._domain_output_path = pathlib.Path(session_output).joinpath(
        self._domain.name()
    )
    self._os = os_type
    self._qs = QEMUMonitorSession(
        self._domain, self.take_screenshot, self._domain_output_path
    )
    self._template = ET.ElementTree(
        ET.fromstring(self._domain.XMLDesc(VIR_DOMAIN_XML_SECURE))
    )
    # NOTE(review): domain_network is annotated as list[str] but is used as a
    # single network name below (lookup and string comparison) - confirm.
    self._network = (
        self._conn.networkLookupByName(domain_network) if domain_network else None
    )
    for interface in self._template.getroot().findall("devices/interface"):
        if interface.attrib.get("type") != "network":
            continue
        source = interface.find("source")
        connected = (
            domain_network is not None
            and source is not None
            and source.attrib.get("network") == domain_network
        )
        # Reuse an existing <link> element instead of appending a second one.
        link = interface.find("link")
        if link is None:
            link = ET.SubElement(interface, "link")
        link.set("state", "up" if connected else "down")
    # A domain without any interface has no MAC address to look up.
    mac = self._template.find("devices/interface/mac")
    self._mac = mac.get("address") if mac is not None else None
    self._ipv4 = None  # ip will be available when the domain is active
    self._sniffer = None
    self._screen_recorder = None
    self._screen_recorder_signal = threading.Event()
    os.mkdir(self._domain_output_path)
[docs]
def boot(
    self,
    start_sniffer: bool = True,
    snapshot: str | None = None,
):
    """Boot the VM with the option to start a sniffer or use a specific snapshot.

    The (possibly modified) XML template is defined in libvirt first. If a
    snapshot name is given, the domain is reverted to it; a missing snapshot
    is created from the current state. After the domain has been started, the
    method waits for a DHCP lease (only if a network is configured) and then
    compares screenshots taken five seconds apart until two of them are nearly
    identical, which is taken as the sign that the login screen is shown.

    Args:
        start_sniffer: start sniffer once the guest has an IP address?
        snapshot: Specify a snapshot from which the guest should be booted. If the
            snapshot does not exist, it will be created from the current state of the
            guest and is available for future iterations

    Raises:
        libvirt.libvirtError: if the template cannot be defined
        DomainException: if the domain cannot be started, does not obtain a
            DHCP lease, or its screen never becomes static
    """
    try:
        self._domain = self._conn.defineXMLFlags(
            ET.tostring(
                self._template.getroot(), encoding="utf8", method="xml"
            ).decode("utf-8"),
            VIR_DOMAIN_AFFECT_CURRENT,
        )
    except libvirt.libvirtError:
        logger.error("Cannot create domain from template")
        raise

    if snapshot is not None:
        # Only the lookup is guarded: a failing revert must not trigger the
        # creation of a second snapshot with the same name.
        try:
            snap = self._domain.snapshotLookupByName(snapshot)
        except libvirt.libvirtError:
            logger.info(
                "Cannot find snapshot %s, thus will create it from current VM state",
                snapshot,
            )
            snap = self.create_snapshot(snapshot)
        # revertToSnapshot expects the snapshot object, not its name.
        self._domain.revertToSnapshot(snap)

    if self._domain.create() < 0:
        raise DomainException(f"Cannot boot domain {self._domain.name()}")
    logger.info("Successfully created domain %s", self._domain.name())

    if self._network is not None:
        for _ in range(10):
            lease = self._network.DHCPLeases(self._mac)
            if lease:
                self._ipv4 = lease[0]["ipaddr"]
                break
            time.sleep(5)
        else:
            raise DomainException(
                f"Cannot get a DHCP lease for domain {self._domain.name()}"
            )
        logger.info(
            "Obtained IP address %s for %s", self._ipv4, self._domain.name()
        )
        # The sniffer needs the network bridge and the IP address obtained above.
        if start_sniffer:
            self.start_sniffer()

    time.sleep(5)  # initial short sleep to give domain time to boot up
    # assume that we are on the boot screen if it does not change anymore
    for _ in range(10):
        screenshot_1 = self.take_screenshot()
        time.sleep(5)
        screenshot_2 = self.take_screenshot()
        score = nrmse(screenshot_1, screenshot_2)
        logger.debug("NRMSE score of login screen: %s", score)
        if score < 0.05:
            # assuming the login screen is static; a low error means that the
            # two screenshots are (almost) identical
            break
    else:
        raise DomainException(
            "Cannot determine whether domain is booted. Maybe login screen is not "
            "static?"
        )
    logger.info(
        "Successfully booted domain %s, which is now in login screen",
        self._domain.name(),
    )
[docs]
def shutdown(self, blocking: bool = False):
    """Ask the guest to shut down.

    The request may be ignored by the guest OS. Powering the guest down via
    its PTY or the GUI (if there is one) can be more reliable.

    Notes:
        To forcefully shut down the domain use destroy.

    Args:
        blocking: wait until the guest is reported as inactive (use with
            caution, as there is no timeout)
    """
    # Nothing to do for a missing or already inactive domain.
    if self._domain is None or not self._domain.isActive():
        return

    self._domain.shutdown()
    if blocking:
        # Check once per second until libvirt reports the domain as inactive.
        time.sleep(1)
        while self._domain.isActive():
            time.sleep(1)
    logger.info("Successfully shutdown %s", self._domain.name())
[docs]
def destroy(self):
    """Destroy the domain.

    This is a more powerful variant of shutdown, and can be used if the guest OS is
    known to ignore a normal shutdown request. The method tries a graceful destroy
    first. If libvirt reports an error for it, the 'virtual power cord' is pulled
    out of the guest, which may result in data loss.

    It might be more convenient to power the guest down via its PTY or the GUI (if
    there is one).

    The call is a no-op if there is no domain or the domain is not active.
    """
    # Test for None first; the reverse order dereferences a missing domain.
    if self._domain is None or not self._domain.isActive():
        return
    try:
        self._domain.destroyFlags(libvirt.VIR_DOMAIN_DESTROY_GRACEFUL)
    except libvirt.libvirtError as e:
        logger.warning(e)
        logger.warning("Will forcefully power down %s", self._domain.name())
        self._domain.destroyFlags()
    logger.info("Successfully destroyed %s", self._domain.name())
[docs]
def delete(self):
    """Undefine the domain on the host.

    The undefine request also covers the domain's managed save image, its
    snapshot metadata and its checkpoint metadata (see the flags below).

    Use this method with caution and only on inactive domains.
    """
    # Bit flags are combined with OR; XOR only gives the same result as long
    # as no bit is shared between the operands.
    flags = (
        libvirt.VIR_DOMAIN_UNDEFINE_MANAGED_SAVE
        | libvirt.VIR_DOMAIN_UNDEFINE_SNAPSHOTS_METADATA
        | libvirt.VIR_DOMAIN_UNDEFINE_CHECKPOINTS_METADATA
    )
    self._domain.undefineFlags(flags)
[docs]
def open_pty(
    self, user: str, password: str, shell: ShellType, timeout: int = 2
) -> VirshConsole:
    """Return a console session for the given user, opening it on first use.

    An already opened PTY is reused; the arguments are ignored in that case.

    Args:
        user: username
        password: user's password
        shell: the shell type (important for prompt change)
        timeout: timeout for pty command in seconds

    Returns:
        connection to an active console

    Raises:
        NotImplementedError: if the domain runs Microsoft Windows
        ForTraceException: if the domain is not active
    """
    if self._os == OSType.WINDOWS:
        raise NotImplementedError("Microsoft Windows cannot open a PTY.")
    if not self._domain.isActive():
        raise ForTraceException("Domain must be active to open PTY")

    # TODO: scan template for console names to allocate new pty
    if self._pty is not None:
        return self._pty

    self._pty = VirshConsole(
        self._domain.name(),
        user,
        password,
        shell,
        self._conn.getURI(),
        timeout,
    )
    logger.info("Established PTY to %s", self._domain.name())
    return self._pty
@property
def pty(self) -> VirshConsole | None:
    """Return the console stored by open_pty, or None if none was opened."""
    return self._pty
@property
def domain(self) -> libvirt.virDomain:
    """Return the underlying libvirt domain object."""
    return self._domain
@property
def domain_name(self) -> str:
    """Return the name libvirt reports for the domain."""
    return self._domain.name()
@property
def ipv4(self) -> ipaddress.IPv4Address:
    """Return the address stored by boot; None as long as no lease was obtained."""
    return self._ipv4
@property
def template(self) -> ET.ElementTree:
    """Return the XML tree that boot uses to define the domain."""
    return self._template
[docs]
def take_screenshot(self, screen: int = 0) -> bytes:
    """Take a screenshot of one of the domain's screens.

    A non-graphical guest also shows a console screen, thus this method is also
    available in this class.

    Args:
        screen: ID of the virtualized screen to take a screenshot from (If the
            domain only has one screen, use '0')

    Returns:
        byte sequence of specified screen (empty if there is no connection or
        no domain). Image format is hypervisor specific. Refer to the
        appropriate documentation.

    Warnings:
        Do NOT change this method without reviewing QEMUMonitor_Session
    """
    if not self._conn or self._domain is None:
        return b""

    chunk_size = 262120
    stream = self._conn.newStream()
    self._domain.screenshot(stream, screen)
    # Read until the stream hands back an empty chunk.
    chunks = []
    while chunk := stream.recv(chunk_size):
        chunks.append(chunk)
    stream.finish()
    logger.debug("Took screenshot of %s", self._domain.name())
    return b"".join(chunks)
[docs]
def start_sniffer(self):
    """Start the network sniffer on the session network interface.

    dumpcap is started on the bridge of the domain's network and writes a
    timestamped pcapng file into the domain's output directory.

    Note:
        dumpcap filters for domain related packages, meaning the package must come
        from or must be addressed to the domain. One sniffer can be started per
        domain.

    Raises:
        DomainException: if the domain has no network or no IP address yet
    """
    if self._sniffer is not None:
        logger.warning("Sniffer has already been started.")
        return
    if self._network is None or self._ipv4 is None:
        raise DomainException(
            "Cannot start sniffer: domain has no network or no IP address yet"
        )

    # Built outside the f-string: nesting the same quote type inside an
    # f-string is only valid from Python 3.12 on.
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    file_path = self._domain_output_path.joinpath(
        f"network_traffic_{timestamp}.pcapng"
    )
    self._sniffer = subprocess.Popen(
        [
            "/usr/bin/dumpcap",
            # Suppress the continuous packet counter: nobody reads stderr
            # while the capture runs, so it could fill the pipe and block.
            "-q",
            "-i",
            self._network.bridgeName(),
            "-w",
            str(file_path),
            # Option and capture filter are separate argv entries.
            "-f",
            f"host {self._ipv4}",
        ],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    logger.info(
        "Sniffer started on IP %s writing pcap to %s", self._ipv4, file_path
    )
[docs]
def stop_sniffer(self, timeout: float = 10.0):
    """Stop the network sniffer and wait until its process has exited.

    The sniffer receives SIGINT. If it has not terminated after ``timeout``
    seconds, it is killed. The call is a no-op if no sniffer is running.

    Args:
        timeout: seconds to wait for the sniffer to exit before killing it
    """
    sniffer = self._sniffer
    if sniffer is None:
        return
    self._sniffer = None

    sniffer.send_signal(signal.SIGINT)
    try:
        # communicate() drains and closes the stdout/stderr pipes and reaps
        # the child, so no zombie process or open descriptor is left behind.
        sniffer.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        logger.warning(
            "Sniffer did not stop within %s seconds, killing it", timeout
        )
        sniffer.kill()
        sniffer.communicate()
    logger.info("Sniffer stopped")
[docs]
def insert_cd(self, iso_path: os.PathLike):
    """Insert/change a CD-ROM in the (running) guest domain.

    A ``<disk device='cdrom'>`` description pointing to the ISO file is sent
    to libvirt as a live device update. Please make sure that a CD-ROM drive
    is already part of the guest's configuration.

    Args:
        iso_path: provide the path to the ISO file here

    Raises:
        ConfigurationError: if the path does not end with '.iso'
    """
    if pathlib.Path(iso_path).suffix != ".iso":
        logger.error("Can only insert .iso files as a CD. Provided %s", iso_path)
        raise ConfigurationError("Can only insert .iso files as a CD")

    # Children of the <disk> element, in document order.
    children = (
        ("driver", {"name": "qemu", "type": "raw"}),
        ("source", {"file": str(iso_path), "index": "3"}),
        ("backingStore", {}),
        ("target", {"dev": "sda", "bus": "sata"}),
        ("readonly", {}),
        ("alias", {"name": "sata0-0-0"}),
        (
            "address",
            {
                "type": "drive",
                "controller": "0",
                "bus": "0",
                "target": "0",
                "unit": "0",
            },
        ),
    )
    cdrom = ET.Element("disk", {"type": "file", "device": "cdrom"})
    for tag, attributes in children:
        ET.SubElement(cdrom, tag, attributes)

    device_xml = ET.tostring(cdrom, encoding="unicode", method="xml")
    self._domain.updateDeviceFlags(device_xml, libvirt.VIR_DOMAIN_AFFECT_LIVE)
    logger.info("Inserted %s as CD", iso_path)
[docs]
def transfer_file(
    self, source: os.PathLike, destination: os.PathLike, target_dev: str = "vda"
):
    """Transfer a file to the guest using guestfs.

    The domain has to be powered down, so no corruption of the guest's filesystem
    might occur. The disk image is opened with guestfs, the filesystems of the
    detected operating system are mounted, and the content of ``source`` is
    written to ``destination``.

    Due to a license conflict, guestfs is not part of PyPi and has to be installed
    manually. Thus, it is imported at function level, so ForTrace can run without
    guestfs and only imports it on demand. Besides the Python bindings, libguestfish
    has to be installed on the system.
    For more information visit https://www.libguestfs.org/

    Args:
        source: file to be transferred from the host (it is read as a whole
            with open(); directories are not handled)
        destination: path on the guest file system
        target_dev: name of the target device of XML config.
            If domain has only one disk, 'vda' should work.
            If domain has more disks, they are usually called 'vdb', 'vdc', ... ->
            look into the config

    Raises:
        DomainException: if the domain is active
        ImportError: if the guestfs Python bindings are not installed
        ForTraceException: if the disk is not part of the domain config or no
            operating system is found on it
    """
    if self._domain.isActive():
        logger.warning("File transfer is only supported on inactive domain")
        raise DomainException("File cannot be transferred to the active domain")
    try:
        import guestfs  # pylint: disable=import-outside-toplevel
    except ImportError as exc:
        logger.error(exc)
        raise ImportError(
            "Please make sure you have installed the Python bindings for "
            "'libguestfs', since they cannot be provided in the requirements"
        ) from exc

    # Resolve the image before a guestfs handle exists, so a failed lookup
    # cannot leak the handle.
    target_dev_path = None
    for disk in self._template.getroot().findall("devices/disk"):
        target = disk.find("target")
        disk_source = disk.find("source")
        if target is None or disk_source is None:
            continue
        if target.get("dev") == target_dev:
            target_dev_path = disk_source.get("file")
            break
    if target_dev_path is None:
        raise ForTraceException(f"Cannot find {target_dev} in domain config")

    g = guestfs.GuestFS(python_return_dict=True)
    try:
        g.add_drive_opts(target_dev_path)
        g.launch()
        # Path based calls such as write() need a mounted filesystem.
        # NOTE(review): the first detected operating system is used - confirm
        # that this is the desired choice for multi-boot images.
        roots = g.inspect_os()
        if not roots:
            raise ForTraceException(
                f"Cannot find an operating system on {target_dev}"
            )
        mountpoints = g.inspect_get_mountpoints(roots[0])
        # Shorter mount points first, so '/' is mounted before '/boot' etc.
        for mountpoint in sorted(mountpoints, key=len):
            try:
                g.mount(mountpoints[mountpoint], mountpoint)
            except RuntimeError as exc:
                logger.warning("Cannot mount %s: %s", mountpoint, exc)
        with open(source, "rb") as f:
            g.write(os.fspath(destination), f.read())
        g.umount_all()
    finally:
        # Always release the handle (and its appliance), even on errors.
        g.close()
[docs]
def set_time(self, datetime_guest: dt.datetime | ZoneInfo):
    """Change the time of the guest.

    Set the time of a guest by modifying the ``<clock>`` element of the
    template and specifying a relative offset to UTC (for a datetime) or a
    timezone (for a ZoneInfo). Nothing is changed if the guest is active.

    Args:
        datetime_guest: The time or timezone to set the guest to. A datetime
            without tzinfo is interpreted as local time of the host.

    Raises:
        TypeError: if datetime_guest is neither a datetime nor a ZoneInfo

    Note:
        One must disable all NTP services on the guest, as it will update the offset
        (e.g., on Linux 'timedatectl set-ntp 0')
    """
    if self._domain.isActive():
        logger.warning(
            "Guest %s is active and therefore time cannot be changed",
            self._domain.name(),
        )
        return
    if not isinstance(datetime_guest, (dt.datetime, ZoneInfo)):
        # Reject instead of logging a success for a change that never happened.
        raise TypeError(
            "datetime_guest must be a datetime or a ZoneInfo, "
            f"not {type(datetime_guest).__name__}"
        )

    root = self._template.getroot()
    clock = root.find("clock")
    if clock is None:
        clock = ET.SubElement(root, "clock")
    # Remove what a previous call in the other mode may have left behind.
    for attribute in ("adjustment", "basis", "timezone"):
        clock.attrib.pop(attribute, None)

    if isinstance(datetime_guest, dt.datetime):
        if datetime_guest.tzinfo is None:
            # Naive and aware datetimes cannot be subtracted from each other.
            datetime_guest = datetime_guest.astimezone()
        now = dt.datetime.now(dt.timezone.utc)
        offset = round((datetime_guest - now).total_seconds())
        clock.set("offset", "variable")
        clock.set("adjustment", str(offset))
        clock.set("basis", "utc")
    else:
        clock.set("offset", "timezone")
        clock.set("timezone", str(datetime_guest))
    logger.info("Time of guest set to %s", datetime_guest)
[docs]
def get_time(self) -> dt.datetime:
"""Returns UTC time of guest, based on the settings of guest's configuration.
Note:
libvirt's getTime is only available with a running guest agent, which we
don't want
Returns:
the datetime of the domain
"""
clock = self._template.getroot().find("clock")
match clock.get("offset"):
case "utc", None:
return dt.datetime.now(dt.UTC)
case "localtime":
return dt.datetime.now(dt.datetime.now().astimezone().tzinfo)
case "timezone":
timezone = ZoneInfo(clock.get("timezone"))
return dt.datetime.now(timezone)
case "variable":
adjustment = clock.get("adjustment")
return dt.datetime.now(dt.UTC) - dt.timedelta(seconds=int(adjustment))
case _:
raise ValueError(f"{clock.get('offset')} is not supported")
[docs]
def reset_time(self):
    """Reset the template's clock so the guest follows the host's local time.

    All attributes of the ``<clock>`` element are dropped and its offset is
    set to 'localtime'. Nothing is changed if the guest is active.
    """
    domain = self._domain
    if domain.isActive():
        logger.warning(
            "Guest %s is active and therefore time cannot be changed",
            domain.name(),
        )
        return

    clock = self._template.getroot().find("clock")
    attributes = clock.attrib
    attributes.clear()
    attributes["offset"] = "localtime"
    logger.info("Reset time on domain to time of host")
[docs]
def dump_memory(
    self,
    compression: Literal["elf", "kdump-zlib", "kdump-lzo", "kdump-snappy"] = "elf",
    name: str | None = None,
):
    """Dumps the domain memory via libvirt into the domain's simulation directory.

    By default the memory dump has the current timestamp as its name, to ensure it
    is unique and sortable. The file suffix is determined based on the selected
    compression algorithm ('.elf' for 'elf', '.dump' otherwise).

    Args:
        compression: which compression type to be used
            elf: the default, uncompressed format
            kdump-zlib: kdump-compressed format with zlib compression
            kdump-lzo: kdump-compressed format with LZO compression
            kdump-snappy: kdump-compressed format with Snappy compression
        name: give the RAM dump a unique name (defaults to timestamp)

    Raises:
        ValueError: if compression is not one of the supported values
    """
    dump_formats = {
        "elf": libvirt.VIR_DOMAIN_CORE_DUMP_FORMAT_RAW,
        "kdump-zlib": libvirt.VIR_DOMAIN_CORE_DUMP_FORMAT_KDUMP_ZLIB,
        "kdump-lzo": libvirt.VIR_DOMAIN_CORE_DUMP_FORMAT_KDUMP_LZO,
        "kdump-snappy": libvirt.VIR_DOMAIN_CORE_DUMP_FORMAT_KDUMP_SNAPPY,
    }
    # Validate first, so an unsupported value does not leave an empty file behind.
    if compression not in dump_formats:
        raise ValueError(f"{compression} not supported as compression parameter")

    stem = time.strftime("%Y%m%d_%H%M%S") if name is None else name
    suffix = ".elf" if compression == "elf" else ".dump"
    file = pathlib.Path(self._domain_output_path, stem + suffix)

    # NOTE(review): the file is created here before libvirt writes to it -
    # presumably so that it belongs to the calling user; confirm.
    with open(file, "w"):
        ret = self._domain.coreDumpWithFormat(
            str(file),
            dump_formats[compression],
            libvirt.VIR_DUMP_MEMORY_ONLY,
        )

    if ret == 0:
        logger.info(
            "Completed memory dump of %s. Saved to %s", self._domain.name(), file
        )
    else:
        logger.error("Something went wrong during memory dump")
[docs]
def dump_image(self, images: list[str] | None = None):
    """Back up the domain's disks as raw files into the domain's output path.

    A ``<domainbackup>`` description is built from the template's disks and
    handed to libvirt. Each image is named after its target device name with
    the suffix '.raw'. The method polls the job statistics every five seconds
    and returns once the backup job has completed or was cancelled.

    Args:
        images: list of images to be dumped (None if all images should be dumped).
            The names must be the same as those used in the libvirt configuration.

    Raises:
        ForTraceException: if libvirt reports the job as failed
    """
    backup_xml = ET.Element("domainbackup")
    disks = ET.SubElement(backup_xml, "disks")
    for disk_xml in self._template.getroot().findall("devices/disk"):
        if disk_xml.get("device") != "disk":
            continue
        dev = disk_xml.find("target").get("dev")
        if images is not None and dev not in images:
            # skip disk from backup if not in images
            ET.SubElement(disks, "disk", {"name": dev, "backup": "no"})
            continue
        disk = ET.SubElement(disks, "disk", {"name": dev, "type": "file"})
        path = pathlib.Path(self._domain_output_path, dev).with_suffix(".raw")
        ET.SubElement(disk, "target", {"file": str(path)})
        ET.SubElement(disk, "driver", {"type": "raw"})
        logger.info(
            "Will create a backup of %s in %s", dev, self._domain_output_path
        )

    self._domain.backupBegin(
        ET.tostring(backup_xml, encoding="unicode", method="xml"), None
    )

    # Poll until the job ends; every branch that neither returns nor raises
    # falls through to the next iteration.
    while True:
        time.sleep(5)
        job_stats = self._domain.jobStats(libvirt.VIR_DOMAIN_JOB_STATS_COMPLETED)
        job_type = job_stats["type"]
        if job_type == libvirt.VIR_DOMAIN_JOB_NONE:
            logger.debug("Job has not completed yet")
        elif job_type == libvirt.VIR_DOMAIN_JOB_COMPLETED:
            if job_stats["operation"] != libvirt.VIR_DOMAIN_JOB_OPERATION_BACKUP:
                logger.debug(
                    "Another job has completed. Waiting for successful backup"
                )
            elif job_stats["disk_remaining"] != 0:
                logger.info("There are still disks remaining to be dumped.")
                logger.info(
                    "Total: %s, Processed: %s, Remaining: %s",
                    job_stats["disk_total"],
                    job_stats["disk_processed"],
                    job_stats["disk_remaining"],
                )
            else:
                logger.info(
                    "Image dump of %s completed. Took %sms",
                    self._domain.name(),
                    job_stats["time_elapsed"],
                )
                return
        elif job_type == libvirt.VIR_DOMAIN_JOB_FAILED:
            logger.error("Image dump of %s not successful", self._domain.name())
            logger.error(job_stats)
            raise ForTraceException(
                f"Image dump for {self._domain.name()} was not successful"
            )
        elif job_type == libvirt.VIR_DOMAIN_JOB_CANCELLED:
            logger.warning("Image dump of %s was cancelled", self._domain.name())
            logger.warning(job_stats)
            return
[docs]
def create_snapshot(self, name: str, flags: int = 0) -> libvirt.virDomainSnapshot:
    """Create a named snapshot of the domain.

    The guest might be running or powered off.

    Args:
        name: specify a unique name for the snapshot
        flags: virDomainSnapshotCreateFlags
            (https://libvirt.org/html/libvirt-libvirt-domain-snapshot.html#virDomainSnapshotCreateFlags)

    Returns:
        The newly created snapshot
    """
    # The description only carries the name: <domainsnapshot><name>...</name></domainsnapshot>
    root = ET.Element("domainsnapshot")
    ET.SubElement(root, "name").text = name
    description = ET.tostring(root, encoding="unicode", method="xml")

    created = self._domain.snapshotCreateXML(description, flags)
    logger.info("Created snapshot with name '%s'", name)
    return created
[docs]
def delete_snapshot(self, name: str, flags: int = 0):
    """Delete a snapshot of the domain.

    If libvirt reports an error (e.g., for an unknown name), the error is
    logged and the program is terminated with exit status 1.

    Args:
        name: name of the snapshot to be deleted
        flags: virDomainSnapshotDeleteFlags
            (https://libvirt.org/html/libvirt-libvirt-domain-snapshot.html#virDomainSnapshotDeleteFlags)

    Raises:
        SystemExit: if the snapshot cannot be looked up or deleted
    """
    try:
        snapshot = self._domain.snapshotLookupByName(name)
        snapshot.delete(flags)
    except libvirt.libvirtError as e:
        logger.error(repr(e))
        # exit() is a helper the site module provides for interactive use;
        # raising SystemExit ends the program the same way without it.
        raise SystemExit(1) from e
    logger.info("Deleted snapshot with name '%s'", name)
[docs]
def list_snapshots(self, flags: int = 0) -> list[str]:
    """Get a list of all snapshot names.

    If libvirt reports an error, it is logged and the program is terminated
    with exit status 1.

    Args:
        flags: virDomainSnapshotListFlags
            (https://libvirt.org/html/libvirt-libvirt-domain-snapshot.html#virDomainSnapshotListFlags)

    Returns:
        list of snapshot names

    Raises:
        SystemExit: if the names cannot be retrieved
    """
    try:
        return self._domain.snapshotListNames(flags)
    except libvirt.libvirtError as e:
        logger.error(repr(e))
        # exit() is a helper the site module provides for interactive use;
        # raising SystemExit ends the program the same way without it.
        raise SystemExit(1) from e
[docs]
def revert_to_snapshot(self, name: str, flags: int = 0):
    """Revert the guest to the specified snapshot.

    If the snapshot name is unknown (or libvirt reports any other error),
    the error is logged and ForTrace++ stops with exit status 1.

    Args:
        name: give the unique name of the snapshot
        flags: virDomainSnapshotRevertFlags
            (https://libvirt.org/html/libvirt-libvirt-domain-snapshot.html#virDomainSnapshotRevertFlags)

    Raises:
        SystemExit: if the snapshot cannot be looked up or reverted to
    """
    try:
        snapshot = self._domain.snapshotLookupByName(name)
        self._domain.revertToSnapshot(snapshot, flags)
    except libvirt.libvirtError as e:
        logger.error(repr(e))
        # exit() is a helper the site module provides for interactive use;
        # raising SystemExit ends the program the same way without it.
        raise SystemExit(1) from e
    logger.info("Reverted to snapshot '%s'", name)
def __del__(self):
    """Stop a still running sniffer when the object is garbage collected."""
    # __init__ can raise before _sniffer is assigned (e.g., unknown domain);
    # the finalizer still runs in that case and must not fail on the lookup.
    if getattr(self, "_sniffer", None) is not None:
        self.stop_sniffer()
[docs]
def start_screen_recording(self, interval: float = 5.0):
    """Periodically take screenshots of the domain to document the execution.

    The screenshots are placed in the domain's output directory under
    'screen_recording', e.g., /var/tmp/ForTrace/win10/screen_recording/.
    Each image bears the name of the time it was taken. The images can be combined
    into a hyperlapse, i.e., with ffmpeg:
    > ffmpeg -framerate 24 -pattern_type glob -i "*.png" -c:v libx264 output.mp4

    The screenshots are taken by a background thread, which runs until
    stop_screen_recording is called. Only one recording can run per domain.

    Args:
        interval: delay in seconds between screenshots
    """
    if self._screen_recorder is not None and self._screen_recorder.is_alive():
        # A second thread would replace the only reference to the first one.
        logger.warning("Screen recording has already been started.")
        return

    output_dir = self._domain_output_path.joinpath("screen_recording")

    def capture(domain: "VirshDomain", interval: float = 5.0):
        while domain._screen_recorder_signal.is_set():
            raw = domain.take_screenshot()
            frame = None
            if raw:
                frame = cv2.imdecode(
                    np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR
                )
            if frame is None:
                # Skip this frame; writing an empty image would raise and
                # end the recording thread.
                logger.warning(
                    "Cannot decode screenshot of %s, skipping frame",
                    domain.domain_name,
                )
            else:
                img_name = dt.datetime.now().strftime("%Y%m%d_%H%M%S") + ".png"
                cv2.imwrite(os.path.join(output_dir, img_name), frame)
            time.sleep(interval)

    output_dir.mkdir(exist_ok=True)
    self._screen_recorder = threading.Thread(
        target=capture, kwargs={"domain": self, "interval": interval}
    )
    # The flag has to be set before the thread starts, as the thread's loop
    # runs only while it is set.
    self._screen_recorder_signal.set()
    self._screen_recorder.start()
    logger.info("Started screen recording of domain %s", self.domain_name)
[docs]
def stop_screen_recording(self):
    """Stop the screen recording thread and wait until it has finished.

    The call is a no-op if no recording has been started.
    """
    self._screen_recorder_signal.clear()
    recorder = self._screen_recorder
    if recorder is None:
        # Nothing was started, so there is no thread to join.
        return
    recorder.join()
    self._screen_recorder = None
    logger.info("Stopped screen recording of domain %s", self.domain_name)
[docs]
def attach_usb_device(self, vendor_id: str, product_id: str):
    """Attach an available USB device to the running domain.

    Args:
        vendor_id: vendor ID of the USB device, e.g., '0x1234'
        product_id: product ID of the USB device, e.g., '0x5678'

    Raises:
        DomainException: if no USB device with these IDs is found on the host
        libvirt.libvirtError: if libvirt cannot attach the device

    Note:
        The device is attached only to the live domain. It is NOT attached across
        hibernation or boot cycles. You need to attach it again.
    """
    if not self._usb_device_present(vendor_id, product_id):
        raise DomainException(
            f"The USB device with the vendor ID {vendor_id} and the product ID "
            f"{product_id} do not exist."
        )
    # Built with ElementTree like the other device descriptions of this
    # class, so the attribute values are escaped properly.
    hostdev = ET.Element(
        "hostdev", {"mode": "subsystem", "type": "usb", "managed": "yes"}
    )
    source = ET.SubElement(hostdev, "source")
    ET.SubElement(source, "vendor", {"id": vendor_id})
    ET.SubElement(source, "product", {"id": product_id})
    usb_device_xml = ET.tostring(hostdev, encoding="unicode", method="xml")
    try:
        self.domain.attachDeviceFlags(
            usb_device_xml, libvirt.VIR_DOMAIN_AFFECT_LIVE
        )
    except libvirt.libvirtError as e:
        logger.error(repr(e))
        raise
[docs]
def detach_usb_device(self, vendor_id: str, product_id: str):
    """Detach a USB device from the running domain.

    Args:
        vendor_id: vendor ID of the USB device, e.g., '0x1234'
        product_id: product ID of the USB device, e.g., '0x5678'

    Raises:
        libvirt.libvirtError: if libvirt cannot detach the device
    """
    # TODO: check that device is attached to the domain
    # Built with ElementTree like the other device descriptions of this
    # class, so the attribute values are escaped properly.
    hostdev = ET.Element(
        "hostdev", {"mode": "subsystem", "type": "usb", "managed": "yes"}
    )
    source = ET.SubElement(hostdev, "source")
    ET.SubElement(source, "vendor", {"id": vendor_id})
    ET.SubElement(source, "product", {"id": product_id})
    usb_device_xml = ET.tostring(hostdev, encoding="unicode", method="xml")
    try:
        self.domain.detachDeviceFlags(
            usb_device_xml, libvirt.VIR_DOMAIN_AFFECT_LIVE
        )
    except libvirt.libvirtError as e:
        logger.error(repr(e))
        raise
@staticmethod
def _usb_device_present(vendor_id: str, product_id: str):
    """Tell whether a USB device with the given IDs is found on the host.

    Args:
        vendor_id: vendor ID as hexadecimal string, e.g., '0x1234'
        product_id: product ID as hexadecimal string, e.g., '0x5678'

    Returns:
        True if usb.core.find returns a device for the IDs, False otherwise
    """
    vendor = int(vendor_id, 16)
    product = int(product_id, 16)
    return usb.core.find(idVendor=vendor, idProduct=product) is not None
[docs]
class GraphicalVirshDomain(VirshDomain):
    """VirshDomain that additionally carries a desktop environment object."""

    def __init__(
        self,
        domain_name: str,
        connection: libvirt.virConnect,
        domain_network: list[str] | None,
        os_type: OSType,
        session_output: os.PathLike,
        desktop_env: DesktopEnvironmentType,
    ):
        """Set up the domain and the helper for its desktop environment.

        Args:
            domain_name: The name of the domain to be created
            connection: an active libvirt connection to the hypervisor
            domain_network: names of networks the VM should connect to
            os_type: operating system of the domain
            session_output: path to session output folder
            desktop_env: type of the desktop environment the guest runs
        """
        super().__init__(
            domain_name=domain_name,
            connection=connection,
            domain_network=domain_network,
            os_type=os_type,
            session_output=session_output,
        )
        # Built from the OS type and the monitor session the base class set up.
        self._env = get_desktop_env(self._os, desktop_env, self._qs)

    @property
    def env(self):
        """Return the desktop environment object created for this domain."""
        return self._env